diff --git a/backend/trip/alembic/versions/26c89b7466f2_trip_multi_users.py b/backend/trip/alembic/versions/26c89b7466f2_trip_multi_users.py new file mode 100644 index 0000000..3c6c825 --- /dev/null +++ b/backend/trip/alembic/versions/26c89b7466f2_trip_multi_users.py @@ -0,0 +1,91 @@ +"""Trip multi-users + +Revision ID: 26c89b7466f2 +Revises: 60a9bb641d8a +Create Date: 2025-08-18 23:19:37.457354 + +""" + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from alembic import op + +# revision identifiers, used by Alembic. +revision = "26c89b7466f2" +down_revision = "60a9bb641d8a" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "tripmember", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("invited_by", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("invited_at", sa.DateTime(), nullable=False), + sa.Column("joined_at", sa.DateTime(), nullable=True), + sa.Column("trip_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["invited_by"], + ["user.username"], + name=op.f("fk_tripmember_invited_by_user"), + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["trip_id"], ["trip.id"], name=op.f("fk_tripmember_trip_id_trip"), ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["user"], ["user.username"], name=op.f("fk_tripmember_user_user"), ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_tripmember")), + ) + with op.batch_alter_table("tripchecklistitem", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("fk_tripchecklistitem_user_user"), type_="foreignkey") + batch_op.drop_column("user") + + with op.batch_alter_table("trippackinglistitem", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("fk_trippackinglistitem_user_user"), type_="foreignkey") + batch_op.drop_column("user") + + with op.batch_alter_table("tripitem", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("fk_tripitem_day_id_tripday"), type_="foreignkey") + + with op.batch_alter_table("tripday", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("fk_tripday_user_user"), type_="foreignkey") + batch_op.drop_column("user") + + with op.batch_alter_table("tripitem", schema=None) as batch_op: + batch_op.create_foreign_key( + batch_op.f("fk_tripitem_day_id_tripday"), + "tripday", + ["day_id"], + ["id"], + ondelete="CASCADE", + ) + + +def downgrade(): + with op.batch_alter_table("trippackinglistitem", schema=None) as batch_op: + batch_op.add_column(sa.Column("user", sa.VARCHAR(), nullable=False)) + batch_op.create_foreign_key( + batch_op.f("fk_trippackinglistitem_user_user"), + "user", + ["user"], + ["username"], + ondelete="CASCADE", + ) + + with op.batch_alter_table("tripday", schema=None) as batch_op: + batch_op.add_column(sa.Column("user", sa.VARCHAR(), nullable=False)) + batch_op.create_foreign_key( + batch_op.f("fk_tripday_user_user"), "user", ["user"], ["username"], ondelete="CASCADE" + ) + + with op.batch_alter_table("tripchecklistitem", schema=None) as batch_op: + batch_op.add_column(sa.Column("user", sa.VARCHAR(), nullable=False)) + batch_op.create_foreign_key( + batch_op.f("fk_tripchecklistitem_user_user"), "user", ["user"], ["username"], ondelete="CASCADE" + ) + + op.drop_table("tripmember") diff --git a/backend/trip/models/models.py b/backend/trip/models/models.py index 3ebdbb5..c09e420 100644 --- a/backend/trip/models/models.py +++ b/backend/trip/models/models.py @@ -261,6 +261,7 @@ class Trip(TripBase, table=True): shares: list["TripShare"] = Relationship(back_populates="trip", cascade_delete=True) packing_items: list["TripPackingListItem"] = Relationship(back_populates="trip", cascade_delete=True) checklist_items: list["TripChecklistItem"] = Relationship(back_populates="trip", cascade_delete=True) + memberships: list["TripMember"] = Relationship(back_populates="trip", cascade_delete=True) class TripCreate(TripBase): @@ -279,6 +280,7 @@ class TripReadBase(TripBase): image: str | None image_id: int | None days: int + collaborators: list["TripMemberRead"] @classmethod def serialize(cls, obj: Trip) -> "TripRead": @@ -289,6 +291,7 @@ class TripReadBase(TripBase): image=_prefix_assets_url(obj.image.filename) if obj.image else None, image_id=obj.image_id, days=len(obj.days), + collaborators=[TripMemberRead.serialize(m) for m in obj.memberships], ) @@ -298,6 +301,7 @@ class TripRead(TripBase): image_id: int | None days: list["TripDayRead"] places: list["PlaceRead"] + collaborators: list["TripMemberRead"] @classmethod def serialize(cls, obj: Trip) -> "TripRead": @@ -309,17 +313,49 @@ class TripRead(TripBase): image_id=obj.image_id, days=[TripDayRead.serialize(day) for day in obj.days], places=[PlaceRead.serialize(place) for place in obj.places], + collaborators=[TripMemberRead.serialize(m) for m in obj.memberships], ) +class TripMember(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + user: str = Field(foreign_key="user.username", ondelete="CASCADE") + invited_by: str | None = Field(default=None, foreign_key="user.username", ondelete="SET NULL") + invited_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + joined_at: datetime | None = None + + trip_id: int = Field(foreign_key="trip.id", ondelete="CASCADE") + trip: Trip | None = Relationship(back_populates="memberships") + + +class TripMemberCreate(BaseModel): + user: str + + +class TripMemberRead(BaseModel): + user: str + invited_by: str | None = None + invited_at: datetime | None = None + joined_at: datetime | None = None + + @classmethod + def serialize(cls, obj: TripMember) -> "TripMemberRead": + return cls( + user=obj.user, invited_by=obj.invited_by, invited_at=obj.invited_at, joined_at=obj.joined_at + ) + + +class TripInvitationRead(TripReadBase): + invited_by: str | None = None + invited_at: datetime + + class TripDayBase(SQLModel): label: str class TripDay(TripDayBase, table=True): id: int | None = Field(default=None, primary_key=True) - user: str = Field(foreign_key="user.username", ondelete="CASCADE") - trip_id: int = Field(foreign_key="trip.id", ondelete="CASCADE") trip: Trip | None = Relationship(back_populates="days") @@ -420,7 +456,6 @@ class TripPackingListItemBase(SQLModel): class TripPackingListItem(TripPackingListItemBase, table=True): id: int | None = Field(default=None, primary_key=True) - user: str = Field(foreign_key="user.username", ondelete="CASCADE") trip_id: int = Field(foreign_key="trip.id", ondelete="CASCADE") trip: Trip | None = Relationship(back_populates="packing_items") diff --git a/backend/trip/routers/trips.py b/backend/trip/routers/trips.py index afd7a7b..04d87a0 100644 --- a/backend/trip/routers/trips.py +++ b/backend/trip/routers/trips.py @@ -8,33 +8,120 @@ from ..deps import SessionDep, get_current_username from ..models.models import (Image, Place, Trip, TripChecklistItem, TripChecklistItemCreate, TripChecklistItemRead, TripChecklistItemUpdate, TripCreate, TripDay, - TripDayBase, TripDayRead, TripItem, - TripItemCreate, TripItemRead, TripItemUpdate, - TripPackingListItem, TripPackingListItemCreate, - TripPackingListItemRead, TripPackingListItemUpdate, - TripRead, TripReadBase, TripShare, - TripShareURL, TripUpdate) -from ..security import verify_exists_and_owns + TripDayBase, TripDayRead, TripInvitationRead, + TripItem, TripItemCreate, TripItemRead, + TripItemUpdate, TripMember, TripMemberCreate, + TripMemberRead, TripPackingListItem, + TripPackingListItemCreate, + TripPackingListItemRead, + TripPackingListItemUpdate, TripRead, TripReadBase, + TripShare, TripShareURL, TripUpdate, User) from ..utils.utils import (b64img_decode, generate_urlsafe, remove_image, - save_image_to_file) + save_image_to_file, utc_now) router = APIRouter(prefix="/api/trips", tags=["trips"]) +def _get_trip_or_404(session, trip_id: int) -> Trip: + trip = session.get(Trip, trip_id) + if not trip: + raise HTTPException(status_code=404, detail="Not found") + return trip + + +def _trip_from_token_or_404(session, token: str) -> TripShare: + share = session.exec(select(TripShare).where(TripShare.token == token)).first() + if not share: + raise HTTPException(status_code=404, detail="Not found") + return share + + +def _trip_usernames(session, trip_id: int) -> set[str]: + owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first() + members = session.exec(select(TripMember.user).where(TripMember.trip_id == trip_id)).all() + return {owner} | set(members) + + +def _can_access_trip(session, trip_id: int, username: str) -> bool: + # TODO: Optimize, if trip does not exist, trip_owner is none and we keep iterating on members despite trip not found + trip_owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first() + if trip_owner == username: + return True + + is_member = session.exec( + select(TripMember.id) + .where(TripMember.trip_id == trip_id, TripMember.user == username, TripMember.joined_at.is_not(None)) + .limit(1) + ).first() + return bool(is_member) + + +def _verify_trip_member(session, trip_id: int, username: str) -> None: + if not _can_access_trip(session, trip_id, username): + raise HTTPException(status_code=404, detail="Not found") + + @router.get("", response_model=list[TripReadBase]) def read_trips( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)] ) -> list[TripReadBase]: - trips = session.exec(select(Trip).filter(Trip.user == current_user)) + trips = session.exec( + select(Trip) + .join(TripMember, isouter=True) + .where( + (Trip.user == current_user) + | ((TripMember.user == current_user) & (TripMember.joined_at.is_not(None))) + ) + .distinct() + ) return [TripReadBase.serialize(trip) for trip in trips] +@router.get("/invitations", response_model=list[TripInvitationRead]) +def read_pending_invitations( + session: SessionDep, + current_user: Annotated[str, Depends(get_current_username)], +) -> list[TripInvitationRead]: + pending_inviattions = session.exec( + select(TripMember, Trip) + .join(Trip, Trip.id == TripMember.trip_id) + .where( + TripMember.user == current_user, + TripMember.joined_at.is_(None), + ) + ).all() + + invitations: list[TripInvitationRead] = [] + for tm, trip in pending_inviattions: + base = TripReadBase.serialize(trip) + invitations.append( + TripInvitationRead( + **base.model_dump(), + invited_by=tm.invited_by, + invited_at=tm.invited_at, + ) + ) + + return invitations + + +@router.get("/invitations/pending", response_model=bool) +def has_pending_invitations( + session: SessionDep, + current_user: Annotated[str, Depends(get_current_username)], +) -> bool: + pending = session.exec( + select(TripMember.id).where(TripMember.user == current_user, TripMember.joined_at.is_(None)).limit(1) + ).first() + return bool(pending) + + @router.get("/{trip_id}", response_model=TripRead) def read_trip( session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] ) -> TripRead: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) return TripRead.serialize(db_trip) @@ -42,10 +129,7 @@ def read_trip( def create_trip( trip: TripCreate, session: SessionDep, current_user: Annotated[str, Depends(get_current_username)] ) -> TripReadBase: - new_trip = Trip( - name=trip.name, - user=current_user, - ) + new_trip = Trip(name=trip.name, user=current_user) if trip.image: image_bytes = b64img_decode(trip.image) @@ -55,17 +139,10 @@ def create_trip( image = Image(filename=filename, user=current_user) session.add(image) - session.commit() + session.flush() session.refresh(image) new_trip.image_id = image.id - if trip.place_ids: - for place_id in trip.place_ids: - db_place = session.get(Place, place_id) - verify_exists_and_owns(current_user, db_place) - session.add(TripPlaceLink(trip_id=new_trip.id, place_id=db_place.id)) - session.commit() - session.add(new_trip) session.commit() session.refresh(new_trip) @@ -79,8 +156,8 @@ def update_trip( trip: TripUpdate, current_user: Annotated[str, Depends(get_current_username)], ) -> TripRead: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived and (trip.archived is not False): raise HTTPException(status_code=400, detail="Bad request") @@ -100,8 +177,8 @@ def update_trip( image = Image(filename=filename, user=current_user) session.add(image) - session.commit() session.refresh(image) + session.flush() if db_trip.image_id: old_image = session.get(Image, db_trip.image_id) @@ -117,11 +194,16 @@ def update_trip( place_ids = trip_data.pop("place_ids", None) if place_ids is not None: # Could be empty [], so 'in' - db_trip.places.clear() + allowed_users = _trip_usernames(session, trip_id) + new_places = [] for place_id in place_ids: db_place = session.get(Place, place_id) - verify_exists_and_owns(current_user, db_place) - db_trip.places.append(db_place) + if not db_place: + raise HTTPException(status_code=404, detail="Not found") + if db_place.user not in allowed_users: + raise HTTPException(status_code=403, detail="Place not accessible by trip members") + new_places.append(db_place) + db_trip.places = new_places item_place_ids = { item.place.id for day in db_trip.days for item in day.items if item.place is not None @@ -143,8 +225,8 @@ def update_trip( def delete_trip( session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] ): - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") @@ -171,13 +253,13 @@ def create_tripday( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripDayRead: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") - new_day = TripDay(label=td.label, trip_id=trip_id, user=current_user) + new_day = TripDay(label=td.label, trip_id=trip_id) session.add(new_day) session.commit() @@ -193,15 +275,14 @@ def update_tripday( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripDayRead: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") db_day = session.get(TripDay, day_id) - verify_exists_and_owns(current_user, db_day) - if db_day.trip_id != trip_id: + if not db_day or (db_day.trip_id != trip_id): raise HTTPException(status_code=400, detail="Bad request") td_data = td.model_dump(exclude_unset=True) @@ -221,15 +302,14 @@ def delete_tripday( day_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") db_day = session.get(TripDay, day_id) - verify_exists_and_owns(current_user, db_day) - if db_day.trip_id != trip_id: + if not db_day or (db_day.trip_id != trip_id): raise HTTPException(status_code=400, detail="Bad request") session.delete(db_day) @@ -245,14 +325,14 @@ def create_tripitem( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripItemRead: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") db_day = session.get(TripDay, day_id) - if db_day.trip_id != trip_id: + if not db_day or (db_day.trip_id != trip_id): raise HTTPException(status_code=400, detail="Bad request") new_item = TripItem( @@ -266,7 +346,7 @@ def create_tripitem( status=item.status, ) - if item.place and item.place != "": + if item.place is not None: place_in_trip = any(place.id == item.place for place in db_trip.places) if not place_in_trip: raise HTTPException(status_code=400, detail="Bad request") @@ -287,18 +367,18 @@ def update_tripitem( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)], ) -> TripItemRead: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") db_day = session.get(TripDay, day_id) - if db_day.trip_id != trip_id: + if not db_day or (db_day.trip_id != trip_id): raise HTTPException(status_code=400, detail="Bad request") db_item = session.get(TripItem, item_id) - if db_item.day_id != day_id: + if not db_item or (db_item.day_id != day_id): raise HTTPException(status_code=400, detail="Bad request") item_data = item.model_dump(exclude_unset=True) @@ -327,18 +407,18 @@ def delete_tripitem( item_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) if db_trip.archived: raise HTTPException(status_code=400, detail="Bad request") db_day = session.get(TripDay, day_id) - if db_day.trip_id != trip_id: + if not db_day or (db_day.trip_id != trip_id): raise HTTPException(status_code=400, detail="Bad request") db_item = session.get(TripItem, item_id) - if db_item.day_id != day_id: + if not db_item or (db_item.day_id != day_id): raise HTTPException(status_code=400, detail="Bad request") session.delete(db_item) @@ -351,11 +431,7 @@ def read_shared_trip( session: SessionDep, token: str, ) -> TripRead: - share = session.exec(select(TripShare).where(TripShare.token == token)).first() - if not share: - raise HTTPException(status_code=404, detail="Not found") - - db_trip = session.get(Trip, share.trip_id) + db_trip = session.get(Trip, _trip_from_token_or_404(session, token).trip_id) return TripRead.serialize(db_trip) @@ -365,8 +441,7 @@ def get_shared_trip_url( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripShareURL: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + _verify_trip_member(session, trip_id, current_user) share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() if not share: @@ -381,15 +456,14 @@ def create_shared_trip( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripShareURL: - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + _verify_trip_member(session, trip_id, current_user) shared = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() if shared: raise HTTPException(status_code=409, detail="The resource already exists") token = generate_urlsafe() - trip_share = TripShare(token=token, trip_id=trip_id, user=current_user) + trip_share = TripShare(token=token, trip_id=trip_id) session.add(trip_share) session.commit() return {"url": f"/s/t/{token}"} @@ -401,8 +475,7 @@ def delete_shared_trip( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ): - db_trip = session.get(Trip, trip_id) - verify_exists_and_owns(current_user, db_trip) + _verify_trip_member(session, trip_id, current_user) db_share = session.exec(select(TripShare).where(TripShare.trip_id == trip_id)).first() if not db_share: @@ -419,15 +492,25 @@ def read_packing_list( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> list[TripPackingListItemRead]: - p_items = session.exec( - select(TripPackingListItem) - .where(TripPackingListItem.trip_id == trip_id, TripPackingListItem.user == current_user) - .order_by(TripPackingListItem.id.asc()) - ).all() + _verify_trip_member(session, trip_id, current_user) + p_items = session.exec(select(TripPackingListItem).where(TripPackingListItem.trip_id == trip_id)) return [TripPackingListItemRead.serialize(i) for i in p_items] +@router.get("/shared/{token}/packing", response_model=list[TripPackingListItemRead]) +def read_shared_trip_packing_list( + session: SessionDep, + token: str, +) -> list[TripPackingListItemRead]: + p_items = session.exec( + select(TripPackingListItem).where( + TripPackingListItem.trip_id == _trip_from_token_or_404(session, token).trip_id + ) + ) + return [TripPackingListItemRead.serialize(i) for i in p_items] + + @router.post("/{trip_id}/packing", response_model=TripPackingListItemRead) def create_packing_item( session: SessionDep, @@ -435,11 +518,8 @@ def create_packing_item( data: TripPackingListItemCreate, current_user: Annotated[str, Depends(get_current_username)], ) -> TripPackingListItemRead: - item = TripPackingListItem( - **data.model_dump(), - trip_id=trip_id, - user=current_user, - ) + _verify_trip_member(session, trip_id, current_user) + item = TripPackingListItem(**data.model_dump(), trip_id=trip_id) session.add(item) session.commit() session.refresh(item) @@ -454,11 +534,10 @@ def update_packing_item( p_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripPackingListItemRead: + _verify_trip_member(session, trip_id, current_user) db_item = session.exec( select(TripPackingListItem).where( - TripPackingListItem.id == p_id, - TripPackingListItem.trip_id == trip_id, - TripPackingListItem.user == current_user, + TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id ) ).one_or_none() @@ -482,11 +561,10 @@ def delete_packing_item( p_id: int, current_user: Annotated[str, Depends(get_current_username)], ): + _verify_trip_member(session, trip_id, current_user) item = session.exec( select(TripPackingListItem).where( - TripPackingListItem.id == p_id, - TripPackingListItem.trip_id == trip_id, - TripPackingListItem.user == current_user, + TripPackingListItem.id == p_id, TripPackingListItem.trip_id == trip_id ) ).one_or_none() @@ -504,11 +582,8 @@ def read_checklist( trip_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> list[TripChecklistItemRead]: - items = session.exec( - select(TripChecklistItem).where( - TripChecklistItem.trip_id == trip_id, TripChecklistItem.user == current_user - ) - ) + _verify_trip_member(session, trip_id, current_user) + items = session.exec(select(TripChecklistItem).where(TripChecklistItem.trip_id == trip_id)) return [TripChecklistItemRead.serialize(i) for i in items] @@ -532,11 +607,8 @@ def create_checklist_item( data: TripChecklistItemCreate, current_user: Annotated[str, Depends(get_current_username)], ) -> TripChecklistItemRead: - item = TripChecklistItem( - **data.model_dump(), - trip_id=trip_id, - user=current_user, - ) + _verify_trip_member(session, trip_id, current_user) + item = TripChecklistItem(**data.model_dump(), trip_id=trip_id) session.add(item) session.commit() session.refresh(item) @@ -551,12 +623,9 @@ def update_checklist_item( id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> TripChecklistItemRead: + _verify_trip_member(session, trip_id, current_user) db_item = session.exec( - select(TripChecklistItem).where( - TripChecklistItem.id == id, - TripChecklistItem.trip_id == trip_id, - TripChecklistItem.user == current_user, - ) + select(TripChecklistItem).where(TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id) ).one_or_none() if not db_item: @@ -579,11 +648,11 @@ def delete_checklist_item( id: int, current_user: Annotated[str, Depends(get_current_username)], ): + _verify_trip_member(session, trip_id, current_user) item = session.exec( select(TripChecklistItem).where( TripChecklistItem.id == id, TripChecklistItem.trip_id == trip_id, - TripChecklistItem.user == current_user, ) ).one_or_none() @@ -593,3 +662,122 @@ def delete_checklist_item( session.delete(item) session.commit() return {} + + +@router.get("/{trip_id}/members", response_model=list[TripMemberRead]) +def read_trip_members( + session: SessionDep, trip_id: int, current_user: Annotated[str, Depends(get_current_username)] +) -> list[TripMemberRead]: + _verify_trip_member(session, trip_id, current_user) + + members: list[TripMemberRead] = [] + owner = session.exec(select(Trip.user).where(Trip.id == trip_id)).first() + members.append(TripMemberRead(user=owner, invited_by=None, invited_at=None, joined_at=None)) + + db_members = session.exec(select(TripMember).where(TripMember.trip_id == trip_id)).all() + members.extend(TripMemberRead.serialize(m) for m in db_members) + return members + + +@router.post("/{trip_id}/members", response_model=TripMemberRead) +def invite_trip_member( + session: SessionDep, + trip_id: int, + data: TripMemberCreate, + current_user: Annotated[str, Depends(get_current_username)], +) -> TripMemberRead: + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) + + if db_trip.user == data.user: + raise HTTPException(status_code=409, detail="The resource already exists") + + exists = session.exec( + select(TripMember.id) + .where( + TripMember.trip_id == trip_id, + TripMember.user == data.user, + ) + .limit(1) + ).first() + if exists: + raise HTTPException(status_code=409, detail="The resource already exists") + + db_user = session.get(User, data.user) + print(data.user, db_user) + if not db_user: + raise HTTPException(status_code=404, detail="Not found") + + new_member = TripMember(trip_id=trip_id, user=data.user, invited_by=current_user) + session.add(new_member) + session.commit() + session.refresh(new_member) + return TripMemberRead.serialize(new_member) + + +@router.delete("/{trip_id}/members/{username}") +def delete_trip_member( + session: SessionDep, + trip_id: int, + username: str, + current_user: Annotated[str, Depends(get_current_username)], +): + print("yo") + db_trip = _get_trip_or_404(session, trip_id) + _verify_trip_member(session, trip_id, current_user) + + if current_user == db_trip.user and current_user == username: + raise HTTPException(status_code=400, detail="Bad request") + + if current_user != db_trip.user and current_user != username: + raise HTTPException(status_code=403, detail="Forbidden") + + member = session.exec( + select(TripMember).where( + TripMember.user == username, + TripMember.trip_id == trip_id, + ) + ).one_or_none() + if not member: + raise HTTPException(status_code=404, detail="Not found") + + session.delete(member) + session.commit() + return {} + + +@router.post("/{trip_id}/members/accept") +def accept_invite( + session: SessionDep, + trip_id: int, + current_user: Annotated[str, Depends(get_current_username)], +): + db_member = session.exec( + select(TripMember).where(TripMember.trip_id == trip_id, TripMember.user == current_user) + ).one_or_none() + if not db_member: + raise HTTPException(status_code=404, detail="Not found") + if db_member.joined_at: + raise HTTPException(status_code=409, detail="Already a member") + db_member.joined_at = utc_now() + session.add(db_member) + session.commit() + return {} + + +@router.post("/{trip_id}/members/decline") +def decline_invite( + session: SessionDep, + trip_id: int, + current_user: Annotated[str, Depends(get_current_username)], +): + db_member = session.exec( + select(TripMember).where(TripMember.trip_id == trip_id, TripMember.user == current_user) + ).one_or_none() + if not db_member: + raise HTTPException(status_code=404, detail="Not found") + if db_member.joined_at: + raise HTTPException(status_code=409, detail="Already a member") + session.delete(db_member) + session.commit() + return {} diff --git a/backend/trip/utils/utils.py b/backend/trip/utils/utils.py index 7d4ddb4..19f4c2e 100644 --- a/backend/trip/utils/utils.py +++ b/backend/trip/utils/utils.py @@ -1,5 +1,5 @@ import base64 -from datetime import date +from datetime import UTC, date, datetime from io import BytesIO from pathlib import Path from secrets import token_urlsafe @@ -48,6 +48,10 @@ def remove_image(path: str): raise Exception("Error deleting image:", exc, path) +def utc_now(): + return datetime.now(UTC) + + def parse_str_or_date_to_date(cdate: str | date) -> date: if isinstance(cdate, str): try: diff --git a/src/src/app/components/trip/trip.component.html b/src/src/app/components/trip/trip.component.html index 08aece1..64c7284 100644 --- a/src/src/app/components/trip/trip.component.html +++ b/src/src/app/components/trip/trip.component.html @@ -15,6 +15,7 @@
{{ trip.days || 0 }} {{ trip.days > 1 ? 'days' : 'day'}}
+ +{{ trip.days || 0 }} {{ trip.days > 1 ? 'days' : 'day'}}
+ @if (trip.collaborators.length) {-{{ trip.collaborators.length + 1 }} user{{ trip.collaborators.length > 0 + ? 's' + : '' }}
} +