From 5e7d6e6c4a360442bc28bfb8d5a3e780f02e74ab Mon Sep 17 00:00:00 2001 From: itskovacs Date: Sat, 11 Oct 2025 17:58:33 +0200 Subject: [PATCH] :zap: Eager loading, merge code, early returns --- backend/trip/routers/auth.py | 21 ++++++++++----------- backend/trip/routers/categories.py | 30 +++++++++++++----------------- backend/trip/routers/places.py | 23 +++++++++++++++++------ backend/trip/routers/settings.py | 9 +++++++-- 4 files changed, 47 insertions(+), 36 deletions(-) diff --git a/backend/trip/routers/auth.py b/backend/trip/routers/auth.py index cfb153b..4668ba6 100644 --- a/backend/trip/routers/auth.py +++ b/backend/trip/routers/auth.py @@ -17,17 +17,16 @@ router = APIRouter(prefix="/api/auth", tags=["auth"]) async def auth_params() -> AuthParams: data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE} - response = JSONResponse(content=data) - if settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET: - oidc_config = await get_oidc_config() - auth_endpoint = oidc_config.get("authorization_endpoint") - uri, state = get_oidc_client().create_authorization_url(auth_endpoint) - data["oidc"] = uri + if not (settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET): + return {"oidc": None, "register_enabled": settings.REGISTER_ENABLE} - response = JSONResponse(content=data) - response.set_cookie( - "oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60 - ) + oidc_config = await get_oidc_config() + auth_endpoint = oidc_config.get("authorization_endpoint") + uri, state = get_oidc_client().create_authorization_url(auth_endpoint) + data["oidc"] = uri + + response = JSONResponse(content=data) + response.set_cookie("oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60) return response @@ -147,7 +146,7 @@ def refresh_token(refresh_token: str = Body(..., embed=True)): payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) username = payload.get("sub", None) - if username is None: + if not username: raise HTTPException(status_code=401, detail="Invalid Token") new_access_token = create_access_token(data={"sub": username}) diff --git a/backend/trip/routers/categories.py b/backend/trip/routers/categories.py index 40d59ca..dd41b3a 100644 --- a/backend/trip/routers/categories.py +++ b/backend/trip/routers/categories.py @@ -1,12 +1,13 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException -from sqlmodel import select +from sqlalchemy.orm import selectinload +from sqlmodel import func, select from ..config import settings from ..deps import SessionDep, get_current_username from ..models.models import (Category, CategoryCreate, CategoryRead, - CategoryUpdate, Image) + CategoryUpdate, Image, Place) from ..security import verify_exists_and_owns from ..utils.utils import b64img_decode, remove_image, save_image_to_file @@ -17,8 +18,10 @@ router = APIRouter(prefix="/api/categories", tags=["categories"]) def read_categories( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)] ) -> list[Category]: - categories = session.exec(select(Category).filter(Category.user == current_user)) - return [CategoryRead.serialize(category) for category in categories] + db_categories = session.exec( + select(Category).options(selectinload(Category.image)).where(Category.user == current_user) + ).all() + return [CategoryRead.serialize(category) for category in db_categories] @router.post("", response_model=CategoryRead) @@ -48,7 +51,7 @@ def post_category( @router.put("/{category_id}", response_model=CategoryRead) -def put_category( +def update_category( session: SessionDep, category_id: int, category: CategoryUpdate, @@ -104,7 +107,11 @@ def delete_category( db_category = session.get(Category, category_id) verify_exists_and_owns(current_user, db_category) - if get_category_placess_cnt(session, category_id, current_user) > 0: + places_count = session.exec( + select(func.count(Place.id)).where(Place.category_id == category_id, Place.user == current_user) + ).one() + + if places_count > 0: raise HTTPException(status_code=409, detail="The resource is not orphan") if db_category.image: @@ -120,14 +127,3 @@ def delete_category( session.delete(db_category) session.commit() return {} - - -@router.get("/{category_id}/count") -def get_category_placess_cnt( - session: SessionDep, - category_id: int, - current_user: Annotated[str, Depends(get_current_username)], -) -> int: - db_category = session.get(Category, category_id) - verify_exists_and_owns(current_user, db_category) - return len(db_category.places) diff --git a/backend/trip/routers/places.py b/backend/trip/routers/places.py index 1913a3d..f5d4656 100644 --- a/backend/trip/routers/places.py +++ b/backend/trip/routers/places.py @@ -1,6 +1,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import selectinload from sqlmodel import select from ..config import settings @@ -18,8 +19,12 @@ router = APIRouter(prefix="/api/places", tags=["places"]) def read_places( session: SessionDep, current_user: Annotated[str, Depends(get_current_username)] ) -> list[PlaceRead]: - places = session.exec(select(Place).filter(Place.user == current_user)) - return [PlaceRead.serialize(p) for p in places] + db_places = session.exec( + select(Place) + .options(selectinload(Place.image), selectinload(Place.category)) + .where(Place.user == current_user) + ).all() + return [PlaceRead.serialize(p) for p in db_places] @router.post("", response_model=PlaceRead) @@ -70,7 +75,7 @@ async def create_places( for place in places: category_name = place.category category = session.exec( - select(Category).filter(Category.user == current_user, Category.name == category_name) + select(Category).where(Category.user == current_user, Category.name == category_name) ).first() if not category: continue @@ -132,16 +137,18 @@ def update_place( session.commit() session.refresh(image) - place_data["image_id"] = image.id - if db_place.image_id: old_image = session.get(Image, db_place.image_id) try: remove_image(old_image.filename) session.delete(old_image) + db_place.image_id = None + session.refresh(db_place) except Exception: raise HTTPException(status_code=400, detail="Bad request") + db_place.image_id = image.id + for key, value in place_data.items(): setattr(db_place, key, value) @@ -179,7 +186,11 @@ def get_place( place_id: int, current_user: Annotated[str, Depends(get_current_username)], ) -> PlaceRead: - db_place = session.get(Place, place_id) + db_place = session.exec( + select(Place) + .options(selectinload(Place.image), selectinload(Place.category)) + .where(Place.id == place_id) + ).first() verify_exists_and_owns(current_user, db_place) return PlaceRead.serialize(db_place, exclude_gpx=False) diff --git a/backend/trip/routers/settings.py b/backend/trip/routers/settings.py index 8d89ecb..bfdb4da 100644 --- a/backend/trip/routers/settings.py +++ b/backend/trip/routers/settings.py @@ -117,7 +117,7 @@ async def import_data( existing_categories = { category.name: category - for category in session.exec(select(Category).filter(Category.user == current_user)).all() + for category in session.exec(select(Category).where(Category.user == current_user)).all() } categories_to_add = [] @@ -142,6 +142,7 @@ async def import_data( image = Image(filename=filename, user=current_user) session.add(image) session.flush() + session.refresh(image) if category_exists.image_id: old_image = session.get(Image, category_exists.image_id) @@ -177,6 +178,7 @@ async def import_data( image = Image(filename=filename, user=current_user) session.add(image) session.flush() + session.refresh(image) category_data["image_id"] = image.id new_category = Category(**category_data) @@ -213,6 +215,7 @@ async def import_data( image = Image(filename=filename, user=current_user) session.add(image) session.flush() + session.refresh(image) place_data["image_id"] = image.id new_place = Place(**place_data) @@ -267,6 +270,7 @@ async def import_data( image = Image(filename=filename, user=current_user) session.add(image) session.flush() + session.refresh(image) trip_data["image_id"] = image.id new_trip = Trip(**trip_data) @@ -315,6 +319,7 @@ async def import_data( image = Image(filename=filename, user=current_user) session.add(image) session.flush() + session.refresh(image) item_data["image_id"] = image.id trip_item = TripItem(**item_data) @@ -330,7 +335,7 @@ async def import_data( "categories": [ CategoryRead.serialize(c) for c in session.exec( - select(Category).options(selectinload(Category.image)).filter(Category.user == current_user) + select(Category).options(selectinload(Category.image)).where(Category.user == current_user) ).all() ], "settings": UserRead.serialize(session.get(User, current_user)),