Phase 2 Complete: SQL migration with SQLModel and Alembic
This commit is contained in:
@@ -32,6 +32,7 @@ class UserManager:
|
||||
|
||||
def get_user(self, username: str) -> Optional[UserTable]:
|
||||
"""Get user by username"""
|
||||
from app.models.watchlist import WatchlistItemTable # Force registration
|
||||
with Session(engine) as session:
|
||||
statement = select(UserTable).where(UserTable.username == username)
|
||||
return session.exec(statement).first()
|
||||
@@ -210,6 +211,14 @@ def _save_refresh_tokens(tokens: Dict[str, dict]):
|
||||
logger.error(f"Error saving refresh tokens: {e}")
|
||||
|
||||
|
||||
def _get_jwt_config() -> dict:
|
||||
return {
|
||||
"SECRET_KEY": settings.jwt_secret_key,
|
||||
"ALGORITHM": settings.jwt_algorithm,
|
||||
"ACCESS_TOKEN_EXPIRE_MINUTES": settings.access_token_expire_minutes,
|
||||
"REFRESH_TOKEN_EXPIRE_DAYS": 30
|
||||
}
|
||||
|
||||
def create_access_refresh_tokens(data: dict) -> tuple[str, str]:
|
||||
"""
|
||||
Create both access and refresh tokens.
|
||||
|
||||
+4
-3
@@ -17,10 +17,11 @@ engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
|
||||
def create_db_and_tables():
|
||||
"""Create the database and tables based on the models"""
|
||||
# Import all models here to ensure they are registered with SQLModel.metadata
|
||||
# CRITICAL: Import ALL models here to ensure they are registered with SQLModel.metadata
|
||||
from app.models.auth import UserTable
|
||||
from app.models.watchlist import WatchlistItemTable
|
||||
# Add other models as they are migrated
|
||||
from app.models.watchlist import WatchlistItemTable, WatchlistSettingsTable
|
||||
from app.models.favorites import FavoriteTable
|
||||
from app.models.sonarr import SonarrMappingTable, SonarrConfigTable
|
||||
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
+73
-79
@@ -1,52 +1,24 @@
|
||||
"""
|
||||
Favorites management system for Ohm Stream Downloader
|
||||
Stores user's favorite anime with metadata in a local JSON file
|
||||
Stores user's favorite anime with metadata using SQLModel
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
import aiofiles
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from app.database import engine
|
||||
from app.models.favorites import FavoriteTable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FavoritesManager:
|
||||
"""Manages user's favorite anime list"""
|
||||
"""Manages user's favorite anime list using SQL database"""
|
||||
|
||||
def __init__(self, storage_path: str = "data/favorites.json"):
|
||||
self.storage_path = Path(storage_path)
|
||||
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._favorites: Dict[str, Dict] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _load(self):
|
||||
"""Load favorites from disk"""
|
||||
async with self._lock:
|
||||
await self._load_for_operation()
|
||||
|
||||
async def _load_for_operation(self):
|
||||
"""Load favorites from disk without acquiring lock (lock must already be held)"""
|
||||
if self.storage_path.exists():
|
||||
try:
|
||||
async with aiofiles.open(self.storage_path, 'r', encoding='utf-8') as f:
|
||||
content = await f.read()
|
||||
self._favorites = json.loads(content) if content.strip() else {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading favorites: {e}")
|
||||
self._favorites = {}
|
||||
else:
|
||||
self._favorites = {}
|
||||
|
||||
async def _save(self):
|
||||
"""Save favorites to disk (assumes lock is already held)"""
|
||||
try:
|
||||
async with aiofiles.open(self.storage_path, 'w', encoding='utf-8') as f:
|
||||
await f.write(json.dumps(self._favorites, indent=2, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving favorites: {e}")
|
||||
def __init__(self, storage_path: str = None):
|
||||
# Database connection is managed via engine and sessions
|
||||
pass
|
||||
|
||||
async def add_favorite(
|
||||
self,
|
||||
@@ -58,48 +30,55 @@ class FavoritesManager:
|
||||
poster_url: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""Add an anime to favorites"""
|
||||
async with self._lock:
|
||||
await self._load_for_operation()
|
||||
with Session(engine) as session:
|
||||
statement = select(FavoriteTable).where(FavoriteTable.anime_id == anime_id)
|
||||
existing = session.exec(statement).first()
|
||||
|
||||
if anime_id in self._favorites:
|
||||
if existing:
|
||||
# Update existing favorite
|
||||
self._favorites[anime_id]["updated_at"] = datetime.now().isoformat()
|
||||
existing.updated_at = datetime.now()
|
||||
if metadata:
|
||||
self._favorites[anime_id]["metadata"] = metadata
|
||||
existing.anime_metadata = metadata
|
||||
if poster_url:
|
||||
self._favorites[anime_id]["poster_url"] = poster_url
|
||||
existing.poster_url = poster_url
|
||||
session.add(existing)
|
||||
session.commit()
|
||||
session.refresh(existing)
|
||||
return self._to_dict(existing)
|
||||
else:
|
||||
# Add new favorite
|
||||
self._favorites[anime_id] = {
|
||||
"id": anime_id,
|
||||
"title": title,
|
||||
"url": url,
|
||||
"provider": provider,
|
||||
"metadata": metadata or {},
|
||||
"poster_url": poster_url,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
await self._save()
|
||||
return self._favorites[anime_id]
|
||||
fav = FavoriteTable(
|
||||
anime_id=anime_id,
|
||||
title=title,
|
||||
url=url,
|
||||
provider=provider,
|
||||
anime_metadata=metadata or {},
|
||||
poster_url=poster_url
|
||||
)
|
||||
session.add(fav)
|
||||
session.commit()
|
||||
session.refresh(fav)
|
||||
return self._to_dict(fav)
|
||||
|
||||
async def remove_favorite(self, anime_id: str) -> bool:
|
||||
"""Remove an anime from favorites"""
|
||||
async with self._lock:
|
||||
await self._load_for_operation()
|
||||
|
||||
if anime_id in self._favorites:
|
||||
del self._favorites[anime_id]
|
||||
await self._save()
|
||||
with Session(engine) as session:
|
||||
statement = select(FavoriteTable).where(FavoriteTable.anime_id == anime_id)
|
||||
existing = session.exec(statement).first()
|
||||
if existing:
|
||||
session.delete(existing)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def get_favorite(self, anime_id: str) -> Optional[Dict]:
|
||||
"""Get a specific favorite by ID"""
|
||||
await self._load()
|
||||
return self._favorites.get(anime_id)
|
||||
with Session(engine) as session:
|
||||
statement = select(FavoriteTable).where(FavoriteTable.anime_id == anime_id)
|
||||
existing = session.exec(statement).first()
|
||||
if existing:
|
||||
return self._to_dict(existing)
|
||||
return None
|
||||
|
||||
async def list_favorites(
|
||||
self,
|
||||
@@ -109,13 +88,15 @@ class FavoritesManager:
|
||||
filter_genre: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""List all favorites with optional sorting and filtering"""
|
||||
await self._load()
|
||||
|
||||
favorites = list(self._favorites.values())
|
||||
|
||||
# Apply filters
|
||||
if filter_provider:
|
||||
favorites = [f for f in favorites if f["provider"] == filter_provider]
|
||||
with Session(engine) as session:
|
||||
statement = select(FavoriteTable)
|
||||
|
||||
if filter_provider:
|
||||
statement = statement.where(FavoriteTable.provider == filter_provider)
|
||||
|
||||
# SQLite JSON filtering for genres is complex, handle it in Python
|
||||
results = session.exec(statement).all()
|
||||
favorites = [self._to_dict(fav) for fav in results]
|
||||
|
||||
if filter_genre:
|
||||
favorites = [
|
||||
@@ -144,8 +125,9 @@ class FavoritesManager:
|
||||
|
||||
async def is_favorite(self, anime_id: str) -> bool:
|
||||
"""Check if an anime is in favorites"""
|
||||
await self._load()
|
||||
return anime_id in self._favorites
|
||||
with Session(engine) as session:
|
||||
statement = select(FavoriteTable).where(FavoriteTable.anime_id == anime_id)
|
||||
return session.exec(statement).first() is not None
|
||||
|
||||
async def toggle_favorite(
|
||||
self,
|
||||
@@ -168,19 +150,18 @@ class FavoritesManager:
|
||||
|
||||
async def get_stats(self) -> Dict:
|
||||
"""Get statistics about favorites"""
|
||||
await self._load()
|
||||
|
||||
total = len(self._favorites)
|
||||
favorites = await self.list_favorites()
|
||||
total = len(favorites)
|
||||
|
||||
# Count by provider
|
||||
by_provider = {}
|
||||
for fav in self._favorites.values():
|
||||
for fav in favorites:
|
||||
provider = fav["provider"]
|
||||
by_provider[provider] = by_provider.get(provider, 0) + 1
|
||||
|
||||
# Count by genre
|
||||
by_genre = {}
|
||||
for fav in self._favorites.values():
|
||||
for fav in favorites:
|
||||
for genre in fav.get("metadata", {}).get("genres", []):
|
||||
by_genre[genre] = by_genre.get(genre, 0) + 1
|
||||
|
||||
@@ -190,6 +171,19 @@ class FavoritesManager:
|
||||
"by_genre": by_genre
|
||||
}
|
||||
|
||||
def _to_dict(self, fav: FavoriteTable) -> Dict:
|
||||
"""Convert a FavoriteTable instance to a dictionary for API compatibility"""
|
||||
return {
|
||||
"id": fav.anime_id,
|
||||
"title": fav.title,
|
||||
"url": fav.url,
|
||||
"provider": fav.provider,
|
||||
"metadata": fav.anime_metadata,
|
||||
"poster_url": fav.poster_url,
|
||||
"created_at": fav.created_at.isoformat() if fav.created_at else None,
|
||||
"updated_at": fav.updated_at.isoformat() if fav.updated_at else None
|
||||
}
|
||||
|
||||
|
||||
# Global favorites manager instance
|
||||
_favorites_manager: Optional[FavoritesManager] = None
|
||||
|
||||
@@ -63,3 +63,9 @@ class AnimeSearchResult(BaseModel):
|
||||
cover_image: Optional[str] = None
|
||||
type: str # "search_result" or "direct"
|
||||
metadata: Optional[AnimeMetadata] = None
|
||||
|
||||
# Import all SQLModel tables here to ensure they are registered together
|
||||
from .auth import UserTable
|
||||
from .watchlist import WatchlistItemTable, WatchlistSettingsTable
|
||||
from .favorites import FavoriteTable
|
||||
from .sonarr import SonarrMappingTable, SonarrConfigTable
|
||||
|
||||
+4
-1
@@ -28,7 +28,7 @@ class UserTable(UserBase, table=True):
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
last_login: Optional[datetime] = None
|
||||
|
||||
# Relationships
|
||||
# Relationships - Using string reference to avoid circular import errors
|
||||
watchlist_items: List["WatchlistItemTable"] = Relationship(back_populates="user")
|
||||
|
||||
|
||||
@@ -60,3 +60,6 @@ class Token(BaseModel):
|
||||
class UserInDB(User):
|
||||
"""Schema for user stored in database (with hashed password)"""
|
||||
hashed_password: str
|
||||
|
||||
# Import WatchlistItemTable here to resolve SQLModel Relationship mappings
|
||||
from .watchlist import WatchlistItemTable
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Models for Favorites system with SQLModel support"""
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional, Dict, List
|
||||
from datetime import datetime
|
||||
from sqlmodel import SQLModel, Field, Column, String
|
||||
|
||||
class FavoriteBase(SQLModel):
|
||||
"""Base schema for favorite anime"""
|
||||
anime_id: str = Field(index=True)
|
||||
title: str = Field(index=True)
|
||||
url: str
|
||||
provider: str
|
||||
poster_url: Optional[str] = None
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
class FavoriteTable(FavoriteBase, table=True):
|
||||
"""Database table for favorites"""
|
||||
__tablename__ = "favorites"
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
primary_key=True,
|
||||
index=True,
|
||||
nullable=False
|
||||
)
|
||||
user_id: str = Field(foreign_key="users.id", index=True, default="default")
|
||||
|
||||
# Store metadata dictionary as JSON string in SQLite
|
||||
metadata_json: Optional[str] = Field(default="{}", sa_column=Column(String))
|
||||
|
||||
@property
|
||||
def anime_metadata(self) -> Dict:
|
||||
try:
|
||||
return json.loads(self.metadata_json or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
@anime_metadata.setter
|
||||
def anime_metadata(self, value: Dict):
|
||||
self.metadata_json = json.dumps(value or {})
|
||||
+55
-6
@@ -1,8 +1,10 @@
|
||||
"""Pydantic models for Sonarr webhook integration"""
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field as PydanticField, validator
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from sqlmodel import SQLModel, Field
|
||||
import uuid
|
||||
|
||||
|
||||
class SonarrEventType(str, Enum):
|
||||
@@ -45,7 +47,7 @@ class SonarrEpisodeFile(BaseModel):
|
||||
|
||||
class SonarrSeries(BaseModel):
|
||||
"""Series information from Sonarr"""
|
||||
tvdbId: int = Field(..., alias="tvdbId")
|
||||
tvdbId: int = PydanticField(..., alias="tvdbId")
|
||||
title: str
|
||||
sortTitle: str
|
||||
status: str
|
||||
@@ -129,8 +131,33 @@ class SonarrWebhookPayload(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class SonarrMappingBase(SQLModel):
|
||||
sonarr_series_id: int = Field(index=True, unique=True)
|
||||
sonarr_title: str
|
||||
anime_provider: str
|
||||
anime_url: str
|
||||
anime_title: str
|
||||
lang: str = Field(default="vostfr")
|
||||
quality_preference: Optional[str] = None
|
||||
auto_download: bool = Field(default=True)
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SonarrMappingTable(SonarrMappingBase, table=True):
|
||||
"""Database table for Sonarr mappings"""
|
||||
__tablename__ = "sonarr_mappings"
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
primary_key=True,
|
||||
index=True,
|
||||
nullable=False
|
||||
)
|
||||
user_id: str = Field(foreign_key="users.id", index=True, default="default")
|
||||
|
||||
|
||||
class SonarrMapping(BaseModel):
|
||||
"""Mapping between Sonarr series and anime providers"""
|
||||
"""Mapping between Sonarr series and anime providers (API model)"""
|
||||
sonarr_series_id: int
|
||||
sonarr_title: str
|
||||
anime_provider: str # 'anime-sama', 'neko-sama', etc.
|
||||
@@ -139,8 +166,8 @@ class SonarrMapping(BaseModel):
|
||||
lang: str = "vostfr"
|
||||
quality_preference: Optional[str] = None # '1080p', '720p', etc.
|
||||
auto_download: bool = True
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
created_at: datetime = PydanticField(default_factory=datetime.now)
|
||||
updated_at: datetime = PydanticField(default_factory=datetime.now)
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
@@ -148,8 +175,30 @@ class SonarrMapping(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class SonarrConfigBase(SQLModel):
|
||||
webhook_enabled: bool = Field(default=False)
|
||||
webhook_secret: Optional[str] = None
|
||||
auto_download_enabled: bool = Field(default=True)
|
||||
default_language: str = Field(default="vostfr")
|
||||
default_quality: Optional[str] = None
|
||||
default_provider: str = Field(default="anime-sama")
|
||||
verify_hmac: bool = Field(default=False)
|
||||
log_webhooks: bool = Field(default=True)
|
||||
|
||||
|
||||
class SonarrConfigTable(SonarrConfigBase, table=True):
|
||||
"""Database table for Sonarr configuration (singleton)"""
|
||||
__tablename__ = "sonarr_config"
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
primary_key=True,
|
||||
index=True,
|
||||
nullable=False
|
||||
)
|
||||
|
||||
|
||||
class SonarrConfig(BaseModel):
|
||||
"""Sonarr webhook configuration"""
|
||||
"""Sonarr webhook configuration (API Model)"""
|
||||
webhook_enabled: bool = False
|
||||
webhook_secret: Optional[str] = None # HMAC SHA256 secret
|
||||
auto_download_enabled: bool = True
|
||||
|
||||
+22
-1
@@ -74,7 +74,7 @@ class WatchlistItemTable(WatchlistItemBase, table=True):
|
||||
def genres(self, value: List[str]):
|
||||
self.genres_json = json.dumps(value or [])
|
||||
|
||||
# Relationships
|
||||
# Relationships - Using string reference
|
||||
user: Optional["UserTable"] = Relationship(back_populates="watchlist_items")
|
||||
|
||||
|
||||
@@ -148,6 +148,24 @@ class AutoDownloadResult(BaseModel):
|
||||
checked_at: datetime = PydanticField(default_factory=datetime.now)
|
||||
|
||||
|
||||
class WatchlistSettingsBase(SQLModel):
|
||||
check_interval_hours: int = Field(default=6)
|
||||
auto_download_enabled: bool = Field(default=True)
|
||||
max_concurrent_auto_downloads: int = Field(default=2)
|
||||
notify_on_new_episodes: bool = Field(default=False)
|
||||
include_completed_anime: bool = Field(default=False)
|
||||
|
||||
class WatchlistSettingsTable(WatchlistSettingsBase, table=True):
|
||||
"""Database table for global watchlist settings"""
|
||||
__tablename__ = "watchlist_settings"
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
primary_key=True,
|
||||
index=True,
|
||||
nullable=False
|
||||
)
|
||||
user_id: str = Field(foreign_key="users.id", index=True, default="default")
|
||||
|
||||
class WatchlistSettings(BaseModel):
|
||||
"""Global watchlist settings"""
|
||||
check_interval_hours: int = PydanticField(default=6, ge=1, le=168)
|
||||
@@ -155,3 +173,6 @@ class WatchlistSettings(BaseModel):
|
||||
max_concurrent_auto_downloads: int = PydanticField(default=2, ge=1, le=10)
|
||||
notify_on_new_episodes: bool = PydanticField(default=False)
|
||||
include_completed_anime: bool = PydanticField(default=False)
|
||||
|
||||
# Import UserTable here to resolve SQLModel Relationship mappings
|
||||
from .auth import UserTable
|
||||
|
||||
+145
-113
@@ -1,18 +1,19 @@
|
||||
"""Sonarr webhook handler and integration logic"""
|
||||
"""Sonarr webhook handler and integration logic using SQLModel"""
|
||||
import hmac
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, List, Tuple, Any
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Any
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from app.database import engine
|
||||
from app.models.sonarr import (
|
||||
SonarrWebhookPayload,
|
||||
SonarrEventType,
|
||||
SonarrMapping,
|
||||
SonarrMappingTable,
|
||||
SonarrConfig,
|
||||
SonarrConfigTable,
|
||||
SonarrDownloadRequest
|
||||
)
|
||||
from app.models import DownloadRequest
|
||||
@@ -23,69 +24,150 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SonarrHandler:
|
||||
"""Handles Sonarr webhooks and manages series mappings"""
|
||||
"""Handles Sonarr webhooks and manages series mappings using SQL database"""
|
||||
|
||||
def __init__(self, config_path: str = "config/sonarr.json", mappings_path: str = "config/sonarr_mappings.json"):
|
||||
self.config_path = Path(config_path)
|
||||
self.mappings_path = Path(mappings_path)
|
||||
self.config = self._load_config()
|
||||
self.mappings = self._load_mappings()
|
||||
def __init__(self, config_path: str = None, mappings_path: str = None):
|
||||
self.download_manager = None
|
||||
|
||||
# Create config directories if they don't exist
|
||||
self.config_path.parent.mkdir(exist_ok=True)
|
||||
self.mappings_path.parent.mkdir(exist_ok=True)
|
||||
self._ensure_default_config()
|
||||
|
||||
def set_download_manager(self, download_manager):
|
||||
self.download_manager = download_manager
|
||||
|
||||
def _load_config(self) -> SonarrConfig:
|
||||
"""Load Sonarr configuration from file"""
|
||||
if self.config_path.exists():
|
||||
try:
|
||||
with open(self.config_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
return SonarrConfig(**data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load Sonarr config: {e}")
|
||||
return SonarrConfig()
|
||||
def _ensure_default_config(self):
|
||||
"""Ensure a default config exists in the database"""
|
||||
with Session(engine) as session:
|
||||
statement = select(SonarrConfigTable)
|
||||
if not session.exec(statement).first():
|
||||
session.add(SonarrConfigTable())
|
||||
session.commit()
|
||||
|
||||
def _save_config(self):
|
||||
try:
|
||||
temp_file = f"{self.config_path}.tmp"
|
||||
with open(temp_file, 'w') as f:
|
||||
json.dump(self.config.model_dump(mode='json'), f, indent=2)
|
||||
os.replace(temp_file, self.config_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save Sonarr config: {e}")
|
||||
raise
|
||||
def get_config(self) -> SonarrConfig:
|
||||
"""Get current configuration"""
|
||||
with Session(engine) as session:
|
||||
statement = select(SonarrConfigTable)
|
||||
db_config = session.exec(statement).first()
|
||||
if db_config:
|
||||
return SonarrConfig(
|
||||
webhook_enabled=db_config.webhook_enabled,
|
||||
webhook_secret=db_config.webhook_secret,
|
||||
auto_download_enabled=db_config.auto_download_enabled,
|
||||
default_language=db_config.default_language,
|
||||
default_quality=db_config.default_quality,
|
||||
default_provider=db_config.default_provider,
|
||||
verify_hmac=db_config.verify_hmac,
|
||||
log_webhooks=db_config.log_webhooks
|
||||
)
|
||||
return SonarrConfig()
|
||||
|
||||
def _load_mappings(self) -> List[SonarrMapping]:
|
||||
"""Load Sonarr to anime mappings from file"""
|
||||
if self.mappings_path.exists():
|
||||
try:
|
||||
with open(self.mappings_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
return [SonarrMapping(**item) for item in data]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load Sonarr mappings: {e}")
|
||||
return []
|
||||
def update_config(self, config: SonarrConfig) -> SonarrConfig:
|
||||
"""Update configuration"""
|
||||
with Session(engine) as session:
|
||||
statement = select(SonarrConfigTable)
|
||||
db_config = session.exec(statement).first()
|
||||
|
||||
if not db_config:
|
||||
db_config = SonarrConfigTable()
|
||||
|
||||
db_config.webhook_enabled = config.webhook_enabled
|
||||
db_config.webhook_secret = config.webhook_secret
|
||||
db_config.auto_download_enabled = config.auto_download_enabled
|
||||
db_config.default_language = config.default_language
|
||||
db_config.default_quality = config.default_quality
|
||||
db_config.default_provider = config.default_provider
|
||||
db_config.verify_hmac = config.verify_hmac
|
||||
db_config.log_webhooks = config.log_webhooks
|
||||
|
||||
session.add(db_config)
|
||||
session.commit()
|
||||
|
||||
logger.info("Sonarr configuration updated in database")
|
||||
return config
|
||||
|
||||
def _save_mappings(self):
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.mappings_path), exist_ok=True)
|
||||
temp_file = f"{self.mappings_path}.tmp"
|
||||
with open(temp_file, 'w') as f:
|
||||
mappings_data = [m.model_dump(mode='json') for m in self.mappings]
|
||||
json.dump(mappings_data, f, indent=2)
|
||||
os.replace(temp_file, self.mappings_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save mappings: {e}")
|
||||
raise
|
||||
def _to_pydantic(self, db_mapping: SonarrMappingTable) -> SonarrMapping:
|
||||
return SonarrMapping(
|
||||
sonarr_series_id=db_mapping.sonarr_series_id,
|
||||
sonarr_title=db_mapping.sonarr_title,
|
||||
anime_provider=db_mapping.anime_provider,
|
||||
anime_url=db_mapping.anime_url,
|
||||
anime_title=db_mapping.anime_title,
|
||||
lang=db_mapping.lang,
|
||||
quality_preference=db_mapping.quality_preference,
|
||||
auto_download=db_mapping.auto_download,
|
||||
created_at=db_mapping.created_at,
|
||||
updated_at=db_mapping.updated_at
|
||||
)
|
||||
|
||||
def get_mappings(self) -> List[SonarrMapping]:
|
||||
"""Get all mappings"""
|
||||
with Session(engine) as session:
|
||||
statement = select(SonarrMappingTable)
|
||||
db_mappings = session.exec(statement).all()
|
||||
return [self._to_pydantic(m) for m in db_mappings]
|
||||
|
||||
def get_mapping(self, sonarr_series_id: int) -> Optional[SonarrMapping]:
|
||||
"""Get mapping for specific series"""
|
||||
with Session(engine) as session:
|
||||
statement = select(SonarrMappingTable).where(SonarrMappingTable.sonarr_series_id == sonarr_series_id)
|
||||
db_mapping = session.exec(statement).first()
|
||||
if db_mapping:
|
||||
return self._to_pydantic(db_mapping)
|
||||
return None
|
||||
|
||||
def add_mapping(self, mapping: SonarrMapping) -> SonarrMapping:
|
||||
"""Add or update a mapping"""
|
||||
with Session(engine) as session:
|
||||
statement = select(SonarrMappingTable).where(SonarrMappingTable.sonarr_series_id == mapping.sonarr_series_id)
|
||||
db_mapping = session.exec(statement).first()
|
||||
|
||||
if db_mapping:
|
||||
# Update existing
|
||||
db_mapping.sonarr_title = mapping.sonarr_title
|
||||
db_mapping.anime_provider = mapping.anime_provider
|
||||
db_mapping.anime_url = mapping.anime_url
|
||||
db_mapping.anime_title = mapping.anime_title
|
||||
db_mapping.lang = mapping.lang
|
||||
db_mapping.quality_preference = mapping.quality_preference
|
||||
db_mapping.auto_download = mapping.auto_download
|
||||
db_mapping.updated_at = datetime.now()
|
||||
logger.info(f"Updated mapping for series {mapping.sonarr_title}")
|
||||
else:
|
||||
# Create new
|
||||
db_mapping = SonarrMappingTable(
|
||||
user_id="default",
|
||||
sonarr_series_id=mapping.sonarr_series_id,
|
||||
sonarr_title=mapping.sonarr_title,
|
||||
anime_provider=mapping.anime_provider,
|
||||
anime_url=mapping.anime_url,
|
||||
anime_title=mapping.anime_title,
|
||||
lang=mapping.lang,
|
||||
quality_preference=mapping.quality_preference,
|
||||
auto_download=mapping.auto_download,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
logger.info(f"Added mapping for series {mapping.sonarr_title}")
|
||||
|
||||
session.add(db_mapping)
|
||||
session.commit()
|
||||
session.refresh(db_mapping)
|
||||
return self._to_pydantic(db_mapping)
|
||||
|
||||
def delete_mapping(self, sonarr_series_id: int) -> bool:
|
||||
"""Delete a mapping"""
|
||||
with Session(engine) as session:
|
||||
statement = select(SonarrMappingTable).where(SonarrMappingTable.sonarr_series_id == sonarr_series_id)
|
||||
db_mapping = session.exec(statement).first()
|
||||
if db_mapping:
|
||||
session.delete(db_mapping)
|
||||
session.commit()
|
||||
logger.info(f"Deleted mapping for series ID {sonarr_series_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def verify_hmac(self, payload: bytes, signature: str) -> bool:
|
||||
"""Verify HMAC SHA256 signature"""
|
||||
if not self.config.verify_hmac or not self.config.webhook_secret:
|
||||
config = self.get_config()
|
||||
if not config.verify_hmac or not config.webhook_secret:
|
||||
return True
|
||||
|
||||
try:
|
||||
@@ -94,7 +176,7 @@ class SonarrHandler:
|
||||
signature = signature[7:]
|
||||
|
||||
computed_hmac = hmac.new(
|
||||
self.config.webhook_secret.encode(),
|
||||
config.webhook_secret.encode(),
|
||||
payload,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
@@ -104,57 +186,6 @@ class SonarrHandler:
|
||||
logger.error(f"HMAC verification failed: {e}")
|
||||
return False
|
||||
|
||||
def get_config(self) -> SonarrConfig:
|
||||
"""Get current configuration"""
|
||||
return self.config
|
||||
|
||||
def update_config(self, config: SonarrConfig) -> SonarrConfig:
|
||||
"""Update configuration"""
|
||||
self.config = config
|
||||
self._save_config()
|
||||
logger.info("Sonarr configuration updated")
|
||||
return self.config
|
||||
|
||||
def get_mappings(self) -> List[SonarrMapping]:
|
||||
"""Get all mappings"""
|
||||
return self.mappings
|
||||
|
||||
def get_mapping(self, sonarr_series_id: int) -> Optional[SonarrMapping]:
|
||||
"""Get mapping for specific series"""
|
||||
for mapping in self.mappings:
|
||||
if mapping.sonarr_series_id == sonarr_series_id:
|
||||
return mapping
|
||||
return None
|
||||
|
||||
def add_mapping(self, mapping: SonarrMapping) -> SonarrMapping:
|
||||
"""Add or update a mapping"""
|
||||
# Check if mapping already exists
|
||||
for i, existing in enumerate(self.mappings):
|
||||
if existing.sonarr_series_id == mapping.sonarr_series_id:
|
||||
mapping.updated_at = datetime.now()
|
||||
self.mappings[i] = mapping
|
||||
self._save_mappings()
|
||||
logger.info(f"Updated mapping for series {mapping.sonarr_title}")
|
||||
return mapping
|
||||
|
||||
# Add new mapping
|
||||
mapping.created_at = datetime.now()
|
||||
mapping.updated_at = datetime.now()
|
||||
self.mappings.append(mapping)
|
||||
self._save_mappings()
|
||||
logger.info(f"Added mapping for series {mapping.sonarr_title}")
|
||||
return mapping
|
||||
|
||||
def delete_mapping(self, sonarr_series_id: int) -> bool:
|
||||
"""Delete a mapping"""
|
||||
for i, mapping in enumerate(self.mappings):
|
||||
if mapping.sonarr_series_id == sonarr_series_id:
|
||||
del self.mappings[i]
|
||||
self._save_mappings()
|
||||
logger.info(f"Deleted mapping for series ID {sonarr_series_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def search_anime_by_title(self, title: str, provider: str = "anime-sama", lang: str = "vostfr") -> List[Dict]:
|
||||
"""Search for anime by title using specified provider"""
|
||||
try:
|
||||
@@ -197,15 +228,16 @@ class SonarrHandler:
|
||||
|
||||
async def process_webhook(self, payload: SonarrWebhookPayload) -> Dict[str, Any]:
|
||||
"""Process Sonarr webhook payload"""
|
||||
if not self.config.webhook_enabled:
|
||||
config = self.get_config()
|
||||
if not config.webhook_enabled:
|
||||
return {"status": "ignored", "reason": "Webhook not enabled"}
|
||||
|
||||
if self.config.log_webhooks:
|
||||
if config.log_webhooks:
|
||||
logger.info(f"Received Sonarr webhook: {payload.eventType.value}")
|
||||
|
||||
# Handle different event types
|
||||
if payload.eventType == SonarrEventType.GRAB:
|
||||
return await self._handle_grab(payload)
|
||||
return await self._handle_grab(payload, config)
|
||||
elif payload.eventType == SonarrEventType.DOWNLOAD:
|
||||
return await self._handle_download(payload)
|
||||
elif payload.eventType == SonarrEventType.RENAME:
|
||||
@@ -217,9 +249,9 @@ class SonarrHandler:
|
||||
else:
|
||||
return {"status": "ignored", "reason": f"Unhandled event type: {payload.eventType}"}
|
||||
|
||||
async def _handle_grab(self, payload: SonarrWebhookPayload) -> Dict:
|
||||
async def _handle_grab(self, payload: SonarrWebhookPayload, config: SonarrConfig) -> Dict:
|
||||
"""Handle Grab event (when Sonarr downloads a release)"""
|
||||
if not self.config.auto_download_enabled:
|
||||
if not config.auto_download_enabled:
|
||||
return {"status": "ignored", "reason": "Auto-download disabled"}
|
||||
|
||||
if not payload.series or not payload.episodes:
|
||||
|
||||
+43
-23
@@ -16,50 +16,70 @@ from app.models.watchlist import (
|
||||
WatchlistItemUpdate,
|
||||
WatchlistStatus,
|
||||
WatchlistSettings,
|
||||
WatchlistSettingsTable,
|
||||
NewEpisodeInfo,
|
||||
AutoDownloadResult
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Settings file remains JSON for simplicity for now
|
||||
WATCHLIST_SETTINGS_FILE = "config/watchlist_settings.json"
|
||||
|
||||
|
||||
class WatchlistManager:
|
||||
"""Manages user watchlist for automatic episode downloads using SQL database"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings_file = WATCHLIST_SETTINGS_FILE
|
||||
self.settings: Optional[WatchlistSettings] = None
|
||||
self._load_settings()
|
||||
|
||||
def _load_settings(self):
|
||||
"""Load watchlist settings from JSON file"""
|
||||
"""Load watchlist settings from database"""
|
||||
try:
|
||||
if os.path.exists(self.settings_file):
|
||||
with open(self.settings_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
self.settings = WatchlistSettings(**data)
|
||||
logger.info(f"Loaded watchlist settings")
|
||||
else:
|
||||
self.settings = WatchlistSettings()
|
||||
self._save_settings()
|
||||
logger.info("Settings file not found, using defaults")
|
||||
with Session(engine) as session:
|
||||
statement = select(WatchlistSettingsTable).where(WatchlistSettingsTable.user_id == "default")
|
||||
db_settings = session.exec(statement).first()
|
||||
if db_settings:
|
||||
self.settings = WatchlistSettings(
|
||||
check_interval_hours=db_settings.check_interval_hours,
|
||||
auto_download_enabled=db_settings.auto_download_enabled,
|
||||
max_concurrent_auto_downloads=db_settings.max_concurrent_auto_downloads,
|
||||
notify_on_new_episodes=db_settings.notify_on_new_episodes,
|
||||
include_completed_anime=db_settings.include_completed_anime
|
||||
)
|
||||
logger.info(f"Loaded watchlist settings from database")
|
||||
else:
|
||||
self.settings = WatchlistSettings()
|
||||
self._save_settings()
|
||||
logger.info("Settings not found in database, created defaults")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading settings: {e}")
|
||||
logger.error(f"Error loading settings from database: {e}")
|
||||
self.settings = WatchlistSettings()
|
||||
|
||||
def _save_settings(self):
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.settings_file), exist_ok=True)
|
||||
temp_file = f"{self.settings_file}.tmp"
|
||||
with open(temp_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.settings.model_dump(mode='json'), f, indent=2, ensure_ascii=False)
|
||||
os.replace(temp_file, self.settings_file)
|
||||
logger.debug("Saved watchlist settings")
|
||||
with Session(engine) as session:
|
||||
statement = select(WatchlistSettingsTable).where(WatchlistSettingsTable.user_id == "default")
|
||||
db_settings = session.exec(statement).first()
|
||||
|
||||
if db_settings:
|
||||
db_settings.check_interval_hours = self.settings.check_interval_hours
|
||||
db_settings.auto_download_enabled = self.settings.auto_download_enabled
|
||||
db_settings.max_concurrent_auto_downloads = self.settings.max_concurrent_auto_downloads
|
||||
db_settings.notify_on_new_episodes = self.settings.notify_on_new_episodes
|
||||
db_settings.include_completed_anime = self.settings.include_completed_anime
|
||||
else:
|
||||
db_settings = WatchlistSettingsTable(
|
||||
user_id="default",
|
||||
check_interval_hours=self.settings.check_interval_hours,
|
||||
auto_download_enabled=self.settings.auto_download_enabled,
|
||||
max_concurrent_auto_downloads=self.settings.max_concurrent_auto_downloads,
|
||||
notify_on_new_episodes=self.settings.notify_on_new_episodes,
|
||||
include_completed_anime=self.settings.include_completed_anime
|
||||
)
|
||||
session.add(db_settings)
|
||||
|
||||
session.commit()
|
||||
logger.debug("Saved watchlist settings to database")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving settings: {e}")
|
||||
logger.error(f"Error saving settings to database: {e}")
|
||||
|
||||
def _to_api_model(self, db_item: WatchlistItemTable) -> WatchlistItem:
|
||||
"""Convert database table model to API response model"""
|
||||
|
||||
Reference in New Issue
Block a user