"""Playlist service."""

from datetime import datetime
from typing import List, Optional
from uuid import UUID

from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.playlist import Playlist
from app.models.playlist_track import PlaylistTrack
from app.models.track import Track


class PlaylistService:
    """Service for playlist operations.

    Every mutating method commits the session passed to the constructor
    before returning.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_playlist(
        self,
        user_id: UUID,
        name: str,
        description: Optional[str] = None,
        image_url: Optional[str] = None,
        is_public: bool = False,
        is_collaborative: bool = False,
    ) -> Playlist:
        """
        Create a new, empty playlist.

        Args:
            user_id: User UUID of the owner
            name: Playlist name
            description: Optional description
            image_url: Optional cover image URL
            is_public: Whether playlist is public
            is_collaborative: Whether playlist is collaborative

        Returns:
            Created playlist (refreshed after commit)
        """
        playlist = Playlist(
            user_id=user_id,
            name=name,
            description=description,
            image_url=image_url,
            is_public=is_public,
            is_collaborative=is_collaborative,
            track_count=0,
            total_duration=0,
        )
        self.db.add(playlist)
        await self.db.commit()
        await self.db.refresh(playlist)
        return playlist

    async def get_playlist(
        self,
        playlist_id: UUID,
        include_tracks: bool = False,
    ) -> Optional[Playlist]:
        """
        Get playlist by ID.

        Args:
            playlist_id: Playlist UUID
            include_tracks: Whether to eagerly load ``Playlist.playlist_tracks``

        Returns:
            Playlist or None if no row matches
        """
        stmt = select(Playlist).where(Playlist.id == playlist_id)
        if include_tracks:
            stmt = stmt.options(selectinload(Playlist.playlist_tracks))
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_playlists(
        self,
        user_id: UUID,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Playlist]:
        """
        Get a page of playlists owned by a user, most recently updated first.

        Args:
            user_id: User UUID
            limit: Maximum results
            offset: Pagination offset

        Returns:
            List of playlists
        """
        stmt = (
            select(Playlist)
            .where(Playlist.user_id == user_id)
            .order_by(Playlist.updated_at.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def update_playlist(
        self,
        playlist_id: UUID,
        user_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        image_url: Optional[str] = None,
        is_public: Optional[bool] = None,
    ) -> Playlist:
        """
        Update playlist metadata.

        Only arguments that are not None are applied; passing None leaves the
        corresponding field unchanged (so fields cannot be cleared here).

        Args:
            playlist_id: Playlist UUID
            user_id: User UUID (for ownership check)
            name: New name
            description: New description
            image_url: New image URL
            is_public: New public status

        Returns:
            Updated playlist

        Raises:
            ValueError: If playlist not found or user not owner
        """
        playlist = await self._get_owned_playlist(playlist_id, user_id, "update")

        if name is not None:
            playlist.name = name
        if description is not None:
            playlist.description = description
        if image_url is not None:
            playlist.image_url = image_url
        if is_public is not None:
            playlist.is_public = is_public

        playlist.updated_at = datetime.utcnow()
        await self.db.commit()
        await self.db.refresh(playlist)
        return playlist

    async def delete_playlist(
        self,
        playlist_id: UUID,
        user_id: UUID,
    ) -> None:
        """
        Delete a playlist.

        Args:
            playlist_id: Playlist UUID
            user_id: User UUID (for ownership check)

        Raises:
            ValueError: If playlist not found or user not owner
        """
        playlist = await self._get_owned_playlist(playlist_id, user_id, "delete")
        await self.db.delete(playlist)
        await self.db.commit()

    async def add_tracks(
        self,
        playlist_id: UUID,
        track_ids: List[UUID],
        user_id: UUID,
        position: Optional[int] = None,
    ) -> Playlist:
        """
        Add tracks to a playlist.

        IDs that do not match an existing Track are silently skipped; the
        remaining tracks are inserted contiguously in the order given
        (duplicates are kept).

        Args:
            playlist_id: Playlist UUID
            track_ids: List of track UUIDs
            user_id: User UUID adding the tracks (stored as ``added_by``)
            position: Optional 0-based insertion index. None, or any value past
                the end, appends; negative values are treated as 0. Entries
                already at or after the index are shifted back to make room.

        Returns:
            Updated playlist

        Raises:
            ValueError: If playlist not found
        """
        playlist = await self.get_playlist(playlist_id)
        if not playlist:
            raise ValueError("Playlist not found")
        # NOTE(review): unlike the other mutating methods, there is no ownership
        # check here. Presumably this is intended for collaborative playlists.
        # Confirm whether callers enforce `owner or is_collaborative` themselves.

        append_at = await self._get_max_position(playlist_id) + 1
        # Clamp the insertion point so positions stay contiguous from 0.
        if position is None or position > append_at:
            position = append_at
        position = max(position, 0)

        valid_ids = await self._filter_existing_track_ids(track_ids)

        if valid_ids and position < append_at:
            # Inserting mid-playlist: shift existing entries back. This runs
            # before the new rows are added, so they are not picked up here.
            shift_stmt = select(PlaylistTrack).where(
                PlaylistTrack.playlist_id == playlist_id,
                PlaylistTrack.position >= position,
            )
            shift_result = await self.db.execute(shift_stmt)
            for entry in shift_result.scalars().all():
                entry.position += len(valid_ids)

        for offset, track_id in enumerate(valid_ids):
            self.db.add(
                PlaylistTrack(
                    playlist_id=playlist_id,
                    track_id=track_id,
                    position=position + offset,
                    added_by=user_id,
                )
            )

        # Count only the tracks actually inserted, not every requested ID.
        playlist.track_count += len(valid_ids)
        playlist.total_duration = await self._calculate_playlist_duration(playlist_id)
        playlist.updated_at = datetime.utcnow()

        await self.db.commit()
        await self.db.refresh(playlist)
        return playlist

    async def remove_track(
        self,
        playlist_id: UUID,
        track_id: UUID,
        user_id: UUID,
    ) -> Playlist:
        """
        Remove a track from a playlist and renumber the rest from 0.

        If the track appears more than once, only the occurrence with the
        lowest position is removed.

        Args:
            playlist_id: Playlist UUID
            track_id: Track UUID to remove
            user_id: User UUID (for ownership check)

        Returns:
            Updated playlist

        Raises:
            ValueError: If playlist not found, user not owner, or track not
                in playlist
        """
        playlist = await self._get_owned_playlist(playlist_id, user_id, "modify")

        # limit(1) + first(): duplicates are allowed by add_tracks, so
        # scalar_one_or_none() could raise MultipleResultsFound here.
        stmt = (
            select(PlaylistTrack)
            .where(
                PlaylistTrack.playlist_id == playlist_id,
                PlaylistTrack.track_id == track_id,
            )
            .order_by(PlaylistTrack.position)
            .limit(1)
        )
        result = await self.db.execute(stmt)
        playlist_track = result.scalars().first()
        if playlist_track is None:
            raise ValueError("Track not in playlist")

        await self.db.delete(playlist_track)

        # Exclude the deleted entry explicitly rather than relying on autoflush.
        remaining = [
            entry
            for entry in await self._get_ordered_entries(playlist_id)
            if entry is not playlist_track
        ]
        for index, entry in enumerate(remaining):
            entry.position = index

        # Re-derive the count from the rows actually present.
        playlist.track_count = len(remaining)
        playlist.total_duration = await self._calculate_playlist_duration(playlist_id)
        playlist.updated_at = datetime.utcnow()

        await self.db.commit()
        await self.db.refresh(playlist)
        return playlist

    async def reorder_track(
        self,
        playlist_id: UUID,
        track_id: UUID,
        new_position: int,
        user_id: UUID,
    ) -> Playlist:
        """
        Move a track to a new position within a playlist.

        The first (lowest-position) occurrence of ``track_id`` is moved, and
        all entries are renumbered 0..n-1.

        Args:
            playlist_id: Playlist UUID
            track_id: Track UUID to reorder
            new_position: New position (0-indexed); clamped to the valid range
            user_id: User UUID (for ownership check)

        Returns:
            Updated playlist

        Raises:
            ValueError: If playlist not found, user not owner, or track not
                in playlist
        """
        playlist = await self._get_owned_playlist(playlist_id, user_id, "modify")

        entries = await self._get_ordered_entries(playlist_id)
        old_index = next(
            (i for i, entry in enumerate(entries) if entry.track_id == track_id),
            None,
        )
        if old_index is None:
            raise ValueError("Track not in playlist")

        # Out-of-range targets previously produced gaps or negative positions.
        new_position = max(0, min(new_position, len(entries) - 1))

        moved = entries.pop(old_index)
        entries.insert(new_position, moved)
        for index, entry in enumerate(entries):
            entry.position = index

        playlist.updated_at = datetime.utcnow()

        await self.db.commit()
        await self.db.refresh(playlist)
        return playlist

    async def _calculate_playlist_duration(self, playlist_id: UUID) -> int:
        """Calculate total duration of a playlist in seconds.

        Sums ``Track.duration`` over every playlist entry (duplicates counted
        once per entry); NULL durations are ignored.
        """
        # Select only the column; there is no need to load full Track entities.
        stmt = (
            select(Track.duration)
            .select_from(Track)
            .join(PlaylistTrack, Track.id == PlaylistTrack.track_id)
            .where(PlaylistTrack.playlist_id == playlist_id)
        )
        result = await self.db.execute(stmt)
        return sum(
            duration for duration in result.scalars().all() if duration is not None
        )

    async def _get_owned_playlist(
        self, playlist_id: UUID, user_id: UUID, action: str
    ) -> Playlist:
        """Fetch a playlist and verify ``user_id`` owns it.

        Raises:
            ValueError: "Playlist not found", or
                "Not authorized to {action} this playlist"
        """
        playlist = await self.get_playlist(playlist_id)
        if not playlist:
            raise ValueError("Playlist not found")
        if playlist.user_id != user_id:
            raise ValueError(f"Not authorized to {action} this playlist")
        return playlist

    async def _get_ordered_entries(self, playlist_id: UUID) -> List[PlaylistTrack]:
        """Return all PlaylistTrack rows of a playlist ordered by position."""
        stmt = (
            select(PlaylistTrack)
            .where(PlaylistTrack.playlist_id == playlist_id)
            .order_by(PlaylistTrack.position)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def _get_max_position(self, playlist_id: UUID) -> int:
        """Return the highest position in a playlist, or -1 if it is empty."""
        stmt = (
            select(PlaylistTrack.position)
            .where(PlaylistTrack.playlist_id == playlist_id)
            .order_by(PlaylistTrack.position.desc())
            .limit(1)
        )
        result = await self.db.execute(stmt)
        last_position = result.scalar_one_or_none()
        return -1 if last_position is None else last_position

    async def _filter_existing_track_ids(self, track_ids: List[UUID]) -> List[UUID]:
        """Drop IDs with no matching Track row, preserving order and duplicates.

        Uses a single IN query instead of one lookup per ID. The membership
        test compares UUIDs in Python, as the ownership checks above already do.
        """
        if not track_ids:
            return []
        stmt = select(Track.id).where(Track.id.in_(track_ids))
        result = await self.db.execute(stmt)
        known_ids = set(result.scalars().all())
        return [track_id for track_id in track_ids if track_id in known_ids]