2025-08-21 08:35:50 +02:00

784 lines
25 KiB
Python

from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import select
from ..config import settings
from ..deps import SessionDep, get_current_username
from ..models.models import (Image, Place, Trip, TripChecklistItem,
TripChecklistItemCreate, TripChecklistItemRead,
TripChecklistItemUpdate, TripCreate, TripDay,
TripDayBase, TripDayRead, TripInvitationRead,
TripItem, TripItemCreate, TripItemRead,
TripItemUpdate, TripMember, TripMemberCreate,
TripMemberRead, TripPackingListItem,
TripPackingListItemCreate,
TripPackingListItemRead,
TripPackingListItemUpdate, TripRead, TripReadBase,
TripShare, TripShareURL, TripUpdate, User)
from ..utils.utils import (b64img_decode, generate_urlsafe, remove_image,
save_image_to_file, utc_now)
# Router collecting every trip endpoint of this module; all routes are served under /api/trips.
router = APIRouter(prefix="/api/trips", tags=["trips"])
def _get_trip_or_404(session, trip_id: int) -> Trip:
    """Load the trip whose primary key is ``trip_id``.

    Raises:
        HTTPException: 404 "Not found" when the lookup yields nothing.
    """
    found = session.get(Trip, trip_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Not found")
def _trip_from_token_or_404(session, token: str) -> TripShare:
    """Return the ``TripShare`` row matching ``token`` (the share, not the trip itself).

    Raises:
        HTTPException: 404 "Not found" when no share carries this token.
    """
    query = select(TripShare).where(TripShare.token == token)
    match = session.exec(query).first()
    if match:
        return match
    raise HTTPException(status_code=404, detail="Not found")
def _trip_usernames(session, trip_id: int) -> set[str]:
    """Collect the trip owner plus every user that has a ``TripMember`` row for the trip.

    Member rows are not filtered on ``joined_at``, so users with a pending
    invitation are included. The owner lookup result is added as-is, even when
    the query returned nothing (``None``).
    """
    owner_query = select(Trip.user).where(Trip.id == trip_id)
    member_query = select(TripMember.user).where(TripMember.trip_id == trip_id)

    usernames = {session.exec(owner_query).first()}
    usernames.update(session.exec(member_query).all())
    return usernames
def _can_access_trip(session, trip_id: int, username: str) -> bool:
    """Tell whether ``username`` owns the trip or is a member that has joined it.

    A membership only counts once its ``joined_at`` is set; a pending
    invitation does not grant access.
    """
    # TODO: Optimize: when the trip does not exist the owner lookup returns None
    # and the members table is still queried even though the trip was not found.
    owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
    if owner == username:
        return True

    membership_query = (
        select(TripMember.id)
        .where(TripMember.trip_id == trip_id)
        .where(TripMember.user == username)
        .where(TripMember.joined_at.is_not(None))
        .limit(1)
    )
    return bool(session.exec(membership_query).first())
def _verify_trip_member(session, trip_id: int, username: str) -> None:
    """Abort the request unless ``_can_access_trip`` grants ``username`` access.

    Raises:
        HTTPException: 404 "Not found" when access is denied.
    """
    if _can_access_trip(session, trip_id, username):
        return
    raise HTTPException(status_code=404, detail="Not found")
@router.get("", response_model=list[TripReadBase])
def read_trips(
    session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
) -> list[TripReadBase]:
    """List the trips the current user owns or has joined as a member."""
    owns_trip = Trip.user == current_user
    joined_trip = (TripMember.user == current_user) & (TripMember.joined_at.is_not(None))

    # Outer join so trips without any member row can still match through ownership;
    # DISTINCT collapses the duplicates produced by trips with several members.
    query = select(Trip).join(TripMember, isouter=True).where(owns_trip | joined_trip).distinct()
    return list(map(TripReadBase.serialize, session.exec(query)))
@router.get("/invitations", response_model=list[TripInvitationRead])
def read_pending_invitations(
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
) -> list[TripInvitationRead]:
    """Return the trips the current user was invited to and has not joined yet.

    Each entry is the base trip representation extended with who sent the
    invitation and when.
    """
    query = (
        select(TripMember, Trip)
        .join(Trip, Trip.id == TripMember.trip_id)
        .where(TripMember.user == current_user, TripMember.joined_at.is_(None))
    )
    pending_rows = session.exec(query).all()

    return [
        TripInvitationRead(
            **TripReadBase.serialize(db_trip).model_dump(),
            invited_by=membership.invited_by,
            invited_at=membership.invited_at,
        )
        for membership, db_trip in pending_rows
    ]
@router.get("/invitations/pending", response_model=bool)
def has_pending_invitations(
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
) -> bool:
    """Tell whether the current user has at least one invitation not yet joined."""
    # Only existence matters, so fetch a single id at most.
    query = (
        select(TripMember.id)
        .where(TripMember.user == current_user)
        .where(TripMember.joined_at.is_(None))
        .limit(1)
    )
    first_pending_id = session.exec(query).first()
    return bool(first_pending_id)
@router.get("/{trip_id}", response_model=TripRead)
def read_trip(
    session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
) -> TripRead:
    """Return the full representation of one trip.

    Responds 404 when the trip does not exist or the current user may not
    access it.
    """
    trip = _get_trip_or_404(session, trip_id)
    # The access check runs after the existence check; both failures answer 404.
    _verify_trip_member(session, trip_id, current_user)
    return TripRead.serialize(trip)
@router.post("", response_model=TripReadBase)
def create_trip(
    trip: TripCreate, session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
) -> TripReadBase:
    """Create a trip owned by the current user, optionally with an image.

    When ``trip.image`` is provided it is decoded, written through
    ``save_image_to_file`` and stored as an ``Image`` row linked to the trip.

    Raises:
        HTTPException: 400 "Bad request" when the image cannot be decoded or
            when ``save_image_to_file`` returns no filename.
    """
    new_trip = Trip(name=trip.name, user=current_user)

    if trip.image:
        # Answer 400 on an undecodable payload, the same way update_trip does,
        # instead of letting the decoding error surface as a 500.
        try:
            image_bytes = b64img_decode(trip.image)
        except Exception as err:
            raise HTTPException(status_code=400, detail="Bad request") from err

        filename = save_image_to_file(image_bytes, settings.TRIP_IMAGE_SIZE)
        if not filename:
            raise HTTPException(status_code=400, detail="Bad request")

        image = Image(filename=filename, user=current_user)
        session.add(image)
        # Flush and refresh so image.id is populated before the trip references it.
        session.flush()
        session.refresh(image)
        new_trip.image_id = image.id

    session.add(new_trip)
    session.commit()
    session.refresh(new_trip)
    return TripReadBase.serialize(new_trip)
# Update a trip: optionally replace its image, optionally replace its list of
# places, then copy every remaining submitted field onto the row and commit.
# Responds 404 when the trip is missing or not accessible to the current user,
# 400 on an archived trip / bad image / inconsistent places, 403 when a place
# does not belong to one of the trip's users.
# NOTE(review): leading indentation was lost in this copy of the file, so the
# nesting of the statements below has to be confirmed against the real source.
@router.put("/{trip_id}", response_model=TripRead)
def update_trip(
session: SessionDep,
trip_id: int,
trip: TripUpdate,
current_user: Annotated[str, Depends(get_current_username)],
) -> TripRead:
db_trip = _get_trip_or_404(session, trip_id)
_verify_trip_member(session, trip_id, current_user)
# An archived trip only accepts a request that explicitly sets archived to False.
if db_trip.archived and (trip.archived is not False):
raise HTTPException(status_code=400, detail="Bad request")
# exclude_unset keeps only the fields the client actually sent.
trip_data = trip.model_dump(exclude_unset=True)
# "image" and "place_ids" are popped because they need dedicated handling and
# must not reach the generic setattr loop at the end.
image_b64 = trip_data.pop("image", None)
if image_b64:
try:
image_bytes = b64img_decode(image_b64)
except Exception:
raise HTTPException(status_code=400, detail="Bad request")
filename = save_image_to_file(image_bytes, settings.TRIP_IMAGE_SIZE)
if not filename:
raise HTTPException(status_code=400, detail="Bad request")
image = Image(filename=filename, user=current_user)
session.add(image)
# Flush and refresh so image.id is available for the assignment further down.
session.flush()
session.refresh(image)
# Drop the previous image (row and, presumably, the file behind it) before
# linking the new one. Any failure here, including a missing Image row, is
# reported as 400.
# NOTE(review): remove_image runs before the place validation and the commit
# below; if a later check raises, the old file may already be gone while the
# row still references it. Confirm whether that is acceptable.
if db_trip.image_id:
old_image = session.get(Image, db_trip.image_id)
try:
remove_image(old_image.filename)
session.delete(old_image)
db_trip.image_id = None
session.refresh(db_trip)
except Exception:
raise HTTPException(status_code=400, detail="Bad request")
db_trip.image_id = image.id
place_ids = trip_data.pop("place_ids", None)
# None means the field was not sent; an empty list is a valid value and
# replaces the trip's places with nothing, hence the explicit None test.
if place_ids is not None: # Could be empty [], so 'in'
# Every submitted place must exist (404) and belong to the trip owner or to
# a user with a member row on the trip (403).
allowed_users = _trip_usernames(session, trip_id)
new_places = []
for place_id in place_ids:
db_place = session.get(Place, place_id)
if not db_place:
raise HTTPException(status_code=404, detail="Not found")
if db_place.user not in allowed_users:
raise HTTPException(status_code=403, detail="Place not accessible by trip members")
new_places.append(db_place)
db_trip.places = new_places
# Places referenced by the items of the trip's days must all remain in
# Trip.places; otherwise the request is rejected.
# NOTE(review): confirm whether this check is nested under the place_ids
# branch above or runs on every update.
item_place_ids = {
item.place.id for day in db_trip.days for item in day.items if item.place is not None
}
invalid_place_ids = item_place_ids - set(place.id for place in db_trip.places)
if invalid_place_ids: # TripItem references a Place that Trip.places misses
raise HTTPException(status_code=400, detail="Bad Request")
# Remaining submitted fields map directly onto Trip attributes.
for key, value in trip_data.items():
setattr(db_trip, key, value)
session.add(db_trip)
session.commit()
session.refresh(db_trip)
return TripRead.serialize(db_trip)
@router.delete("/{trip_id}")
def delete_trip(
    session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
):
    """Delete a trip together with its image, if it has one.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible,
            400 when the trip is archived, 500 when removing the image fails.
    """
    trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)
    if trip.archived:
        raise HTTPException(status_code=400, detail="Bad request")

    cover = trip.image
    if cover:
        # Both the image removal and the row deletion are covered by the same handler.
        try:
            remove_image(cover.filename)
            session.delete(cover)
        except Exception:
            raise HTTPException(
                status_code=500,
                detail="Roses are red, violets are blue, if you're reading this, I'm sorry for you",
            )

    session.delete(trip)
    session.commit()
    return {}
