2025-07-25 18:01:03 +02:00

147 lines
5.4 KiB
Python

from urllib.parse import urlencode

import jwt
from fastapi import APIRouter, Body, HTTPException

from ..config import settings
from ..db.core import init_user_data
from ..deps import SessionDep
from ..models.models import AuthParams, LoginRegisterModel, Token, User
from ..security import (
    create_access_token,
    create_tokens,
    get_oidc_client,
    get_oidc_config,
    hash_password,
    verify_password,
)
from ..utils.utils import generate_filename
# Router for the authentication endpoints defined below; every route path
# registered on it is prefixed with /api/auth.
router = APIRouter(
    tags=["auth"],
    prefix="/api/auth",
)
@router.get("/params", response_model=AuthParams)
async def auth_params() -> AuthParams:
    """Describe which authentication options the client should offer.

    Returns a mapping (validated against ``AuthParams`` by FastAPI) with:
      - ``register_enabled``: the ``settings.REGISTER_ENABLE`` flag.
      - ``oidc``: ``None`` unless both ``OIDC_CLIENT_ID`` and
        ``OIDC_CLIENT_SECRET`` are set; otherwise the provider's authorization
        URL for the authorization-code flow with scope ``openid profile``.

    Raises:
        HTTPException(500): OIDC is configured but the provider's discovery
            document has no ``authorization_endpoint``.
    """
    data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE}
    if settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET:
        oidc_config = await get_oidc_config()
        auth_endpoint = oidc_config.get("authorization_endpoint")
        if not auth_endpoint:
            # Previously this produced a bogus "None?client_id=..." URL.
            raise HTTPException(status_code=500, detail="OIDC authorization endpoint missing")
        # Form-urlencode the query: redirect_uri is itself a URL and must not leak its
        # own "?", "&" or "#" into the authorization request. The space in the scope
        # is encoded as "+", matching the previous hard-coded "openid+profile".
        query = urlencode(
            {
                "client_id": settings.OIDC_CLIENT_ID,
                "redirect_uri": settings.OIDC_REDIRECT_URI,
                "response_type": "code",
                "scope": "openid profile",
            }
        )
        # An endpoint that already carries a query string needs "&", not a second "?".
        separator = "&" if "?" in auth_endpoint else "?"
        # NOTE(review): no `state` parameter is sent, so the callback is not bound to
        # the initiating browser session (CSRF) unless the client adds one — confirm.
        data["oidc"] = f"{auth_endpoint}{separator}{query}"
    return data
@router.post("/oidc/login", response_model=Token)
async def oidc_login(session: SessionDep, code: str = Body(..., embed=True)) -> Token:
if not (settings.OIDC_CLIENT_ID or settings.OIDC_CLIENT_SECRET):
raise HTTPException(status_code=400, detail="Partial OIDC config")
oidc_config = await get_oidc_config()
token_endpoint = oidc_config.get("token_endpoint")
try:
oidc_client = get_oidc_client()
token = oidc_client.fetch_token(
token_endpoint,
grant_type="authorization_code",
code=code,
)
except Exception:
raise HTTPException(status_code=401, detail="OIDC login failed")
id_token = token.get("id_token")
alg = jwt.get_unverified_header(id_token).get("alg")
match alg:
case "HS256":
decoded = jwt.decode(
id_token,
settings.OIDC_CLIENT_SECRET,
algorithms=["HS256"],
audience=settings.OIDC_CLIENT_ID,
)
case "RS256":
jwks_uri = oidc_config.get("jwks_uri")
issuer = oidc_config.get("issuer")
jwks_client = jwt.PyJWKClient(jwks_uri)
try:
signing_key = jwks_client.get_signing_key_from_jwt(id_token)
decoded = jwt.decode(
id_token,
key=signing_key.key,
algorithms=["RS256"],
audience=settings.OIDC_CLIENT_ID,
issuer=issuer,
)
except Exception:
raise HTTPException(status_code=401, detail="Invalid ID token")
case _:
raise HTTPException(status_code=500, detail="OIDC login failed, algorithm not handled")
if not decoded:
raise HTTPException(status_code=401, detail="Invalid ID token")
username = decoded.get("preferred_username")
if not username:
raise HTTPException(status_code=401, detail="OIDC login failed, preferred_username missing")
user = session.get(User, username)
if not user:
# TODO: password is non-null, we must init the pw with something, the model is not made for OIDC
user = User(username=username, password=hash_password(generate_filename("find-something-else")))
session.add(user)
session.commit()
init_user_data(session, username)
return create_tokens(data={"sub": username})
@router.post("/login", response_model=Token)
def login(req: LoginRegisterModel, session: SessionDep) -> Token:
    """Authenticate with a local username and password and issue tokens.

    Raises:
        HTTPException(400): any OIDC setting is present (local login is refused).
        HTTPException(401): the user does not exist or the password is wrong.
    """
    oidc_enabled = settings.OIDC_CLIENT_ID or settings.OIDC_CLIENT_SECRET
    if oidc_enabled:
        raise HTTPException(status_code=400, detail="OIDC is configured")

    account = session.get(User, req.username)
    # NOTE(review): verify_password is skipped for unknown users, so response time
    # may reveal whether a username exists.
    if account and verify_password(req.password, account.password):
        return create_tokens(data={"sub": account.username})
    raise HTTPException(status_code=401, detail="Invalid credentials")
@router.post("/register", response_model=Token)
def register(req: LoginRegisterModel, session: SessionDep) -> Token:
    """Create a local account with a hashed password and return tokens for it.

    Raises:
        HTTPException(400): registration is disabled, or an OIDC setting is present.
        HTTPException(409): a user named ``req.username`` already exists.
    """
    if not settings.REGISTER_ENABLE:
        raise HTTPException(status_code=400, detail="Registration disabled")
    if settings.OIDC_CLIENT_ID or settings.OIDC_CLIENT_SECRET:
        raise HTTPException(status_code=400, detail="OIDC is configured")

    # NOTE(review): check-then-insert; two concurrent registrations of the same name
    # can both pass this check and one will then fail at commit.
    existing = session.get(User, req.username)
    if existing:
        raise HTTPException(status_code=409, detail="The resource already exists")

    account = User(username=req.username, password=hash_password(req.password))
    session.add(account)
    session.commit()
    # Per-user data is seeded after the commit (presumably it relies on the stored row).
    init_user_data(session, account.username)
    return create_tokens(data={"sub": account.username})
@router.post("/refresh")
def refresh_token(refresh_token: str = Body(..., embed=True)):
    """Issue a new access token from a refresh token.

    The token is decoded with ``settings.SECRET_KEY`` / ``settings.ALGORITHM`` and
    its ``sub`` claim is copied into a fresh access token.

    Returns:
        ``{"access_token": <new access token>}``.

    Raises:
        HTTPException(400): the refresh token is empty.
        HTTPException(401): the token is expired, malformed, has a bad signature,
            or carries no ``sub`` claim.
    """
    if not refresh_token:
        raise HTTPException(status_code=400, detail="Refresh token expected")

    # Only decoding belongs in the try: a failure in create_access_token is a server
    # error, not an invalid client token.
    try:
        payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except jwt.PyJWTError as err:  # also covers ExpiredSignatureError
        raise HTTPException(status_code=401, detail="Invalid Token") from err

    # NOTE(review): nothing here distinguishes a refresh token from an access token;
    # if both are signed with SECRET_KEY, an access token is accepted too — confirm
    # how create_tokens marks them.
    username = payload.get("sub")
    if username is None:
        raise HTTPException(status_code=401, detail="Invalid Token")
    return {"access_token": create_access_token(data={"sub": username})}