520be53901
- Add proper Alembic initial migration (0001_initial_schema.py) - Migrate refresh tokens from JSON file to SQLite (RefreshTokenTable) - Remove Neko-Sama provider entirely (redirects to Gupy, not a host) - Fix provider health check always showing UNKNOWN - Run check_all_health() on startup - Fix POST /providers/health/check background task bug - Add HTMX refresh after manual health check trigger - Fix anime search relevance scoring with MIN_RELEVANCE_THRESHOLD=0.5 - Replace bare 'except:' with 'except Exception:' across codebase - Add Playwright E2E test suite (12 tests, auth setup, helpers) - Fix toast container blocking clicks via pointer-events: none - Remove obsolete Jest/Vite test files and config - Clean up obsolete test_watchlist scripts - Update sonarr model comment for active providers
348 lines
10 KiB
Python
348 lines
10 KiB
Python
"""User authentication and management system with SQLModel support"""
|
|
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional
|
|
from jose import jwt
|
|
from passlib.context import CryptContext
|
|
import logging
|
|
from fastapi import HTTPException
|
|
from fastapi.security import HTTPAuthorizationCredentials
|
|
from sqlmodel import Session, select
|
|
from app.database import engine
|
|
from app.models.auth import UserTable, RefreshTokenTable
|
|
from app.config import get_settings
|
|
|
|
# Logger namespaced to this module.
logger = logging.getLogger(name=__name__)

# Settings are resolved once, at import time, and read through this
# module-level name by every helper below (which also makes it easy to patch).
settings = get_settings()

# Password hashing context (bcrypt is the only configured scheme).
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
|
|
|
|
class UserManager:
    """Manages user storage and authentication using SQL database.

    Each method opens its own short-lived ``Session`` on the module-level
    ``engine``; the ``UserTable`` instances it returns outlive that session.
    """

    # bcrypt only makes use of the first 72 bytes of a secret.
    _BCRYPT_MAX_BYTES = 72

    def __init__(self):
        # Database connection is managed via engine and sessions
        pass

    @staticmethod
    def _truncate_for_bcrypt(password: str) -> str:
        """Clamp *password* to bcrypt's 72-byte input limit.

        The UTF-8 encoding is cut at 72 bytes and decoded with
        ``errors="ignore"``, so a multi-byte character split by the cut is
        dropped. Hashing (``create_user``) and verification
        (``authenticate_user``) both go through this helper so they always
        operate on the same secret.
        """
        password_bytes = password.encode("utf-8")
        limit = UserManager._BCRYPT_MAX_BYTES
        if len(password_bytes) > limit:
            return password_bytes[:limit].decode("utf-8", errors="ignore")
        return password

    def get_user(self, username: str) -> Optional[UserTable]:
        """Get user by username; return ``None`` if no such user exists."""
        # Force registration: imported only for its side effect, presumably so
        # the watchlist model is mapped before UserTable is queried -- TODO confirm.
        from app.models.watchlist import WatchlistItemTable  # noqa: F401

        with Session(engine) as session:
            statement = select(UserTable).where(UserTable.username == username)
            return session.exec(statement).first()

    def get_user_by_id(self, user_id: str) -> Optional[UserTable]:
        """Get user by primary-key ID; return ``None`` if not found."""
        with Session(engine) as session:
            statement = select(UserTable).where(UserTable.id == user_id)
            return session.exec(statement).first()

    def create_user(
        self,
        username: str,
        password: str,
        email: Optional[str] = None,
        full_name: Optional[str] = None,
    ) -> UserTable:
        """Create and persist a new, active user.

        Args:
            username: Unique login name.
            password: Plain-text password; clamped to 72 bytes before hashing.
            email: Optional email address.
            full_name: Optional display name.

        Returns:
            The committed and refreshed ``UserTable`` row.

        Raises:
            ValueError: If ``username`` is already taken.
        """
        with Session(engine) as session:
            # Check if user already exists
            statement = select(UserTable).where(UserTable.username == username)
            if session.exec(statement).first():
                raise ValueError(f"Username '{username}' already exists")

            # Clamp exactly as authenticate_user does so the hash stays verifiable.
            hashed_password = pwd_context.hash(self._truncate_for_bcrypt(password))

            user = UserTable(
                username=username,
                email=email,
                full_name=full_name,
                hashed_password=hashed_password,
                is_active=True,
                created_at=datetime.now(),
            )

            session.add(user)
            session.commit()
            session.refresh(user)

            logger.info(f"Created user: {username}")
            return user

    def authenticate_user(self, username: str, password: str) -> Optional[UserTable]:
        """Authenticate user with username and password.

        On success ``last_login`` is updated and the refreshed row is returned.
        Returns ``None`` for an unknown user or a wrong password.
        """
        user = self.get_user(username)
        if not user:
            return None

        # Apply the same 72-byte clamp used by create_user. Without it, a long
        # password is checked against a different secret than the one that was
        # hashed, or is rejected outright by bcrypt backends that refuse >72 bytes.
        if not pwd_context.verify(self._truncate_for_bcrypt(password), user.hashed_password):
            return None

        # Update last login
        with Session(engine) as session:
            db_user = session.get(UserTable, user.id)
            if db_user:
                db_user.last_login = datetime.now()
                session.add(db_user)
                session.commit()
                session.refresh(db_user)
                return db_user

        return user

    def update_user(self, user_id: str, update_data: dict) -> Optional[UserTable]:
        """Update user information.

        Every key in ``update_data`` that names an existing attribute of the
        user row is assigned; unknown keys are ignored.

        Returns:
            The updated row, or ``None`` if no user has ``user_id``.
        """
        with Session(engine) as session:
            db_user = session.get(UserTable, user_id)
            if not db_user:
                return None

            # NOTE(review): any attribute (including id / hashed_password) can be
            # overwritten here; confirm callers never pass unfiltered client input.
            for key, value in update_data.items():
                if hasattr(db_user, key):
                    setattr(db_user, key, value)

            session.add(db_user)
            session.commit()
            session.refresh(db_user)
            return db_user


