Eager loading, merge code, early returns

This commit is contained in:
itskovacs 2025-10-11 17:58:33 +02:00
parent bcfb736e9f
commit 5e7d6e6c4a
4 changed files with 47 additions and 36 deletions

View File

@ -17,17 +17,16 @@ router = APIRouter(prefix="/api/auth", tags=["auth"])
async def auth_params() -> AuthParams:
data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
response = JSONResponse(content=data)
if settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET:
oidc_config = await get_oidc_config()
auth_endpoint = oidc_config.get("authorization_endpoint")
uri, state = get_oidc_client().create_authorization_url(auth_endpoint)
data["oidc"] = uri
if not (settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET):
return {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
response = JSONResponse(content=data)
response.set_cookie(
"oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60
)
oidc_config = await get_oidc_config()
auth_endpoint = oidc_config.get("authorization_endpoint")
uri, state = get_oidc_client().create_authorization_url(auth_endpoint)
data["oidc"] = uri
response = JSONResponse(content=data)
response.set_cookie("oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60)
return response
@ -147,7 +146,7 @@ def refresh_token(refresh_token: str = Body(..., embed=True)):
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
username = payload.get("sub", None)
if username is None:
if not username:
raise HTTPException(status_code=401, detail="Invalid Token")
new_access_token = create_access_token(data={"sub": username})

View File

@ -1,12 +1,13 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import select
from sqlalchemy.orm import selectinload
from sqlmodel import func, select
from ..config import settings
from ..deps import SessionDep, get_current_username
from ..models.models import (Category, CategoryCreate, CategoryRead,
CategoryUpdate, Image)
CategoryUpdate, Image, Place)
from ..security import verify_exists_and_owns
from ..utils.utils import b64img_decode, remove_image, save_image_to_file
@ -17,8 +18,10 @@ router = APIRouter(prefix="/api/categories", tags=["categories"])
def read_categories(
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
) -> list[Category]:
categories = session.exec(select(Category).filter(Category.user == current_user))
return [CategoryRead.serialize(category) for category in categories]
db_categories = session.exec(
select(Category).options(selectinload(Category.image)).where(Category.user == current_user)
).all()
return [CategoryRead.serialize(category) for category in db_categories]
@router.post("", response_model=CategoryRead)
@ -48,7 +51,7 @@ def post_category(
@router.put("/{category_id}", response_model=CategoryRead)
def put_category(
def update_category(
session: SessionDep,
category_id: int,
category: CategoryUpdate,
@ -104,7 +107,11 @@ def delete_category(
db_category = session.get(Category, category_id)
verify_exists_and_owns(current_user, db_category)
if get_category_placess_cnt(session, category_id, current_user) > 0:
places_count = session.exec(
select(func.count(Place.id)).where(Place.category_id == category_id, Place.user == current_user)
).one()
if places_count > 0:
raise HTTPException(status_code=409, detail="The resource is not orphan")
if db_category.image:
@ -120,14 +127,3 @@ def delete_category(
session.delete(db_category)
session.commit()
return {}
@router.get("/{category_id}/count")
def get_category_placess_cnt(
session: SessionDep,
category_id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> int:
db_category = session.get(Category, category_id)
verify_exists_and_owns(current_user, db_category)
return len(db_category.places)

View File

@ -1,6 +1,7 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import selectinload
from sqlmodel import select
from ..config import settings
@ -18,8 +19,12 @@ router = APIRouter(prefix="/api/places", tags=["places"])
def read_places(
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
) -> list[PlaceRead]:
places = session.exec(select(Place).filter(Place.user == current_user))
return [PlaceRead.serialize(p) for p in places]
db_places = session.exec(
select(Place)
.options(selectinload(Place.image), selectinload(Place.category))
.where(Place.user == current_user)
).all()
return [PlaceRead.serialize(p) for p in db_places]
@router.post("", response_model=PlaceRead)
@ -70,7 +75,7 @@ async def create_places(
for place in places:
category_name = place.category
category = session.exec(
select(Category).filter(Category.user == current_user, Category.name == category_name)
select(Category).where(Category.user == current_user, Category.name == category_name)
).first()
if not category:
continue
@ -132,16 +137,18 @@ def update_place(
session.commit()
session.refresh(image)
place_data["image_id"] = image.id
if db_place.image_id:
old_image = session.get(Image, db_place.image_id)
try:
remove_image(old_image.filename)
session.delete(old_image)
db_place.image_id = None
session.refresh(db_place)
except Exception:
raise HTTPException(status_code=400, detail="Bad request")
db_place.image_id = image.id
for key, value in place_data.items():
setattr(db_place, key, value)
@ -179,7 +186,11 @@ def get_place(
place_id: int,
current_user: Annotated[str, Depends(get_current_username)],
) -> PlaceRead:
db_place = session.get(Place, place_id)
db_place = session.exec(
select(Place)
.options(selectinload(Place.image), selectinload(Place.category))
.where(Place.id == place_id)
).first()
verify_exists_and_owns(current_user, db_place)
return PlaceRead.serialize(db_place, exclude_gpx=False)

View File

@ -117,7 +117,7 @@ async def import_data(
existing_categories = {
category.name: category
for category in session.exec(select(Category).filter(Category.user == current_user)).all()
for category in session.exec(select(Category).where(Category.user == current_user)).all()
}
categories_to_add = []
@ -142,6 +142,7 @@ async def import_data(
image = Image(filename=filename, user=current_user)
session.add(image)
session.flush()
session.refresh(image)
if category_exists.image_id:
old_image = session.get(Image, category_exists.image_id)
@ -177,6 +178,7 @@ async def import_data(
image = Image(filename=filename, user=current_user)
session.add(image)
session.flush()
session.refresh(image)
category_data["image_id"] = image.id
new_category = Category(**category_data)
@ -213,6 +215,7 @@ async def import_data(
image = Image(filename=filename, user=current_user)
session.add(image)
session.flush()
session.refresh(image)
place_data["image_id"] = image.id
new_place = Place(**place_data)
@ -267,6 +270,7 @@ async def import_data(
image = Image(filename=filename, user=current_user)
session.add(image)
session.flush()
session.refresh(image)
trip_data["image_id"] = image.id
new_trip = Trip(**trip_data)
@ -315,6 +319,7 @@ async def import_data(
image = Image(filename=filename, user=current_user)
session.add(image)
session.flush()
session.refresh(image)
item_data["image_id"] = image.id
trip_item = TripItem(**item_data)
@ -330,7 +335,7 @@ async def import_data(
"categories": [
CategoryRead.serialize(c)
for c in session.exec(
select(Category).options(selectinload(Category.image)).filter(Category.user == current_user)
select(Category).options(selectinload(Category.image)).where(Category.user == current_user)
).all()
],
"settings": UserRead.serialize(session.get(User, current_user)),