feat: migrate persistence from JSON to SQLModel (Phase 1)
- Integrated SQLModel with SQLite for robust data persistence - Refactored UserManager and WatchlistManager to use SQL queries - Migrated models to SQLModel with relationships and primary keys - Updated test suite with in-memory database isolation - Removed deprecated JSON storage files
This commit is contained in:
+84
-100
@@ -1,120 +1,118 @@
|
||||
"""User authentication and management system"""
|
||||
"""User authentication and management system with SQLModel support"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict
|
||||
from typing import Optional, Dict, List
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
import logging
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from sqlmodel import Session, select
|
||||
from app.database import engine
|
||||
from app.models.auth import UserTable
|
||||
from app.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Load settings at module level for easier mocking and access
|
||||
settings = get_settings()
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Users database file
|
||||
USERS_DB_FILE = "config/users.json"
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""Manages user storage and authentication"""
|
||||
"""Manages user storage and authentication using SQL database"""
|
||||
|
||||
def __init__(self, db_file: str = USERS_DB_FILE):
|
||||
self.db_file = db_file
|
||||
self.users: Dict[str, dict] = {}
|
||||
self._load_users()
|
||||
def __init__(self):
|
||||
# Database connection is managed via engine and sessions
|
||||
pass
|
||||
|
||||
def _load_users(self):
|
||||
"""Load users from JSON file"""
|
||||
try:
|
||||
if os.path.exists(self.db_file):
|
||||
with open(self.db_file, "r", encoding="utf-8") as f:
|
||||
self.users = json.load(f)
|
||||
logger.info(f"Loaded {len(self.users)} users from database")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading users: {e}")
|
||||
self.users = {}
|
||||
|
||||
def _save_users(self):
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.db_file), exist_ok=True)
|
||||
temp_file = f"{self.db_file}.tmp"
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.users, f, indent=2, ensure_ascii=False, default=str)
|
||||
os.replace(temp_file, self.db_file)
|
||||
logger.info(f"Saved {len(self.users)} users to database")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving users: {e}")
|
||||
|
||||
def get_user(self, username: str) -> Optional[dict]:
|
||||
def get_user(self, username: str) -> Optional[UserTable]:
|
||||
"""Get user by username"""
|
||||
return self.users.get(username)
|
||||
with Session(engine) as session:
|
||||
statement = select(UserTable).where(UserTable.username == username)
|
||||
return session.exec(statement).first()
|
||||
|
||||
def get_user_by_id(self, user_id: str) -> Optional[dict]:
|
||||
def get_user_by_id(self, user_id: str) -> Optional[UserTable]:
|
||||
"""Get user by ID"""
|
||||
for user in self.users.values():
|
||||
if user.get("id") == user_id:
|
||||
return user
|
||||
return None
|
||||
with Session(engine) as session:
|
||||
statement = select(UserTable).where(UserTable.id == user_id)
|
||||
return session.exec(statement).first()
|
||||
|
||||
def create_user(
|
||||
self, username: str, password: str, email: str = None, full_name: str = None
|
||||
) -> dict:
|
||||
) -> UserTable:
|
||||
"""Create a new user"""
|
||||
if username in self.users:
|
||||
raise ValueError(f"Username '{username}' already exists")
|
||||
with Session(engine) as session:
|
||||
# Check if user already exists
|
||||
statement = select(UserTable).where(UserTable.username == username)
|
||||
if session.exec(statement).first():
|
||||
raise ValueError(f"Username '{username}' already exists")
|
||||
|
||||
# Truncate password to 72 bytes if necessary (bcrypt limitation)
|
||||
password_bytes = password.encode("utf-8")
|
||||
if len(password_bytes) > 72:
|
||||
password = password_bytes[:72].decode("utf-8", errors="ignore")
|
||||
# Truncate password to 72 bytes if necessary (bcrypt limitation)
|
||||
password_bytes = password.encode("utf-8")
|
||||
if len(password_bytes) > 72:
|
||||
password = password_bytes[:72].decode("utf-8", errors="ignore")
|
||||
|
||||
# Hash password
|
||||
hashed_password = pwd_context.hash(password)
|
||||
# Hash password
|
||||
hashed_password = pwd_context.hash(password)
|
||||
|
||||
# Create user
|
||||
user = {
|
||||
"id": hashlib.sha256(username.encode()).hexdigest()[:32],
|
||||
"username": username,
|
||||
"email": email,
|
||||
"full_name": full_name,
|
||||
"hashed_password": hashed_password,
|
||||
"is_active": True,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"last_login": None,
|
||||
}
|
||||
# Create user
|
||||
user = UserTable(
|
||||
username=username,
|
||||
email=email,
|
||||
full_name=full_name,
|
||||
hashed_password=hashed_password,
|
||||
is_active=True,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
|
||||
self.users[username] = user
|
||||
self._save_users()
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
|
||||
logger.info(f"Created user: {username}")
|
||||
return user
|
||||
logger.info(f"Created user: {username}")
|
||||
return user
|
||||
|
||||
def authenticate_user(self, username: str, password: str) -> Optional[dict]:
|
||||
def authenticate_user(self, username: str, password: str) -> Optional[UserTable]:
|
||||
"""Authenticate user with username and password"""
|
||||
user = self.get_user(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not pwd_context.verify(password, user["hashed_password"]):
|
||||
if not pwd_context.verify(password, user.hashed_password):
|
||||
return None
|
||||
|
||||
# Update last login
|
||||
user["last_login"] = datetime.now().isoformat()
|
||||
self._save_users()
|
||||
with Session(engine) as session:
|
||||
db_user = session.get(UserTable, user.id)
|
||||
if db_user:
|
||||
db_user.last_login = datetime.now()
|
||||
session.add(db_user)
|
||||
session.commit()
|
||||
session.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
return user
|
||||
|
||||
def update_last_login(self, username: str):
|
||||
"""Update user's last login time"""
|
||||
user = self.get_user(username)
|
||||
if user:
|
||||
user["last_login"] = datetime.now().isoformat()
|
||||
self._save_users()
|
||||
def update_user(self, user_id: str, update_data: dict) -> Optional[UserTable]:
|
||||
"""Update user information"""
|
||||
with Session(engine) as session:
|
||||
db_user = session.get(UserTable, user_id)
|
||||
if not db_user:
|
||||
return None
|
||||
|
||||
for key, value in update_data.items():
|
||||
if hasattr(db_user, key):
|
||||
setattr(db_user, key, value)
|
||||
|
||||
session.add(db_user)
|
||||
session.commit()
|
||||
session.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
|
||||
# Global user manager instance
|
||||
@@ -131,27 +129,11 @@ def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def _get_jwt_config() -> dict:
|
||||
"""Get JWT configuration from settings"""
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
return {
|
||||
"SECRET_KEY": settings.jwt_secret_key,
|
||||
"ALGORITHM": settings.jwt_algorithm,
|
||||
"ACCESS_TOKEN_EXPIRE_MINUTES": settings.access_token_expire_minutes,
|
||||
"REFRESH_TOKEN_EXPIRE_DAYS": settings.refresh_token_expire_days,
|
||||
}
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
|
||||
"""Create JWT access token"""
|
||||
from jose import jwt
|
||||
|
||||
jwt_config = _get_jwt_config()
|
||||
SECRET_KEY = jwt_config["SECRET_KEY"]
|
||||
ALGORITHM = jwt_config["ALGORITHM"]
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = jwt_config["ACCESS_TOKEN_EXPIRE_MINUTES"]
|
||||
SECRET_KEY = settings.jwt_secret_key
|
||||
ALGORITHM = settings.jwt_algorithm
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
|
||||
|
||||
to_encode = data.copy()
|
||||
|
||||
@@ -168,12 +150,10 @@ def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
|
||||
|
||||
def verify_token(token: str) -> Optional[str]:
|
||||
"""Verify JWT token and return username"""
|
||||
from jose import jwt
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
jwt_config = _get_jwt_config()
|
||||
SECRET_KEY = jwt_config["SECRET_KEY"]
|
||||
ALGORITHM = jwt_config["ALGORITHM"]
|
||||
SECRET_KEY = settings.jwt_secret_key
|
||||
ALGORITHM = settings.jwt_algorithm
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
@@ -189,7 +169,7 @@ def verify_token(token: str) -> Optional[str]:
|
||||
get_user_from_token = verify_token
|
||||
|
||||
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials) -> dict:
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials) -> UserTable:
|
||||
"""Get current user from JWT token"""
|
||||
token = credentials.credentials
|
||||
username = verify_token(token)
|
||||
@@ -197,16 +177,19 @@ def get_current_user(credentials: HTTPAuthorizationCredentials) -> dict:
|
||||
user = user_manager.get_user(username)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
if not user.get("is_active", True):
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="Inactive user")
|
||||
return user
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||
|
||||
|
||||
# Refresh tokens storage
|
||||
REFRESH_TOKENS_FILE = "config/refresh_tokens.json"
|
||||
|
||||
|
||||
def _load_refresh_tokens() -> Dict[str, dict]:
|
||||
"""Load refresh tokens from file"""
|
||||
import json
|
||||
try:
|
||||
if os.path.exists(REFRESH_TOKENS_FILE):
|
||||
with open(REFRESH_TOKENS_FILE, 'r', encoding='utf-8') as f:
|
||||
@@ -218,6 +201,7 @@ def _load_refresh_tokens() -> Dict[str, dict]:
|
||||
|
||||
def _save_refresh_tokens(tokens: Dict[str, dict]):
|
||||
"""Save refresh tokens to file"""
|
||||
import json
|
||||
try:
|
||||
os.makedirs(os.path.dirname(REFRESH_TOKENS_FILE), exist_ok=True)
|
||||
with open(REFRESH_TOKENS_FILE, 'w', encoding='utf-8') as f:
|
||||
|
||||
Reference in New Issue
Block a user