diff --git a/backend/trip/__init__.py b/backend/trip/__init__.py index 5b60188..e4adfb8 100644 --- a/backend/trip/__init__.py +++ b/backend/trip/__init__.py @@ -1 +1 @@ -__version__ = "1.5.0" +__version__ = "1.6.0" diff --git a/backend/trip/config.py b/backend/trip/config.py index 6a1d4da..a549424 100644 --- a/backend/trip/config.py +++ b/backend/trip/config.py @@ -17,6 +17,14 @@ class Settings(BaseSettings): ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 REFRESH_TOKEN_EXPIRE_MINUTES: int = 1440 + REGISTER_ENABLE: bool = True + OIDC_PROTOCOL: str = "https" + OIDC_CLIENT_ID: str = "" + OIDC_CLIENT_SECRET: str = "" + OIDC_HOST: str = "" + OIDC_REALM: str = "master" + OIDC_REDIRECT_URI: str = "" + class Config: env_file = "storage/config.yml" diff --git a/backend/trip/deps.py b/backend/trip/deps.py index 68064bf..09542a1 100644 --- a/backend/trip/deps.py +++ b/backend/trip/deps.py @@ -1,6 +1,7 @@ from typing import Annotated import jwt +from authlib.integrations.httpx_client import OAuth2Client from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from sqlmodel import Session @@ -9,7 +10,7 @@ from .config import settings from .db.core import get_engine from .models.models import User -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") +oauth_password_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") def get_session(): @@ -21,7 +22,7 @@ def get_session(): SessionDep = Annotated[Session, Depends(get_session)] -def get_current_username(token: Annotated[str, Depends(oauth2_scheme)], session: SessionDep) -> str: +def get_current_username(token: Annotated[str, Depends(oauth_password_scheme)], session: SessionDep) -> str: try: payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) username = payload.get("sub") @@ -34,3 +35,12 @@ def get_current_username(token: Annotated[str, Depends(oauth2_scheme)], session: if not user: raise HTTPException(status_code=401, detail="Invalid Token") return user.username + + +def get_oidc_client(): + return OAuth2Client( + client_id=settings.OIDC_CLIENT_ID, + client_secret=settings.OIDC_CLIENT_SECRET, + scope="openid", + redirect_uri=settings.OIDC_REDIRECT_URI, + ) diff --git a/backend/trip/models/models.py b/backend/trip/models/models.py index 81eac07..96008ec 100644 --- a/backend/trip/models/models.py +++ b/backend/trip/models/models.py @@ -27,6 +27,11 @@ def _prefix_assets_url(filename: str) -> str: return base + filename +class AuthParams(BaseModel): + oidc: str | None + register_enabled: bool + + class TripItemStatusEnum(str, Enum): PENDING = "pending" CONFIRMED = "booked" diff --git a/backend/trip/routers/auth.py b/backend/trip/routers/auth.py index 910eb9d..f6cc3a4 100644 --- a/backend/trip/routers/auth.py +++ b/backend/trip/routers/auth.py @@ -1,16 +1,96 @@ +import json + import jwt from fastapi import APIRouter, Body, HTTPException +from jwt.algorithms import RSAAlgorithm from ..config import settings from ..db.core import init_user_data -from ..deps import SessionDep -from ..models.models import LoginRegisterModel, Token, User +from ..deps import SessionDep, get_oidc_client +from ..models.models import AuthParams, LoginRegisterModel, Token, User from ..security import (create_access_token, create_tokens, hash_password, verify_password) +from ..utils.utils import generate_filename, httpx_get router = APIRouter(prefix="/api/auth", tags=["auth"]) +@router.get("/params", response_model=AuthParams) +async def auth_params() -> AuthParams: + data = {"oidc": None, "register_enabled": settings.REGISTER_ENABLE} + + if settings.OIDC_HOST and settings.OIDC_CLIENT_ID and settings.OIDC_CLIENT_SECRET: + oidc_complete_url = f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}/protocol/openid-connect/auth?client_id={settings.OIDC_CLIENT_ID}&redirect_uri={settings.OIDC_REDIRECT_URI}&response_type=code&scope=openid" + data["oidc"] = oidc_complete_url + + return data + + +@router.post("/oidc/login", response_model=Token) +async def oidc_login(session: SessionDep, code: str = Body(..., embed=True)) -> Token: + if settings.AUTH_METHOD != "oidc": + raise HTTPException(status_code=400, detail="Bad request") + + try: + oidc_client = get_oidc_client() + token = oidc_client.fetch_token( + f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}/protocol/openid-connect/token", + grant_type="authorization_code", + code=code, + ) + except Exception: + raise HTTPException(status_code=401, detail="OIDC login failed") + + id_token = token.get("id_token") + alg = jwt.get_unverified_header(id_token).get("alg") + + match alg: + case "HS256": + decoded = jwt.decode( + id_token, + settings.OIDC_CLIENT_SECRET, + algorithms=alg, + audience=settings.OIDC_CLIENT_ID, + ) + case "RS256": + config = await httpx_get( + f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}/.well-known/openid-configuration" + ) + jwks_uri = config.get("jwks_uri") + jwks = await httpx_get(jwks_uri) + keys = jwks.get("keys") + + for key in keys: + try: + pk = RSAAlgorithm.from_jwk(json.dumps(key)) + decoded = jwt.decode( + id_token, + key=pk, + algorithms=alg, + audience=settings.OIDC_CLIENT_ID, + issuer=f"{settings.OIDC_PROTOCOL}://{settings.OIDC_HOST}/realms/{settings.OIDC_REALM}", + ) + break + except Exception: + continue + case _: + raise HTTPException(status_code=500, detail="OIDC login failed, algorithm not handled") + + if not decoded: + raise HTTPException(status_code=401, detail="Invalid ID token") + + username = decoded.get("preferred_username") + user = session.get(User, username) + if not user: + # TODO: password is non-null, we must init the pw with something, the model is not made for OIDC + user = User(username=username, password=hash_password(generate_filename("find-something-else"))) + session.add(user) + session.commit() + init_user_data(session, username) + + return create_tokens(data={"sub": username}) + + @router.post("/login", response_model=Token) def login(req: LoginRegisterModel, session: SessionDep) -> Token: db_user = session.get(User, req.username) diff --git a/backend/trip/utils/utils.py b/backend/trip/utils/utils.py index f0dc72e..fb642ed 100644 --- a/backend/trip/utils/utils.py +++ b/backend/trip/utils/utils.py @@ -36,6 +36,7 @@ def remove_image(path: str): try: fpath = Path(assets_folder_path() / path) if not fpath.exists(): + # Skips missing file return fpath.unlink() except OSError as exc: @@ -51,6 +52,23 @@ def parse_str_or_date_to_date(cdate: str | date) -> date: return cdate +async def httpx_get(link: str) -> str: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + "Accept-Language": "en-US,en;q=0.5", + "Referer": link, + } + + try: + async with httpx.AsyncClient(follow_redirects=True, headers=headers, timeout=5) as client: + response = await client.get(link) + response.raise_for_status() + return response.json() + except Exception: + raise HTTPException(status_code=400, detail="Bad Request") + + async def download_file(link: str, raise_on_error: bool = False) -> str: headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36", diff --git a/src/src/app/components/auth/auth.component.html b/src/src/app/components/auth/auth.component.html index 1faa35c..7803596 100644 --- a/src/src/app/components/auth/auth.component.html +++ b/src/src/app/components/auth/auth.component.html @@ -18,6 +18,13 @@ } + @defer () { + @if (authParams?.oidc) { +
+ +
+ } @else {
+ @if (authParams?.register_enabled) {
@if (isRegistering) {

@@ -70,6 +78,19 @@ an account

} + } + } + } @placeholder (minimum 0.4s) { +
+ +
+
+ +
+
+ +
+ }
@@ -85,7 +106,7 @@ Welcome to TRIP
- Tourism and Recreation Interest Points. + Tourism and Recreational Interest Points.
diff --git a/src/src/app/components/auth/auth.component.ts b/src/src/app/components/auth/auth.component.ts index 456c009..975ff66 100644 --- a/src/src/app/components/auth/auth.component.ts +++ b/src/src/app/components/auth/auth.component.ts @@ -1,23 +1,24 @@ -import { Component } from '@angular/core'; +import { Component } from "@angular/core"; -import { FloatLabelModule } from 'primeng/floatlabel'; +import { FloatLabelModule } from "primeng/floatlabel"; import { FormBuilder, FormGroup, FormsModule, ReactiveFormsModule, Validators, -} from '@angular/forms'; -import { ActivatedRoute, Router } from '@angular/router'; -import { InputTextModule } from 'primeng/inputtext'; -import { ButtonModule } from 'primeng/button'; -import { FocusTrapModule } from 'primeng/focustrap'; -import { AuthService } from '../../services/auth.service'; -import { MessageModule } from 'primeng/message'; -import { HttpErrorResponse } from '@angular/common/http'; +} from "@angular/forms"; +import { ActivatedRoute, Router } from "@angular/router"; +import { InputTextModule } from "primeng/inputtext"; +import { ButtonModule } from "primeng/button"; +import { FocusTrapModule } from "primeng/focustrap"; +import { AuthParams, AuthService, Token } from "../../services/auth.service"; +import { MessageModule } from "primeng/message"; +import { HttpErrorResponse } from "@angular/common/http"; +import { SkeletonModule } from "primeng/skeleton"; @Component({ - selector: 'app-auth', + selector: "app-auth", standalone: true, imports: [ FloatLabelModule, @@ -25,16 +26,18 @@ import { HttpErrorResponse } from '@angular/common/http'; ButtonModule, FormsModule, InputTextModule, + SkeletonModule, FocusTrapModule, MessageModule, ], - templateUrl: './auth.component.html', - styleUrl: './auth.component.scss', + templateUrl: "./auth.component.html", + styleUrl: "./auth.component.scss", }) export class AuthComponent { private redirectURL: string; + authParams: AuthParams | undefined; authForm: FormGroup; - error: string = ''; + error: string = ""; isRegistering: boolean = false; constructor( @@ -43,12 +46,31 @@ export class AuthComponent { private route: ActivatedRoute, private fb: FormBuilder, ) { + this.route.queryParams.subscribe((params) => { + const code = params["code"]; + if (code) { + this.authService.oidcLogin(code).subscribe({ + next: (data) => { + if (!data.access_token) { + this.error = "Authentication failed"; + return; + } + this.router.navigateByUrl(this.redirectURL); + }, + }); + } + }); + + this.authService.authParams().subscribe({ + next: (params) => (this.authParams = params), + }); + this.redirectURL = - this.route.snapshot.queryParams['redirectURL'] || '/home'; + this.route.snapshot.queryParams["redirectURL"] || "/home"; this.authForm = this.fb.group({ - username: ['', { validators: Validators.required }], - password: ['', { validators: Validators.required }], + username: ["", { validators: Validators.required }], + password: ["", { validators: Validators.required }], }); } @@ -58,7 +80,7 @@ export class AuthComponent { } register(): void { - this.error = ''; + this.error = ""; if (this.authForm.valid) { this.authService.register(this.authForm.value).subscribe({ next: () => { @@ -73,17 +95,22 @@ export class AuthComponent { } authenticate(): void { - this.error = ''; - if (this.authForm.valid) { - this.authService.login(this.authForm.value).subscribe({ - next: () => { - this.router.navigateByUrl(this.redirectURL); - }, - error: (err: HttpErrorResponse) => { - this.authForm.reset(); - this.error = err.error.detail; - }, - }); + this.error = ""; + if (this.authParams?.oidc) { + window.location.replace(encodeURI(this.authParams.oidc)); } + + this.authService.login(this.authForm.value).subscribe({ + next: (data) => { + if (!data.access_token) { + this.error = "Authentication failed"; + return; + } + this.router.navigateByUrl(this.redirectURL); + }, + error: () => { + this.authForm.reset(); + }, + }); } } diff --git a/src/src/app/services/auth.service.ts b/src/src/app/services/auth.service.ts index d660c3c..9b0b6ac 100644 --- a/src/src/app/services/auth.service.ts +++ b/src/src/app/services/auth.service.ts @@ -11,6 +11,11 @@ export interface Token { access_token: string; } +export interface AuthParams { + register_enabled: boolean; + oidc?: string; +} + const JWT_TOKEN = "TRIP_AT"; const REFRESH_TOKEN = "TRIP_RT"; const JWT_USER = "TRIP_USER"; @@ -23,7 +28,7 @@ export class AuthService { private httpClient: HttpClient, private router: Router, private apiService: ApiService, - private utilsService: UtilsService + private utilsService: UtilsService, ) { this.apiBaseUrl = this.apiService.apiBaseUrl; } @@ -52,6 +57,10 @@ export class AuthService { return localStorage.getItem(REFRESH_TOKEN) ?? ""; } + authParams(): Observable { + return this.httpClient.get(this.apiBaseUrl + "/auth/params"); + } + storeTokens(tokens: Token): void { this.accessToken = tokens.access_token; this.refreshToken = tokens.refresh_token; @@ -71,26 +80,46 @@ export class AuthService { .pipe( tap((tokens: Token) => { this.accessToken = tokens.access_token; - }) + }), ); } login(authForm: { username: string; password: string }): Observable { - return this.httpClient.post(this.apiBaseUrl + "/auth/login", authForm).pipe( - tap((tokens: Token) => { - this.loggedUser = authForm.username; - this.storeTokens(tokens); - }) - ); + return this.httpClient + .post(this.apiBaseUrl + "/auth/login", authForm) + .pipe( + tap((tokens: Token) => { + this.loggedUser = authForm.username; + this.storeTokens(tokens); + }), + ); } - register(authForm: { username: string; password: string }): Observable { - return this.httpClient.post(this.apiBaseUrl + "/auth/register", authForm).pipe( - tap((tokens: Token) => { - this.loggedUser = authForm.username; - this.storeTokens(tokens); - }) - ); + register(authForm: { + username: string; + password: string; + }): Observable { + return this.httpClient + .post(this.apiBaseUrl + "/auth/register", authForm) + .pipe( + tap((tokens: Token) => { + this.loggedUser = authForm.username; + this.storeTokens(tokens); + }), + ); + } + + oidcLogin(code: string): Observable { + return this.httpClient + .post(this.apiBaseUrl + "/auth/oidc/login", { code }) + .pipe( + tap((data: any) => { + if (data.access_token && data.refresh_token) { + this.loggedUser = this._getTokenUsername(data.access_token); + this.storeTokens(data); + } + }), + ); } logout(custom_msg: string = "", is_error = false): void { @@ -99,7 +128,11 @@ export class AuthService { if (custom_msg) { if (is_error) { - this.utilsService.toast("error", "You must be authenticated", custom_msg); + this.utilsService.toast( + "error", + "You must be authenticated", + custom_msg, + ); } else { this.utilsService.toast("success", "Success", custom_msg); } @@ -135,19 +168,25 @@ export class AuthService { private _b64DecodeUnicode(str: any): string { return decodeURIComponent( Array.prototype.map - .call(this._b64decode(str), (c: any) => "%" + ("00" + c.charCodeAt(0).toString(16)).slice(-2)) - .join("") + .call( + this._b64decode(str), + (c: any) => "%" + ("00" + c.charCodeAt(0).toString(16)).slice(-2), + ) + .join(""), ); } private _b64decode(str: string): string { - const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; + const chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; let output = ""; str = String(str).replace(/=+$/, ""); if (str.length % 4 === 1) { - throw new Error("'atob' failed: The string to be decoded is not correctly encoded."); + throw new Error( + "'atob' failed: The string to be decoded is not correctly encoded.", + ); } /* eslint-disable */ @@ -186,6 +225,20 @@ export class AuthService { return this._b64DecodeUnicode(output); } + private _getTokenUsername(token: string): string { + const decodedToken = this._decodeToken(token); + + if (decodedToken === null) { + return ""; + } + + if (!decodedToken.hasOwnProperty("sub")) { + return ""; + } + + return decodedToken.sub; + } + private _decodeToken(token: string): any { if (!token) { return null;