# Global user manager instance
user_manager = UserManager()
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored hash using ``pwd_context``.

    Returns True on a match, False otherwise.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Return the ``pwd_context`` hash of *password*, suitable for storage."""
    digest = pwd_context.hash(password)
    return digest
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``"sub"``, which ``verify_token`` reads).
            The mapping is copied, so the caller's dict is not mutated.
        expires_delta: Token lifetime. When ``None``, falls back to
            ``settings.access_token_expire_minutes``. Any explicit value,
            including ``timedelta(0)``, is honoured as given.

    Returns:
        The encoded JWT string, signed with ``settings.jwt_secret_key``.
    """
    secret_key = settings.jwt_secret_key
    algorithm = settings.jwt_algorithm

    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently swap an explicit zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.access_token_expire_minutes)

    to_encode = data.copy()
    to_encode.update({"exp": datetime.utcnow() + expires_delta})
    return jwt.encode(to_encode, secret_key, algorithm=algorithm)
|
|
|
|
|
|
def verify_token(token: str) -> Optional[str]:
    """Verify a JWT access token and return its username (the ``sub`` claim).

    Returns ``None`` in these cases:
    - the token fails decoding or validation (``JWTError``);
    - the token has no ``sub`` claim;
    - the token is a refresh token (``type == "refresh"``).
    """
    from jose.exceptions import JWTError

    secret_key = settings.jwt_secret_key
    algorithm = settings.jwt_algorithm

    try:
        payload = jwt.decode(token, secret_key, algorithms=[algorithm])
    except JWTError:
        return None

    # Refresh tokens are signed with the same key and also carry "sub", so
    # without this check a long-lived refresh token would be accepted wherever
    # an access token is expected. Tokens from create_access_token carry no
    # "type" claim, so only the explicit "refresh" marker is rejected.
    if payload.get("type") == "refresh":
        return None

    username: Optional[str] = payload.get("sub")
    if username is None:
        return None
    return username


# Alias for backward compatibility
get_user_from_token = verify_token
|
|
|
|
|
|
def get_current_user(credentials: HTTPAuthorizationCredentials) -> UserTable:
    """Resolve the active user behind the bearer token in *credentials*.

    Raises:
        HTTPException: 401 if the token does not yield a username, the user
            does not exist, or the user is inactive.
    """
    username = verify_token(credentials.credentials)
    if not username:
        raise HTTPException(status_code=401, detail="Invalid authentication credentials")

    user = user_manager.get_user(username)
    if not user:
        raise HTTPException(status_code=401, detail="User not found")
    if not user.is_active:
        raise HTTPException(status_code=401, detail="Inactive user")

    return user
|
|
|
|
|
|
def _get_refresh_token(token_id: str) -> Optional[RefreshTokenTable]:
    """Fetch the stored refresh token whose ``token_id`` matches, or ``None``."""
    query = select(RefreshTokenTable).where(RefreshTokenTable.token_id == token_id)
    with Session(engine) as db:
        found = db.exec(query).first()
    return found
|
|
|
|
|
|
def _save_refresh_token(token: RefreshTokenTable):
    """Persist *token* (new row or pending changes) and commit immediately."""
    db = Session(engine)
    with db:
        db.add(token)
        db.commit()
|
|
|
|
|
|
def _revoke_refresh_token_db(token_id: str) -> bool:
    """Flag the stored refresh token with *token_id* as revoked.

    Sets ``revoked`` and stamps ``revoked_at`` with the current local time.

    Returns:
        True if a matching row was found and updated, False otherwise.
    """
    with Session(engine) as db:
        lookup = select(RefreshTokenTable).where(RefreshTokenTable.token_id == token_id)
        record = db.exec(lookup).first()
        if not record:
            return False

        record.revoked = True
        record.revoked_at = datetime.now()
        db.add(record)
        db.commit()
        return True