@router.post("/{trip_id}/days", response_model=TripDayRead)
def create_tripday(
    td: TripDayBase,
    trip_id: int,
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripDayRead:
    """Add a day, labelled ``td.label``, to a trip.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible,
            400 when the trip is archived.
    """
    trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)
    if trip.archived:
        raise HTTPException(status_code=400, detail="Bad request")

    day = TripDay(trip_id=trip_id, label=td.label)
    session.add(day)
    session.commit()
    session.refresh(day)
    return TripDayRead.serialize(day)
@router.put("/{trip_id}/days/{day_id}", response_model=TripDayRead)
def update_tripday(
    td: TripDayBase,
    trip_id: int,
    day_id: int,
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripDayRead:
    """Apply the submitted fields of ``td`` to a day of the trip.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible,
            400 when the trip is archived or the day is unknown / belongs to
            another trip.
    """
    trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)
    if trip.archived:
        raise HTTPException(status_code=400, detail="Bad request")

    day = session.get(TripDay, day_id)
    # The day must exist and be attached to the trip named in the URL.
    if not (day and day.trip_id == trip_id):
        raise HTTPException(status_code=400, detail="Bad request")

    # Only fields present in the request body are written.
    for field, new_value in td.model_dump(exclude_unset=True).items():
        setattr(day, field, new_value)

    session.add(day)
    session.commit()
    session.refresh(day)
    return TripDayRead.serialize(day)
