From c2058f0a39d0929dd14b72b17eddb14cde09eefd Mon Sep 17 00:00:00 2001 From: itskovacs Date: Sat, 11 Oct 2025 15:52:36 +0200 Subject: [PATCH] :zap: Eager loading, merge lookups, bulk sqlalchemy update --- backend/trip/routers/trips.py | 122 +++++++++++++++++----------------- 1 file changed, 60 insertions(+), 62 deletions(-) diff --git a/backend/trip/routers/trips.py b/backend/trip/routers/trips.py index 5b530cb..0d2b083 100644 --- a/backend/trip/routers/trips.py +++ b/backend/trip/routers/trips.py @@ -1,6 +1,8 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import update +from sqlalchemy.orm import selectinload from sqlmodel import select from ..config import settings @@ -22,13 +24,6 @@ from ..utils.utils import (b64img_decode, generate_urlsafe, remove_image, router = APIRouter(prefix="/api/trips", tags=["trips"]) -def _get_trip_or_404(session, trip_id: int) -> Trip: - trip = session.get(Trip, trip_id) - if not trip: - raise HTTPException(status_code=404, detail="Not found") - return trip - - def _trip_from_token_or_404(session, token: str) -> TripShare: share = session.exec(select(TripShare).where(TripShare.token == token)).first() if not share: @@ -42,23 +37,20 @@ def _trip_usernames(session, trip_id: int) -> set[str]: return {owner} | set(members) -def _can_access_trip(session, trip_id: int, username: str) -> bool: - # TODO: Optimize, if trip does not exist, trip_owner is none and we keep iterating on members despite trip not found - trip_owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first() - if trip_owner == username: - return True - - is_member = session.exec( - select(TripMember.id) - .where(TripMember.trip_id == trip_id, TripMember.user == username, TripMember.joined_at.is_not(None)) - .limit(1) +def _get_verified_trip(session, trip_id: int, username: str) -> Trip: + # Merge of _verify_trip_member(+_can_access_trip) + _get_trip_or_404 + # Returns a Trip if: it exists and username is a TripMember or trip.user (owner) + trip = session.exec( + select(Trip) + .outerjoin(TripMember) + .where( + Trip.id == trip_id, + (Trip.user == username) | ((TripMember.user == username) & (TripMember.joined_at.is_not(None))), + ) ).first() - return bool(is_member) - - -def _verify_trip_member(session, trip_id: int, username: str) -> None: - if not _can_access_trip(session, trip_id, username): + if not trip: raise HTTPException(status_code=404, detail="Not found") + return trip @router.get("", response_model=list[TripReadBase]) @@ -120,8 +112,24 @@ def has_pending_invitations( def read_trip( session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] ) -> TripRead: - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = session.exec( + select(Trip) + .options( + selectinload(Trip.days).selectinload(TripDay.items), + selectinload(Trip.places), + selectinload(Trip.image), + selectinload(Trip.memberships), + ) + .outerjoin(TripMember) + .where( + Trip.id == trip_id, + (Trip.user == current_user) + | ((TripMember.user == current_user) & (TripMember.joined_at.is_not(None))), + ) + ).first() + + if not db_trip: + raise HTTPException(status_code=404, detail="Not found") return TripRead.serialize(db_trip) @@ -156,8 +164,7 @@ def update_trip( trip: TripUpdate, current_user: Annotated[str, Depends(get_current_username)], ) -> TripRead: - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived and (trip.archived is not False): raise HTTPException(status_code=400, detail="Bad request") @@ -225,8 +232,7 @@ def update_trip( def delete_trip( session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] ): - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -252,14 +258,14 @@ def get_trip_balance( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) members = _trip_usernames(session, trip_id) if len(members) < 2: raise HTTPException(status_code=400, detail="Bad request") trip_items = session.exec( - select(TripItem) + select(TripItem.price, TripItem.paid_by) .join(TripDay) .where(TripDay.trip_id == trip_id, TripItem.price.is_not(None), TripItem.paid_by.is_not(None)) ).all() @@ -281,8 +287,7 @@ def create_tripday( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripDayRead: - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -303,8 +308,7 @@ def update_tripday( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripDayRead: - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -330,8 +334,7 @@ def delete_tripday( day_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -353,8 +356,7 @@ def create_tripitem( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripItemRead: - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -415,8 +417,7 @@ def update_tripitem( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripItemRead: - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -506,8 +507,7 @@ def delete_tripitem( item_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -540,7 +540,7 @@ def get_shared_trip_url( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripShareURL: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() if not share: @@ -555,7 +555,7 @@ def create_shared_trip( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripShareURL: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() if shared: @@ -574,7 +574,7 @@ def delete_shared_trip( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() if not db_share: @@ -591,7 +591,7 @@ def read_packing_list( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> list[TripPackingListItemRead]: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id)) return [TripPackingListItemRead.serialize(i) for i in p_items] @@ -617,7 +617,7 @@ def create_packing_item( data: TripPackingListItemCreate, current_user: Annotated[str, Depends(get_current_username)], ) -> TripPackingListItemRead: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) item = TripPackingListItem(**data.model_dump(), trip_id=trip_id) session.add(item) session.commit() @@ -633,7 +633,7 @@ def update_packing_item( p_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripPackingListItemRead: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) db_item = session.exec( select(TripPackingListItem).where( TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id @@ -660,7 +660,7 @@ def delete_packing_item( p_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) item = session.exec( select(TripPackingListItem).where( TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id @@ -681,7 +681,7 @@ def read_checklist( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> list[TripChecklistItemRead]: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id)) return [TripChecklistItemRead.serialize(i) for i in items] @@ -706,7 +706,7 @@ def create_checklist_item( data: TripChecklistItemCreate, current_user: Annotated[str, Depends(get_current_username)], ) -> TripChecklistItemRead: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) item = TripChecklistItem(**data.model_dump(), trip_id=trip_id) session.add(item) session.commit() @@ -722,7 +722,7 @@ def update_checklist_item( id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripChecklistItemRead: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) db_item = session.exec( select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id) ).one_or_none() @@ -747,7 +747,7 @@ def delete_checklist_item( id: int, current_user: Annotated[str, Depends(get_current_username)], ): - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) item = session.exec( select(TripChecklistItem).where( TripChecklistItem.id == id, @@ -767,7 +767,7 @@ def delete_checklist_item( def read_trip_members( session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] ) -> list[TripMemberRead]: - _verify_trip_member(session, trip_id, current_user) + _get_verified_trip(session, trip_id, current_user) members: list[TripMemberRead] = [] owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first() @@ -785,8 +785,7 @@ def invite_trip_member( data: TripMemberCreate, current_user: Annotated[str, Depends(get_current_username)], ) -> TripMemberRead: - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if db_trip.user == data.user: raise HTTPException(status_code=409, detail="The resource already exists") @@ -820,8 +819,7 @@ def delete_trip_member( username: str, current_user: Annotated[str, Depends(get_current_username)], ): - db_trip = _get_trip_or_404(session, trip_id) - _verify_trip_member(session, trip_id, current_user) + db_trip = _get_verified_trip(session, trip_id, current_user) if current_user == db_trip.user and current_user == username: raise HTTPException(status_code=400, detail="Bad request") @@ -840,11 +838,11 @@ def delete_trip_member( # Set NULL to TripItem.paid_by for this username trip_items = session.exec( - select(TripItem).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username) + select(TripItem.id).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username) ).all() - for item in trip_items: - item.paid_by = None - session.add_all(trip_items) + + if trip_items: + session.exec(update(TripItem).where(TripItem.id.in_([id for id in trip_items])).values(paid_by=None)) session.delete(member) session.commit()