diff --git a/backend/trip/__init__.py b/backend/trip/__init__.py index e4adfb8..14d9d2f 100644 --- a/backend/trip/__init__.py +++ b/backend/trip/__init__.py @@ -1 +1 @@ -__version__ = "1.6.0" +__version__ = "1.7.0" diff --git a/backend/trip/config.py b/backend/trip/config.py index a549424..e665cf3 100644 --- a/backend/trip/config.py +++ b/backend/trip/config.py @@ -18,11 +18,11 @@ class Settings(BaseSettings): REFRESH_TOKEN_EXPIRE_MINUTES: int = 1440 REGISTER_ENABLE: bool = True + OIDC_DISCOVERY_URL: str = "" OIDC_PROTOCOL: str = "https" OIDC_CLIENT_ID: str = "" OIDC_CLIENT_SECRET: str = "" OIDC_HOST: str = "" - OIDC_REALM: str = "master" OIDC_REDIRECT_URI: str = "" class Config: diff --git a/backend/trip/deps.py b/backend/trip/deps.py index 09542a1..5ac84a1 100644 --- a/backend/trip/deps.py +++ b/backend/trip/deps.py @@ -1,7 +1,6 @@ from typing import Annotated import jwt -from authlib.integrations.httpx_client import OAuth2Client from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from sqlmodel import Session @@ -35,12 +34,3 @@ def get_current_username(token: Annotated[str, Depends(oauth_password_scheme)], if not user: raise HTTPException(status_code=401, detail="Invalid Token") return user.username - - -def get_oidc_client(): - return OAuth2Client( - client_id=settings.OIDC_CLIENT_ID, - client_secret=settings.OIDC_CLIENT_SECRET, - scope="openid", - redirect_uri=settings.OIDC_REDIRECT_URI, - ) diff --git a/backend/trip/routers/auth.py b/backend/trip/routers/auth.py index f6cc3a4..d8973ae 100644 --- a/backend/trip/routers/auth.py +++ b/backend/trip/routers/auth.py @@ -1,16 +1,13 @@ -import json - import jwt from fastapi import APIRouter, Body, HTTPException -from jwt.algorithms import RSAAlgorithm from ..config import settings from ..db.core import init_user_data -from ..deps import SessionDep, get_oidc_client +from ..deps import SessionDep from ..models.models import AuthParams, LoginRegisterModel, Token, User -from ..security import (create_access_token, create_tokens, hash_password, - verify_password) -from ..utils.utils import generate_filename, httpx_get +from ..security import (create_access_token, create_tokens, get_oidc_client, + get_oidc_config, hash_password, verify_password) +from ..utils.utils import generate_filename router = APIRouter(prefix="/api/auth", tags=["auth"]) @@ -20,21 +17,26 @@ async def auth_params() -> AuthParams: data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE} if settings.OIDC_HOST and settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET: - oidc_complete_url = f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}/protocol/openid-connect/auth?client_id={settings.OIDC_CLIENT_ID}&redirect_uri={settings.OIDC_REDIRECT_URI}&response_type=code&scope=openid" - data["oidc"] = oidc_complete_url + oidc_config = await get_oidc_config() + auth_endpoint = oidc_config.get("authorization_endpoint") + data["oidc"] = ( + f"{auth_endpoint}?client_id={settings.OIDC_CLIENT_ID}&redirect_uri={settings.OIDC_REDIRECT_URI}&response_type=code&scope=openid" + ) return data @router.post("/oidc/login", response_model=Token) async def oidc_login(session: SessionDep, code: str = Body(..., embed=True)) -> Token: - if settings.AUTH_METHOD != "oidc": - raise HTTPException(status_code=400, detail="Bad request") + if not (settings.OIDC_HOST or settings.OIDC_CLIENT_ID or settings.OIDC_CLIENT_SECRET): + raise HTTPException(status_code=400, detail="Partial OIDC config") + oidc_config = await get_oidc_config() + token_endpoint = oidc_config.get("token_endpoint") try: oidc_client = get_oidc_client() token = oidc_client.fetch_token( - f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}/protocol/openid-connect/token", + token_endpoint, grant_type="authorization_code", code=code, ) @@ -49,30 +51,25 @@ async def oidc_login(session: SessionDep, code: str = Body(..., embed=True)) -> decoded = jwt.decode( id_token, settings.OIDC_CLIENT_SECRET, - algorithms=alg, + algorithms=["HS256"], audience=settings.OIDC_CLIENT_ID, ) case "RS256": - config = await httpx_get( - f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}/.well-known/openid-configuration" - ) - jwks_uri = config.get("jwks_uri") - jwks = await httpx_get(jwks_uri) - keys = jwks.get("keys") + jwks_uri = oidc_config.get("jwks_uri") + issuer = oidc_config.get("issuer") + jwks_client = jwt.PyJWKClient(jwks_uri) - for key in keys: - try: - pk = RSAAlgorithm.from_jwk(json.dumps(key)) - decoded = jwt.decode( - id_token, - key=pk, - algorithms=alg, - audience=settings.OIDC_CLIENT_ID, - issuer=f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}", - ) - break - except Exception: - continue + try: + signing_key = jwks_client.get_signing_key_from_jwt(id_token) + decoded = jwt.decode( + id_token, + key=signing_key.key, + algorithms=["RS256"], + audience=settings.OIDC_CLIENT_ID, + issuer=issuer, + ) + except Exception: + raise HTTPException(status_code=401, detail="Invalid ID token") case _: raise HTTPException(status_code=500, detail="OIDC login failed, algorithm not handled") @@ -80,6 +77,9 @@ async def oidc_login(session: SessionDep, code: str = Body(..., embed=True)) -> raise HTTPException(status_code=401, detail="Invalid ID token") username = decoded.get("preferred_username") + if not username: + raise HTTPException(status_code=401, detail="OIDC login failed, preferred_username missing") + user = session.get(User, username) if not user: # TODO: password is non-null, we must init the pw with something, the model is not made for OIDC diff --git a/backend/trip/security.py b/backend/trip/security.py index 98555fc..23e2770 100644 --- a/backend/trip/security.py +++ b/backend/trip/security.py @@ -3,12 +3,15 @@ from datetime import UTC, datetime, timedelta import jwt from argon2 import PasswordHasher from argon2 import exceptions as argon_exceptions +from authlib.integrations.httpx_client import OAuth2Client from fastapi import HTTPException from .config import settings from .models.models import Token +from .utils.utils import httpx_get ph = PasswordHasher() +OIDC_CONFIG = {} def hash_password(password: str) -> str: @@ -52,3 +55,25 @@ def verify_exists_and_owns(username: str, obj) -> None: raise PermissionError return None + + +def get_oidc_client(): + return OAuth2Client( + client_id=settings.OIDC_CLIENT_ID, + client_secret=settings.OIDC_CLIENT_SECRET, + scope="openid", + redirect_uri=settings.OIDC_REDIRECT_URI, + ) + + +async def get_oidc_config(): + global OIDC_CONFIG + if OIDC_CONFIG: + return OIDC_CONFIG + + discovery_url = f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/.well-known/openid-configuration" + if settings.OIDC_DISCOVERY_URL: + discovery_url = settings.OIDC_DISCOVERY_URL + + OIDC_CONFIG = await httpx_get(discovery_url) + return OIDC_CONFIG