@router.delete("/{trip_id}/days/{day_id}")
def delete_tripday(
    session: SessionDep,
    trip_id: int,
    day_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Delete one day of a trip and return an empty object.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible,
            400 when the trip is archived or the day is unknown / belongs to
            another trip.
    """
    trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)
    if trip.archived:
        raise HTTPException(status_code=400, detail="Bad request")

    day = session.get(TripDay, day_id)
    # The day must exist and be attached to the trip named in the URL.
    if not (day and day.trip_id == trip_id):
        raise HTTPException(status_code=400, detail="Bad request")

    session.delete(day)
    session.commit()
    return {}
@router.post("/{trip_id}/days/{day_id}/items", response_model=TripItemRead)
def create_tripitem(
    item: TripItemCreate,
    trip_id: int,
    day_id: int,
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripItemRead:
    """Create an item inside a day of the trip.

    When ``item.place`` is given, it must be the id of a place already listed
    in the trip's places.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible,
            400 when the trip is archived, the day is unknown / belongs to
            another trip, or the place is not part of the trip.
    """
    trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)
    if trip.archived:
        raise HTTPException(status_code=400, detail="Bad request")

    day = session.get(TripDay, day_id)
    if not (day and day.trip_id == trip_id):
        raise HTTPException(status_code=400, detail="Bad request")

    fields = {
        "time": item.time,
        "text": item.text,
        "comment": item.comment,
        "lat": item.lat,
        "lng": item.lng,
        "day_id": day_id,
        "price": item.price,
        "status": item.status,
    }
    created = TripItem(**fields)

    if item.place is not None:
        # An item may only point at a place that is attached to this trip.
        trip_place_ids = [place.id for place in trip.places]
        if item.place not in trip_place_ids:
            raise HTTPException(status_code=400, detail="Bad request")
        created.place_id = item.place

    session.add(created)
    session.commit()
    session.refresh(created)
    return TripItemRead.serialize(created)
@router.put("/{trip_id}/days/{day_id}/items/{item_id}", response_model=TripItemRead)
def update_tripitem(
    item: TripItemUpdate,
    trip_id: int,
    day_id: int,
    item_id: int,
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripItemRead:
    """Apply the submitted fields of ``item`` to an existing trip item.

    Only fields present in the request body are written. ``place`` follows the
    same rule: omitting it leaves the current place untouched, sending ``null``
    clears it, and sending an id requires that place to be listed in the
    trip's places.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible,
            400 when the trip is archived, the day or the item does not match
            the URL, or the place is not part of the trip.
    """
    db_trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)
    if db_trip.archived:
        raise HTTPException(status_code=400, detail="Bad request")

    db_day = session.get(TripDay, day_id)
    if not db_day or (db_day.trip_id != trip_id):
        raise HTTPException(status_code=400, detail="Bad request")

    db_item = session.get(TripItem, item_id)
    if not db_item or (db_item.day_id != day_id):
        raise HTTPException(status_code=400, detail="Bad request")

    item_data = item.model_dump(exclude_unset=True)

    # Touch the place link only when the client sent the field. Assigning the
    # popped default unconditionally would wipe the place on every partial
    # update that omits "place".
    if "place" in item_data:
        place_id = item_data.pop("place")
        if place_id is not None and not any(p.id == place_id for p in db_trip.places):
            raise HTTPException(status_code=400, detail="Bad request")
        db_item.place_id = place_id

    for key, value in item_data.items():
        setattr(db_item, key, value)

    session.add(db_item)
    session.commit()
    session.refresh(db_item)
    return TripItemRead.serialize(db_item)
@router.delete("/{trip_id}/days/{day_id}/items/{item_id}")
def delete_tripitem(
    session: SessionDep,
    trip_id: int,
    day_id: int,
    item_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Delete one item of a trip day and return an empty object.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible,
            400 when the trip is archived or the day / item does not match
            the URL.
    """
    trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)
    if trip.archived:
        raise HTTPException(status_code=400, detail="Bad request")

    # Walk the URL hierarchy: the day must belong to the trip, the item to the day.
    day = session.get(TripDay, day_id)
    if not (day and day.trip_id == trip_id):
        raise HTTPException(status_code=400, detail="Bad request")

    target = session.get(TripItem, item_id)
    if not (target and target.day_id == day_id):
        raise HTTPException(status_code=400, detail="Bad request")

    session.delete(target)
    session.commit()
    return {}
