import jwt from fastapi import APIRouter, Body, HTTPException from ..config import settings from ..db.core import init_user_data from ..deps import SessionDep from ..models.models import AuthParams, LoginRegisterModel, Token, User from ..security import (create_access_token, create_tokens, get_oidc_client, get_oidc_config, hash_password, verify_password) from ..utils.utils import generate_filename router = APIRouter(prefix="/api/auth", tags=["auth"]) @router.get("/params", response_model=AuthParams) async def auth_params() -> AuthParams: data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE} if settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET: oidc_config = await get_oidc_config() auth_endpoint = oidc_config.get("authorization_endpoint") data["oidc"] = ( f"{auth_endpoint}?client_id={settings.OIDC_CLIENT_ID}&redirect_uri={settings.OIDC_REDIRECT_URI}&response_type=code&scope=openid+profile" ) return data @router.post("/oidc/login", response_model=Token) async def oidc_login(session: SessionDep, code: str = Body(..., embed=True)) -> Token: if not (settings.OIDC_CLIENT_ID or settings.OIDC_CLIENT_SECRET): raise HTTPException(status_code=400, detail="Partial OIDC config") oidc_config = await get_oidc_config() token_endpoint = oidc_config.get("token_endpoint") try: oidc_client = get_oidc_client() token = oidc_client.fetch_token( token_endpoint, grant_type="authorization_code", code=code, ) except Exception: raise HTTPException(status_code=401, detail="OIDC login failed") id_token = token.get("id_token") alg = jwt.get_unverified_header(id_token).get("alg") match alg: case "HS256": decoded = jwt.decode( id_token, settings.OIDC_CLIENT_SECRET, algorithms=["HS256"], audience=settings.OIDC_CLIENT_ID, ) case "RS256": jwks_uri = oidc_config.get("jwks_uri") issuer = oidc_config.get("issuer") jwks_client = jwt.PyJWKClient(jwks_uri) try: signing_key = jwks_client.get_signing_key_from_jwt(id_token) decoded = jwt.decode( id_token, key=signing_key.key, 
algorithms=["RS256"], audience=settings.OIDC_CLIENT_ID, issuer=issuer, ) except Exception: raise HTTPException(status_code=401, detail="Invalid ID token") case _: raise HTTPException(status_code=500, detail="OIDC login failed, algorithm not handled") if not decoded: raise HTTPException(status_code=401, detail="Invalid ID token") username = decoded.get("preferred_username") if not username: raise HTTPException(status_code=401, detail="OIDC login failed, preferred_username missing") user = session.get(User, username) if not user: # TODO: password is non-null, we must init the pw with something, the model is not made for OIDC user = User(username=username, password=hash_password(generate_filename("find-something-else"))) session.add(user) session.commit() init_user_data(session, username) return create_tokens(data={"sub": username}) @router.post("/login", response_model=Token) def login(req: LoginRegisterModel, session: SessionDep) -> Token: if settings.OIDC_CLIENT_ID or settings.OIDC_CLIENT_SECRET: raise HTTPException(status_code=400, detail="OIDC is configured") db_user = session.get(User, req.username) if not db_user or not verify_password(req.password, db_user.password): raise HTTPException(status_code=401, detail="Invalid credentials") return create_tokens(data={"sub": db_user.username}) @router.post("/register", response_model=Token) def register(req: LoginRegisterModel, session: SessionDep) -> Token: if not settings.REGISTER_ENABLE: raise HTTPException(status_code=400, detail="Registration disabled") if settings.OIDC_CLIENT_ID or settings.OIDC_CLIENT_SECRET: raise HTTPException(status_code=400, detail="OIDC is configured") db_user = session.get(User, req.username) if db_user: raise HTTPException(status_code=409, detail="The resource already exists") new_user = User(username=req.username, password=hash_password(req.password)) session.add(new_user) session.commit() init_user_data(session, new_user.username) return create_tokens(data={"sub": new_user.username}) 
@router.post("/refresh")
def refresh_token(refresh_token: str = Body(..., embed=True)):
    """Issue a new access token from a refresh token.

    The refresh token must decode with ``settings.SECRET_KEY`` and
    ``settings.ALGORITHM`` and carry a ``sub`` claim. The new access token is
    created for that same subject.

    Returns:
        ``{"access_token": <new access token>}``.

    Raises:
        HTTPException: 400 if the token is empty; 401 if it is expired,
            otherwise invalid, or has no ``sub`` claim.
    """
    if not refresh_token:
        raise HTTPException(status_code=400, detail="Refresh token expected")

    # NOTE(review): nothing here distinguishes a refresh token from an access
    # token. If create_tokens signs both with SECRET_KEY, an access token is
    # accepted here too — confirm, and consider checking a token-type claim.
    try:
        payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except jwt.PyJWTError as err:
        # ExpiredSignatureError is a PyJWTError subclass; expired and
        # otherwise-invalid tokens deliberately get the same response.
        raise HTTPException(status_code=401, detail="Invalid Token") from err

    username = payload.get("sub")
    if username is None:
        raise HTTPException(status_code=401, detail="Invalid Token")

    return {"access_token": create_access_token(data={"sub": username})}