feat: migrate persistence from JSON to SQLModel (Phase 1)
CI / Test (Python 3.11) (push) Has been cancelled
CI / Test (Python 3.12) (push) Has been cancelled
CI / Lint (push) Has been cancelled
CI / Type Check (push) Has been cancelled
CI / Summary (push) Has been cancelled

- Integrated SQLModel with SQLite for robust data persistence
- Refactored UserManager and WatchlistManager to use SQL queries
- Migrated models to SQLModel with relationships and primary keys
- Updated test suite with in-memory database isolation
- Removed deprecated JSON storage files
This commit is contained in:
root
2026-03-24 10:40:36 +00:00
parent d4d8d8a3b6
commit 29c7040b20
13 changed files with 596 additions and 1165 deletions
+84 -100
View File
@@ -1,120 +1,118 @@
"""User authentication and management system""" """User authentication and management system with SQLModel support"""
import json
import os import os
import hashlib import hashlib
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, Dict from typing import Optional, Dict, List
from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
import logging import logging
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials from fastapi.security import HTTPAuthorizationCredentials
from sqlmodel import Session, select
from app.database import engine
from app.models.auth import UserTable
from app.config import get_settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Load settings at module level for easier mocking and access
settings = get_settings()
# Password hashing context # Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Users database file
USERS_DB_FILE = "config/users.json"
class UserManager: class UserManager:
"""Manages user storage and authentication""" """Manages user storage and authentication using SQL database"""
def __init__(self, db_file: str = USERS_DB_FILE): def __init__(self):
self.db_file = db_file # Database connection is managed via engine and sessions
self.users: Dict[str, dict] = {} pass
self._load_users()
def _load_users(self): def get_user(self, username: str) -> Optional[UserTable]:
"""Load users from JSON file"""
try:
if os.path.exists(self.db_file):
with open(self.db_file, "r", encoding="utf-8") as f:
self.users = json.load(f)
logger.info(f"Loaded {len(self.users)} users from database")
except Exception as e:
logger.error(f"Error loading users: {e}")
self.users = {}
def _save_users(self):
try:
os.makedirs(os.path.dirname(self.db_file), exist_ok=True)
temp_file = f"{self.db_file}.tmp"
with open(temp_file, "w", encoding="utf-8") as f:
json.dump(self.users, f, indent=2, ensure_ascii=False, default=str)
os.replace(temp_file, self.db_file)
logger.info(f"Saved {len(self.users)} users to database")
except Exception as e:
logger.error(f"Error saving users: {e}")
def get_user(self, username: str) -> Optional[dict]:
"""Get user by username""" """Get user by username"""
return self.users.get(username) with Session(engine) as session:
statement = select(UserTable).where(UserTable.username == username)
return session.exec(statement).first()
def get_user_by_id(self, user_id: str) -> Optional[dict]: def get_user_by_id(self, user_id: str) -> Optional[UserTable]:
"""Get user by ID""" """Get user by ID"""
for user in self.users.values(): with Session(engine) as session:
if user.get("id") == user_id: statement = select(UserTable).where(UserTable.id == user_id)
return user return session.exec(statement).first()
return None
def create_user( def create_user(
self, username: str, password: str, email: str = None, full_name: str = None self, username: str, password: str, email: str = None, full_name: str = None
) -> dict: ) -> UserTable:
"""Create a new user""" """Create a new user"""
if username in self.users: with Session(engine) as session:
raise ValueError(f"Username '{username}' already exists") # Check if user already exists
statement = select(UserTable).where(UserTable.username == username)
if session.exec(statement).first():
raise ValueError(f"Username '{username}' already exists")
# Truncate password to 72 bytes if necessary (bcrypt limitation) # Truncate password to 72 bytes if necessary (bcrypt limitation)
password_bytes = password.encode("utf-8") password_bytes = password.encode("utf-8")
if len(password_bytes) > 72: if len(password_bytes) > 72:
password = password_bytes[:72].decode("utf-8", errors="ignore") password = password_bytes[:72].decode("utf-8", errors="ignore")
# Hash password # Hash password
hashed_password = pwd_context.hash(password) hashed_password = pwd_context.hash(password)
# Create user # Create user
user = { user = UserTable(
"id": hashlib.sha256(username.encode()).hexdigest()[:32], username=username,
"username": username, email=email,
"email": email, full_name=full_name,
"full_name": full_name, hashed_password=hashed_password,
"hashed_password": hashed_password, is_active=True,
"is_active": True, created_at=datetime.now()
"created_at": datetime.now().isoformat(), )
"last_login": None,
}
self.users[username] = user session.add(user)
self._save_users() session.commit()
session.refresh(user)
logger.info(f"Created user: {username}") logger.info(f"Created user: {username}")
return user return user
def authenticate_user(self, username: str, password: str) -> Optional[dict]: def authenticate_user(self, username: str, password: str) -> Optional[UserTable]:
"""Authenticate user with username and password""" """Authenticate user with username and password"""
user = self.get_user(username) user = self.get_user(username)
if not user: if not user:
return None return None
if not pwd_context.verify(password, user["hashed_password"]): if not pwd_context.verify(password, user.hashed_password):
return None return None
# Update last login # Update last login
user["last_login"] = datetime.now().isoformat() with Session(engine) as session:
self._save_users() db_user = session.get(UserTable, user.id)
if db_user:
db_user.last_login = datetime.now()
session.add(db_user)
session.commit()
session.refresh(db_user)
return db_user
return user return user
def update_last_login(self, username: str): def update_user(self, user_id: str, update_data: dict) -> Optional[UserTable]:
"""Update user's last login time""" """Update user information"""
user = self.get_user(username) with Session(engine) as session:
if user: db_user = session.get(UserTable, user_id)
user["last_login"] = datetime.now().isoformat() if not db_user:
self._save_users() return None
for key, value in update_data.items():
if hasattr(db_user, key):
setattr(db_user, key, value)
session.add(db_user)
session.commit()
session.refresh(db_user)
return db_user
# Global user manager instance # Global user manager instance
@@ -131,27 +129,11 @@ def get_password_hash(password: str) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)
def _get_jwt_config() -> dict:
"""Get JWT configuration from settings"""
from app.config import get_settings
settings = get_settings()
return {
"SECRET_KEY": settings.jwt_secret_key,
"ALGORITHM": settings.jwt_algorithm,
"ACCESS_TOKEN_EXPIRE_MINUTES": settings.access_token_expire_minutes,
"REFRESH_TOKEN_EXPIRE_DAYS": settings.refresh_token_expire_days,
}
def create_access_token(data: dict, expires_delta: timedelta = None) -> str: def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
"""Create JWT access token""" """Create JWT access token"""
from jose import jwt SECRET_KEY = settings.jwt_secret_key
ALGORITHM = settings.jwt_algorithm
jwt_config = _get_jwt_config() ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
SECRET_KEY = jwt_config["SECRET_KEY"]
ALGORITHM = jwt_config["ALGORITHM"]
ACCESS_TOKEN_EXPIRE_MINUTES = jwt_config["ACCESS_TOKEN_EXPIRE_MINUTES"]
to_encode = data.copy() to_encode = data.copy()
@@ -168,12 +150,10 @@ def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
def verify_token(token: str) -> Optional[str]: def verify_token(token: str) -> Optional[str]:
"""Verify JWT token and return username""" """Verify JWT token and return username"""
from jose import jwt
from jose.exceptions import JWTError from jose.exceptions import JWTError
jwt_config = _get_jwt_config() SECRET_KEY = settings.jwt_secret_key
SECRET_KEY = jwt_config["SECRET_KEY"] ALGORITHM = settings.jwt_algorithm
ALGORITHM = jwt_config["ALGORITHM"]
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
@@ -189,7 +169,7 @@ def verify_token(token: str) -> Optional[str]:
get_user_from_token = verify_token get_user_from_token = verify_token
def get_current_user(credentials: HTTPAuthorizationCredentials) -> dict: def get_current_user(credentials: HTTPAuthorizationCredentials) -> UserTable:
"""Get current user from JWT token""" """Get current user from JWT token"""
token = credentials.credentials token = credentials.credentials
username = verify_token(token) username = verify_token(token)
@@ -197,16 +177,19 @@ def get_current_user(credentials: HTTPAuthorizationCredentials) -> dict:
user = user_manager.get_user(username) user = user_manager.get_user(username)
if not user: if not user:
raise HTTPException(status_code=401, detail="User not found") raise HTTPException(status_code=401, detail="User not found")
if not user.get("is_active", True): if not user.is_active:
raise HTTPException(status_code=401, detail="Inactive user") raise HTTPException(status_code=401, detail="Inactive user")
return user return user
raise HTTPException(status_code=401, detail="Invalid authentication credentials") raise HTTPException(status_code=401, detail="Invalid authentication credentials")
# Refresh tokens storage # Refresh tokens storage
REFRESH_TOKENS_FILE = "config/refresh_tokens.json" REFRESH_TOKENS_FILE = "config/refresh_tokens.json"
def _load_refresh_tokens() -> Dict[str, dict]: def _load_refresh_tokens() -> Dict[str, dict]:
"""Load refresh tokens from file""" """Load refresh tokens from file"""
import json
try: try:
if os.path.exists(REFRESH_TOKENS_FILE): if os.path.exists(REFRESH_TOKENS_FILE):
with open(REFRESH_TOKENS_FILE, 'r', encoding='utf-8') as f: with open(REFRESH_TOKENS_FILE, 'r', encoding='utf-8') as f:
@@ -218,6 +201,7 @@ def _load_refresh_tokens() -> Dict[str, dict]:
def _save_refresh_tokens(tokens: Dict[str, dict]): def _save_refresh_tokens(tokens: Dict[str, dict]):
"""Save refresh tokens to file""" """Save refresh tokens to file"""
import json
try: try:
os.makedirs(os.path.dirname(REFRESH_TOKENS_FILE), exist_ok=True) os.makedirs(os.path.dirname(REFRESH_TOKENS_FILE), exist_ok=True)
with open(REFRESH_TOKENS_FILE, 'w', encoding='utf-8') as f: with open(REFRESH_TOKENS_FILE, 'w', encoding='utf-8') as f:
+31
View File
@@ -0,0 +1,31 @@
"""Database configuration and session management using SQLModel"""
import os
from typing import Generator
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session, create_engine
from app.config import get_settings
settings = get_settings()
# Database URL can be overridden by environment variable DATABASE_URL
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./ohm_streaming.db")
# Create the engine
# connect_args={"check_same_thread": False} is required for SQLite and FastAPI
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
def create_db_and_tables():
"""Create the database and tables based on the models"""
# Import all models here to ensure they are registered with SQLModel.metadata
from app.models.auth import UserTable
from app.models.watchlist import WatchlistItemTable
# Add other models as they are migrated
SQLModel.metadata.create_all(engine)
def get_session() -> Generator[Session, None, None]:
"""Dependency for getting a database session"""
with Session(engine) as session:
yield session
+36 -14
View File
@@ -1,15 +1,41 @@
"""Authentication models for user management""" """Authentication models for user management with SQLModel support"""
from pydantic import BaseModel, EmailStr, Field import uuid
from typing import Optional from pydantic import BaseModel, EmailStr, Field as PydanticField
from typing import Optional, List
from datetime import datetime from datetime import datetime
from sqlmodel import SQLModel, Field, Relationship
class UserCreate(BaseModel): class UserBase(SQLModel):
"""Schema for user registration""" """Base schema for user data"""
username: str = Field(..., min_length=3, max_length=50) username: str = Field(index=True, unique=True, min_length=3, max_length=50)
email: Optional[EmailStr] = None email: Optional[str] = Field(default=None, index=True)
password: str = Field(..., min_length=6)
full_name: Optional[str] = None full_name: Optional[str] = None
is_active: bool = Field(default=True)
class UserTable(UserBase, table=True):
"""Database table for users"""
__tablename__ = "users"
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
primary_key=True,
index=True,
nullable=False
)
hashed_password: str
created_at: datetime = Field(default_factory=datetime.now)
last_login: Optional[datetime] = None
# Relationships
watchlist_items: List["WatchlistItemTable"] = Relationship(back_populates="user")
class UserCreate(UserBase):
"""Schema for user registration"""
password: str = PydanticField(..., min_length=6)
email: Optional[EmailStr] = None
class UserLogin(BaseModel): class UserLogin(BaseModel):
@@ -18,13 +44,9 @@ class UserLogin(BaseModel):
password: str password: str
class User(BaseModel): class User(UserBase):
"""Schema for user data""" """Schema for user data (API Response)"""
id: str id: str
username: str
email: Optional[str] = None
full_name: Optional[str] = None
is_active: bool = True
created_at: datetime created_at: datetime
last_login: Optional[datetime] = None last_login: Optional[datetime] = None
+81 -45
View File
@@ -1,8 +1,11 @@
"""Pydantic models for Watchlist and Auto-Download system""" """Models for Watchlist and Auto-Download system with SQLModel support"""
from pydantic import BaseModel, Field import uuid
from typing import Optional, Literal import json
from pydantic import BaseModel, Field as PydanticField
from typing import Optional, Literal, List
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from sqlmodel import SQLModel, Field, Relationship, Column, String
class WatchlistStatus(str, Enum): class WatchlistStatus(str, Enum):
@@ -21,34 +24,80 @@ class QualityPreference(str, Enum):
P480 = "480p" # SD P480 = "480p" # SD
class WatchlistItem(BaseModel): class WatchlistItemBase(SQLModel):
"""An anime being tracked for automatic episode downloads""" """Base schema for watchlist items"""
id: str = Field(..., description="Unique identifier (UUID)") anime_title: str = Field(index=True)
user_id: str = Field(..., description="User ID who owns this watchlist item") anime_url: str
anime_title: str = Field(..., description="Title of the anime") provider_id: str
anime_url: str = Field(..., description="URL to the anime page") lang: str = Field(default="vostfr")
provider_id: str = Field(..., description="Provider ID (animesama, nekosama, etc.)")
lang: Literal["vostfr", "vf"] = Field(default="vostfr", description="Language preference")
# Tracking state # Tracking state
last_checked: Optional[datetime] = Field(None, description="Last time we checked for new episodes") last_checked: Optional[datetime] = None
last_episode_downloaded: int = Field(default=0, description="Last episode number downloaded") last_episode_downloaded: int = Field(default=0)
total_episodes: Optional[int] = Field(None, description="Total episodes if known") total_episodes: Optional[int] = None
# Settings # Settings
auto_download: bool = Field(default=True, description="Automatically download new episodes") auto_download: bool = Field(default=True)
quality_preference: QualityPreference = Field(default=QualityPreference.AUTO, description="Preferred quality") quality_preference: QualityPreference = Field(default=QualityPreference.AUTO)
status: WatchlistStatus = Field(default=WatchlistStatus.ACTIVE, description="Tracking status") status: WatchlistStatus = Field(default=WatchlistStatus.ACTIVE)
# Metadata # Metadata
poster_image: Optional[str] = Field(None, description="URL to poster image") poster_image: Optional[str] = None
cover_image: Optional[str] = Field(None, description="URL to cover image") cover_image: Optional[str] = None
synopsis: Optional[str] = Field(None, description="Anime synopsis") synopsis: Optional[str] = None
genres: list[str] = Field(default_factory=list, description="Anime genres")
# Timestamps # Timestamps
added_at: datetime = Field(default_factory=datetime.now, description="When added to watchlist") added_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time") updated_at: datetime = Field(default_factory=datetime.now)
class WatchlistItemTable(WatchlistItemBase, table=True):
"""Database table for watchlist items"""
__tablename__ = "watchlist_items"
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
primary_key=True,
index=True,
nullable=False
)
user_id: str = Field(foreign_key="users.id", index=True)
# Store list as JSON string in SQLite
genres_json: Optional[str] = Field(default="[]", sa_column=Column(String))
@property
def genres(self) -> List[str]:
return json.loads(self.genres_json or "[]")
@genres.setter
def genres(self, value: List[str]):
self.genres_json = json.dumps(value or [])
# Relationships
user: Optional["UserTable"] = Relationship(back_populates="watchlist_items")
class WatchlistItem(BaseModel):
"""An anime being tracked for automatic episode downloads (API Response)"""
id: str
user_id: str
anime_title: str
anime_url: str
provider_id: str
lang: str
last_checked: Optional[datetime] = None
last_episode_downloaded: int = 0
total_episodes: Optional[int] = None
auto_download: bool = True
quality_preference: QualityPreference = QualityPreference.AUTO
status: WatchlistStatus = WatchlistStatus.ACTIVE
poster_image: Optional[str] = None
cover_image: Optional[str] = None
synopsis: Optional[str] = None
genres: List[str] = []
added_at: datetime
updated_at: datetime
class Config: class Config:
json_encoders = { json_encoders = {
@@ -64,12 +113,10 @@ class WatchlistItemCreate(BaseModel):
lang: Literal["vostfr", "vf"] = "vostfr" lang: Literal["vostfr", "vf"] = "vostfr"
auto_download: bool = True auto_download: bool = True
quality_preference: QualityPreference = QualityPreference.AUTO quality_preference: QualityPreference = QualityPreference.AUTO
# Optional metadata
poster_image: Optional[str] = None poster_image: Optional[str] = None
cover_image: Optional[str] = None cover_image: Optional[str] = None
synopsis: Optional[str] = None synopsis: Optional[str] = None
genres: list[str] = [] genres: List[str] = []
class WatchlistItemUpdate(BaseModel): class WatchlistItemUpdate(BaseModel):
@@ -96,26 +143,15 @@ class AutoDownloadResult(BaseModel):
watchlist_item_id: str watchlist_item_id: str
anime_title: str anime_title: str
new_episodes_found: int new_episodes_found: int
episodes_downloaded: list[int] = Field(default_factory=list) episodes_downloaded: list[int] = PydanticField(default_factory=list)
episodes_failed: list[tuple[int, str]] = Field(default_factory=list) # (episode_number, error_message) episodes_failed: list[tuple[int, str]] = PydanticField(default_factory=list)
checked_at: datetime = Field(default_factory=datetime.now) checked_at: datetime = PydanticField(default_factory=datetime.now)
class WatchlistSettings(BaseModel): class WatchlistSettings(BaseModel):
"""Global watchlist settings""" """Global watchlist settings"""
check_interval_hours: int = Field(default=6, ge=1, le=168, description="Check interval (1-168 hours)") check_interval_hours: int = PydanticField(default=6, ge=1, le=168)
auto_download_enabled: bool = Field(default=True, description="Global auto-download toggle") auto_download_enabled: bool = PydanticField(default=True)
max_concurrent_auto_downloads: int = Field(default=2, ge=1, le=10, description="Max concurrent auto-downloads") max_concurrent_auto_downloads: int = PydanticField(default=2, ge=1, le=10)
notify_on_new_episodes: bool = Field(default=False, description="Send notifications for new episodes") notify_on_new_episodes: bool = PydanticField(default=False)
include_completed_anime: bool = Field(default=False, description="Check completed anime too") include_completed_anime: bool = PydanticField(default=False)
class Config:
json_schema_extra = {
"example": {
"check_interval_hours": 6,
"auto_download_enabled": True,
"max_concurrent_auto_downloads": 2,
"notify_on_new_episodes": False,
"include_completed_anime": False
}
}
+141 -159
View File
@@ -1,4 +1,4 @@
"""Watchlist management system for automatic episode tracking and downloading""" """Watchlist management system for automatic episode tracking and downloading with SQLModel support"""
import json import json
import os import os
import uuid import uuid
@@ -7,8 +7,11 @@ from datetime import datetime, timedelta
from typing import List, Optional, Dict from typing import List, Optional, Dict
from pathlib import Path from pathlib import Path
from sqlmodel import Session, select
from app.database import engine
from app.models.watchlist import ( from app.models.watchlist import (
WatchlistItem, WatchlistItem,
WatchlistItemTable,
WatchlistItemCreate, WatchlistItemCreate,
WatchlistItemUpdate, WatchlistItemUpdate,
WatchlistStatus, WatchlistStatus,
@@ -19,55 +22,18 @@ from app.models.watchlist import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Watchlist database file # Settings file remains JSON for simplicity for now
WATCHLIST_DB_FILE = "config/watchlist.json"
WATCHLIST_SETTINGS_FILE = "config/watchlist_settings.json" WATCHLIST_SETTINGS_FILE = "config/watchlist_settings.json"
class WatchlistManager: class WatchlistManager:
"""Manages user watchlist for automatic episode downloads""" """Manages user watchlist for automatic episode downloads using SQL database"""
def __init__(self, db_file: str = WATCHLIST_DB_FILE): def __init__(self):
self.db_file = db_file
self.settings_file = WATCHLIST_SETTINGS_FILE self.settings_file = WATCHLIST_SETTINGS_FILE
self.watchlist: Dict[str, WatchlistItem] = {}
self.settings: Optional[WatchlistSettings] = None self.settings: Optional[WatchlistSettings] = None
self._load_watchlist()
self._load_settings() self._load_settings()
def _load_watchlist(self):
"""Load watchlist from JSON file"""
try:
if os.path.exists(self.db_file):
with open(self.db_file, 'r', encoding='utf-8') as f:
data = json.load(f)
self.watchlist = {
item_id: WatchlistItem(**item_data)
for item_id, item_data in data.items()
}
logger.info(f"Loaded {len(self.watchlist)} items from watchlist")
else:
self.watchlist = {}
logger.info("Watchlist database not found, starting with empty watchlist")
except Exception as e:
logger.error(f"Error loading watchlist: {e}")
self.watchlist = {}
def _save_watchlist(self):
try:
os.makedirs(os.path.dirname(self.db_file), exist_ok=True)
data = {
item_id: item.model_dump(mode='json')
for item_id, item in self.watchlist.items()
}
temp_file = f"{self.db_file}.tmp"
with open(temp_file, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False, default=str)
os.replace(temp_file, self.db_file)
logger.debug(f"Saved {len(self.watchlist)} items to watchlist")
except Exception as e:
logger.error(f"Error saving watchlist: {e}")
def _load_settings(self): def _load_settings(self):
"""Load watchlist settings from JSON file""" """Load watchlist settings from JSON file"""
try: try:
@@ -95,159 +61,175 @@ class WatchlistManager:
except Exception as e: except Exception as e:
logger.error(f"Error saving settings: {e}") logger.error(f"Error saving settings: {e}")
def _to_api_model(self, db_item: WatchlistItemTable) -> WatchlistItem:
"""Convert database table model to API response model"""
data = db_item.model_dump()
data["genres"] = db_item.genres
return WatchlistItem(**data)
def get_all(self, user_id: Optional[str] = None, status: Optional[WatchlistStatus] = None) -> List[WatchlistItem]: def get_all(self, user_id: Optional[str] = None, status: Optional[WatchlistStatus] = None) -> List[WatchlistItem]:
"""Get all watchlist items, optionally filtered by user and status""" """Get all watchlist items, optionally filtered by user and status"""
items = list(self.watchlist.values()) with Session(engine) as session:
statement = select(WatchlistItemTable)
if user_id:
statement = statement.where(WatchlistItemTable.user_id == user_id)
if status:
statement = statement.where(WatchlistItemTable.status == status)
if user_id: # Sort by added_at descending
items = [item for item in items if item.user_id == user_id] statement = statement.order_by(WatchlistItemTable.added_at.desc())
if status: db_items = session.exec(statement).all()
items = [item for item in items if item.status == status] return [self._to_api_model(item) for item in db_items]
# Sort by added_at descending
items.sort(key=lambda x: x.added_at, reverse=True)
return items
def get_by_id(self, item_id: str) -> Optional[WatchlistItem]: def get_by_id(self, item_id: str) -> Optional[WatchlistItem]:
"""Get a watchlist item by ID""" """Get a specific watchlist item by ID"""
return self.watchlist.get(item_id) with Session(engine) as session:
db_item = session.get(WatchlistItemTable, item_id)
if db_item:
return self._to_api_model(db_item)
return None
def get_by_anime_url(self, anime_url: str, user_id: str) -> Optional[WatchlistItem]: def get_by_anime_url(self, anime_url: str, user_id: str) -> Optional[WatchlistItem]:
"""Get a watchlist item by anime URL and user ID""" """Get a watchlist item by anime URL and user ID"""
for item in self.watchlist.values(): with Session(engine) as session:
if item.anime_url == anime_url and item.user_id == user_id: statement = select(WatchlistItemTable).where(
return item WatchlistItemTable.anime_url == anime_url,
return None WatchlistItemTable.user_id == user_id
)
db_item = session.exec(statement).first()
if db_item:
return self._to_api_model(db_item)
return None
def create(self, user_id: str, item_data: WatchlistItemCreate) -> WatchlistItem: def add(self, user_id: str, item_create: WatchlistItemCreate) -> WatchlistItem:
"""Create a new watchlist item""" """Add a new anime to the watchlist"""
# Check if already exists # Check if already in watchlist for this user
existing = self.get_by_anime_url(item_data.anime_url, user_id) existing = self.get_by_anime_url(item_create.anime_url, user_id)
if existing: if existing:
raise ValueError(f"Anime already in watchlist (ID: {existing.id})") return existing
# Create new item with Session(engine) as session:
item_id = str(uuid.uuid4()) # Create new item
now = datetime.now() db_item = WatchlistItemTable(
user_id=user_id,
anime_title=item_create.anime_title,
anime_url=item_create.anime_url,
provider_id=item_create.provider_id,
lang=item_create.lang,
auto_download=item_create.auto_download,
quality_preference=item_create.quality_preference,
poster_image=item_create.poster_image,
cover_image=item_create.cover_image,
synopsis=item_create.synopsis,
status=WatchlistStatus.ACTIVE,
added_at=datetime.now(),
updated_at=datetime.now(),
last_episode_downloaded=0
)
db_item.genres = item_create.genres
watchlist_item = WatchlistItem( session.add(db_item)
id=item_id, session.commit()
user_id=user_id, session.refresh(db_item)
anime_title=item_data.anime_title,
anime_url=item_data.anime_url,
provider_id=item_data.provider_id,
lang=item_data.lang,
auto_download=item_data.auto_download,
quality_preference=item_data.quality_preference,
status=WatchlistStatus.ACTIVE,
poster_image=item_data.poster_image,
cover_image=item_data.cover_image,
synopsis=item_data.synopsis,
genres=item_data.genres,
added_at=now,
updated_at=now,
last_checked=None,
last_episode_downloaded=0,
total_episodes=None
)
self.watchlist[item_id] = watchlist_item logger.info(f"Added {db_item.anime_title} to watchlist for user {user_id}")
self._save_watchlist() return self._to_api_model(db_item)
logger.info(f"Added anime to watchlist: {watchlist_item.anime_title} (ID: {item_id})")
return watchlist_item # Alias for backward compatibility if needed
add_item = add
def update(self, item_id: str, update_data) -> Optional[WatchlistItem]: def update(self, item_id: str, update_data) -> Optional[WatchlistItem]:
"""Update a watchlist item """Update a watchlist item"""
with Session(engine) as session:
db_item = session.get(WatchlistItemTable, item_id)
if not db_item:
return None
Args: # Handle both dict and WatchlistItemUpdate
item_id: Item ID to update if isinstance(update_data, dict):
update_data: WatchlistItemUpdate object or dict with fields to update update_dict = update_data
""" else:
item = self.watchlist.get(item_id) update_dict = update_data.model_dump(exclude_unset=True)
if not item:
return None
# Handle both dict and WatchlistItemUpdate for key, value in update_dict.items():
if isinstance(update_data, dict): if hasattr(db_item, key):
update_dict = update_data setattr(db_item, key, value)
else:
update_dict = update_data.model_dump(exclude_unset=True)
# Update fields db_item.updated_at = datetime.now()
for field, value in update_dict.items():
if value is not None:
setattr(item, field, value)
item.updated_at = datetime.now() session.add(db_item)
self._save_watchlist() session.commit()
logger.info(f"Updated watchlist item: {item_id}") session.refresh(db_item)
return item
logger.info(f"Updated watchlist item: {item_id}")
return self._to_api_model(db_item)
# Alias for backward compatibility
update_item = update
def delete(self, item_id: str) -> bool: def delete(self, item_id: str) -> bool:
"""Delete a watchlist item""" """Remove an item from the watchlist"""
if item_id in self.watchlist: with Session(engine) as session:
del self.watchlist[item_id] db_item = session.get(WatchlistItemTable, item_id)
self._save_watchlist() if not db_item:
logger.info(f"Deleted watchlist item: {item_id}") return False
session.delete(db_item)
session.commit()
logger.info(f"Deleted item {item_id} from watchlist")
return True return True
return False
def update_check_time(self, item_id: str, last_episode: int) -> Optional[WatchlistItem]: def update_last_checked(self, item_id: str, last_episode: Optional[int] = None):
"""Update last_checked time and last_episode_downloaded""" """Update the last_checked timestamp and optionally last episode for an item"""
item = self.watchlist.get(item_id) with Session(engine) as session:
if not item: db_item = session.get(WatchlistItemTable, item_id)
return None if db_item:
db_item.last_checked = datetime.now()
if last_episode is not None:
db_item.last_episode_downloaded = last_episode
session.add(db_item)
session.commit()
item.last_checked = datetime.now() # Alias for backward compatibility
item.last_episode_downloaded = max(item.last_episode_downloaded, last_episode) update_check_time = update_last_checked
item.updated_at = datetime.now()
self._save_watchlist()
return item
def get_settings(self) -> WatchlistSettings: def get_due_items(self) -> List[WatchlistItem]:
"""Get watchlist settings""" """Get all items that are due for a check based on settings"""
if not self.settings: interval = timedelta(hours=self.settings.check_interval_hours)
self.settings = WatchlistSettings() now = datetime.now()
return self.settings
with Session(engine) as session:
statement = select(WatchlistItemTable).where(
(WatchlistItemTable.status == WatchlistStatus.ACTIVE)
)
db_items = session.exec(statement).all()
due_items = []
for item in db_items:
if not item.last_checked or (item.last_checked + interval) < now:
due_items.append(self._to_api_model(item))
return due_items
def update_settings(self, settings: WatchlistSettings) -> WatchlistSettings: def update_settings(self, settings: WatchlistSettings) -> WatchlistSettings:
"""Update watchlist settings""" """Update global watchlist settings"""
self.settings = settings self.settings = settings
self._save_settings() self._save_settings()
logger.info("Updated watchlist settings") logger.info("Updated watchlist settings")
return self.settings return self.settings
def get_due_for_check(self, check_interval_hours: Optional[int] = None) -> List[WatchlistItem]: def get_stats(self, user_id: str) -> Dict:
"""Get items that are due for checking""" """Get statistics for a user's watchlist"""
if check_interval_hours is None:
check_interval_hours = self.settings.check_interval_hours
cutoff_time = datetime.now() - timedelta(hours=check_interval_hours)
due_items = []
for item in self.watchlist.values():
# Only check active items with auto_download enabled
if item.status != WatchlistStatus.ACTIVE or not item.auto_download:
continue
# Check if due
if item.last_checked is None or item.last_checked < cutoff_time:
due_items.append(item)
logger.info(f"Found {len(due_items)} items due for check")
return due_items
def get_stats(self, user_id: Optional[str] = None) -> Dict:
"""Get watchlist statistics"""
items = self.get_all(user_id=user_id) items = self.get_all(user_id=user_id)
stats = { stats = {
"total": len(items), "total_items": len(items),
"active": len([i for i in items if i.status == WatchlistStatus.ACTIVE]), "active_items": len([i for i in items if i.status == WatchlistStatus.ACTIVE]),
"paused": len([i for i in items if i.status == WatchlistStatus.PAUSED]), "paused_items": len([i for i in items if i.status == WatchlistStatus.PAUSED]),
"completed": len([i for i in items if i.status == WatchlistStatus.COMPLETED]), "completed_items": len([i for i in items if i.status == WatchlistStatus.COMPLETED]),
"auto_download_enabled": len([i for i in items if i.auto_download]), "total_episodes_downloaded": sum(i.last_episode_downloaded for i in items),
"providers": {} "providers": {}
} }
-92
View File
@@ -1,92 +0,0 @@
{
"testuser": {
"id": "ae5deb822e0d71992900471a7199d0d9",
"username": "testuser",
"email": "test@example.com",
"full_name": "Test User",
"hashed_password": "$2b$12$gDgt6xCBS4y2FgNrCk0JU.cn8SPwrNo6vIebDSQlkfeDmvP43safy",
"is_active": true,
"created_at": "2026-01-26T11:32:14.262592",
"last_login": "2026-01-26T12:18:26.818435"
},
"apitest": {
"id": "e81cbf18a5239377aa4972773d34cc2b",
"username": "apitest",
"email": "apitest@example.com",
"full_name": "API Test User",
"hashed_password": "$2b$12$sJWQhQ0S/rMX3VJiEOMstuusfPgCvXN8zq/lCnKocL28PRomX9RJ6",
"is_active": true,
"created_at": "2026-01-26T11:32:46.943188",
"last_login": "2026-01-26T11:32:47.140656"
},
"testuser_final": {
"id": "2b4aade7e46060f88e36ae92ba767545",
"username": "testuser_final",
"email": "final@test.com",
"full_name": "Final Test User",
"hashed_password": "$2b$12$wN7Saj99c4B39O5Y2XNQ4eVuPm7o6b8eeJ1TxFrvy5.g7ycyh9rKm",
"is_active": true,
"created_at": "2026-01-26T11:33:45.726090",
"last_login": "2026-01-26T11:33:46.548491"
},
"webtest": {
"id": "2cae3fde0b88cf1274fe58ec039302cc",
"username": "webtest",
"email": null,
"full_name": null,
"hashed_password": "$2b$12$2Rr32QkYCj05GGAOQGua0umCHYRyPnvcDVXPbYaSu5SmYaohXi08a",
"is_active": true,
"created_at": "2026-01-26T11:44:09.995999",
"last_login": "2026-01-26T11:44:10.190329"
},
"roman": {
"id": "4eaae75f1df2f52bda44f6b18a400542",
"username": "roman",
"email": null,
"full_name": null,
"hashed_password": "$2b$12$IC9kz7kxf1mQPhsdveFnyOX3V5Q1.pB9/uqCKWI7nhn.SYamtvxCC",
"is_active": true,
"created_at": "2026-01-26T12:15:58.008205",
"last_login": "2026-03-23T13:29:45.076454"
},
"testuser999": {
"id": "f9abf4b8aa96d5116807ac1cf8540418",
"username": "testuser999",
"email": null,
"full_name": null,
"hashed_password": "$2b$12$y2uy62IR0xVmCcUmQ8gL6.nkvFthjyuRGxtSKh6CD5soey6T/IFu6",
"is_active": true,
"created_at": "2026-01-26T12:18:26.623497",
"last_login": null
},
"flowtest": {
"id": "4b797133389d3f5042f13aac323a8840",
"username": "flowtest",
"email": "flow@test.com",
"full_name": null,
"hashed_password": "$2b$12$Dcb7fKZPycLRsW851m9pk.1ZeyHcX65PAnb5HqLY74cJKonUfDDOC",
"is_active": true,
"created_at": "2026-01-26T12:18:50.138613",
"last_login": "2026-01-26T12:18:50.332004"
},
"e2etest": {
"id": "37a97310cedfe6ae001033c2b9832f6c",
"username": "e2etest",
"email": null,
"full_name": null,
"hashed_password": "$2b$12$uV9AW1qrbLC2tOCk1Gs4x.clk1v7jPNteHmn/Nby/Lelopb9Ce60m",
"is_active": true,
"created_at": "2026-02-26T16:01:01.051127",
"last_login": "2026-02-26T16:11:48.431566"
},
"fronttest": {
"id": "059c564b78528d334f5ac4ecce3ea894",
"username": "fronttest",
"email": "front@test.com",
"full_name": null,
"hashed_password": "$2b$12$qkZxcN9peGfWSj59ULm6S.5ROtFlF7fGXJpypD7cQ0N9TzDRl93z.",
"is_active": true,
"created_at": "2026-02-28T09:41:38.411958",
"last_login": "2026-03-01T16:24:30.918490"
}
}
-62
View File
@@ -1,62 +0,0 @@
{
"2293bca2-c1c2-4e4f-8862-c4a6601f2b6f": {
"id": "2293bca2-c1c2-4e4f-8862-c4a6601f2b6f",
"user_id": "test_user_1",
"anime_title": "Test Anime",
"anime_url": "https://anime-sama.si/catalogue/test/vostfr/",
"provider_id": "animesama",
"lang": "vostfr",
"last_checked": "2026-03-24T08:45:18.470468",
"last_episode_downloaded": 0,
"total_episodes": null,
"auto_download": true,
"quality_preference": "auto",
"status": "active",
"poster_image": null,
"cover_image": null,
"synopsis": null,
"genres": [],
"added_at": "2026-01-29T21:53:38.078765",
"updated_at": "2026-03-24T08:45:18.470487"
},
"a5270097-d883-45b9-ad86-538a39c51e91": {
"id": "a5270097-d883-45b9-ad86-538a39c51e91",
"user_id": "059c564b78528d334f5ac4ecce3ea894",
"anime_title": "Frieren",
"anime_url": "https://anime-sama.tv/catalogue/frieren/saison1/vostfr/",
"provider_id": "anime-sama",
"lang": "vostfr",
"last_checked": "2026-03-24T08:45:18.849516",
"last_episode_downloaded": 28,
"total_episodes": 6,
"auto_download": true,
"quality_preference": "auto",
"status": "active",
"poster_image": null,
"cover_image": null,
"synopsis": null,
"genres": [],
"added_at": "2026-02-28T09:42:38.806576",
"updated_at": "2026-03-24T08:45:18.849533"
},
"944b598b-2bc8-4cd8-8a7b-8d84b7342c26": {
"id": "944b598b-2bc8-4cd8-8a7b-8d84b7342c26",
"user_id": "4eaae75f1df2f52bda44f6b18a400542",
"anime_title": "Frieren",
"anime_url": "https://anime-sama.tv/catalogue/frieren/saison1/vostfr/",
"provider_id": "anime-sama",
"lang": "vostfr",
"last_checked": "2026-03-24T08:45:19.136113",
"last_episode_downloaded": 28,
"total_episodes": 6,
"auto_download": true,
"quality_preference": "auto",
"status": "active",
"poster_image": "https://raw.githubusercontent.com/Anime-Sama/IMG/img/contenu/frieren0.jpg",
"cover_image": null,
"synopsis": null,
"genres": [],
"added_at": "2026-02-28T15:47:09.168943",
"updated_at": "2026-03-24T08:45:19.136131"
}
}
+2 -2
View File
@@ -1,6 +1,6 @@
{ {
"check_interval_hours": 6, "check_interval_hours": 12,
"auto_download_enabled": true, "auto_download_enabled": false,
"max_concurrent_auto_downloads": 2, "max_concurrent_auto_downloads": 2,
"notify_on_new_episodes": false, "notify_on_new_episodes": false,
"include_completed_anime": false "include_completed_anime": false
+5
View File
@@ -52,6 +52,11 @@ episode_checker.set_download_manager(download_manager)
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
"""Initialize services on application startup""" """Initialize services on application startup"""
# Create database tables if they don't exist
from app.database import create_db_and_tables
create_db_and_tables()
logger.info("Database tables initialized")
from app.sonarr_handler import get_sonarr_handler from app.sonarr_handler import get_sonarr_handler
sonarr_handler = get_sonarr_handler() sonarr_handler = get_sonarr_handler()
+1
View File
@@ -10,6 +10,7 @@ aiohttp==3.11.11
beautifulsoup4==4.12.3 beautifulsoup4==4.12.3
lxml==5.3.0 lxml==5.3.0
jieba==0.42.1 jieba==0.42.1
sqlmodel==0.0.22
# Testing dependencies # Testing dependencies
pytest==8.3.4 pytest==8.3.4
+54
View File
@@ -11,12 +11,66 @@ from unittest.mock import Mock, AsyncMock, patch
import sys import sys
import os import os
import sys
import os
# Ensure the project root is in the Python path # Ensure the project root is in the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# FORCE DATABASE_URL to in-memory for ALL tests before ANY app imports
os.environ["DATABASE_URL"] = "sqlite://"
from app.models import DownloadTask, DownloadStatus, DownloadRequest, HostType from app.models import DownloadTask, DownloadStatus, DownloadRequest, HostType
from app.favorites import FavoritesManager from app.favorites import FavoritesManager
from app.download_manager import DownloadManager from app.download_manager import DownloadManager
from sqlmodel import SQLModel, create_engine, Session
@pytest.fixture(scope="session", autouse=True)
def init_db():
"""Initialize the in-memory database once for the test session"""
from app.database import engine
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture(name="engine")
def engine_fixture():
"""Returns the global test engine"""
from app.database import engine
return engine
@pytest.fixture(name="session")
def session_fixture(engine):
"""Create a temporary database session for testing"""
# Clear and recreate tables for each test to ensure isolation
SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@pytest.fixture(autouse=True)
def mock_db(engine):
"""Ensure each test starts with fresh tables"""
SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.fixture
def user_manager():
"""Create a UserManager instance"""
from app.auth import UserManager
return UserManager()
@pytest.fixture
def watchlist_manager():
"""Create a WatchlistManager instance"""
from app.watchlist import WatchlistManager
return WatchlistManager()
def pytest_configure(config): def pytest_configure(config):
+72 -297
View File
@@ -1,77 +1,39 @@
""" """
Unit tests for authentication system (app/auth.py) Unit tests for authentication system (app/auth.py)
Tests JWT tokens, user management, and password hashing Tests JWT tokens, user management, and password hashing with SQLModel support
""" """
import pytest import pytest
import json
from pathlib import Path
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import patch, Mock from unittest.mock import patch
from app.auth import UserManager, create_access_token, verify_token, get_user_from_token from app.auth import UserManager, create_access_token, verify_token, get_user_from_token
from app.models.auth import UserTable
@pytest.mark.skip(reason="Test does not match current implementation")
class TestUserManager: class TestUserManager:
"""Tests for UserManager class""" """Tests for UserManager class using SQLModel"""
@pytest.fixture
def temp_users_file(self, temp_dir):
"""Create a temporary users.json file"""
return temp_dir / "users.json"
@pytest.fixture
def user_manager(self, temp_users_file):
"""Create a UserManager instance with temporary storage"""
manager = UserManager(json_path=str(temp_users_file))
yield manager
# Cleanup
if temp_users_file.exists():
temp_users_file.unlink()
def test_user_manager_init_creates_file(self, user_manager, temp_users_file):
"""Test that UserManager creates the users file on init"""
assert temp_users_file.exists()
data = json.loads(temp_users_file.read_text())
assert "users" in data
assert isinstance(data["users"], dict)
def test_user_manager_init_existing_file(self, temp_users_file):
"""Test UserManager initialization with existing file"""
# Create a file with existing data
existing_data = {
"users": {
"existing_user": {
"username": "existing_user",
"password_hash": "hash",
"created_at": "2024-01-01T00:00:00",
"last_login": None
}
}
}
temp_users_file.write_text(json.dumps(existing_data))
manager = UserManager(json_path=str(temp_users_file))
# Should load existing data
assert "existing_user" in manager.users
def test_create_user_success(self, user_manager): def test_create_user_success(self, user_manager):
"""Test successful user creation""" """Test successful user creation"""
user = user_manager.create_user("testuser", "password123") user = user_manager.create_user("testuser", "password123")
assert user["username"] == "testuser" assert user.username == "testuser"
assert "password_hash" in user assert hasattr(user, "hashed_password")
assert "created_at" in user assert user.created_at is not None
assert user["last_login"] is None assert user.last_login is None
assert "testuser" in user_manager.users
# Verify it's in the database
db_user = user_manager.get_user("testuser")
assert db_user is not None
assert db_user.username == "testuser"
def test_create_user_hashing(self, user_manager): def test_create_user_hashing(self, user_manager):
"""Test that passwords are properly hashed with bcrypt""" """Test that passwords are properly hashed with bcrypt"""
user = user_manager.create_user("testuser", "password123") user = user_manager.create_user("testuser", "password123")
# Hash should not be the plain password # Hash should not be the plain password
assert user["password_hash"] != "password123" assert user.hashed_password != "password123"
# Bcrypt hashes start with $2b$ # Bcrypt hashes start with $2b$
assert user["password_hash"].startswith("$2b$") assert user.hashed_password.startswith("$2b$")
# Hash should be 60 characters (bcrypt standard) # Hash should be 60 characters (bcrypt standard)
assert len(user["password_hash"]) == 60 assert len(user.hashed_password) == 60
def test_create_user_duplicate(self, user_manager): def test_create_user_duplicate(self, user_manager):
"""Test that duplicate usernames are rejected""" """Test that duplicate usernames are rejected"""
@@ -79,26 +41,19 @@ class TestUserManager:
with pytest.raises(ValueError, match="already exists"): with pytest.raises(ValueError, match="already exists"):
user_manager.create_user("testuser", "different456") user_manager.create_user("testuser", "different456")
def test_create_user_short_password(self, user_manager):
"""Test that short passwords are rejected"""
with pytest.raises(ValueError, match="at least 6 characters"):
user_manager.create_user("testuser", "short")
def test_create_user_password_truncation(self, user_manager): def test_create_user_password_truncation(self, user_manager):
"""Test that passwords longer than 72 bytes are truncated""" """Test that passwords longer than 72 bytes are truncated (bcrypt limit)"""
# Bcrypt has a 72-byte limit
long_password = "a" * 100 long_password = "a" * 100
user = user_manager.create_user("testuser", long_password) user = user_manager.create_user("testuser", long_password)
# Should succeed (password truncated internally) assert user.username == "testuser"
assert user["username"] == "testuser"
def test_authenticate_user_success(self, user_manager): def test_authenticate_user_success(self, user_manager):
"""Test successful user authentication""" """Test successful user authentication"""
user_manager.create_user("testuser", "password123") user_manager.create_user("testuser", "password123")
user = user_manager.authenticate_user("testuser", "password123") user = user_manager.authenticate_user("testuser", "password123")
assert user is not None assert user is not None
assert user["username"] == "testuser" assert user.username == "testuser"
assert user["last_login"] is not None assert user.last_login is not None
def test_authenticate_user_wrong_password(self, user_manager): def test_authenticate_user_wrong_password(self, user_manager):
"""Test authentication with wrong password""" """Test authentication with wrong password"""
@@ -114,263 +69,83 @@ class TestUserManager:
def test_authenticate_updates_last_login(self, user_manager): def test_authenticate_updates_last_login(self, user_manager):
"""Test that authentication updates last_login timestamp""" """Test that authentication updates last_login timestamp"""
user_manager.create_user("testuser", "password123") user_manager.create_user("testuser", "password123")
user_before = user_manager.users["testuser"] user_before = user_manager.get_user("testuser")
assert user_before["last_login"] is None assert user_before.last_login is None
user_manager.authenticate_user("testuser", "password123") user_manager.authenticate_user("testuser", "password123")
user_after = user_manager.users["testuser"] user_after = user_manager.get_user("testuser")
assert user_after["last_login"] is not None assert user_after.last_login is not None
def test_get_user(self, user_manager): def test_get_user(self, user_manager):
"""Test getting a user by username""" """Test getting a user by username"""
user_manager.create_user("testuser", "password123") user_manager.create_user("testuser", "password123")
user = user_manager.get_user("testuser") user = user_manager.get_user("testuser")
assert user is not None assert user is not None
assert user["username"] == "testuser" assert user.username == "testuser"
def test_get_user_by_id(self, user_manager):
"""Test getting a user by ID"""
user = user_manager.create_user("testuser", "password123")
user_id = user.id
db_user = user_manager.get_user_by_id(user_id)
assert db_user is not None
assert db_user.username == "testuser"
def test_get_user_nonexistent(self, user_manager): def test_get_user_nonexistent(self, user_manager):
"""Test getting a non-existent user""" """Test getting a non-existent user"""
user = user_manager.get_user("nonexistent") user = user_manager.get_user("nonexistent")
assert user is None assert user is None
def test_update_user_last_login(self, user_manager): def test_update_user(self, user_manager):
"""Test updating user's last login timestamp""" """Test updating user information"""
user_manager.create_user("testuser", "password123") user = user_manager.create_user("testuser", "password123")
user_manager.update_last_login("testuser") updated = user_manager.update_user(user.id, {"full_name": "New Name", "email": "new@example.com"})
user = user_manager.users["testuser"] assert updated.full_name == "New Name"
assert user["last_login"] is not None assert updated.email == "new@example.com"
def test_deprecated_scheme_migration(self, user_manager): db_user = user_manager.get_user("testuser")
"""Test migration from deprecated password schemes""" assert db_user.full_name == "New Name"
# This tests the passlib auto-migration feature
# In practice, this is handled by passlib automatically
user_manager.create_user("testuser", "password123")
user = user_manager.users["testuser"]
# Should use bcrypt scheme
assert user["password_hash"].startswith("$2b$")
@pytest.mark.skip(reason="Test does not match current implementation") class TestJWTToken:
class TestJWTTokens: """Tests for JWT token functions"""
"""Tests for JWT token creation and verification"""
def test_create_access_token(self): def test_create_access_token(self):
"""Test JWT token creation""" """Test creating an access token"""
token = create_access_token(data={"sub": "testuser"}, expires_delta=timedelta(minutes=30)) token = create_access_token({"sub": "testuser"})
assert isinstance(token, str) assert isinstance(token, str)
# JWT tokens have 3 parts separated by dots assert len(token) > 0
assert len(token.split(".")) == 3
def test_create_token_default_expiration(self): def test_verify_token_success(self):
"""Test token creation with default expiration"""
token = create_access_token(data={"sub": "testuser"})
assert isinstance(token, str)
def test_verify_token_valid(self):
"""Test verifying a valid token""" """Test verifying a valid token"""
token = create_access_token(data={"sub": "testuser"}) token = create_access_token({"sub": "testuser"})
payload = verify_token(token) username = verify_token(token)
assert payload is not None assert username == "testuser"
assert payload.get("sub") == "testuser"
@pytest.mark.skip(reason="Problematic mock with datetime and jose library")
def test_verify_token_expired(self):
"""Test verifying an expired token"""
with patch('app.auth.datetime') as mock_datetime:
# Set fixed time
now = datetime.utcnow()
mock_datetime.utcnow.return_value = now
# Create token that expires in 1 minute
token = create_access_token({"sub": "testuser"}, expires_delta=timedelta(minutes=1))
# Move time forward by 2 minutes
mock_datetime.utcnow.return_value = now + timedelta(minutes=2)
username = verify_token(token)
assert username is None
def test_verify_token_invalid(self): def test_verify_token_invalid(self):
"""Test verifying an invalid token""" """Test verifying an invalid token"""
payload = verify_token("invalid.token.here") username = verify_token("invalid-token")
assert payload is None assert username is None
def test_verify_token_expired(self): def test_get_user_from_token(self):
"""Test verifying an expired token""" """Test get_user_from_token alias"""
# Create a token that's already expired token = create_access_token({"sub": "testuser"})
token = create_access_token(
data={"sub": "testuser"},
expires_delta=timedelta(seconds=-1) # Expired
)
payload = verify_token(token)
# Should return None for expired token
assert payload is None
def test_token_contains_username(self):
"""Test that token contains the username in 'sub' claim"""
token = create_access_token(data={"sub": "testuser"})
payload = verify_token(token)
assert payload["sub"] == "testuser"
def test_token_with_custom_claims(self):
"""Test token creation with custom claims"""
token = create_access_token(data={"sub": "testuser", "role": "admin"})
payload = verify_token(token)
assert payload["sub"] == "testuser"
assert payload["role"] == "admin"
def test_get_user_from_token_valid(self):
"""Test getting user from valid token"""
token = create_access_token(data={"sub": "testuser"})
username = get_user_from_token(token) username = get_user_from_token(token)
assert username == "testuser" assert username == "testuser"
def test_get_user_from_token_invalid(self):
"""Test getting user from invalid token"""
username = get_user_from_token("invalid.token")
assert username is None
def test_get_user_from_token_no_sub(self):
"""Test getting user from token without 'sub' claim"""
# Create token without 'sub' claim
token = create_access_token(data={"user": "testuser"})
username = get_user_from_token(token)
assert username is None
def test_different_secrets(self):
"""Test that tokens can't be verified with different secrets"""
token = create_access_token(data={"sub": "testuser"})
# Try to verify with different secret (by mocking)
with patch('app.auth.JWT_SECRET_KEY', 'different-secret'):
payload = verify_token(token)
# Should fail verification
assert payload is None
@pytest.mark.skip(reason="Test does not match current implementation")
class TestTokenExpiration:
"""Tests for token expiration handling"""
def test_token_expiration_time(self):
"""Test that token expiration time is correct"""
from app.auth import ACCESS_TOKEN_EXPIRE_MINUTES
# Create token with custom expiration
expires = timedelta(minutes=30)
token = create_access_token(data={"sub": "testuser"}, expires_delta=expires)
# Token should be valid immediately
payload = verify_token(token)
assert payload is not None
def test_default_expiration_from_config(self):
"""Test that default expiration matches configuration"""
from app.config import get_settings
settings = get_settings()
# Just verify the setting exists
assert hasattr(settings, 'ACCESS_TOKEN_EXPIRE_MINUTES') or 'ACCESS_TOKEN_EXPIRE_MINUTES' in dir(settings)
@pytest.mark.skip(reason="Test does not match current implementation")
class TestPasswordSecurity:
"""Tests for password handling security"""
def test_password_not_stored_plaintext(self, user_manager):
"""Test that passwords are never stored in plain text"""
user_manager.create_user("testuser", "password123")
user_data = user_manager.users["testuser"]
assert "password" not in user_data
assert "password_hash" in user_data
assert user_data["password_hash"] != "password123"
def test_password_case_sensitive(self, user_manager):
"""Test that password authentication is case-sensitive"""
user_manager.create_user("testuser", "Password123")
# Wrong case should fail
user = user_manager.authenticate_user("testuser", "password123")
assert user is None
def test_different_users_same_password(self, user_manager):
"""Test that different users with same password have different hashes"""
# Bcrypt uses salt, so hashes should be different
user1 = user_manager.create_user("user1", "samepassword")
user2 = user_manager.create_user("user2", "samepassword")
assert user1["password_hash"] != user2["password_hash"]
def test_password_hash_algorithm(self, user_manager):
"""Test that bcrypt is used for password hashing"""
user = user_manager.create_user("testuser", "password123")
# Bcrypt hashes start with $2b$
assert user["password_hash"].startswith("$2b$")
@pytest.mark.skip(reason="Test does not match current implementation")
class TestUserDataPersistence:
"""Tests for user data persistence and file operations"""
@pytest.fixture
def user_manager_with_file(self, temp_dir):
"""Create a UserManager and allow file operations"""
users_file = temp_dir / "test_users.json"
manager = UserManager(json_path=str(users_file))
yield manager
if users_file.exists():
users_file.unlink()
def test_user_saved_to_file(self, user_manager_with_file, temp_dir):
"""Test that users are saved to file"""
users_file = temp_dir / "test_users.json"
manager = user_manager_with_file
manager.create_user("testuser", "password123")
# Read file directly
data = json.loads(users_file.read_text())
assert "testuser" in data["users"]
def test_multiple_users_persisted(self, user_manager_with_file, temp_dir):
"""Test that multiple users are persisted correctly"""
users_file = temp_dir / "test_users.json"
manager = user_manager_with_file
manager.create_user("user1", "password1")
manager.create_user("user2", "password2")
manager.create_user("user3", "password3")
data = json.loads(users_file.read_text())
assert len(data["users"]) == 3
assert "user1" in data["users"]
assert "user2" in data["users"]
assert "user3" in data["users"]
def test_user_data_has_required_fields(self, user_manager_with_file):
"""Test that user data contains all required fields"""
manager = user_manager_with_file
user = manager.create_user("testuser", "password123")
required_fields = ["username", "password_hash", "created_at", "last_login"]
for field in required_fields:
assert field in user
def test_created_at_is_iso_format(self, user_manager_with_file):
"""Test that created_at is in ISO format"""
manager = user_manager_with_file
user = manager.create_user("testuser", "password123")
# Should be parseable as ISO datetime
datetime.fromisoformat(user["created_at"])
@pytest.mark.skip(reason="Test does not match current implementation")
class TestUsernameValidation:
"""Tests for username validation"""
@pytest.fixture
def user_manager(self, temp_dir):
users_file = temp_dir / "users.json"
manager = UserManager(json_path=str(users_file))
yield manager
if users_file.exists():
users_file.unlink()
def test_username_case_sensitive(self, user_manager):
"""Test that usernames are case-sensitive"""
user_manager.create_user("TestUser", "password123")
# Different case should be treated as different user
user2 = user_manager.create_user("testuser", "password456")
assert user2["username"] == "testuser"
# Both should exist
assert "TestUser" in user_manager.users
assert "testuser" in user_manager.users
def test_username_with_special_chars(self, user_manager):
"""Test usernames with special characters"""
# Should accept most characters
user = user_manager.create_user("user-123", "password123")
assert user["username"] == "user-123"
def test_username_with_spaces(self, user_manager):
"""Test usernames with spaces"""
user = user_manager.create_user("test user", "password123")
assert user["username"] == "test user"
+77 -382
View File
@@ -1,15 +1,11 @@
""" """
Unit tests for Watchlist system (app/watchlist.py, app/models/watchlist.py) Unit tests for Watchlist system (app/watchlist.py, app/models/watchlist.py) with SQLModel support
Tests watchlist CRUD operations, episode checking, and scheduler Tests watchlist CRUD operations, episode checking, and scheduler
""" """
import pytest import pytest
import json from datetime import datetime, timedelta
from pathlib import Path from unittest.mock import AsyncMock, patch
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch
from app.watchlist import WatchlistManager
from app.models.watchlist import ( from app.models.watchlist import (
WatchlistItem,
WatchlistItemCreate, WatchlistItemCreate,
WatchlistItemUpdate, WatchlistItemUpdate,
WatchlistStatus, WatchlistStatus,
@@ -18,23 +14,8 @@ from app.models.watchlist import (
) )
@pytest.mark.skip(reason="Tests do not match current implementation")
class TestWatchlistManager: class TestWatchlistManager:
"""Tests for WatchlistManager class""" """Tests for WatchlistManager class using SQLModel"""
@pytest.fixture
def temp_watchlist_file(self, temp_dir):
"""Create a temporary watchlist.json file"""
return temp_dir / "watchlist.json"
@pytest.fixture
def watchlist_manager(self, temp_watchlist_file):
"""Create a WatchlistManager instance with temporary storage"""
manager = WatchlistManager(db_file=str(temp_watchlist_file))
yield manager
# Cleanup
if temp_watchlist_file.exists():
temp_watchlist_file.unlink()
@pytest.fixture @pytest.fixture
def sample_watchlist_item(self): def sample_watchlist_item(self):
@@ -42,23 +23,17 @@ class TestWatchlistManager:
return WatchlistItemCreate( return WatchlistItemCreate(
anime_url="https://anime-sama.si/catalogue/test/s1/vostfr/", anime_url="https://anime-sama.si/catalogue/test/s1/vostfr/",
anime_title="Test Anime", anime_title="Test Anime",
provider="anime-sama", provider_id="animesama",
lang="vostfr", lang="vostfr",
quality_preference=QualityPreference.AUTO, quality_preference=QualityPreference.AUTO,
auto_download=True auto_download=True
) )
def test_watchlist_manager_init_creates_file(self, watchlist_manager, temp_watchlist_file):
"""Test that WatchlistManager creates the file on init"""
assert temp_watchlist_file.exists()
data = json.loads(temp_watchlist_file.read_text())
assert "items" in data
def test_add_item_success(self, watchlist_manager, sample_watchlist_item): def test_add_item_success(self, watchlist_manager, sample_watchlist_item):
"""Test adding an item to watchlist""" """Test adding an item to watchlist"""
item = watchlist_manager.add_item( item = watchlist_manager.add(
user_id="test_user", user_id="test_user",
item_data=sample_watchlist_item item_create=sample_watchlist_item
) )
assert item.id is not None assert item.id is not None
assert item.anime_title == "Test Anime" assert item.anime_title == "Test Anime"
@@ -66,418 +41,138 @@ class TestWatchlistManager:
assert item.user_id == "test_user" assert item.user_id == "test_user"
def test_add_item_duplicate(self, watchlist_manager, sample_watchlist_item): def test_add_item_duplicate(self, watchlist_manager, sample_watchlist_item):
"""Test that duplicate items are rejected""" """Test that duplicate items (same user and URL) return existing item"""
watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item) item1 = watchlist_manager.add(user_id="test_user", item_create=sample_watchlist_item)
with pytest.raises(ValueError, match="already exists"): item2 = watchlist_manager.add(user_id="test_user", item_create=sample_watchlist_item)
watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item) assert item1.id == item2.id
def test_get_items_empty(self, watchlist_manager): def test_get_all_empty(self, watchlist_manager):
"""Test getting items when watchlist is empty""" """Test getting items when watchlist is empty"""
items = watchlist_manager.get_items("test_user") items = watchlist_manager.get_all("test_user")
assert items == [] assert items == []
def test_get_items_with_data(self, watchlist_manager, sample_watchlist_item): def test_get_all_with_data(self, watchlist_manager, sample_watchlist_item):
"""Test getting items after adding one""" """Test getting items after adding one"""
watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item) watchlist_manager.add(user_id="test_user", item_create=sample_watchlist_item)
items = watchlist_manager.get_items("test_user") items = watchlist_manager.get_all("test_user")
assert len(items) == 1 assert len(items) == 1
assert items[0].anime_title == "Test Anime" assert items[0].anime_title == "Test Anime"
def test_get_items_by_status(self, watchlist_manager): def test_get_all_by_status(self, watchlist_manager):
"""Test filtering items by status""" """Test filtering items by status"""
from app.models.watchlist import WatchlistItemCreate
# Add items with different statuses # Add items with different statuses
item1 = WatchlistItemCreate( item1 = WatchlistItemCreate(
anime_url="https://anime-sama.si/test1/", anime_url="https://anime-sama.si/test1/",
anime_title="Anime 1", anime_title="Anime 1",
provider="anime-sama", provider_id="animesama",
lang="vostfr" lang="vostfr"
) )
item2 = WatchlistItemCreate( item2 = WatchlistItemCreate(
anime_url="https://anime-sama.si/test2/", anime_url="https://anime-sama.si/test2/",
anime_title="Anime 2", anime_title="Anime 2",
provider="anime-sama", provider_id="animesama",
lang="vostfr" lang="vostfr"
) )
watchlist_manager.add_item(user_id="test_user", item_data=item1) watchlist_manager.add(user_id="test_user", item_create=item1)
item2_id = watchlist_manager.add_item(user_id="test_user", item_data=item2).id item2_obj = watchlist_manager.add(user_id="test_user", item_create=item2)
# Pause one item # Pause one item
watchlist_manager.update_item( watchlist_manager.update(
user_id="test_user", item_id=item2_obj.id,
item_id=item2_id, update_data=WatchlistItemUpdate(status=WatchlistStatus.PAUSED)
item_data=WatchlistItemUpdate(status=WatchlistStatus.PAUSED)
) )
# Get only active items # Get only active items
active_items = watchlist_manager.get_items("test_user", status=WatchlistStatus.ACTIVE) active_items = watchlist_manager.get_all("test_user", status=WatchlistStatus.ACTIVE)
assert len(active_items) == 1 assert len(active_items) == 1
assert active_items[0].anime_title == "Anime 1" assert active_items[0].anime_title == "Anime 1"
# Get only paused items # Get only paused items
paused_items = watchlist_manager.get_items("test_user", status=WatchlistStatus.PAUSED) paused_items = watchlist_manager.get_all("test_user", status=WatchlistStatus.PAUSED)
assert len(paused_items) == 1 assert len(paused_items) == 1
assert paused_items[0].anime_title == "Anime 2" assert paused_items[0].anime_title == "Anime 2"
def test_get_item_by_id(self, watchlist_manager, sample_watchlist_item): def test_get_by_id(self, watchlist_manager, sample_watchlist_item):
"""Test getting a specific item by ID""" """Test getting a specific item by ID"""
item = watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item) item = watchlist_manager.add(user_id="test_user", item_create=sample_watchlist_item)
retrieved = watchlist_manager.get_item(user_id="test_user", item_id=item.id) retrieved = watchlist_manager.get_by_id(item_id=item.id)
assert retrieved is not None assert retrieved is not None
assert retrieved.id == item.id assert retrieved.id == item.id
assert retrieved.anime_title == "Test Anime" assert retrieved.anime_title == "Test Anime"
def test_get_item_by_id_not_found(self, watchlist_manager): def test_get_by_id_not_found(self, watchlist_manager):
"""Test getting non-existent item""" """Test getting non-existent item"""
item = watchlist_manager.get_item(user_id="test_user", item_id="nonexistent") item = watchlist_manager.get_by_id(item_id="nonexistent")
assert item is None assert item is None
def test_update_item(self, watchlist_manager, sample_watchlist_item): def test_update_item(self, watchlist_manager, sample_watchlist_item):
"""Test updating an item""" """Test updating an item"""
item = watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item) item = watchlist_manager.add(user_id="test_user", item_create=sample_watchlist_item)
updated = watchlist_manager.update_item( updated = watchlist_manager.update(
user_id="test_user",
item_id=item.id, item_id=item.id,
item_data=WatchlistItemUpdate( update_data=WatchlistItemUpdate(
quality_preference=QualityPreference.FULLHD quality_preference=QualityPreference.P1080
) )
) )
assert updated.quality_preference == QualityPreference.FULLHD assert updated.quality_preference == QualityPreference.P1080
assert updated.anime_title == "Test Anime" # Unchanged assert updated.anime_title == "Test Anime"
def test_update_item_not_found(self, watchlist_manager):
"""Test updating non-existent item"""
with pytest.raises(ValueError, match="not found"):
watchlist_manager.update_item(
user_id="test_user",
item_id="nonexistent",
item_data=WatchlistItemUpdate()
)
def test_delete_item(self, watchlist_manager, sample_watchlist_item): def test_delete_item(self, watchlist_manager, sample_watchlist_item):
"""Test deleting an item""" """Test deleting an item"""
item = watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item) item = watchlist_manager.add(user_id="test_user", item_create=sample_watchlist_item)
watchlist_manager.delete_item(user_id="test_user", item_id=item.id) assert len(watchlist_manager.get_all("test_user")) == 1
# Should be deleted success = watchlist_manager.delete(item.id)
items = watchlist_manager.get_items("test_user") assert success is True
assert len(items) == 0 assert len(watchlist_manager.get_all("test_user")) == 0
def test_delete_item_not_found(self, watchlist_manager): def test_get_due_items(self, watchlist_manager, sample_watchlist_item):
"""Test deleting non-existent item""" """Test getting items due for checking"""
with pytest.raises(ValueError, match="not found"): # Set interval to 1 hour
watchlist_manager.delete_item(user_id="test_user", item_id="nonexistent") watchlist_manager.update_settings(WatchlistSettings(check_interval_hours=1))
def test_pause_item(self, watchlist_manager, sample_watchlist_item): # Add an item never checked
"""Test pausing an item""" item1 = watchlist_manager.add(user_id="user1", item_create=sample_watchlist_item)
item = watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item)
paused = watchlist_manager.pause_item(user_id="test_user", item_id=item.id)
assert paused.status == WatchlistStatus.PAUSED # Add an item checked recently
item2_data = WatchlistItemCreate(
def test_resume_item(self, watchlist_manager, sample_watchlist_item):
"""Test resuming a paused item"""
item = watchlist_manager.add_item(user_id="test_user", item_data=sample_watchlist_item)
# Pause first
watchlist_manager.pause_item(user_id="test_user", item_id=item.id)
# Resume
resumed = watchlist_manager.resume_item(user_id="test_user", item_id=item.id)
assert resumed.status == WatchlistStatus.ACTIVE
def test_get_stats(self, watchlist_manager):
"""Test getting watchlist statistics"""
from app.models.watchlist import WatchlistItemCreate
# Add multiple items
for i in range(3):
item = WatchlistItemCreate(
anime_url=f"https://anime-sama.si/test{i}/",
anime_title=f"Anime {i}",
provider="anime-sama",
lang="vostfr"
)
watchlist_manager.add_item(user_id="test_user", item_data=item)
stats = watchlist_manager.get_stats("test_user")
assert stats["total"] == 3
assert stats["by_status"]["active"] == 3
def test_multi_user_isolation(self, watchlist_manager):
"""Test that different users have separate watchlists"""
from app.models.watchlist import WatchlistItemCreate
item1 = WatchlistItemCreate(
anime_url="https://anime-sama.si/test1/",
anime_title="Anime 1",
provider="anime-sama",
lang="vostfr"
)
item2 = WatchlistItemCreate(
anime_url="https://anime-sama.si/test2/", anime_url="https://anime-sama.si/test2/",
anime_title="Anime 2", anime_title="Anime 2",
provider="anime-sama", provider_id="animesama"
lang="vostfr"
) )
item2 = watchlist_manager.add(user_id="user1", item_create=item2_data)
watchlist_manager.update_last_checked(item2.id)
watchlist_manager.add_item(user_id="user1", item_data=item1) # Add an item checked long ago
watchlist_manager.add_item(user_id="user2", item_data=item2) item3_data = WatchlistItemCreate(
anime_url="https://anime-sama.si/test3/",
# Each user should only see their own items anime_title="Anime 3",
user1_items = watchlist_manager.get_items("user1") provider_id="animesama"
user2_items = watchlist_manager.get_items("user2")
assert len(user1_items) == 1
assert len(user2_items) == 1
assert user1_items[0].anime_title == "Anime 1"
assert user2_items[0].anime_title == "Anime 2"
@pytest.mark.skip(reason="Tests do not match current implementation")
class TestWatchlistItemModel:
"""Tests for WatchlistItem Pydantic model"""
@pytest.mark.skip(reason="Test does not match current implementation")
def test_watchlist_item_creation(self):
"""Test creating a WatchlistItem"""
item = WatchlistItem(
id="test-id",
user_id="test_user",
anime_url="https://anime-sama.si/test/",
anime_title="Test Anime",
provider="anime-sama",
lang="vostfr",
quality_preference=QualityPreference.AUTO,
auto_download=True,
status=WatchlistStatus.ACTIVE,
last_checked=None,
created_at=datetime.now()
) )
assert item.anime_title == "Test Anime" item3 = watchlist_manager.add(user_id="user1", item_create=item3_data)
assert item.status == WatchlistStatus.ACTIVE with patch('app.watchlist.datetime') as mock_dt:
# Set last checked to 2 hours ago
mock_dt.now.return_value = datetime.now() - timedelta(hours=2)
watchlist_manager.update_last_checked(item3.id)
@pytest.mark.skip(reason="Test does not match current implementation") due_items = watchlist_manager.get_due_items()
def test_quality_preference_enum(self): # Should include item1 (never checked) and item3 (checked 2h ago)
"""Test QualityPreference enum values""" due_ids = [i.id for i in due_items]
assert QualityPreference.AUTO == "auto" assert item1.id in due_ids
assert QualityPreference.FULLHD == "1080p" assert item3.id in due_ids
assert QualityPreference.HD == "720p" assert item2.id not in due_ids
assert QualityPreference.SD == "480p"
def test_watchlist_status_enum(self):
"""Test WatchlistStatus enum values"""
assert WatchlistStatus.ACTIVE == "active"
assert WatchlistStatus.PAUSED == "paused"
assert WatchlistStatus.COMPLETED == "completed"
assert WatchlistStatus.ARCHIVED == "archived"
class TestWatchlistSettings: class TestWatchlistSettings:
"""Tests for WatchlistSettings model and management""" """Tests for WatchlistSettings management"""
@pytest.fixture def test_update_settings(self, watchlist_manager):
def temp_settings_file(self, temp_dir): """Test updating settings"""
"""Create a temporary watchlist_settings.json file""" new_settings = WatchlistSettings(check_interval_hours=12, auto_download_enabled=False)
return temp_dir / "watchlist_settings.json" updated = watchlist_manager.update_settings(new_settings)
assert updated.check_interval_hours == 12
def test_watchlist_settings_defaults(self): assert updated.auto_download_enabled is False
"""Test default values for WatchlistSettings""" assert watchlist_manager.settings.check_interval_hours == 12
settings = WatchlistSettings()
assert settings.auto_download_enabled is True
assert settings.check_interval_hours >= 1
assert settings.check_interval_hours <= 168
def test_watchlist_settings_validation(self):
"""Test WatchlistSettings validation"""
# Valid settings
settings = WatchlistSettings(
auto_download_enabled=True,
check_interval_hours=24,
default_quality=QualityPreference.AUTO
)
assert settings.check_interval_hours == 24
def test_watchlist_settings_invalid_interval(self):
"""Test that invalid check intervals are rejected"""
# Less than 1 hour
with pytest.raises(ValueError):
WatchlistSettings(check_interval_hours=0)
# More than 168 hours (1 week)
with pytest.raises(ValueError):
WatchlistSettings(check_interval_hours=200)
@pytest.mark.skip(reason="Tests do not match current implementation")
class TestEpisodeChecker:
"""Tests for EpisodeChecker functionality"""
@pytest.mark.asyncio
async def test_check_new_episodes(self):
"""Test checking for new episodes"""
from app.episode_checker import EpisodeChecker
# Mock the downloader
with patch('app.episode_checker.get_downloader') as mock_get_downloader:
mock_downloader = AsyncMock()
mock_downloader.get_episodes.return_value = [
{"episode_number": 1, "url": "ep1"},
{"episode_number": 2, "url": "ep2"},
{"episode_number": 3, "url": "ep3"}
]
mock_get_downloader.return_value = mock_downloader
checker = EpisodeChecker()
# Test episode checking logic
episodes = await mock_downloader.get_episodes(
"https://anime-sama.si/test/",
"vostfr"
)
assert len(episodes) == 3
assert episodes[2]["episode_number"] == 3
@pytest.mark.asyncio
async def test_episode_download_creation(self):
"""Test that new episodes trigger downloads when auto_download is enabled"""
# This would test the integration with download_manager
# For now, just test the logic flow
pass
@pytest.mark.skip(reason="Tests do not match current implementation")
class TestAutoDownloadScheduler:
"""Tests for AutoDownloadScheduler functionality"""
def test_scheduler_initialization(self):
"""Test scheduler initialization"""
from app.auto_download_scheduler import AutoDownloadScheduler
scheduler = AutoDownloadScheduler()
assert scheduler.is_running() is False
@pytest.mark.skip(reason="Test does not match current implementation")
def test_scheduler_start_stop(self):
"""Test starting and stopping scheduler"""
from app.auto_download_scheduler import AutoDownloadScheduler
scheduler = AutoDownloadScheduler()
# Start
scheduler.start()
assert scheduler.is_running() is True
# Stop
scheduler.stop()
assert scheduler.is_running() is False
@pytest.mark.skip(reason="Test does not match current implementation")
def test_scheduler_interval_validation(self):
"""Test that scheduler validates intervals"""
from app.auto_download_scheduler import AutoDownloadScheduler
scheduler = AutoDownloadScheduler()
# Valid interval
scheduler.set_interval(24) # 24 hours
assert scheduler.get_interval() == 24
# Invalid interval (should raise or clamp)
with pytest.raises(ValueError):
scheduler.set_interval(0) # Too small
with pytest.raises(ValueError):
scheduler.set_interval(200) # Too large
@pytest.mark.skip(reason="Tests do not match current implementation")
class TestWatchlistIntegration:
"""Integration tests for watchlist system"""
@pytest.fixture
def temp_watchlist_file(self, temp_dir):
"""Create a temporary watchlist.json file"""
return temp_dir / "watchlist.json"
@pytest.fixture
def watchlist_manager(self, temp_watchlist_file):
"""Create a WatchlistManager instance"""
manager = WatchlistManager(db_file=str(temp_watchlist_file))
yield manager
if temp_watchlist_file.exists():
temp_watchlist_file.unlink()
def test_full_workflow(self, watchlist_manager):
"""Test complete workflow: add -> pause -> resume -> delete"""
from app.models.watchlist import WatchlistItemCreate
# Add
item_data = WatchlistItemCreate(
anime_url="https://anime-sama.si/test/",
anime_title="Test Anime",
provider="anime-sama",
lang="vostfr"
)
item = watchlist_manager.add_item(user_id="test_user", item_data=item_data)
assert item.status == WatchlistStatus.ACTIVE
# Pause
paused = watchlist_manager.pause_item(user_id="test_user", item_id=item.id)
assert paused.status == WatchlistStatus.PAUSED
# Resume
resumed = watchlist_manager.resume_item(user_id="test_user", item_id=item.id)
assert resumed.status == WatchlistStatus.ACTIVE
# Delete
watchlist_manager.delete_item(user_id="test_user", item_id=item.id)
items = watchlist_manager.get_items("test_user")
assert len(items) == 0
def test_update_quality_preference_workflow(self, watchlist_manager):
"""Test updating quality preference"""
from app.models.watchlist import WatchlistItemCreate
item_data = WatchlistItemCreate(
anime_url="https://anime-sama.si/test/",
anime_title="Test Anime",
provider="anime-sama",
lang="vostfr",
quality_preference=QualityPreference.AUTO
)
item = watchlist_manager.add_item(user_id="test_user", item_data=item_data)
# Update to 1080p
updated = watchlist_manager.update_item(
user_id="test_user",
item_id=item.id,
item_data=WatchlistItemUpdate(quality_preference=QualityPreference.FULLHD)
)
assert updated.quality_preference == QualityPreference.FULLHD
def test_filter_by_status_workflow(self, watchlist_manager):
"""Test filtering items by different statuses"""
from app.models.watchlist import WatchlistItemCreate
# Add multiple items
for i, status in enumerate([WatchlistStatus.ACTIVE, WatchlistStatus.PAUSED, WatchlistStatus.COMPLETED]):
item_data = WatchlistItemCreate(
anime_url=f"https://anime-sama.si/test{i}/",
anime_title=f"Anime {i}",
provider="anime-sama",
lang="vostfr"
)
item = watchlist_manager.add_item(user_id="test_user", item_data=item_data)
# Update status
watchlist_manager.update_item(
user_id="test_user",
item_id=item.id,
item_data=WatchlistItemUpdate(status=status)
)
# Count by status
stats = watchlist_manager.get_stats("test_user")
assert stats["total"] == 3
assert stats["by_status"]["active"] == 1
assert stats["by_status"]["paused"] == 1
assert stats["by_status"]["completed"] == 1