"""User authentication and management system with SQLModel support"""

import hashlib
import json
import logging
import os
import secrets
import tempfile
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional

from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from jose import jwt
from jose.exceptions import JWTError
from passlib.context import CryptContext
from sqlmodel import Session, select

from app.database import engine
from app.models.auth import UserTable
from app.config import get_settings

logger = logging.getLogger(__name__)

# Load settings at module level for easier mocking and access
settings = get_settings()

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# bcrypt only considers the first 72 bytes of a password.
BCRYPT_MAX_PASSWORD_BYTES = 72

# Lifetime of refresh tokens issued by create_access_refresh_tokens().
REFRESH_TOKEN_EXPIRE_DAYS = 30


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime.

    ``datetime.utcnow()`` is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc)


def _truncate_password(password: str) -> str:
    """Clip *password* to bcrypt's 72-byte input limit.

    The cut is made on a UTF-8 character boundary: a multi-byte character that
    would be split is dropped entirely, so the result is still valid text.
    Every hash *and* every verify path must apply this same transformation;
    otherwise a long password stored in clipped form could never verify.
    """
    encoded = password.encode("utf-8")
    if len(encoded) <= BCRYPT_MAX_PASSWORD_BYTES:
        return password
    return encoded[:BCRYPT_MAX_PASSWORD_BYTES].decode("utf-8", errors="ignore")


def _verify_password_hash(plain_password: str, hashed_password: Optional[str]) -> bool:
    """Check *plain_password* against a stored hash.

    The password is clipped exactly as it was when hashed. A missing or
    malformed stored hash counts as a failed check instead of raising.
    """
    if not hashed_password:
        return False
    try:
        return pwd_context.verify(_truncate_password(plain_password), hashed_password)
    except ValueError as exc:
        # passlib signals unrecognised / malformed hashes with ValueError subclasses.
        logger.warning("Could not verify password hash: %s", exc)
        return False


class UserManager:
    """Manages user storage and authentication using SQL database"""

    def __init__(self):
        # Database connection is managed via engine and sessions
        pass

    def get_user(self, username: str) -> Optional[UserTable]:
        """Return the user with *username*, or ``None`` if there is none."""
        # Imported only for its side effect ("force registration"); presumably
        # this makes the watchlist model known to SQLModel before UserTable is
        # queried so its relationships resolve — TODO confirm.
        from app.models.watchlist import WatchlistItemTable  # noqa: F401

        with Session(engine) as session:
            statement = select(UserTable).where(UserTable.username == username)
            return session.exec(statement).first()

    def get_user_by_id(self, user_id: str) -> Optional[UserTable]:
        """Return the user with primary key *user_id*, or ``None``."""
        with Session(engine) as session:
            statement = select(UserTable).where(UserTable.id == user_id)
            return session.exec(statement).first()

    def create_user(
        self,
        username: str,
        password: str,
        email: Optional[str] = None,
        full_name: Optional[str] = None,
    ) -> UserTable:
        """Create and persist a new active user.

        Raises:
            ValueError: if *username* is already taken.
        """
        with Session(engine) as session:
            # Check if user already exists
            statement = select(UserTable).where(UserTable.username == username)
            if session.exec(statement).first():
                raise ValueError(f"Username '{username}' already exists")

            # Same clipping rule as every verify path (bcrypt 72-byte limit).
            hashed_password = pwd_context.hash(_truncate_password(password))

            user = UserTable(
                username=username,
                email=email,
                full_name=full_name,
                hashed_password=hashed_password,
                is_active=True,
                created_at=datetime.now(),
            )
            session.add(user)
            session.commit()
            session.refresh(user)

            logger.info("Created user: %s", username)
            return user

    def authenticate_user(self, username: str, password: str) -> Optional[UserTable]:
        """Return the user if *username*/*password* match, else ``None``.

        On success the user's ``last_login`` is updated to the current time.
        """
        user = self.get_user(username)
        if not user:
            # Spend roughly the time of a real hash check so response timing
            # does not reveal whether the username exists.
            pwd_context.dummy_verify()
            return None

        if not _verify_password_hash(password, user.hashed_password):
            return None

        # Update last login
        with Session(engine) as session:
            db_user = session.get(UserTable, user.id)
            if db_user:
                db_user.last_login = datetime.now()
                session.add(db_user)
                session.commit()
                session.refresh(db_user)
                return db_user

        return user

    def update_user(self, user_id: str, update_data: dict) -> Optional[UserTable]:
        """Apply *update_data* to the user with *user_id*.

        Keys that are not attributes of the model are silently ignored.
        Returns the refreshed user, or ``None`` if no such user exists.
        """
        # NOTE(review): any existing attribute (including id/hashed_password)
        # can be overwritten here; callers must filter untrusted input.
        with Session(engine) as session:
            db_user = session.get(UserTable, user_id)
            if not db_user:
                return None

            for key, value in update_data.items():
                if hasattr(db_user, key):
                    setattr(db_user, key, value)

            session.add(db_user)
            session.commit()
            session.refresh(db_user)
            return db_user