@router.get("/shared/{token}", response_model=TripRead)
def read_shared_trip(
    session: SessionDep,
    token: str,
) -> TripRead:
    """Return the trip behind a share token; no authenticated user is required.

    Responds 404 when no share carries this token.
    """
    share = _trip_from_token_or_404(session, token)
    trip = session.get(Trip, share.trip_id)
    return TripRead.serialize(trip)
@router.get("/{trip_id}/share", response_model=TripShareURL)
def get_shared_trip_url(
    session: SessionDep,
    trip_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripShareURL:
    """Return the share URL of a trip.

    Raises:
        HTTPException: 404 when the user may not access the trip or the trip
            has no share.
    """
    _verify_trip_member(session, trip_id, current_user)

    query = select(TripShare).where(TripShare.trip_id == trip_id)
    share = session.exec(query).first()
    if share:
        return {"url": f"/s/t/{share.token}"}
    raise HTTPException(status_code=404, detail="Not found")
@router.post("/{trip_id}/share", response_model=TripShareURL)
def create_shared_trip(
    session: SessionDep,
    trip_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripShareURL:
    """Create a share for a trip and return its URL.

    Raises:
        HTTPException: 404 when the user may not access the trip,
            409 when the trip already has a share.
    """
    _verify_trip_member(session, trip_id, current_user)

    existing = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first()
    if existing:
        raise HTTPException(status_code=409, detail="The resource already exists")

    token = generate_urlsafe()
    session.add(TripShare(token=token, trip_id=trip_id))
    session.commit()
    return {"url": f"/s/t/{token}"}
@router.delete("/{trip_id}/share")
def delete_shared_trip(
    session: SessionDep,
    trip_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Remove the share of a trip and return an empty object.

    Raises:
        HTTPException: 404 when the user may not access the trip or the trip
            has no share.
    """
    _verify_trip_member(session, trip_id, current_user)

    query = select(TripShare).where(TripShare.trip_id == trip_id)
    share = session.exec(query).first()
    if not share:
        raise HTTPException(status_code=404, detail="Not found")

    session.delete(share)
    session.commit()
    return {}
@router.get("/{trip_id}/packing", response_model=list[TripPackingListItemRead])
def read_packing_list(
    session: SessionDep,
    trip_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
) -> list[TripPackingListItemRead]:
    """List the packing items of a trip; 404 when the user may not access it."""
    _verify_trip_member(session, trip_id, current_user)

    query = select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id)
    return list(map(TripPackingListItemRead.serialize, session.exec(query)))
@router.get("/shared/{token}/packing", response_model=list[TripPackingListItemRead])
def read_shared_trip_packing_list(
    session: SessionDep,
    token: str,
) -> list[TripPackingListItemRead]:
    """List the packing items of the trip behind a share token; 404 on unknown token."""
    shared_trip_id = _trip_from_token_or_404(session, token).trip_id

    query = select(TripPackingListItem).where(TripPackingListItem.trip_id == shared_trip_id)
    return list(map(TripPackingListItemRead.serialize, session.exec(query)))
@router.post("/{trip_id}/packing", response_model=TripPackingListItemRead)
def create_packing_item(
    session: SessionDep,
    trip_id: int,
    data: TripPackingListItemCreate,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripPackingListItemRead:
    """Add a packing item to a trip; 404 when the user may not access it."""
    _verify_trip_member(session, trip_id, current_user)

    payload = data.model_dump()
    created = TripPackingListItem(**payload, trip_id=trip_id)
    session.add(created)
    session.commit()
    session.refresh(created)
    return TripPackingListItemRead.serialize(created)
@router.put("/{trip_id}/packing/{p_id}", response_model=TripPackingListItemRead)
def update_packing_item(
    session: SessionDep,
    p_item: TripPackingListItemUpdate,
    trip_id: int,
    p_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripPackingListItemRead:
    """Apply the submitted fields of ``p_item`` to a packing item of the trip.

    Raises:
        HTTPException: 404 when the user may not access the trip or the item
            does not exist on this trip.
    """
    _verify_trip_member(session, trip_id, current_user)

    # Filtering on trip_id too prevents touching an item of another trip.
    query = (
        select(TripPackingListItem)
        .where(TripPackingListItem.id == p_id)
        .where(TripPackingListItem.trip_id == trip_id)
    )
    target = session.exec(query).one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Not found")

    # Only fields present in the request body are written.
    for field, new_value in p_item.model_dump(exclude_unset=True).items():
        setattr(target, field, new_value)

    session.add(target)
    session.commit()
    session.refresh(target)
    return TripPackingListItemRead.serialize(target)
@router.delete("/{trip_id}/packing/{p_id}")
def delete_packing_item(
    session: SessionDep,
    trip_id: int,
    p_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Delete a packing item of the trip and return an empty object.

    Raises:
        HTTPException: 404 when the user may not access the trip or the item
            does not exist on this trip.
    """
    _verify_trip_member(session, trip_id, current_user)

    # Filtering on trip_id too prevents deleting an item of another trip.
    query = (
        select(TripPackingListItem)
        .where(TripPackingListItem.id == p_id)
        .where(TripPackingListItem.trip_id == trip_id)
    )
    target = session.exec(query).one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Not found")

    session.delete(target)
    session.commit()
    return {}
@router.get("/{trip_id}/checklist", response_model=list[TripChecklistItemRead])
def read_checklist(
    session: SessionDep,
    trip_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
) -> list[TripChecklistItemRead]:
    """List the checklist items of a trip; 404 when the user may not access it."""
    _verify_trip_member(session, trip_id, current_user)

    query = select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id)
    return list(map(TripChecklistItemRead.serialize, session.exec(query)))
@router.get("/shared/{token}/checklist", response_model=list[TripChecklistItemRead])
def read_shared_trip_checklist(
    session: SessionDep,
    token: str,
) -> list[TripChecklistItemRead]:
    """List the checklist items of the trip behind a share token; 404 on unknown token."""
    shared_trip_id = _trip_from_token_or_404(session, token).trip_id

    query = select(TripChecklistItem).where(TripChecklistItem.trip_id == shared_trip_id)
    return list(map(TripChecklistItemRead.serialize, session.exec(query)))
@router.post("/{trip_id}/checklist", response_model=TripChecklistItemRead)
def create_checklist_item(
    session: SessionDep,
    trip_id: int,
    data: TripChecklistItemCreate,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripChecklistItemRead:
    """Add a checklist item to a trip; 404 when the user may not access it."""
    _verify_trip_member(session, trip_id, current_user)

    payload = data.model_dump()
    created = TripChecklistItem(**payload, trip_id=trip_id)
    session.add(created)
    session.commit()
    session.refresh(created)
    return TripChecklistItemRead.serialize(created)
@router.put("/{trip_id}/checklist/{id}", response_model=TripChecklistItemRead)
def update_checklist_item(
    session: SessionDep,
    item: TripChecklistItemUpdate,
    trip_id: int,
    id: int,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripChecklistItemRead:
    """Apply the submitted fields of ``item`` to a checklist item of the trip.

    Raises:
        HTTPException: 404 when the user may not access the trip or the item
            does not exist on this trip.
    """
    _verify_trip_member(session, trip_id, current_user)

    # Filtering on trip_id too prevents touching an item of another trip.
    query = (
        select(TripChecklistItem)
        .where(TripChecklistItem.id == id)
        .where(TripChecklistItem.trip_id == trip_id)
    )
    target = session.exec(query).one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Not found")

    # Only fields present in the request body are written.
    for field, new_value in item.model_dump(exclude_unset=True).items():
        setattr(target, field, new_value)

    session.add(target)
    session.commit()
    session.refresh(target)
    return TripChecklistItemRead.serialize(target)
@router.delete("/{trip_id}/checklist/{id}")
def delete_checklist_item(
    session: SessionDep,
    trip_id: int,
    id: int,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Delete a checklist item of the trip and return an empty object.

    Raises:
        HTTPException: 404 when the user may not access the trip or the item
            does not exist on this trip.
    """
    _verify_trip_member(session, trip_id, current_user)

    # Filtering on trip_id too prevents deleting an item of another trip.
    query = (
        select(TripChecklistItem)
        .where(TripChecklistItem.id == id)
        .where(TripChecklistItem.trip_id == trip_id)
    )
    target = session.exec(query).one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Not found")

    session.delete(target)
    session.commit()
    return {}
@router.get("/{trip_id}/members", response_model=list[TripMemberRead])
def read_trip_members(
    session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)]
) -> list[TripMemberRead]:
    """List the owner of a trip followed by every member row of that trip.

    The owner entry carries no invitation or join information. Member rows are
    not filtered, so pending invitations are listed as well. Responds 404 when
    the user may not access the trip.
    """
    _verify_trip_member(session, trip_id, current_user)

    owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first()
    owner_entry = TripMemberRead(user=owner, invited_by=None, invited_at=None, joined_at=None)

    member_rows = session.exec(select(TripMember).where(TripMember.trip_id == trip_id)).all()
    return [owner_entry, *(TripMemberRead.serialize(row) for row in member_rows)]
@router.post("/{trip_id}/members", response_model=TripMemberRead)
def invite_trip_member(
    session: SessionDep,
    trip_id: int,
    data: TripMemberCreate,
    current_user: Annotated[str, Depends(get_current_username)],
) -> TripMemberRead:
    """Invite ``data.user`` to a trip on behalf of the current user.

    The created member row records the current user as ``invited_by``.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible, or when
            the invited user does not exist; 409 when the invited user is the
            trip owner or already has a member row on this trip.
    """
    db_trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)

    # The owner is implicitly part of the trip and cannot be invited.
    if db_trip.user == data.user:
        raise HTTPException(status_code=409, detail="The resource already exists")

    # Any existing row, joined or still pending, blocks a second invitation.
    exists = session.exec(
        select(TripMember.id)
        .where(
            TripMember.trip_id == trip_id,
            TripMember.user == data.user,
        )
        .limit(1)
    ).first()
    if exists:
        raise HTTPException(status_code=409, detail="The resource already exists")

    db_user = session.get(User, data.user)
    if not db_user:
        raise HTTPException(status_code=404, detail="Not found")

    new_member = TripMember(trip_id=trip_id, user=data.user, invited_by=current_user)
    session.add(new_member)
    session.commit()
    session.refresh(new_member)
    return TripMemberRead.serialize(new_member)
@router.delete("/{trip_id}/members/{username}")
def delete_trip_member(
    session: SessionDep,
    trip_id: int,
    username: str,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Remove the member row of ``username`` from a trip and return an empty object.

    The trip owner may remove any member but not themselves; any other user
    may only remove themselves.

    Raises:
        HTTPException: 404 when the trip is missing or not accessible, or when
            ``username`` has no member row on this trip; 400 when the owner
            targets themselves; 403 when a non-owner targets someone else.
    """
    db_trip = _get_trip_or_404(session, trip_id)
    _verify_trip_member(session, trip_id, current_user)

    if current_user == db_trip.user and current_user == username:
        raise HTTPException(status_code=400, detail="Bad request")

    if current_user != db_trip.user and current_user != username:
        raise HTTPException(status_code=403, detail="Forbidden")

    member = session.exec(
        select(TripMember).where(
            TripMember.user == username,
            TripMember.trip_id == trip_id,
        )
    ).one_or_none()
    if not member:
        raise HTTPException(status_code=404, detail="Not found")

    session.delete(member)
    session.commit()
    return {}
@router.post("/{trip_id}/members/accept")
def accept_invite(
    session: SessionDep,
    trip_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Accept the current user's invitation to a trip by stamping ``joined_at``.

    Raises:
        HTTPException: 404 when the user has no member row on this trip,
            409 "Already a member" when the invitation was accepted before.
    """
    query = (
        select(TripMember)
        .where(TripMember.trip_id == trip_id)
        .where(TripMember.user == current_user)
    )
    invitation = session.exec(query).one_or_none()
    if not invitation:
        raise HTTPException(status_code=404, detail="Not found")
    if invitation.joined_at:
        raise HTTPException(status_code=409, detail="Already a member")

    invitation.joined_at = utc_now()
    session.add(invitation)
    session.commit()
    return {}
@router.post("/{trip_id}/members/decline")
def decline_invite(
    session: SessionDep,
    trip_id: int,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Decline the current user's pending invitation by deleting its member row.

    Raises:
        HTTPException: 404 when the user has no member row on this trip,
            409 "Already a member" when the invitation was already accepted.
    """
    query = (
        select(TripMember)
        .where(TripMember.trip_id == trip_id)
        .where(TripMember.user == current_user)
    )
    invitation = session.exec(query).one_or_none()
    if not invitation:
        raise HTTPException(status_code=404, detail="Not found")
    if invitation.joined_at:
        raise HTTPException(status_code=409, detail="Already a member")

    session.delete(invitation)
    session.commit()
    return {}