Eager loading, merge lookups, bulk sqlalchemy update

This commit is contained in:
itskovacs 2025-10-11 15:52:36 +02:00
parent 11cf3052f3
commit c2058f0a39

View File

@ -1,6 +1,8 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import update
from sqlalchemy.orm import selectinload
from sqlmodel import select from sqlmodel import select
from ..config import settings from ..config import settings
@ -22,13 +24,6 @@ from ..utils.utils import (b64img_decode, generate_urlsafe, remove_image,
router = APIRouter(prefix="/api/trips", tags=["trips"]) router = APIRouter(prefix="/api/trips", tags=["trips"])
def _get_trip_or_404(session, trip_id: int) -> Trip:
trip = session.get(Trip, trip_id)
if not trip:
raise HTTPException(status_code=404, detail="Not found")
return trip
def _trip_from_token_or_404(session, token: str) -> TripShare: def _trip_from_token_or_404(session, token: str) -> TripShare:
share = session.exec(select(TripShare).where(TripShare.token == token)).first() share = session.exec(select(TripShare).where(TripShare.token == token)).first()
if not share: if not share:
@ -42,23 +37,20 @@ def _trip_usernames(session, trip_id: int) -> set[str]:
return {owner} | set(members) return {owner} | set(members)
def _can_access_trip(session, trip_id: int, username: str) -> bool: def _get_verified_trip(session, trip_id: int, username: str) -> Trip:
# TODO: Optimize, if trip does not exist, trip_owner is none and we keep iterating on members despite trip not found # Merge of _verify_trip_member(+_can_access_trip) + _get_trip_or_404
trip_owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first() # Returns a Trip if: it exists and username is a TripMember or trip.user (owner)
if trip_owner == username: trip = session.exec(
return True select(Trip)
.outerjoin(TripMember)
is_member = session.exec( .where(
select(TripMember.id) Trip.id == trip_id,
.where(TripMember.trip_id == trip_id, TripMember.user == username, TripMember.joined_at.is_not(None)) (Trip.user == username) | ((TripMember.user == username) & (TripMember.joined_at.is_not(None))),
.limit(1) )
).first() ).first()
return bool(is_member) if not trip:
def _verify_trip_member(session, trip_id: int, username: str) -> None:
if not _can_access_trip(session, trip_id, username):
raise HTTPException(status_code=404, detail="Not found") raise HTTPException(status_code=404, detail="Not found")
return trip
@router.get("", response_model=list[TripReadBase]) @router.get("", response_model=list[TripReadBase])
@ -120,8 +112,24 @@ def has_pending_invitations(
def read_trip( def read_trip(
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
) -> TripRead: ) -> TripRead:
db_trip = _get_trip_or_404(session, trip_id) db_trip = session.exec(
_verify_trip_member(session, trip_id, current_user) select(Trip)
.options(
selectinload(Trip.days).selectinload(TripDay.items),
selectinload(Trip.places),
selectinload(Trip.image),
selectinload(Trip.memberships),
)
.outerjoin(TripMember)
.where(
Trip.id == trip_id,
(Trip.user == current_user)
| ((TripMember.user == current_user) & (TripMember.joined_at.is_not(None))),
)
).first()
if not db_trip:
raise HTTPException(status_code=404, detail="Not found")
return TripRead.serialize(db_trip) return TripRead.serialize(db_trip)
@ -156,8 +164,7 @@ def update_trip(
trip: TripUpdate, trip: TripUpdate,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripRead: ) -> TripRead:
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived and (trip.archived is not False): if db_trip.archived and (trip.archived is not False):
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -225,8 +232,7 @@ def update_trip(
def delete_trip( def delete_trip(
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
): ):
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived: if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -252,14 +258,14 @@ def get_trip_balance(
trip_id: int, trip_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
): ):
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
members = _trip_usernames(session, trip_id) members = _trip_usernames(session, trip_id)
if len(members) < 2: if len(members) < 2:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
trip_items = session.exec( trip_items = session.exec(
select(TripItem) select(TripItem.price, TripItem.paid_by)
.join(TripDay) .join(TripDay)
.where(TripDay.trip_id == trip_id, TripItem.price.is_not(None), TripItem.paid_by.is_not(None)) .where(TripDay.trip_id == trip_id, TripItem.price.is_not(None), TripItem.paid_by.is_not(None))
).all() ).all()
@ -281,8 +287,7 @@ def create_tripday(
session: SessionDep, session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripDayRead: ) -> TripDayRead:
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived: if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -303,8 +308,7 @@ def update_tripday(
session: SessionDep, session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripDayRead: ) -> TripDayRead:
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived: if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -330,8 +334,7 @@ def delete_tripday(
day_id: int, day_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
): ):
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived: if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -353,8 +356,7 @@ def create_tripitem(
session: SessionDep, session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripItemRead: ) -> TripItemRead:
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived: if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -415,8 +417,7 @@ def update_tripitem(
session: SessionDep, session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripItemRead: ) -> TripItemRead:
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived: if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -506,8 +507,7 @@ def delete_tripitem(
item_id: int, item_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
): ):
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.archived: if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -540,7 +540,7 @@ def get_shared_trip_url(
trip_id: int, trip_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripShareURL: ) -> TripShareURL:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
if not share: if not share:
@ -555,7 +555,7 @@ def create_shared_trip(
trip_id: int, trip_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripShareURL: ) -> TripShareURL:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
if shared: if shared:
@ -574,7 +574,7 @@ def delete_shared_trip(
trip_id: int, trip_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
): ):
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
if not db_share: if not db_share:
@ -591,7 +591,7 @@ def read_packing_list(
trip_id: int, trip_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> list[TripPackingListItemRead]: ) -> list[TripPackingListItemRead]:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id)) p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id))
return [TripPackingListItemRead.serialize(i) for i in p_items] return [TripPackingListItemRead.serialize(i) for i in p_items]
@ -617,7 +617,7 @@ def create_packing_item(
data: TripPackingListItemCreate, data: TripPackingListItemCreate,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripPackingListItemRead: ) -> TripPackingListItemRead:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
item = TripPackingListItem(**data.model_dump(), trip_id=trip_id) item = TripPackingListItem(**data.model_dump(), trip_id=trip_id)
session.add(item) session.add(item)
session.commit() session.commit()
@ -633,7 +633,7 @@ def update_packing_item(
p_id: int, p_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripPackingListItemRead: ) -> TripPackingListItemRead:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
db_item = session.exec( db_item = session.exec(
select(TripPackingListItem).where( select(TripPackingListItem).where(
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
@ -660,7 +660,7 @@ def delete_packing_item(
p_id: int, p_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
): ):
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
item = session.exec( item = session.exec(
select(TripPackingListItem).where( select(TripPackingListItem).where(
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
@ -681,7 +681,7 @@ def read_checklist(
trip_id: int, trip_id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> list[TripChecklistItemRead]: ) -> list[TripChecklistItemRead]:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id)) items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id))
return [TripChecklistItemRead.serialize(i) for i in items] return [TripChecklistItemRead.serialize(i) for i in items]
@ -706,7 +706,7 @@ def create_checklist_item(
data: TripChecklistItemCreate, data: TripChecklistItemCreate,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripChecklistItemRead: ) -> TripChecklistItemRead:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
item = TripChecklistItem(**data.model_dump(), trip_id=trip_id) item = TripChecklistItem(**data.model_dump(), trip_id=trip_id)
session.add(item) session.add(item)
session.commit() session.commit()
@ -722,7 +722,7 @@ def update_checklist_item(
id: int, id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripChecklistItemRead: ) -> TripChecklistItemRead:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
db_item = session.exec( db_item = session.exec(
select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id) select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id)
).one_or_none() ).one_or_none()
@ -747,7 +747,7 @@ def delete_checklist_item(
id: int, id: int,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
): ):
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
item = session.exec( item = session.exec(
select(TripChecklistItem).where( select(TripChecklistItem).where(
TripChecklistItem.id == id, TripChecklistItem.id == id,
@ -767,7 +767,7 @@ def delete_checklist_item(
def read_trip_members( def read_trip_members(
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
) -> list[TripMemberRead]: ) -> list[TripMemberRead]:
_verify_trip_member(session, trip_id, current_user) _get_verified_trip(session, trip_id, current_user)
members: list[TripMemberRead] = [] members: list[TripMemberRead] = []
owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first() owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
@ -785,8 +785,7 @@ def invite_trip_member(
data: TripMemberCreate, data: TripMemberCreate,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
) -> TripMemberRead: ) -> TripMemberRead:
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if db_trip.user == data.user: if db_trip.user == data.user:
raise HTTPException(status_code=409, detail="The resource already exists") raise HTTPException(status_code=409, detail="The resource already exists")
@ -820,8 +819,7 @@ def delete_trip_member(
username: str, username: str,
current_user: Annotated[str, Depends(get_current_username)], current_user: Annotated[str, Depends(get_current_username)],
): ):
db_trip = _get_trip_or_404(session, trip_id) db_trip = _get_verified_trip(session, trip_id, current_user)
_verify_trip_member(session, trip_id, current_user)
if current_user == db_trip.user and current_user == username: if current_user == db_trip.user and current_user == username:
raise HTTPException(status_code=400, detail="Bad request") raise HTTPException(status_code=400, detail="Bad request")
@ -840,11 +838,11 @@ def delete_trip_member(
# Set NULL to TripItem.paid_by for this username # Set NULL to TripItem.paid_by for this username
trip_items = session.exec( trip_items = session.exec(
select(TripItem).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username) select(TripItem.id).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username)
).all() ).all()
for item in trip_items:
item.paid_by = None if trip_items:
session.add_all(trip_items) session.exec(update(TripItem).where(TripItem.id.in_([id for id in trip_items])).values(paid_by=None))
session.delete(member) session.delete(member)
session.commit() session.commit()