# Global user manager instance
user_manager = UserManager()


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against a hash (missing/malformed hashes fail)."""
    return _verify_password_hash(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Hash a password for storage, applying the bcrypt 72-byte clip."""
    return pwd_context.hash(_truncate_password(password))


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (``verify_token`` reads ``sub``); not mutated.
        expires_delta: Token lifetime. When omitted or falsy (e.g.
            ``timedelta(0)``), ``settings.access_token_expire_minutes`` is used.

    Returns:
        The encoded JWT string.
    """
    to_encode = data.copy()
    lifetime = expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
    to_encode.update({"exp": _utcnow() + lifetime})
    return jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)


def verify_token(token: str) -> Optional[str]:
    """Verify a JWT access token and return its ``sub`` (username).

    Returns ``None`` if the token is invalid, expired, lacks ``sub``, or is
    not an access token.
    """
    try:
        payload = jwt.decode(
            token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
        )
    except JWTError:
        return None

    # Refresh tokens are signed with the same key and also carry "sub"; they
    # must not be usable as access tokens. Tokens from create_access_token
    # have no "type" claim, so a missing claim is treated as "access".
    if payload.get("type", "access") != "access":
        return None

    username = payload.get("sub")
    if username is None:
        return None
    return username


# Alias for backward compatibility
get_user_from_token = verify_token


def get_current_user(credentials: HTTPAuthorizationCredentials) -> UserTable:
    """Resolve the active user identified by a bearer access token.

    Raises:
        HTTPException: 401 if the token is invalid, or the user is missing
            or inactive.
    """
    username = verify_token(credentials.credentials)
    if not username:
        raise HTTPException(status_code=401, detail="Invalid authentication credentials")

    user = user_manager.get_user(username)
    if not user:
        raise HTTPException(status_code=401, detail="User not found")
    if not user.is_active:
        raise HTTPException(status_code=401, detail="Inactive user")
    return user


# Refresh tokens storage
REFRESH_TOKENS_FILE = "config/refresh_tokens.json"


def _load_refresh_tokens() -> Dict[str, dict]:
    """Load the refresh-token registry (token_id -> metadata) from disk.

    Returns an empty dict when the file is missing, unreadable, not valid
    JSON, or not a JSON object.
    """
    try:
        with open(REFRESH_TOKENS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        logger.error("Error loading refresh tokens: %s", e)
        return {}

    if not isinstance(data, dict):
        logger.error(
            "Error loading refresh tokens: expected a JSON object, got %s",
            type(data).__name__,
        )
        return {}
    return data


def _save_refresh_tokens(tokens: Dict[str, dict]):
    """Persist the refresh-token registry to disk (best effort).

    The data is written to a temporary file in the same directory and then
    moved over the target with ``os.replace``, so a crash mid-write never
    leaves a truncated file behind (which would log out every session).
    Failures are logged, not raised.
    """
    directory = os.path.dirname(REFRESH_TOKENS_FILE) or "."
    tmp_path = None
    try:
        os.makedirs(directory, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(
            dir=directory, prefix=".refresh_tokens.", suffix=".tmp"
        )
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(tokens, f, indent=2, ensure_ascii=False, default=str)
        os.replace(tmp_path, REFRESH_TOKENS_FILE)
        tmp_path = None  # now owned by REFRESH_TOKENS_FILE
    except (OSError, TypeError, ValueError) as e:
        logger.error("Error saving refresh tokens: %s", e)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def _prune_expired_refresh_tokens(tokens: Dict[str, dict], now: datetime) -> Dict[str, dict]:
    """Return *tokens* without entries whose ``expires_at`` is in the past.

    An expired refresh token is already rejected by ``jwt.decode``, so its
    entry only bloats the registry. Entries with a missing or unparseable
    ``expires_at`` are kept.
    """
    kept: Dict[str, dict] = {}
    for token_id, entry in tokens.items():
        expires_at = entry.get("expires_at") if isinstance(entry, dict) else None
        if isinstance(expires_at, str):
            try:
                expiry = datetime.fromisoformat(expires_at)
            except ValueError:
                expiry = None
            if expiry is not None:
                if expiry.tzinfo is None:
                    # Older entries were written from naive UTC timestamps.
                    expiry = expiry.replace(tzinfo=timezone.utc)
                if expiry <= now:
                    continue
        kept[token_id] = entry
    return kept


def _get_jwt_config() -> dict:
    """Collect the JWT signing settings used by the refresh-token helpers."""
    return {
        "SECRET_KEY": settings.jwt_secret_key,
        "ALGORITHM": settings.jwt_algorithm,
        "ACCESS_TOKEN_EXPIRE_MINUTES": settings.access_token_expire_minutes,
        "REFRESH_TOKEN_EXPIRE_DAYS": REFRESH_TOKEN_EXPIRE_DAYS,
    }


def create_access_refresh_tokens(data: dict) -> tuple[str, str]:
    """
    Create both access and refresh tokens.

    Access token: lifetime ``settings.access_token_expire_minutes``, claim
    ``type="access"``.
    Refresh token: lifetime ``REFRESH_TOKEN_EXPIRE_DAYS`` days, claim
    ``type="refresh"`` plus a random ``token_id`` that is recorded in the
    refresh-token registry so it can later be checked or revoked.

    Raises:
        KeyError: if *data* has no ``"sub"`` key.

    Returns: (access_token, refresh_token)
    """
    jwt_config = _get_jwt_config()
    now = _utcnow()

    # Create access token (short-lived)
    access_data = data.copy()
    access_data.update(
        {
            "exp": now + timedelta(minutes=jwt_config["ACCESS_TOKEN_EXPIRE_MINUTES"]),
            "type": "access",
        }
    )
    access_token = jwt.encode(
        access_data, jwt_config["SECRET_KEY"], algorithm=jwt_config["ALGORITHM"]
    )

    # Create refresh token (long-lived), identified by a unique random ID
    refresh_expire = now + timedelta(days=jwt_config["REFRESH_TOKEN_EXPIRE_DAYS"])
    token_id = secrets.token_urlsafe(32)
    refresh_data = {
        "sub": data["sub"],
        "token_id": token_id,
        "exp": refresh_expire,
        "type": "refresh",
    }
    refresh_token = jwt.encode(
        refresh_data, jwt_config["SECRET_KEY"], algorithm=jwt_config["ALGORITHM"]
    )

    # Store refresh token mapping (dropping entries that can no longer be used)
    refresh_tokens = _prune_expired_refresh_tokens(_load_refresh_tokens(), now)
    refresh_tokens[token_id] = {
        "username": data["sub"],
        "token_id": token_id,
        "created_at": datetime.now().isoformat(),
        "expires_at": refresh_expire.isoformat(),
    }
    _save_refresh_tokens(refresh_tokens)

    return access_token, refresh_token


def verify_refresh_token(token: str) -> Optional[str]:
    """
    Verify refresh token and return username if valid.

    Returns None if the token is invalid, expired, not a refresh token,
    unknown to the registry, revoked, or registered to a different user.
    """
    jwt_config = _get_jwt_config()
    try:
        payload = jwt.decode(
            token, jwt_config["SECRET_KEY"], algorithms=[jwt_config["ALGORITHM"]]
        )
    except JWTError:  # includes expiry
        return None

    # Verify this is a refresh token
    if payload.get("type") != "refresh":
        return None

    username = payload.get("sub")
    token_id = payload.get("token_id")
    if not username or not token_id:
        return None

    # Check if token exists in storage
    stored_token = _load_refresh_tokens().get(token_id)
    if not isinstance(stored_token, dict):
        return None

    # Verify token hasn't been revoked
    if stored_token.get("revoked"):
        return None

    # Defensive: the registry entry must belong to the token's subject.
    if stored_token.get("username") != username:
        return None

    return username


def revoke_refresh_token(token: str) -> bool:
    """
    Revoke a refresh token.

    Returns True if the token was marked revoked, False if the token is
    invalid/expired, carries no ``token_id``, or is not in the registry.
    """
    jwt_config = _get_jwt_config()
    try:
        payload = jwt.decode(
            token, jwt_config["SECRET_KEY"], algorithms=[jwt_config["ALGORITHM"]]
        )
    except JWTError:
        return False

    token_id = payload.get("token_id")
    if not token_id:
        return False

    refresh_tokens = _load_refresh_tokens()
    entry = refresh_tokens.get(token_id)
    if not isinstance(entry, dict):
        return False

    entry["revoked"] = True
    entry["revoked_at"] = datetime.now().isoformat()
    _save_refresh_tokens(refresh_tokens)
    return True