⚡ Eager loading, merge lookups, bulk sqlalchemy update
This commit is contained in:
parent
11cf3052f3
commit
c2058f0a39
@ -1,6 +1,8 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
from ..config import settings
|
||||
@ -22,13 +24,6 @@ from ..utils.utils import (b64img_decode, generate_urlsafe, remove_image,
|
||||
router = APIRouter(prefix="/api/trips", tags=["trips"])
|
||||
|
||||
|
||||
def _get_trip_or_404(session, trip_id: int) -> Trip:
|
||||
trip = session.get(Trip, trip_id)
|
||||
if not trip:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return trip
|
||||
|
||||
|
||||
def _trip_from_token_or_404(session, token: str) -> TripShare:
|
||||
share = session.exec(select(TripShare).where(TripShare.token == token)).first()
|
||||
if not share:
|
||||
@ -42,23 +37,20 @@ def _trip_usernames(session, trip_id: int) -> set[str]:
|
||||
return {owner} | set(members)
|
||||
|
||||
|
||||
def _can_access_trip(session, trip_id: int, username: str) -> bool:
|
||||
# TODO: Optimize, if trip does not exist, trip_owner is none and we keep iterating on members despite trip not found
|
||||
trip_owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
|
||||
if trip_owner == username:
|
||||
return True
|
||||
|
||||
is_member = session.exec(
|
||||
select(TripMember.id)
|
||||
.where(TripMember.trip_id == trip_id, TripMember.user == username, TripMember.joined_at.is_not(None))
|
||||
.limit(1)
|
||||
def _get_verified_trip(session, trip_id: int, username: str) -> Trip:
|
||||
# Merge of _verify_trip_member(+_can_access_trip) + _get_trip_or_404
|
||||
# Returns a Trip if: it exists and username is a TripMember or trip.user (owner)
|
||||
trip = session.exec(
|
||||
select(Trip)
|
||||
.outerjoin(TripMember)
|
||||
.where(
|
||||
Trip.id == trip_id,
|
||||
(Trip.user == username) | ((TripMember.user == username) & (TripMember.joined_at.is_not(None))),
|
||||
)
|
||||
).first()
|
||||
return bool(is_member)
|
||||
|
||||
|
||||
def _verify_trip_member(session, trip_id: int, username: str) -> None:
|
||||
if not _can_access_trip(session, trip_id, username):
|
||||
if not trip:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return trip
|
||||
|
||||
|
||||
@router.get("", response_model=list[TripReadBase])
|
||||
@ -120,8 +112,24 @@ def has_pending_invitations(
|
||||
def read_trip(
|
||||
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
||||
) -> TripRead:
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = session.exec(
|
||||
select(Trip)
|
||||
.options(
|
||||
selectinload(Trip.days).selectinload(TripDay.items),
|
||||
selectinload(Trip.places),
|
||||
selectinload(Trip.image),
|
||||
selectinload(Trip.memberships),
|
||||
)
|
||||
.outerjoin(TripMember)
|
||||
.where(
|
||||
Trip.id == trip_id,
|
||||
(Trip.user == current_user)
|
||||
| ((TripMember.user == current_user) & (TripMember.joined_at.is_not(None))),
|
||||
)
|
||||
).first()
|
||||
|
||||
if not db_trip:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return TripRead.serialize(db_trip)
|
||||
|
||||
|
||||
@ -156,8 +164,7 @@ def update_trip(
|
||||
trip: TripUpdate,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripRead:
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived and (trip.archived is not False):
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -225,8 +232,7 @@ def update_trip(
|
||||
def delete_trip(
|
||||
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
||||
):
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -252,14 +258,14 @@ def get_trip_balance(
|
||||
trip_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
):
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
members = _trip_usernames(session, trip_id)
|
||||
if len(members) < 2:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
|
||||
trip_items = session.exec(
|
||||
select(TripItem)
|
||||
select(TripItem.price, TripItem.paid_by)
|
||||
.join(TripDay)
|
||||
.where(TripDay.trip_id == trip_id, TripItem.price.is_not(None), TripItem.paid_by.is_not(None))
|
||||
).all()
|
||||
@ -281,8 +287,7 @@ def create_tripday(
|
||||
session: SessionDep,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripDayRead:
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -303,8 +308,7 @@ def update_tripday(
|
||||
session: SessionDep,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripDayRead:
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -330,8 +334,7 @@ def delete_tripday(
|
||||
day_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
):
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -353,8 +356,7 @@ def create_tripitem(
|
||||
session: SessionDep,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripItemRead:
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -415,8 +417,7 @@ def update_tripitem(
|
||||
session: SessionDep,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripItemRead:
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -506,8 +507,7 @@ def delete_tripitem(
|
||||
item_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
):
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.archived:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -540,7 +540,7 @@ def get_shared_trip_url(
|
||||
trip_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripShareURL:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
||||
if not share:
|
||||
@ -555,7 +555,7 @@ def create_shared_trip(
|
||||
trip_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripShareURL:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
||||
if shared:
|
||||
@ -574,7 +574,7 @@ def delete_shared_trip(
|
||||
trip_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
):
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
||||
if not db_share:
|
||||
@ -591,7 +591,7 @@ def read_packing_list(
|
||||
trip_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> list[TripPackingListItemRead]:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id))
|
||||
|
||||
return [TripPackingListItemRead.serialize(i) for i in p_items]
|
||||
@ -617,7 +617,7 @@ def create_packing_item(
|
||||
data: TripPackingListItemCreate,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripPackingListItemRead:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
item = TripPackingListItem(**data.model_dump(), trip_id=trip_id)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
@ -633,7 +633,7 @@ def update_packing_item(
|
||||
p_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripPackingListItemRead:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
db_item = session.exec(
|
||||
select(TripPackingListItem).where(
|
||||
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
|
||||
@ -660,7 +660,7 @@ def delete_packing_item(
|
||||
p_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
):
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
item = session.exec(
|
||||
select(TripPackingListItem).where(
|
||||
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
|
||||
@ -681,7 +681,7 @@ def read_checklist(
|
||||
trip_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> list[TripChecklistItemRead]:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id))
|
||||
return [TripChecklistItemRead.serialize(i) for i in items]
|
||||
|
||||
@ -706,7 +706,7 @@ def create_checklist_item(
|
||||
data: TripChecklistItemCreate,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripChecklistItemRead:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
item = TripChecklistItem(**data.model_dump(), trip_id=trip_id)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
@ -722,7 +722,7 @@ def update_checklist_item(
|
||||
id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripChecklistItemRead:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
db_item = session.exec(
|
||||
select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id)
|
||||
).one_or_none()
|
||||
@ -747,7 +747,7 @@ def delete_checklist_item(
|
||||
id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
):
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
item = session.exec(
|
||||
select(TripChecklistItem).where(
|
||||
TripChecklistItem.id == id,
|
||||
@ -767,7 +767,7 @@ def delete_checklist_item(
|
||||
def read_trip_members(
|
||||
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
||||
) -> list[TripMemberRead]:
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
_get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
members: list[TripMemberRead] = []
|
||||
owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
|
||||
@ -785,8 +785,7 @@ def invite_trip_member(
|
||||
data: TripMemberCreate,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> TripMemberRead:
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if db_trip.user == data.user:
|
||||
raise HTTPException(status_code=409, detail="The resource already exists")
|
||||
@ -820,8 +819,7 @@ def delete_trip_member(
|
||||
username: str,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
):
|
||||
db_trip = _get_trip_or_404(session, trip_id)
|
||||
_verify_trip_member(session, trip_id, current_user)
|
||||
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||
|
||||
if current_user == db_trip.user and current_user == username:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
@ -840,11 +838,11 @@ def delete_trip_member(
|
||||
|
||||
# Set NULL to TripItem.paid_by for this username
|
||||
trip_items = session.exec(
|
||||
select(TripItem).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username)
|
||||
select(TripItem.id).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username)
|
||||
).all()
|
||||
for item in trip_items:
|
||||
item.paid_by = None
|
||||
session.add_all(trip_items)
|
||||
|
||||
if trip_items:
|
||||
session.exec(update(TripItem).where(TripItem.id.in_([id for id in trip_items])).values(paid_by=None))
|
||||
|
||||
session.delete(member)
|
||||
session.commit()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user