|
|
|
|
|
|
def _get_jwt_config() -> dict:
    """Gather JWT signing and lifetime parameters into a single dict.

    Everything comes from ``settings`` except the refresh-token lifetime,
    which is fixed here at 30 days.
    """
    config = dict(
        SECRET_KEY=settings.jwt_secret_key,
        ALGORITHM=settings.jwt_algorithm,
        ACCESS_TOKEN_EXPIRE_MINUTES=settings.access_token_expire_minutes,
    )
    config["REFRESH_TOKEN_EXPIRE_DAYS"] = 30
    return config
|
|
|
|
|
|
def create_access_refresh_tokens(data: dict) -> tuple[str, str]:
    """
    Issue an access token and a refresh token for the subject in *data*.

    The access token lives for ``ACCESS_TOKEN_EXPIRE_MINUTES`` and the
    refresh token for ``REFRESH_TOKEN_EXPIRE_DAYS``; both values come from
    ``_get_jwt_config()``. The refresh token embeds a random ``token_id``,
    which is stored in ``RefreshTokenTable`` so it can be checked and
    revoked later.

    Args:
        data: Claims for the access token; must contain ``"sub"``.

    Returns:
        ``(access_token, refresh_token)``.

    Raises:
        KeyError: If *data* has no ``"sub"`` key.
    """
    import secrets

    cfg = _get_jwt_config()
    key = cfg["SECRET_KEY"]
    alg = cfg["ALGORITHM"]

    # Short-lived access token: caller's claims plus expiry and type marker.
    access_claims = {
        **data,
        "exp": datetime.utcnow() + timedelta(minutes=cfg["ACCESS_TOKEN_EXPIRE_MINUTES"]),
        "type": "access",
    }
    access_token = jwt.encode(access_claims, key, algorithm=alg)

    # Long-lived refresh token, identified by a random token_id.
    refresh_expire = datetime.utcnow() + timedelta(days=cfg["REFRESH_TOKEN_EXPIRE_DAYS"])
    subject = data["sub"]
    token_id = secrets.token_urlsafe(32)
    refresh_claims = {
        "sub": subject,
        "token_id": token_id,
        "exp": refresh_expire,
        "type": "refresh",
    }
    refresh_token = jwt.encode(refresh_claims, key, algorithm=alg)

    # NOTE(review): created_at uses local time (datetime.now()) while
    # expires_at is derived from datetime.utcnow(); confirm this mix is intended.
    _save_refresh_token(
        RefreshTokenTable(
            token_id=token_id,
            username=subject,
            created_at=datetime.now(),
            expires_at=refresh_expire,
            revoked=False,
        )
    )

    return access_token, refresh_token
|
|
|
|
|
|
def verify_refresh_token(token: str) -> Optional[str]:
    """
    Verify refresh token and return username if valid.

    A token is accepted only if all of these hold:
    - it passes JWT decoding and validation;
    - it is marked ``type == "refresh"``;
    - it carries both ``sub`` and ``token_id``;
    - its ``RefreshTokenTable`` row exists;
    - that row is not revoked;
    - that row's ``expires_at`` has not passed.

    Returns None if token is invalid, revoked, or expired.
    """
    from jose.exceptions import JWTError

    jwt_config = _get_jwt_config()

    try:
        payload = jwt.decode(
            token, jwt_config["SECRET_KEY"], algorithms=[jwt_config["ALGORITHM"]]
        )
    except JWTError:
        return None

    # Verify this is a refresh token
    if payload.get("type") != "refresh":
        return None

    username = payload.get("sub")
    token_id = payload.get("token_id")
    if not username or not token_id:
        return None

    # Check if token exists in database and hasn't been revoked
    stored_token = _get_refresh_token(token_id)
    if not stored_token or stored_token.revoked:
        return None

    # expires_at is written by create_access_refresh_tokens from
    # datetime.utcnow(), so compare against UTC "now". Comparing with local
    # datetime.now() skews the check by the server's UTC offset.
    if stored_token.expires_at and stored_token.expires_at < datetime.utcnow():
        return None

    return username
|
|
|
|
|
|
def revoke_refresh_token(token: str) -> bool:
    """
    Revoke the refresh token encoded in *token*.

    The JWT is decoded with the standard validation; a token that fails it
    is not revoked. The embedded ``token_id`` is then marked revoked in the
    database.

    Returns:
        True if a stored token was revoked. False if the JWT is invalid,
        has no ``token_id``, or no stored token matches.
    """
    from jose.exceptions import JWTError

    cfg = _get_jwt_config()
    try:
        claims = jwt.decode(token, cfg["SECRET_KEY"], algorithms=[cfg["ALGORITHM"]])
    except JWTError:
        return False

    token_id = claims.get("token_id")
    return _revoke_refresh_token_db(token_id) if token_id else False
|