⚡ Eager loading, merge code, early returns
This commit is contained in:
parent
bcfb736e9f
commit
5e7d6e6c4a
@ -17,17 +17,16 @@ router = APIRouter(prefix="/api/auth", tags=["auth"])
|
|||||||
async def auth_params() -> AuthParams:
|
async def auth_params() -> AuthParams:
|
||||||
data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
|
data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
|
||||||
|
|
||||||
response = JSONResponse(content=data)
|
if not (settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET):
|
||||||
if settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET:
|
return {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
|
||||||
oidc_config = await get_oidc_config()
|
|
||||||
auth_endpoint = oidc_config.get("authorization_endpoint")
|
|
||||||
uri, state = get_oidc_client().create_authorization_url(auth_endpoint)
|
|
||||||
data["oidc"] = uri
|
|
||||||
|
|
||||||
response = JSONResponse(content=data)
|
oidc_config = await get_oidc_config()
|
||||||
response.set_cookie(
|
auth_endpoint = oidc_config.get("authorization_endpoint")
|
||||||
"oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60
|
uri, state = get_oidc_client().create_authorization_url(auth_endpoint)
|
||||||
)
|
data["oidc"] = uri
|
||||||
|
|
||||||
|
response = JSONResponse(content=data)
|
||||||
|
response.set_cookie("oidc_state", value=state, httponly=True, secure=True, samesite="Lax", max_age=60)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -147,7 +146,7 @@ def refresh_token(refresh_token: str = Body(..., embed=True)):
|
|||||||
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||||
username = payload.get("sub", None)
|
username = payload.get("sub", None)
|
||||||
|
|
||||||
if username is None:
|
if not username:
|
||||||
raise HTTPException(status_code=401, detail="Invalid Token")
|
raise HTTPException(status_code=401, detail="Invalid Token")
|
||||||
|
|
||||||
new_access_token = create_access_token(data={"sub": username})
|
new_access_token = create_access_token(data={"sub": username})
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from sqlmodel import select
|
from sqlalchemy.orm import selectinload
|
||||||
|
from sqlmodel import func, select
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
from ..deps import SessionDep, get_current_username
|
from ..deps import SessionDep, get_current_username
|
||||||
from ..models.models import (Category, CategoryCreate, CategoryRead,
|
from ..models.models import (Category, CategoryCreate, CategoryRead,
|
||||||
CategoryUpdate, Image)
|
CategoryUpdate, Image, Place)
|
||||||
from ..security import verify_exists_and_owns
|
from ..security import verify_exists_and_owns
|
||||||
from ..utils.utils import b64img_decode, remove_image, save_image_to_file
|
from ..utils.utils import b64img_decode, remove_image, save_image_to_file
|
||||||
|
|
||||||
@ -17,8 +18,10 @@ router = APIRouter(prefix="/api/categories", tags=["categories"])
|
|||||||
def read_categories(
|
def read_categories(
|
||||||
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
|
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
|
||||||
) -> list[Category]:
|
) -> list[Category]:
|
||||||
categories = session.exec(select(Category).filter(Category.user == current_user))
|
db_categories = session.exec(
|
||||||
return [CategoryRead.serialize(category) for category in categories]
|
select(Category).options(selectinload(Category.image)).where(Category.user == current_user)
|
||||||
|
).all()
|
||||||
|
return [CategoryRead.serialize(category) for category in db_categories]
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=CategoryRead)
|
@router.post("", response_model=CategoryRead)
|
||||||
@ -48,7 +51,7 @@ def post_category(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/{category_id}", response_model=CategoryRead)
|
@router.put("/{category_id}", response_model=CategoryRead)
|
||||||
def put_category(
|
def update_category(
|
||||||
session: SessionDep,
|
session: SessionDep,
|
||||||
category_id: int,
|
category_id: int,
|
||||||
category: CategoryUpdate,
|
category: CategoryUpdate,
|
||||||
@ -104,7 +107,11 @@ def delete_category(
|
|||||||
db_category = session.get(Category, category_id)
|
db_category = session.get(Category, category_id)
|
||||||
verify_exists_and_owns(current_user, db_category)
|
verify_exists_and_owns(current_user, db_category)
|
||||||
|
|
||||||
if get_category_placess_cnt(session, category_id, current_user) > 0:
|
places_count = session.exec(
|
||||||
|
select(func.count(Place.id)).where(Place.category_id == category_id, Place.user == current_user)
|
||||||
|
).one()
|
||||||
|
|
||||||
|
if places_count > 0:
|
||||||
raise HTTPException(status_code=409, detail="The resource is not orphan")
|
raise HTTPException(status_code=409, detail="The resource is not orphan")
|
||||||
|
|
||||||
if db_category.image:
|
if db_category.image:
|
||||||
@ -120,14 +127,3 @@ def delete_category(
|
|||||||
session.delete(db_category)
|
session.delete(db_category)
|
||||||
session.commit()
|
session.commit()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{category_id}/count")
|
|
||||||
def get_category_placess_cnt(
|
|
||||||
session: SessionDep,
|
|
||||||
category_id: int,
|
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
|
||||||
) -> int:
|
|
||||||
db_category = session.get(Category, category_id)
|
|
||||||
verify_exists_and_owns(current_user, db_category)
|
|
||||||
return len(db_category.places)
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
@ -18,8 +19,12 @@ router = APIRouter(prefix="/api/places", tags=["places"])
|
|||||||
def read_places(
|
def read_places(
|
||||||
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
|
session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]
|
||||||
) -> list[PlaceRead]:
|
) -> list[PlaceRead]:
|
||||||
places = session.exec(select(Place).filter(Place.user == current_user))
|
db_places = session.exec(
|
||||||
return [PlaceRead.serialize(p) for p in places]
|
select(Place)
|
||||||
|
.options(selectinload(Place.image), selectinload(Place.category))
|
||||||
|
.where(Place.user == current_user)
|
||||||
|
).all()
|
||||||
|
return [PlaceRead.serialize(p) for p in db_places]
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=PlaceRead)
|
@router.post("", response_model=PlaceRead)
|
||||||
@ -70,7 +75,7 @@ async def create_places(
|
|||||||
for place in places:
|
for place in places:
|
||||||
category_name = place.category
|
category_name = place.category
|
||||||
category = session.exec(
|
category = session.exec(
|
||||||
select(Category).filter(Category.user == current_user, Category.name == category_name)
|
select(Category).where(Category.user == current_user, Category.name == category_name)
|
||||||
).first()
|
).first()
|
||||||
if not category:
|
if not category:
|
||||||
continue
|
continue
|
||||||
@ -132,16 +137,18 @@ def update_place(
|
|||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(image)
|
session.refresh(image)
|
||||||
|
|
||||||
place_data["image_id"] = image.id
|
|
||||||
|
|
||||||
if db_place.image_id:
|
if db_place.image_id:
|
||||||
old_image = session.get(Image, db_place.image_id)
|
old_image = session.get(Image, db_place.image_id)
|
||||||
try:
|
try:
|
||||||
remove_image(old_image.filename)
|
remove_image(old_image.filename)
|
||||||
session.delete(old_image)
|
session.delete(old_image)
|
||||||
|
db_place.image_id = None
|
||||||
|
session.refresh(db_place)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=400, detail="Bad request")
|
raise HTTPException(status_code=400, detail="Bad request")
|
||||||
|
|
||||||
|
db_place.image_id = image.id
|
||||||
|
|
||||||
for key, value in place_data.items():
|
for key, value in place_data.items():
|
||||||
setattr(db_place, key, value)
|
setattr(db_place, key, value)
|
||||||
|
|
||||||
@ -179,7 +186,11 @@ def get_place(
|
|||||||
place_id: int,
|
place_id: int,
|
||||||
current_user: Annotated[str, Depends(get_current_username)],
|
current_user: Annotated[str, Depends(get_current_username)],
|
||||||
) -> PlaceRead:
|
) -> PlaceRead:
|
||||||
db_place = session.get(Place, place_id)
|
db_place = session.exec(
|
||||||
|
select(Place)
|
||||||
|
.options(selectinload(Place.image), selectinload(Place.category))
|
||||||
|
.where(Place.id == place_id)
|
||||||
|
).first()
|
||||||
verify_exists_and_owns(current_user, db_place)
|
verify_exists_and_owns(current_user, db_place)
|
||||||
|
|
||||||
return PlaceRead.serialize(db_place, exclude_gpx=False)
|
return PlaceRead.serialize(db_place, exclude_gpx=False)
|
||||||
|
|||||||
@ -117,7 +117,7 @@ async def import_data(
|
|||||||
|
|
||||||
existing_categories = {
|
existing_categories = {
|
||||||
category.name: category
|
category.name: category
|
||||||
for category in session.exec(select(Category).filter(Category.user == current_user)).all()
|
for category in session.exec(select(Category).where(Category.user == current_user)).all()
|
||||||
}
|
}
|
||||||
|
|
||||||
categories_to_add = []
|
categories_to_add = []
|
||||||
@ -142,6 +142,7 @@ async def import_data(
|
|||||||
image = Image(filename=filename, user=current_user)
|
image = Image(filename=filename, user=current_user)
|
||||||
session.add(image)
|
session.add(image)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
session.refresh(image)
|
||||||
|
|
||||||
if category_exists.image_id:
|
if category_exists.image_id:
|
||||||
old_image = session.get(Image, category_exists.image_id)
|
old_image = session.get(Image, category_exists.image_id)
|
||||||
@ -177,6 +178,7 @@ async def import_data(
|
|||||||
image = Image(filename=filename, user=current_user)
|
image = Image(filename=filename, user=current_user)
|
||||||
session.add(image)
|
session.add(image)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
session.refresh(image)
|
||||||
category_data["image_id"] = image.id
|
category_data["image_id"] = image.id
|
||||||
|
|
||||||
new_category = Category(**category_data)
|
new_category = Category(**category_data)
|
||||||
@ -213,6 +215,7 @@ async def import_data(
|
|||||||
image = Image(filename=filename, user=current_user)
|
image = Image(filename=filename, user=current_user)
|
||||||
session.add(image)
|
session.add(image)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
session.refresh(image)
|
||||||
place_data["image_id"] = image.id
|
place_data["image_id"] = image.id
|
||||||
|
|
||||||
new_place = Place(**place_data)
|
new_place = Place(**place_data)
|
||||||
@ -267,6 +270,7 @@ async def import_data(
|
|||||||
image = Image(filename=filename, user=current_user)
|
image = Image(filename=filename, user=current_user)
|
||||||
session.add(image)
|
session.add(image)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
session.refresh(image)
|
||||||
trip_data["image_id"] = image.id
|
trip_data["image_id"] = image.id
|
||||||
|
|
||||||
new_trip = Trip(**trip_data)
|
new_trip = Trip(**trip_data)
|
||||||
@ -315,6 +319,7 @@ async def import_data(
|
|||||||
image = Image(filename=filename, user=current_user)
|
image = Image(filename=filename, user=current_user)
|
||||||
session.add(image)
|
session.add(image)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
session.refresh(image)
|
||||||
item_data["image_id"] = image.id
|
item_data["image_id"] = image.id
|
||||||
|
|
||||||
trip_item = TripItem(**item_data)
|
trip_item = TripItem(**item_data)
|
||||||
@ -330,7 +335,7 @@ async def import_data(
|
|||||||
"categories": [
|
"categories": [
|
||||||
CategoryRead.serialize(c)
|
CategoryRead.serialize(c)
|
||||||
for c in session.exec(
|
for c in session.exec(
|
||||||
select(Category).options(selectinload(Category.image)).filter(Category.user == current_user)
|
select(Category).options(selectinload(Category.image)).where(Category.user == current_user)
|
||||||
).all()
|
).all()
|
||||||
],
|
],
|
||||||
"settings": UserRead.serialize(session.get(User, current_user)),
|
"settings": UserRead.serialize(session.get(User, current_user)),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user