⚡ Eager loading, merge lookups, bulk sqlalchemy update
This commit is contained in:
parent
11cf3052f3
commit
c2058f0a39
@ -1,6 +1,8 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import update
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
@ -22,13 +24,6 @@ from ..utils.utils import (b64img_decode, generate_urlsafe, remove_image,
|
|||||||
router = APIRouter(prefix="/api/trips", tags=["trips"])
|
router = APIRouter(prefix="/api/trips", tags=["trips"])
|
||||||
|
|
||||||
|
|
||||||
def _get_trip_or_404(session, trip_id: int) -> Trip:
|
|
||||||
trip = session.get(Trip, trip_id)
|
|
||||||
if not trip:
|
|
||||||
raise HTTPException(status_code=404, detail="Not found")
|
|
||||||
return trip
|
|
||||||
|
|
||||||
|
|
||||||
def _trip_from_token_or_404(session, token: str) -> TripShare:
|
def _trip_from_token_or_404(session, token: str) -> TripShare:
|
||||||
share = session.exec(select(TripShare).where(TripShare.token == token)).first()
|
share = session.exec(select(TripShare).where(TripShare.token == token)).first()
|
||||||
if not share:
|
if not share:
|
||||||
@ -42,23 +37,20 @@ def _trip_usernames(session, trip_id: int) -> set[str]:
|
|||||||
return {owner} | set(members)
|
return {owner} | set(members)
|
||||||
|
|
||||||
|
|
||||||
def _can_access_trip(session, trip_id: int, username: str) -> bool:
|
def _get_verified_trip(session, trip_id: int, username: str) -> Trip:
|
||||||
# TODO: Optimize, if trip does not exist, trip_owner is none and we keep iterating on members despite trip not found
|
# Merge of _verify_trip_member(+_can_access_trip) + _get_trip_or_404
|
||||||
trip_owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
|
# Returns a Trip if: it exists and username is a TripMember or trip.user (owner)
|
||||||
if trip_owner == username:
|
trip = session.exec(
|
||||||
return True
|
select(Trip)
|
||||||
|
.outerjoin(TripMember)
|
||||||
is_member = session.exec(
|
.where(
|
||||||
select(TripMember.id)
|
Trip.id == trip_id,
|
||||||
.where(TripMember.trip_id == trip_id, TripMember.user == username, TripMember.joined_at.is_not(None))
|
(Trip.user == username) | ((TripMember.user == username) & (TripMember.joined_at.is_not(None))),
|
||||||
.limit(1)
|
)
|
||||||
).first()
|
).first()
|
||||||
return bool(is_member)
|
if not trip:
|
||||||
|
|
||||||
|
|
||||||
def _verify_trip_member(session, trip_id: int, username: str) -> None:
|
|
||||||
if not _can_access_trip(session, trip_id, username):
|
|
||||||
raise HTTPException(status_code=404, detail="Not found")
|
raise HTTPException(status_code=404, detail="Not found")
|
||||||
|
return trip
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[TripReadBase])
|
@router.get("", response_model=list[TripReadBase])
|
||||||
@ -120,8 +112,24 @@ def has_pending_invitations(
|
|||||||
def read_trip(
|
def read_trip(
|
||||||
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
||||||
) -> TripRead:
|
) -> TripRead:
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = session.exec(
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
select(Trip)
|
||||||
|
.options(
|
||||||
|
selectinload(Trip.days).selectinload(TripDay.items),
|
||||||
|
selectinload(Trip.places),
|
||||||
|
selectinload(Trip.image),
|
||||||
|
selectinload(Trip.memberships),
|
||||||
|
)
|
||||||
|
.outerjoin(TripMember)
|
||||||
|
.where(
|
||||||
|
Trip.id == trip_id,
|
||||||
|
(Trip.user == current_user)
|
||||||
|
| ((TripMember.user == current_user) & (TripMember.joined_at.is_not(None))),
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not db_trip:
|
||||||
|
raise HTTPException(status_code=404, detail="Not found")
|
||||||
return TripRead.serialize(db_trip)
|
return TripRead.serialize(db_trip)
|
||||||
|
|
||||||
|
|
||||||
@ -156,8 +164,7 @@ def update_trip(
|
|||||||
trip: TripUpdate,
|
trip: TripUpdate,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripRead:
|
) -> TripRead:
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived and (trip.archived is not False):
|
if db_trip.archived and (trip.archived is not False):
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -225,8 +232,7 @@ def update_trip(
|
|||||||
def delete_trip(
|
def delete_trip(
|
||||||
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
||||||
):
|
):
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived:
|
if db_trip.archived:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -252,14 +258,14 @@ def get_trip_balance(
|
|||||||
trip_id: int,
|
trip_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
):
|
):
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
|
|
||||||
members = _trip_usernames(session, trip_id)
|
members = _trip_usernames(session, trip_id)
|
||||||
if len(members) < 2:
|
if len(members) < 2:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
|
|
||||||
trip_items = session.exec(
|
trip_items = session.exec(
|
||||||
select(TripItem)
|
select(TripItem.price, TripItem.paid_by)
|
||||||
.join(TripDay)
|
.join(TripDay)
|
||||||
.where(TripDay.trip_id == trip_id, TripItem.price.is_not(None), TripItem.paid_by.is_not(None))
|
.where(TripDay.trip_id == trip_id, TripItem.price.is_not(None), TripItem.paid_by.is_not(None))
|
||||||
).all()
|
).all()
|
||||||
@ -281,8 +287,7 @@ def create_tripday(
|
|||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripDayRead:
|
) -> TripDayRead:
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived:
|
if db_trip.archived:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -303,8 +308,7 @@ def update_tripday(
|
|||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripDayRead:
|
) -> TripDayRead:
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived:
|
if db_trip.archived:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -330,8 +334,7 @@ def delete_tripday(
|
|||||||
day_id: int,
|
day_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
):
|
):
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived:
|
if db_trip.archived:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -353,8 +356,7 @@ def create_tripitem(
|
|||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripItemRead:
|
) -> TripItemRead:
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived:
|
if db_trip.archived:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -415,8 +417,7 @@ def update_tripitem(
|
|||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripItemRead:
|
) -> TripItemRead:
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived:
|
if db_trip.archived:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -506,8 +507,7 @@ def delete_tripitem(
|
|||||||
item_id: int,
|
item_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
):
|
):
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.archived:
|
if db_trip.archived:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -540,7 +540,7 @@ def get_shared_trip_url(
|
|||||||
trip_id: int,
|
trip_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripShareURL:
|
) -> TripShareURL:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
|
|
||||||
share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
||||||
if not share:
|
if not share:
|
||||||
@ -555,7 +555,7 @@ def create_shared_trip(
|
|||||||
trip_id: int,
|
trip_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripShareURL:
|
) -> TripShareURL:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
|
|
||||||
shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
||||||
if shared:
|
if shared:
|
||||||
@ -574,7 +574,7 @@ def delete_shared_trip(
|
|||||||
trip_id: int,
|
trip_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
):
|
):
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
|
|
||||||
db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
|
||||||
if not db_share:
|
if not db_share:
|
||||||
@ -591,7 +591,7 @@ def read_packing_list(
|
|||||||
trip_id: int,
|
trip_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> list[TripPackingListItemRead]:
|
) -> list[TripPackingListItemRead]:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id))
|
p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id))
|
||||||
|
|
||||||
return [TripPackingListItemRead.serialize(i) for i in p_items]
|
return [TripPackingListItemRead.serialize(i) for i in p_items]
|
||||||
@ -617,7 +617,7 @@ def create_packing_item(
|
|||||||
data: TripPackingListItemCreate,
|
data: TripPackingListItemCreate,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripPackingListItemRead:
|
) -> TripPackingListItemRead:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
item = TripPackingListItem(**data.model_dump(), trip_id=trip_id)
|
item = TripPackingListItem(**data.model_dump(), trip_id=trip_id)
|
||||||
session.add(item)
|
session.add(item)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -633,7 +633,7 @@ def update_packing_item(
|
|||||||
p_id: int,
|
p_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripPackingListItemRead:
|
) -> TripPackingListItemRead:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
db_item = session.exec(
|
db_item = session.exec(
|
||||||
select(TripPackingListItem).where(
|
select(TripPackingListItem).where(
|
||||||
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
|
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
|
||||||
@ -660,7 +660,7 @@ def delete_packing_item(
|
|||||||
p_id: int,
|
p_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
):
|
):
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
item = session.exec(
|
item = session.exec(
|
||||||
select(TripPackingListItem).where(
|
select(TripPackingListItem).where(
|
||||||
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
|
TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id
|
||||||
@ -681,7 +681,7 @@ def read_checklist(
|
|||||||
trip_id: int,
|
trip_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> list[TripChecklistItemRead]:
|
) -> list[TripChecklistItemRead]:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id))
|
items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id))
|
||||||
return [TripChecklistItemRead.serialize(i) for i in items]
|
return [TripChecklistItemRead.serialize(i) for i in items]
|
||||||
|
|
||||||
@ -706,7 +706,7 @@ def create_checklist_item(
|
|||||||
data: TripChecklistItemCreate,
|
data: TripChecklistItemCreate,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripChecklistItemRead:
|
) -> TripChecklistItemRead:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
item = TripChecklistItem(**data.model_dump(), trip_id=trip_id)
|
item = TripChecklistItem(**data.model_dump(), trip_id=trip_id)
|
||||||
session.add(item)
|
session.add(item)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -722,7 +722,7 @@ def update_checklist_item(
|
|||||||
id: int,
|
id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripChecklistItemRead:
|
) -> TripChecklistItemRead:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
db_item = session.exec(
|
db_item = session.exec(
|
||||||
select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id)
|
select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id)
|
||||||
).one_or_none()
|
).one_or_none()
|
||||||
@ -747,7 +747,7 @@ def delete_checklist_item(
|
|||||||
id: int,
|
id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
):
|
):
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
item = session.exec(
|
item = session.exec(
|
||||||
select(TripChecklistItem).where(
|
select(TripChecklistItem).where(
|
||||||
TripChecklistItem.id == id,
|
TripChecklistItem.id == id,
|
||||||
@ -767,7 +767,7 @@ def delete_checklist_item(
|
|||||||
def read_trip_members(
|
def read_trip_members(
|
||||||
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
|
||||||
) -> list[TripMemberRead]:
|
) -> list[TripMemberRead]:
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
_get_verified_trip(session, trip_id, current_user)
|
||||||
|
|
||||||
members: list[TripMemberRead] = []
|
members: list[TripMemberRead] = []
|
||||||
owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
|
owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
|
||||||
@ -785,8 +785,7 @@ def invite_trip_member(
|
|||||||
data: TripMemberCreate,
|
data: TripMemberCreate,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> TripMemberRead:
|
) -> TripMemberRead:
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if db_trip.user == data.user:
|
if db_trip.user == data.user:
|
||||||
raise HTTPException(status_code=409, detail="The resource already exists")
|
raise HTTPException(status_code=409, detail="The resource already exists")
|
||||||
@ -820,8 +819,7 @@ def delete_trip_member(
|
|||||||
username: str,
|
username: str,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
):
|
):
|
||||||
db_trip = _get_trip_or_404(session, trip_id)
|
db_trip = _get_verified_trip(session, trip_id, current_user)
|
||||||
_verify_trip_member(session, trip_id, current_user)
|
|
||||||
|
|
||||||
if current_user == db_trip.user and current_user == username:
|
if current_user == db_trip.user and current_user == username:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
@ -840,11 +838,11 @@ def delete_trip_member(
|
|||||||
|
|
||||||
# Set NULL to TripItem.paid_by for this username
|
# Set NULL to TripItem.paid_by for this username
|
||||||
trip_items = session.exec(
|
trip_items = session.exec(
|
||||||
select(TripItem).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username)
|
select(TripItem.id).join(TripDay).where(TripDay.trip_id == trip_id, TripItem.paid_by == username)
|
||||||
).all()
|
).all()
|
||||||
for item in trip_items:
|
|
||||||
item.paid_by = None
|
if trip_items:
|
||||||
session.add_all(trip_items)
|
session.exec(update(TripItem).where(TripItem.id.in_([id for id in trip_items])).values(paid_by=None))
|
||||||
|
|
||||||
session.delete(member)
|
session.delete(member)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user