⚡ Eager loading, merge code, early returns
This commit is contained in:
parent
bcfb736e9f
commit
5e7d6e6c4a
@ -17,17 +17,16 @@ router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
async def auth_params() -> AuthParams:
|
||||
data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
|
||||
|
||||
response = JSONResponse(content=data)
|
||||
if settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET:
|
||||
oidc_config = await get_oidc_config()
|
||||
auth_endpoint = oidc_config.get("authorization_endpoint")
|
||||
uri, state = get_oidc_client().create_authorization_url(auth_endpoint)
|
||||
data["oidc"] = uri
|
||||
if not (settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET):
|
||||
return {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
|
||||
|
||||
response = JSONResponse(content=data)
|
||||
response.set_cookie(
|
||||
"oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60
|
||||
)
|
||||
oidc_config = await get_oidc_config()
|
||||
auth_endpoint = oidc_config.get("authorization_endpoint")
|
||||
uri, state = get_oidc_client().create_authorization_url(auth_endpoint)
|
||||
data["oidc"] = uri
|
||||
|
||||
response = JSONResponse(content=data)
|
||||
response.set_cookie("oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60)
|
||||
|
||||
return response
|
||||
|
||||
@ -147,7 +146,7 @@ def refresh_token(refresh_token: str = Body(..., embed=True)):
|
||||
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
username = payload.get("sub", None)
|
||||
|
||||
if username is None:
|
||||
if not username:
|
||||
raise HTTPException(status_code=401, detail="Invalid Token")
|
||||
|
||||
new_access_token = create_access_token(data={"sub": username})
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import func, select
|
||||
|
||||
from ..config import settings
|
||||
from ..deps import SessionDep, get_current_username
|
||||
from ..models.models import (Category, CategoryCreate, CategoryRead,
|
||||
CategoryUpdate, Image)
|
||||
CategoryUpdate, Image, Place)
|
||||
from ..security import verify_exists_and_owns
|
||||
from ..utils.utils import b64img_decode, remove_image, save_image_to_file
|
||||
|
||||
@ -17,8 +18,10 @@ router = APIRouter(prefix="/api/categories", tags=["categories"])
|
||||
def read_categories(
|
||||
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
|
||||
) -> list[Category]:
|
||||
categories = session.exec(select(Category).filter(Category.user == current_user))
|
||||
return [CategoryRead.serialize(category) for category in categories]
|
||||
db_categories = session.exec(
|
||||
select(Category).options(selectinload(Category.image)).where(Category.user == current_user)
|
||||
).all()
|
||||
return [CategoryRead.serialize(category) for category in db_categories]
|
||||
|
||||
|
||||
@router.post("", response_model=CategoryRead)
|
||||
@ -48,7 +51,7 @@ def post_category(
|
||||
|
||||
|
||||
@router.put("/{category_id}", response_model=CategoryRead)
|
||||
def put_category(
|
||||
def update_category(
|
||||
session: SessionDep,
|
||||
category_id: int,
|
||||
category: CategoryUpdate,
|
||||
@ -104,7 +107,11 @@ def delete_category(
|
||||
db_category = session.get(Category, category_id)
|
||||
verify_exists_and_owns(current_user, db_category)
|
||||
|
||||
if get_category_placess_cnt(session, category_id, current_user) > 0:
|
||||
places_count = session.exec(
|
||||
select(func.count(Place.id)).where(Place.category_id == category_id, Place.user == current_user)
|
||||
).one()
|
||||
|
||||
if places_count > 0:
|
||||
raise HTTPException(status_code=409, detail="The resource is not orphan")
|
||||
|
||||
if db_category.image:
|
||||
@ -120,14 +127,3 @@ def delete_category(
|
||||
session.delete(db_category)
|
||||
session.commit()
|
||||
return {}
|
||||
|
||||
|
||||
@router.get("/{category_id}/count")
|
||||
def get_category_placess_cnt(
|
||||
session: SessionDep,
|
||||
category_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> int:
|
||||
db_category = session.get(Category, category_id)
|
||||
verify_exists_and_owns(current_user, db_category)
|
||||
return len(db_category.places)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
from ..config import settings
|
||||
@ -18,8 +19,12 @@ router = APIRouter(prefix="/api/places", tags=["places"])
|
||||
def read_places(
|
||||
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
|
||||
) -> list[PlaceRead]:
|
||||
places = session.exec(select(Place).filter(Place.user == current_user))
|
||||
return [PlaceRead.serialize(p) for p in places]
|
||||
db_places = session.exec(
|
||||
select(Place)
|
||||
.options(selectinload(Place.image), selectinload(Place.category))
|
||||
.where(Place.user == current_user)
|
||||
).all()
|
||||
return [PlaceRead.serialize(p) for p in db_places]
|
||||
|
||||
|
||||
@router.post("", response_model=PlaceRead)
|
||||
@ -70,7 +75,7 @@ async def create_places(
|
||||
for place in places:
|
||||
category_name = place.category
|
||||
category = session.exec(
|
||||
select(Category).filter(Category.user == current_user, Category.name == category_name)
|
||||
select(Category).where(Category.user == current_user, Category.name == category_name)
|
||||
).first()
|
||||
if not category:
|
||||
continue
|
||||
@ -132,16 +137,18 @@ def update_place(
|
||||
session.commit()
|
||||
session.refresh(image)
|
||||
|
||||
place_data["image_id"] = image.id
|
||||
|
||||
if db_place.image_id:
|
||||
old_image = session.get(Image, db_place.image_id)
|
||||
try:
|
||||
remove_image(old_image.filename)
|
||||
session.delete(old_image)
|
||||
db_place.image_id = None
|
||||
session.refresh(db_place)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
|
||||
db_place.image_id = image.id
|
||||
|
||||
for key, value in place_data.items():
|
||||
setattr(db_place, key, value)
|
||||
|
||||
@ -179,7 +186,11 @@ def get_place(
|
||||
place_id: int,
|
||||
current_user: Annotated[str, Depends(get_current_username)],
|
||||
) -> PlaceRead:
|
||||
db_place = session.get(Place, place_id)
|
||||
db_place = session.exec(
|
||||
select(Place)
|
||||
.options(selectinload(Place.image), selectinload(Place.category))
|
||||
.where(Place.id == place_id)
|
||||
).first()
|
||||
verify_exists_and_owns(current_user, db_place)
|
||||
|
||||
return PlaceRead.serialize(db_place, exclude_gpx=False)
|
||||
|
||||
@ -117,7 +117,7 @@ async def import_data(
|
||||
|
||||
existing_categories = {
|
||||
category.name: category
|
||||
for category in session.exec(select(Category).filter(Category.user == current_user)).all()
|
||||
for category in session.exec(select(Category).where(Category.user == current_user)).all()
|
||||
}
|
||||
|
||||
categories_to_add = []
|
||||
@ -142,6 +142,7 @@ async def import_data(
|
||||
image = Image(filename=filename, user=current_user)
|
||||
session.add(image)
|
||||
session.flush()
|
||||
session.refresh(image)
|
||||
|
||||
if category_exists.image_id:
|
||||
old_image = session.get(Image, category_exists.image_id)
|
||||
@ -177,6 +178,7 @@ async def import_data(
|
||||
image = Image(filename=filename, user=current_user)
|
||||
session.add(image)
|
||||
session.flush()
|
||||
session.refresh(image)
|
||||
category_data["image_id"] = image.id
|
||||
|
||||
new_category = Category(**category_data)
|
||||
@ -213,6 +215,7 @@ async def import_data(
|
||||
image = Image(filename=filename, user=current_user)
|
||||
session.add(image)
|
||||
session.flush()
|
||||
session.refresh(image)
|
||||
place_data["image_id"] = image.id
|
||||
|
||||
new_place = Place(**place_data)
|
||||
@ -267,6 +270,7 @@ async def import_data(
|
||||
image = Image(filename=filename, user=current_user)
|
||||
session.add(image)
|
||||
session.flush()
|
||||
session.refresh(image)
|
||||
trip_data["image_id"] = image.id
|
||||
|
||||
new_trip = Trip(**trip_data)
|
||||
@ -315,6 +319,7 @@ async def import_data(
|
||||
image = Image(filename=filename, user=current_user)
|
||||
session.add(image)
|
||||
session.flush()
|
||||
session.refresh(image)
|
||||
item_data["image_id"] = image.id
|
||||
|
||||
trip_item = TripItem(**item_data)
|
||||
@ -330,7 +335,7 @@ async def import_data(
|
||||
"categories": [
|
||||
CategoryRead.serialize(c)
|
||||
for c in session.exec(
|
||||
select(Category).options(selectinload(Category.image)).filter(Category.user == current_user)
|
||||
select(Category).options(selectinload(Category.image)).where(Category.user == current_user)
|
||||
).all()
|
||||
],
|
||||
"settings": UserRead.serialize(session.get(User, current_user)),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user