Eager loading, merge lookups, bulk sqlalchemy update

This commit is contained in:
itskovacs 2025-10-11 15:52:36 +02:00
parent 11cf3052f3
commit c2058f0a39

View File

@ -1,6 +1,8 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import update
from sqlalchemy.orm import selectinload
from sqlmodel import select
from ..config import settings
@ -22,13 +24,6 @@ from ..utils.utils import (b64img_decode, generate_urlsafe, remove_image,
router = APIRouter(prefix="/api/trips", tags=["trips"])
def _get_trip_or_404(session, trip_id: int) -> Trip:
trip = session.get(Trip, trip_id)
if not trip:
raise HTTPException(status_code=404, detail="Not found")
return trip
def _trip_from_token_or_404(session, token: str) -> TripShare:
share = session.exec(select(TripShare).where(TripShare.token == token)).first()
if not share:
@ -42,23 +37,20 @@ def _trip_usernames(session, trip_id: int) -> set[str]:
return {owner} | set(members)
def _can_access_trip(session, trip_id: int, username: str) -> bool:
# TODO: Optimize, if trip does not exist, trip_owner is none and we keep iterating on members despite trip not found
trip_owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
if trip_owner == username:
return True
is_member = session.exec(
select(TripMember.id)
.where(TripMember.trip_id == trip_id, TripMember.user == username, TripMember.joined_at.is_not(None))
.limit(1)
def _get_verified_trip(session, trip_id: int, username: str) -> Trip:
# Merge of _verify_trip_member(+_can_access_trip) + _get_trip_or_404
# Returns a Trip if: it exists and username is a TripMember or trip.user (owner)
trip = session.exec(
select(Trip)
.outerjoin(TripMember)
.where(
Trip.id == trip_id,
(Trip.user == username) | ((TripMember.user == username) & (TripMember.joined_at.is_not(None))),
)
).first()
return bool(is_member)
def _verify_trip_member(session, trip_id: int, username: str) -> None:
if not _can_access_trip(session, trip_id, username):
if not trip:
raise HTTPException(status_code=404, detail="Not found")
return trip
@router.get("", response_model=list[TripReadBase])
@ -120,8 +112,24 @@ def has_pending_invitations(
def read_trip(
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
) -> TripRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = session.exec(
select(Trip)
.options(
selectinload(Trip.days).selectinload(TripDay.items),
selectinload(Trip.places),
selectinload(Trip.image),
selectinload(Trip.memberships),
)
.outerjoin(TripMember)
.where(
Trip.id == trip_id,
(Trip.user == current_user)
| ((TripMember.user == current_user) & (TripMember.joined_at.is_not(None))),
)
).first()
if not db_trip:
raise HTTPException(status_code=404, detail="Not found")
return TripRead.serialize(db_trip)
@ -156,8 +164,7 @@ def update_trip(
trip: TripUpdate,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived and (trip.archived is not False):
raise HTTPException(status_code=400, detail="Bad request")
@ -225,8 +232,7 @@ def update_trip(
def delete_trip(
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
):
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request")
@ -252,14 +258,14 @@ def get_trip_balance(
trip_id: int,
current_user: Annotated[str, Depends(get_current_username)],
):
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
members = _trip_usernames(session, trip_id)
if len(members) < 2:
raise HTTPException(status_code=400, detail="Bad request")
trip_items = session.exec(
select(TripItem)
select(TripItem.price, TripItem.paid_by)
.join(TripDay)
.where(TripDay.trip_id == trip_id, TripItem.price.is_not(None), TripItem.paid_by.is_not(None))
).all()
@ -281,8 +287,7 @@ def create_tripday(
session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripDayRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request")
@ -303,8 +308,7 @@ def update_tripday(
session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripDayRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request")
@ -330,8 +334,7 @@ def delete_tripday(
day_id: int,
current_user: Annotated[str, Depends(get_current_username)],
):
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request")
@ -353,8 +356,7 @@ def create_tripitem(
session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripItemRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request")
@ -415,8 +417,7 @@ def update_tripitem(
session: SessionDep,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripItemRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request")
@ -506,8 +507,7 @@ def delete_tripitem(
item_id: int,
current_user: Annotated[str, Depends(get_current_username)],
):
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.archived:
raise HTTPException(status_code=400, detail="Bad request")
@ -540,7 +540,7 @@ def get_shared_trip_url(
trip_id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripShareURL:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
if not share:
@ -555,7 +555,7 @@ def create_shared_trip(
trip_id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripShareURL:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
if shared:
@ -574,7 +574,7 @@ def delete_shared_trip(
trip_id: int,
current_user: Annotated[str, Depends(get_current_username)],
):
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
if not db_share:
@ -591,7 +591,7 @@ def read_packing_list(
trip_id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> list[TripPackingListItemRead]:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id))
return [TripPackingListItemRead.serialize(i) for i in p_items]
@ -617,7 +617,7 @@ def create_packing_item(
data: TripPackingListItemCreate,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripPackingListItemRead:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
item = TripPackingListItem(**data.model_dump(), trip_id=trip_id)
session.add(item)
session.commit()
@ -633,7 +633,7 @@ def update_packing_item(
p_id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripPackingListItemRead:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
db_item = session.exec(
select(TripPackingListItem).where(
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
@ -660,7 +660,7 @@ def delete_packing_item(
p_id: int,
current_user: Annotated[str, Depends(get_current_username)],
):
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
item = session.exec(
select(TripPackingListItem).where(
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
@ -681,7 +681,7 @@ def read_checklist(
trip_id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> list[TripChecklistItemRead]:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id))
return [TripChecklistItemRead.serialize(i) for i in items]
@ -706,7 +706,7 @@ def create_checklist_item(
data: TripChecklistItemCreate,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripChecklistItemRead:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
item = TripChecklistItem(**data.model_dump(), trip_id=trip_id)
session.add(item)
session.commit()
@ -722,7 +722,7 @@ def update_checklist_item(
id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripChecklistItemRead:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
db_item = session.exec(
select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id)
).one_or_none()
@ -747,7 +747,7 @@ def delete_checklist_item(
id: int,
current_user: Annotated[str, Depends(get_current_username)],
):
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
item = session.exec(
select(TripChecklistItem).where(
TripChecklistItem.id == id,
@ -767,7 +767,7 @@ def delete_checklist_item(
def read_trip_members(
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
) -> list[TripMemberRead]:
_verify_trip_member(session, trip_id, current_user)
_get_verified_trip(session, trip_id, current_user)
members: list[TripMemberRead] = []
owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
@ -785,8 +785,7 @@ def invite_trip_member(
data: TripMemberCreate,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripMemberRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if db_trip.user == data.user:
raise HTTPException(status_code=409, detail="The resource already exists")
@ -820,8 +819,7 @@ def delete_trip_member(
username: str,
current_user: Annotated[str, Depends(get_current_username)],
):
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
db_trip = _get_verified_trip(session, trip_id, current_user)
if current_user == db_trip.user and current_user == username:
raise HTTPException(status_code=400, detail="Bad request")
@ -840,11 +838,11 @@ def delete_trip_member(
# Set NULL to TripItem.paid_by for this username
trip_items = session.exec(
select(TripItem).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username)
select(TripItem.id).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username)
).all()
for item in trip_items:
item.paid_by = None
session.add_all(trip_items)
if trip_items:
session.exec(update(TripItem).where(TripItem.id.in_([id for id in trip_items])).values(paid_by=None))
session.delete(member)
session.commit()