feat: migrate persistence from JSON to SQLModel (Phase 1)
CI / Test (Python 3.11) (push) Has been cancelled
CI / Test (Python 3.12) (push) Has been cancelled
CI / Lint (push) Has been cancelled
CI / Type Check (push) Has been cancelled
CI / Summary (push) Has been cancelled

- Integrated SQLModel with SQLite for robust data persistence
- Refactored UserManager and WatchlistManager to use SQL queries
- Migrated models to SQLModel with relationships and primary keys
- Updated test suite with in-memory database isolation
- Removed deprecated JSON storage files
This commit is contained in:
root
2026-03-24 10:40:36 +00:00
parent d4d8d8a3b6
commit 29c7040b20
13 changed files with 596 additions and 1165 deletions
+84 -100
View File
@@ -1,120 +1,118 @@
"""User authentication and management system"""
"""User authentication and management system with SQLModel support"""
import json
import os
import hashlib
from datetime import datetime, timedelta
from typing import Optional, Dict
from typing import Optional, Dict, List
from jose import jwt
from passlib.context import CryptContext
import logging
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from sqlmodel import Session, select
from app.database import engine
from app.models.auth import UserTable
from app.config import get_settings
logger = logging.getLogger(__name__)
# Load settings at module level for easier mocking and access
settings = get_settings()
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Users database file
USERS_DB_FILE = "config/users.json"
class UserManager:
"""Manages user storage and authentication"""
"""Manages user storage and authentication using SQL database"""
def __init__(self, db_file: str = USERS_DB_FILE):
self.db_file = db_file
self.users: Dict[str, dict] = {}
self._load_users()
def __init__(self):
# Database connection is managed via engine and sessions
pass
def _load_users(self):
"""Load users from JSON file"""
try:
if os.path.exists(self.db_file):
with open(self.db_file, "r", encoding="utf-8") as f:
self.users = json.load(f)
logger.info(f"Loaded {len(self.users)} users from database")
except Exception as e:
logger.error(f"Error loading users: {e}")
self.users = {}
def _save_users(self):
try:
os.makedirs(os.path.dirname(self.db_file), exist_ok=True)
temp_file = f"{self.db_file}.tmp"
with open(temp_file, "w", encoding="utf-8") as f:
json.dump(self.users, f, indent=2, ensure_ascii=False, default=str)
os.replace(temp_file, self.db_file)
logger.info(f"Saved {len(self.users)} users to database")
except Exception as e:
logger.error(f"Error saving users: {e}")
def get_user(self, username: str) -> Optional[dict]:
def get_user(self, username: str) -> Optional[UserTable]:
"""Get user by username"""
return self.users.get(username)
with Session(engine) as session:
statement = select(UserTable).where(UserTable.username == username)
return session.exec(statement).first()
def get_user_by_id(self, user_id: str) -> Optional[dict]:
def get_user_by_id(self, user_id: str) -> Optional[UserTable]:
"""Get user by ID"""
for user in self.users.values():
if user.get("id") == user_id:
return user
return None
with Session(engine) as session:
statement = select(UserTable).where(UserTable.id == user_id)
return session.exec(statement).first()
def create_user(
self, username: str, password: str, email: str = None, full_name: str = None
) -> dict:
) -> UserTable:
"""Create a new user"""
if username in self.users:
raise ValueError(f"Username '{username}' already exists")
with Session(engine) as session:
# Check if user already exists
statement = select(UserTable).where(UserTable.username == username)
if session.exec(statement).first():
raise ValueError(f"Username '{username}' already exists")
# Truncate password to 72 bytes if necessary (bcrypt limitation)
password_bytes = password.encode("utf-8")
if len(password_bytes) > 72:
password = password_bytes[:72].decode("utf-8", errors="ignore")
# Truncate password to 72 bytes if necessary (bcrypt limitation)
password_bytes = password.encode("utf-8")
if len(password_bytes) > 72:
password = password_bytes[:72].decode("utf-8", errors="ignore")
# Hash password
hashed_password = pwd_context.hash(password)
# Hash password
hashed_password = pwd_context.hash(password)
# Create user
user = {
"id": hashlib.sha256(username.encode()).hexdigest()[:32],
"username": username,
"email": email,
"full_name": full_name,
"hashed_password": hashed_password,
"is_active": True,
"created_at": datetime.now().isoformat(),
"last_login": None,
}
# Create user
user = UserTable(
username=username,
email=email,
full_name=full_name,
hashed_password=hashed_password,
is_active=True,
created_at=datetime.now()
)
self.users[username] = user
self._save_users()
session.add(user)
session.commit()
session.refresh(user)
logger.info(f"Created user: {username}")
return user
logger.info(f"Created user: {username}")
return user
def authenticate_user(self, username: str, password: str) -> Optional[dict]:
def authenticate_user(self, username: str, password: str) -> Optional[UserTable]:
"""Authenticate user with username and password"""
user = self.get_user(username)
if not user:
return None
if not pwd_context.verify(password, user["hashed_password"]):
if not pwd_context.verify(password, user.hashed_password):
return None
# Update last login
user["last_login"] = datetime.now().isoformat()
self._save_users()
with Session(engine) as session:
db_user = session.get(UserTable, user.id)
if db_user:
db_user.last_login = datetime.now()
session.add(db_user)
session.commit()
session.refresh(db_user)
return db_user
return user
def update_last_login(self, username: str):
"""Update user's last login time"""
user = self.get_user(username)
if user:
user["last_login"] = datetime.now().isoformat()
self._save_users()
def update_user(self, user_id: str, update_data: dict) -> Optional[UserTable]:
"""Update user information"""
with Session(engine) as session:
db_user = session.get(UserTable, user_id)
if not db_user:
return None
for key, value in update_data.items():
if hasattr(db_user, key):
setattr(db_user, key, value)
session.add(db_user)
session.commit()
session.refresh(db_user)
return db_user
# Global user manager instance
@@ -131,27 +129,11 @@ def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
def _get_jwt_config() -> dict:
"""Get JWT configuration from settings"""
from app.config import get_settings
settings = get_settings()
return {
"SECRET_KEY": settings.jwt_secret_key,
"ALGORITHM": settings.jwt_algorithm,
"ACCESS_TOKEN_EXPIRE_MINUTES": settings.access_token_expire_minutes,
"REFRESH_TOKEN_EXPIRE_DAYS": settings.refresh_token_expire_days,
}
def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
"""Create JWT access token"""
from jose import jwt
jwt_config = _get_jwt_config()
SECRET_KEY = jwt_config["SECRET_KEY"]
ALGORITHM = jwt_config["ALGORITHM"]
ACCESS_TOKEN_EXPIRE_MINUTES = jwt_config["ACCESS_TOKEN_EXPIRE_MINUTES"]
SECRET_KEY = settings.jwt_secret_key
ALGORITHM = settings.jwt_algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
to_encode = data.copy()
@@ -168,12 +150,10 @@ def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
def verify_token(token: str) -> Optional[str]:
"""Verify JWT token and return username"""
from jose import jwt
from jose.exceptions import JWTError
jwt_config = _get_jwt_config()
SECRET_KEY = jwt_config["SECRET_KEY"]
ALGORITHM = jwt_config["ALGORITHM"]
SECRET_KEY = settings.jwt_secret_key
ALGORITHM = settings.jwt_algorithm
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
@@ -189,7 +169,7 @@ def verify_token(token: str) -> Optional[str]:
get_user_from_token = verify_token
def get_current_user(credentials: HTTPAuthorizationCredentials) -> dict:
def get_current_user(credentials: HTTPAuthorizationCredentials) -> UserTable:
"""Get current user from JWT token"""
token = credentials.credentials
username = verify_token(token)
@@ -197,16 +177,19 @@ def get_current_user(credentials: HTTPAuthorizationCredentials) -> dict:
user = user_manager.get_user(username)
if not user:
raise HTTPException(status_code=401, detail="User not found")
if not user.get("is_active", True):
if not user.is_active:
raise HTTPException(status_code=401, detail="Inactive user")
return user
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
# Refresh tokens storage
REFRESH_TOKENS_FILE = "config/refresh_tokens.json"
def _load_refresh_tokens() -> Dict[str, dict]:
"""Load refresh tokens from file"""
import json
try:
if os.path.exists(REFRESH_TOKENS_FILE):
with open(REFRESH_TOKENS_FILE, 'r', encoding='utf-8') as f:
@@ -218,6 +201,7 @@ def _load_refresh_tokens() -> Dict[str, dict]:
def _save_refresh_tokens(tokens: Dict[str, dict]):
"""Save refresh tokens to file"""
import json
try:
os.makedirs(os.path.dirname(REFRESH_TOKENS_FILE), exist_ok=True)
with open(REFRESH_TOKENS_FILE, 'w', encoding='utf-8') as f:
+31
View File
@@ -0,0 +1,31 @@
"""Database configuration and session management using SQLModel"""
import os
from typing import Generator
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session, create_engine
from app.config import get_settings
settings = get_settings()
# Database URL can be overridden by environment variable DATABASE_URL
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./ohm_streaming.db")
# Create the engine
# connect_args={"check_same_thread": False} is required for SQLite and FastAPI
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
def create_db_and_tables():
"""Create the database and tables based on the models"""
# Import all models here to ensure they are registered with SQLModel.metadata
from app.models.auth import UserTable
from app.models.watchlist import WatchlistItemTable
# Add other models as they are migrated
SQLModel.metadata.create_all(engine)
def get_session() -> Generator[Session, None, None]:
"""Dependency for getting a database session"""
with Session(engine) as session:
yield session
+36 -14
View File
@@ -1,15 +1,41 @@
"""Authentication models for user management"""
from pydantic import BaseModel, EmailStr, Field
from typing import Optional
"""Authentication models for user management with SQLModel support"""
import uuid
from pydantic import BaseModel, EmailStr, Field as PydanticField
from typing import Optional, List
from datetime import datetime
from sqlmodel import SQLModel, Field, Relationship
class UserCreate(BaseModel):
"""Schema for user registration"""
username: str = Field(..., min_length=3, max_length=50)
email: Optional[EmailStr] = None
password: str = Field(..., min_length=6)
class UserBase(SQLModel):
"""Base schema for user data"""
username: str = Field(index=True, unique=True, min_length=3, max_length=50)
email: Optional[str] = Field(default=None, index=True)
full_name: Optional[str] = None
is_active: bool = Field(default=True)
class UserTable(UserBase, table=True):
"""Database table for users"""
__tablename__ = "users"
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
primary_key=True,
index=True,
nullable=False
)
hashed_password: str
created_at: datetime = Field(default_factory=datetime.now)
last_login: Optional[datetime] = None
# Relationships
watchlist_items: List["WatchlistItemTable"] = Relationship(back_populates="user")
class UserCreate(UserBase):
"""Schema for user registration"""
password: str = PydanticField(..., min_length=6)
email: Optional[EmailStr] = None
class UserLogin(BaseModel):
@@ -18,13 +44,9 @@ class UserLogin(BaseModel):
password: str
class User(BaseModel):
"""Schema for user data"""
class User(UserBase):
"""Schema for user data (API Response)"""
id: str
username: str
email: Optional[str] = None
full_name: Optional[str] = None
is_active: bool = True
created_at: datetime
last_login: Optional[datetime] = None
+82 -46
View File
@@ -1,8 +1,11 @@
"""Pydantic models for Watchlist and Auto-Download system"""
from pydantic import BaseModel, Field
from typing import Optional, Literal
"""Models for Watchlist and Auto-Download system with SQLModel support"""
import uuid
import json
from pydantic import BaseModel, Field as PydanticField
from typing import Optional, Literal, List
from datetime import datetime
from enum import Enum
from sqlmodel import SQLModel, Field, Relationship, Column, String
class WatchlistStatus(str, Enum):
@@ -21,34 +24,80 @@ class QualityPreference(str, Enum):
P480 = "480p" # SD
class WatchlistItem(BaseModel):
"""An anime being tracked for automatic episode downloads"""
id: str = Field(..., description="Unique identifier (UUID)")
user_id: str = Field(..., description="User ID who owns this watchlist item")
anime_title: str = Field(..., description="Title of the anime")
anime_url: str = Field(..., description="URL to the anime page")
provider_id: str = Field(..., description="Provider ID (animesama, nekosama, etc.)")
lang: Literal["vostfr", "vf"] = Field(default="vostfr", description="Language preference")
class WatchlistItemBase(SQLModel):
"""Base schema for watchlist items"""
anime_title: str = Field(index=True)
anime_url: str
provider_id: str
lang: str = Field(default="vostfr")
# Tracking state
last_checked: Optional[datetime] = Field(None, description="Last time we checked for new episodes")
last_episode_downloaded: int = Field(default=0, description="Last episode number downloaded")
total_episodes: Optional[int] = Field(None, description="Total episodes if known")
last_checked: Optional[datetime] = None
last_episode_downloaded: int = Field(default=0)
total_episodes: Optional[int] = None
# Settings
auto_download: bool = Field(default=True, description="Automatically download new episodes")
quality_preference: QualityPreference = Field(default=QualityPreference.AUTO, description="Preferred quality")
status: WatchlistStatus = Field(default=WatchlistStatus.ACTIVE, description="Tracking status")
auto_download: bool = Field(default=True)
quality_preference: QualityPreference = Field(default=QualityPreference.AUTO)
status: WatchlistStatus = Field(default=WatchlistStatus.ACTIVE)
# Metadata
poster_image: Optional[str] = Field(None, description="URL to poster image")
cover_image: Optional[str] = Field(None, description="URL to cover image")
synopsis: Optional[str] = Field(None, description="Anime synopsis")
genres: list[str] = Field(default_factory=list, description="Anime genres")
poster_image: Optional[str] = None
cover_image: Optional[str] = None
synopsis: Optional[str] = None
# Timestamps
added_at: datetime = Field(default_factory=datetime.now, description="When added to watchlist")
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
added_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
class WatchlistItemTable(WatchlistItemBase, table=True):
"""Database table for watchlist items"""
__tablename__ = "watchlist_items"
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
primary_key=True,
index=True,
nullable=False
)
user_id: str = Field(foreign_key="users.id", index=True)
# Store list as JSON string in SQLite
genres_json: Optional[str] = Field(default="[]", sa_column=Column(String))
@property
def genres(self) -> List[str]:
return json.loads(self.genres_json or "[]")
@genres.setter
def genres(self, value: List[str]):
self.genres_json = json.dumps(value or [])
# Relationships
user: Optional["UserTable"] = Relationship(back_populates="watchlist_items")
class WatchlistItem(BaseModel):
"""An anime being tracked for automatic episode downloads (API Response)"""
id: str
user_id: str
anime_title: str
anime_url: str
provider_id: str
lang: str
last_checked: Optional[datetime] = None
last_episode_downloaded: int = 0
total_episodes: Optional[int] = None
auto_download: bool = True
quality_preference: QualityPreference = QualityPreference.AUTO
status: WatchlistStatus = WatchlistStatus.ACTIVE
poster_image: Optional[str] = None
cover_image: Optional[str] = None
synopsis: Optional[str] = None
genres: List[str] = []
added_at: datetime
updated_at: datetime
class Config:
json_encoders = {
@@ -64,12 +113,10 @@ class WatchlistItemCreate(BaseModel):
lang: Literal["vostfr", "vf"] = "vostfr"
auto_download: bool = True
quality_preference: QualityPreference = QualityPreference.AUTO
# Optional metadata
poster_image: Optional[str] = None
cover_image: Optional[str] = None
synopsis: Optional[str] = None
genres: list[str] = []
genres: List[str] = []
class WatchlistItemUpdate(BaseModel):
@@ -96,26 +143,15 @@ class AutoDownloadResult(BaseModel):
watchlist_item_id: str
anime_title: str
new_episodes_found: int
episodes_downloaded: list[int] = Field(default_factory=list)
episodes_failed: list[tuple[int, str]] = Field(default_factory=list) # (episode_number, error_message)
checked_at: datetime = Field(default_factory=datetime.now)
episodes_downloaded: list[int] = PydanticField(default_factory=list)
episodes_failed: list[tuple[int, str]] = PydanticField(default_factory=list)
checked_at: datetime = PydanticField(default_factory=datetime.now)
class WatchlistSettings(BaseModel):
"""Global watchlist settings"""
check_interval_hours: int = Field(default=6, ge=1, le=168, description="Check interval (1-168 hours)")
auto_download_enabled: bool = Field(default=True, description="Global auto-download toggle")
max_concurrent_auto_downloads: int = Field(default=2, ge=1, le=10, description="Max concurrent auto-downloads")
notify_on_new_episodes: bool = Field(default=False, description="Send notifications for new episodes")
include_completed_anime: bool = Field(default=False, description="Check completed anime too")
class Config:
json_schema_extra = {
"example": {
"check_interval_hours": 6,
"auto_download_enabled": True,
"max_concurrent_auto_downloads": 2,
"notify_on_new_episodes": False,
"include_completed_anime": False
}
}
check_interval_hours: int = PydanticField(default=6, ge=1, le=168)
auto_download_enabled: bool = PydanticField(default=True)
max_concurrent_auto_downloads: int = PydanticField(default=2, ge=1, le=10)
notify_on_new_episodes: bool = PydanticField(default=False)
include_completed_anime: bool = PydanticField(default=False)
+146 -164
View File
@@ -1,4 +1,4 @@
"""Watchlist management system for automatic episode tracking and downloading"""
"""Watchlist management system for automatic episode tracking and downloading with SQLModel support"""
import json
import os
import uuid
@@ -7,8 +7,11 @@ from datetime import datetime, timedelta
from typing import List, Optional, Dict
from pathlib import Path
from sqlmodel import Session, select
from app.database import engine
from app.models.watchlist import (
WatchlistItem,
WatchlistItemTable,
WatchlistItemCreate,
WatchlistItemUpdate,
WatchlistStatus,
@@ -19,55 +22,18 @@ from app.models.watchlist import (
logger = logging.getLogger(__name__)
# Watchlist database file
WATCHLIST_DB_FILE = "config/watchlist.json"
# Settings file remains JSON for simplicity for now
WATCHLIST_SETTINGS_FILE = "config/watchlist_settings.json"
class WatchlistManager:
"""Manages user watchlist for automatic episode downloads"""
"""Manages user watchlist for automatic episode downloads using SQL database"""
def __init__(self, db_file: str = WATCHLIST_DB_FILE):
self.db_file = db_file
def __init__(self):
self.settings_file = WATCHLIST_SETTINGS_FILE
self.watchlist: Dict[str, WatchlistItem] = {}
self.settings: Optional[WatchlistSettings] = None
self._load_watchlist()
self._load_settings()
def _load_watchlist(self):
"""Load watchlist from JSON file"""
try:
if os.path.exists(self.db_file):
with open(self.db_file, 'r', encoding='utf-8') as f:
data = json.load(f)
self.watchlist = {
item_id: WatchlistItem(**item_data)
for item_id, item_data in data.items()
}
logger.info(f"Loaded {len(self.watchlist)} items from watchlist")
else:
self.watchlist = {}
logger.info("Watchlist database not found, starting with empty watchlist")
except Exception as e:
logger.error(f"Error loading watchlist: {e}")
self.watchlist = {}
def _save_watchlist(self):
try:
os.makedirs(os.path.dirname(self.db_file), exist_ok=True)
data = {
item_id: item.model_dump(mode='json')
for item_id, item in self.watchlist.items()
}
temp_file = f"{self.db_file}.tmp"
with open(temp_file, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False, default=str)
os.replace(temp_file, self.db_file)
logger.debug(f"Saved {len(self.watchlist)} items to watchlist")
except Exception as e:
logger.error(f"Error saving watchlist: {e}")
def _load_settings(self):
"""Load watchlist settings from JSON file"""
try:
@@ -95,167 +61,183 @@ class WatchlistManager:
except Exception as e:
logger.error(f"Error saving settings: {e}")
def _to_api_model(self, db_item: WatchlistItemTable) -> WatchlistItem:
"""Convert database table model to API response model"""
data = db_item.model_dump()
data["genres"] = db_item.genres
return WatchlistItem(**data)
def get_all(self, user_id: Optional[str] = None, status: Optional[WatchlistStatus] = None) -> List[WatchlistItem]:
"""Get all watchlist items, optionally filtered by user and status"""
items = list(self.watchlist.values())
if user_id:
items = [item for item in items if item.user_id == user_id]
if status:
items = [item for item in items if item.status == status]
# Sort by added_at descending
items.sort(key=lambda x: x.added_at, reverse=True)
return items
with Session(engine) as session:
statement = select(WatchlistItemTable)
if user_id:
statement = statement.where(WatchlistItemTable.user_id == user_id)
if status:
statement = statement.where(WatchlistItemTable.status == status)
# Sort by added_at descending
statement = statement.order_by(WatchlistItemTable.added_at.desc())
db_items = session.exec(statement).all()
return [self._to_api_model(item) for item in db_items]
def get_by_id(self, item_id: str) -> Optional[WatchlistItem]:
"""Get a watchlist item by ID"""
return self.watchlist.get(item_id)
"""Get a specific watchlist item by ID"""
with Session(engine) as session:
db_item = session.get(WatchlistItemTable, item_id)
if db_item:
return self._to_api_model(db_item)
return None
def get_by_anime_url(self, anime_url: str, user_id: str) -> Optional[WatchlistItem]:
"""Get a watchlist item by anime URL and user ID"""
for item in self.watchlist.values():
if item.anime_url == anime_url and item.user_id == user_id:
return item
return None
with Session(engine) as session:
statement = select(WatchlistItemTable).where(
WatchlistItemTable.anime_url == anime_url,
WatchlistItemTable.user_id == user_id
)
db_item = session.exec(statement).first()
if db_item:
return self._to_api_model(db_item)
return None
def create(self, user_id: str, item_data: WatchlistItemCreate) -> WatchlistItem:
"""Create a new watchlist item"""
# Check if already exists
existing = self.get_by_anime_url(item_data.anime_url, user_id)
def add(self, user_id: str, item_create: WatchlistItemCreate) -> WatchlistItem:
"""Add a new anime to the watchlist"""
# Check if already in watchlist for this user
existing = self.get_by_anime_url(item_create.anime_url, user_id)
if existing:
raise ValueError(f"Anime already in watchlist (ID: {existing.id})")
return existing
# Create new item
item_id = str(uuid.uuid4())
now = datetime.now()
with Session(engine) as session:
# Create new item
db_item = WatchlistItemTable(
user_id=user_id,
anime_title=item_create.anime_title,
anime_url=item_create.anime_url,
provider_id=item_create.provider_id,
lang=item_create.lang,
auto_download=item_create.auto_download,
quality_preference=item_create.quality_preference,
poster_image=item_create.poster_image,
cover_image=item_create.cover_image,
synopsis=item_create.synopsis,
status=WatchlistStatus.ACTIVE,
added_at=datetime.now(),
updated_at=datetime.now(),
last_episode_downloaded=0
)
db_item.genres = item_create.genres
watchlist_item = WatchlistItem(
id=item_id,
user_id=user_id,
anime_title=item_data.anime_title,
anime_url=item_data.anime_url,
provider_id=item_data.provider_id,
lang=item_data.lang,
auto_download=item_data.auto_download,
quality_preference=item_data.quality_preference,
status=WatchlistStatus.ACTIVE,
poster_image=item_data.poster_image,
cover_image=item_data.cover_image,
synopsis=item_data.synopsis,
genres=item_data.genres,
added_at=now,
updated_at=now,
last_checked=None,
last_episode_downloaded=0,
total_episodes=None
)
session.add(db_item)
session.commit()
session.refresh(db_item)
logger.info(f"Added {db_item.anime_title} to watchlist for user {user_id}")
return self._to_api_model(db_item)
self.watchlist[item_id] = watchlist_item
self._save_watchlist()
logger.info(f"Added anime to watchlist: {watchlist_item.anime_title} (ID: {item_id})")
return watchlist_item
# Alias for backward compatibility if needed
add_item = add
def update(self, item_id: str, update_data) -> Optional[WatchlistItem]:
"""Update a watchlist item
"""Update a watchlist item"""
with Session(engine) as session:
db_item = session.get(WatchlistItemTable, item_id)
if not db_item:
return None
Args:
item_id: Item ID to update
update_data: WatchlistItemUpdate object or dict with fields to update
"""
item = self.watchlist.get(item_id)
if not item:
return None
# Handle both dict and WatchlistItemUpdate
if isinstance(update_data, dict):
update_dict = update_data
else:
update_dict = update_data.model_dump(exclude_unset=True)
# Handle both dict and WatchlistItemUpdate
if isinstance(update_data, dict):
update_dict = update_data
else:
update_dict = update_data.model_dump(exclude_unset=True)
for key, value in update_dict.items():
if hasattr(db_item, key):
setattr(db_item, key, value)
db_item.updated_at = datetime.now()
session.add(db_item)
session.commit()
session.refresh(db_item)
logger.info(f"Updated watchlist item: {item_id}")
return self._to_api_model(db_item)
# Update fields
for field, value in update_dict.items():
if value is not None:
setattr(item, field, value)
item.updated_at = datetime.now()
self._save_watchlist()
logger.info(f"Updated watchlist item: {item_id}")
return item
# Alias for backward compatibility
update_item = update
def delete(self, item_id: str) -> bool:
"""Delete a watchlist item"""
if item_id in self.watchlist:
del self.watchlist[item_id]
self._save_watchlist()
logger.info(f"Deleted watchlist item: {item_id}")
"""Remove an item from the watchlist"""
with Session(engine) as session:
db_item = session.get(WatchlistItemTable, item_id)
if not db_item:
return False
session.delete(db_item)
session.commit()
logger.info(f"Deleted item {item_id} from watchlist")
return True
return False
def update_check_time(self, item_id: str, last_episode: int) -> Optional[WatchlistItem]:
"""Update last_checked time and last_episode_downloaded"""
item = self.watchlist.get(item_id)
if not item:
return None
def update_last_checked(self, item_id: str, last_episode: Optional[int] = None):
"""Update the last_checked timestamp and optionally last episode for an item"""
with Session(engine) as session:
db_item = session.get(WatchlistItemTable, item_id)
if db_item:
db_item.last_checked = datetime.now()
if last_episode is not None:
db_item.last_episode_downloaded = last_episode
session.add(db_item)
session.commit()
item.last_checked = datetime.now()
item.last_episode_downloaded = max(item.last_episode_downloaded, last_episode)
item.updated_at = datetime.now()
self._save_watchlist()
return item
# Alias for backward compatibility
update_check_time = update_last_checked
def get_settings(self) -> WatchlistSettings:
"""Get watchlist settings"""
if not self.settings:
self.settings = WatchlistSettings()
return self.settings
def get_due_items(self) -> List[WatchlistItem]:
"""Get all items that are due for a check based on settings"""
interval = timedelta(hours=self.settings.check_interval_hours)
now = datetime.now()
with Session(engine) as session:
statement = select(WatchlistItemTable).where(
(WatchlistItemTable.status == WatchlistStatus.ACTIVE)
)
db_items = session.exec(statement).all()
due_items = []
for item in db_items:
if not item.last_checked or (item.last_checked + interval) < now:
due_items.append(self._to_api_model(item))
return due_items
def update_settings(self, settings: WatchlistSettings) -> WatchlistSettings:
"""Update watchlist settings"""
"""Update global watchlist settings"""
self.settings = settings
self._save_settings()
logger.info("Updated watchlist settings")
return self.settings
def get_due_for_check(self, check_interval_hours: Optional[int] = None) -> List[WatchlistItem]:
"""Get items that are due for checking"""
if check_interval_hours is None:
check_interval_hours = self.settings.check_interval_hours
cutoff_time = datetime.now() - timedelta(hours=check_interval_hours)
due_items = []
for item in self.watchlist.values():
# Only check active items with auto_download enabled
if item.status != WatchlistStatus.ACTIVE or not item.auto_download:
continue
# Check if due
if item.last_checked is None or item.last_checked < cutoff_time:
due_items.append(item)
logger.info(f"Found {len(due_items)} items due for check")
return due_items
def get_stats(self, user_id: Optional[str] = None) -> Dict:
"""Get watchlist statistics"""
def get_stats(self, user_id: str) -> Dict:
"""Get statistics for a user's watchlist"""
items = self.get_all(user_id=user_id)
stats = {
"total": len(items),
"active": len([i for i in items if i.status == WatchlistStatus.ACTIVE]),
"paused": len([i for i in items if i.status == WatchlistStatus.PAUSED]),
"completed": len([i for i in items if i.status == WatchlistStatus.COMPLETED]),
"auto_download_enabled": len([i for i in items if i.auto_download]),
"total_items": len(items),
"active_items": len([i for i in items if i.status == WatchlistStatus.ACTIVE]),
"paused_items": len([i for i in items if i.status == WatchlistStatus.PAUSED]),
"completed_items": len([i for i in items if i.status == WatchlistStatus.COMPLETED]),
"total_episodes_downloaded": sum(i.last_episode_downloaded for i in items),
"providers": {}
}
# Count by provider
for item in items:
provider = item.provider_id
stats["providers"][provider] = stats["providers"].get(provider, 0) + 1
return stats