tripweb/backend/trip/routers/settings.py
2025-10-11 17:04:39 +02:00

337 lines
12 KiB
Python

import json
from datetime import datetime
from pathlib import Path
from typing import Annotated
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import selectinload
from sqlmodel import select
from ..config import settings
from ..deps import SessionDep, get_current_username
from ..models.models import (Category, CategoryRead, Image, Place, PlaceRead,
Trip, TripDay, TripItem, TripRead, User, UserRead,
UserUpdate)
from ..utils.utils import (b64e, b64img_decode, check_update, remove_image,
save_image_to_file)
# Router collecting the endpoints of this module under the /api/settings prefix.
router = APIRouter(prefix="/api/settings", tags=["settings"])
@router.get("", response_model=UserRead)
def get_user_settings(
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
) -> UserRead:
    """Return the stored settings of the current user.

    The ``User`` row is fetched by the username that the
    ``get_current_username`` dependency resolves, then serialized with
    ``UserRead.serialize``.
    """
    return UserRead.serialize(session.get(User, current_user))
@router.put("", response_model=UserRead)
def put_user_settings(
    session: SessionDep,
    data: UserUpdate,
    current_user: Annotated[str, Depends(get_current_username)],
) -> UserRead:
    """Apply a partial settings update to the current user and return the result.

    Only the fields explicitly set in the request body are written. A
    ``do_not_display`` value is stored as a comma-joined string; a falsy
    value is stored as the empty string.
    """
    account = session.get(User, current_user)
    changes = data.model_dump(exclude_unset=True)

    if "do_not_display" in changes:
        hidden = changes["do_not_display"]
        changes["do_not_display"] = ",".join(hidden) if hidden else ""

    for field_name, new_value in changes.items():
        setattr(account, field_name, new_value)

    session.add(account)
    session.commit()
    # Reload so the response reflects what was actually persisted.
    session.refresh(account)
    return UserRead.serialize(account)
@router.get("/checkversion")
async def check_version(
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Return whatever ``check_update`` reports.

    ``session`` and ``current_user`` are not used in the body; they stay in
    the signature so the route keeps resolving both dependencies.
    """
    update_info = await check_update()
    return update_info
@router.get("/export")
def export_data(
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
):
    """Build a full export of the current user's data.

    The returned dict holds an export timestamp, the serialized settings,
    categories, places (serialized with ``exclude_gpx=False``) and trips, plus
    an ``images`` mapping of image id to the ``b64e``-encoded content of the
    image file in the assets folder. Images whose file is missing on disk are
    left out of the mapping.
    """
    # Eager-load everything reachable from a trip's days and their items.
    day_items_loader = (
        selectinload(Trip.days)
        .selectinload(TripDay.items)
        .options(
            selectinload(TripItem.place).selectinload(Place.category).selectinload(Category.image),
            selectinload(TripItem.place).selectinload(Place.image),
            selectinload(TripItem.image),
        )
    )
    # Eager-load the places attached directly to a trip.
    trip_places_loader = selectinload(Trip.places).options(
        selectinload(Place.category).selectinload(Category.image),
        selectinload(Place.image),
    )
    trips_stmt = (
        select(Trip)
        .where(Trip.user == current_user)
        .options(
            day_items_loader,
            trip_places_loader,
            selectinload(Trip.image),
            selectinload(Trip.memberships),
            selectinload(Trip.shares),
        )
    )

    owner_settings = UserRead.serialize(session.get(User, current_user))
    owned_categories = session.exec(select(Category).where(Category.user == current_user)).all()
    owned_places = session.exec(select(Place).where(Place.user == current_user)).all()
    owned_trips = session.exec(trips_stmt).all()
    owned_images = session.exec(select(Image).where(Image.user == current_user)).all()

    export = {
        "_": {"at": datetime.now().timestamp()},
        "settings": owner_settings,
        "categories": [CategoryRead.serialize(category) for category in owned_categories],
        "places": [PlaceRead.serialize(p, exclude_gpx=False) for p in owned_places],
        "trips": [TripRead.serialize(trip) for trip in owned_trips],
        "images": {},
    }

    encoded_images = export["images"]
    for image in owned_images:
        try:
            raw = (Path(settings.ASSETS_FOLDER) / image.filename).read_bytes()
            encoded_images[image.id] = b64e(raw)
        except FileNotFoundError:
            # The database row exists but the file is gone: skip this image.
            continue

    return export
def _store_imported_image(session, b64_image, size, current_user):
    """Decode an exported image, write it to disk and register it in the session.

    Args:
        session: Database session the new ``Image`` row is added to and flushed on.
        b64_image: Encoded image content taken from the import payload.
        size: Size setting forwarded to ``save_image_to_file``.
        current_user: Username recorded as the owner of the new image.

    Returns:
        The flushed ``Image`` (so its ``id`` is populated), or ``None`` when
        ``save_image_to_file`` did not return a filename.
    """
    filename = save_image_to_file(b64img_decode(b64_image), size)
    if not filename:
        return None
    image = Image(filename=filename, user=current_user)
    session.add(image)
    session.flush()
    return image


@router.post("/import")
async def import_data(
    session: SessionDep,
    current_user: Annotated[str, Depends(get_current_username)],
    file: UploadFile = File(...),
):
    """Import a JSON export into the current user's account.

    Processing order (nothing is committed until the final ``session.commit()``):

    1. Categories: a category whose name already exists for the user is
       updated (color, and image when the payload carries it); any other
       category is created.
    2. Places: created under the category matching their category name;
       places whose category cannot be resolved are skipped.
    3. Settings: a fixed list of fields plus ``do_not_display`` is copied
       onto the user.
    4. Trips: recreated with their days and items, and re-linked to the
       newly created places through an old-id -> new-id map.

    Returns:
        A dict with the serialized imported ``places``, all of the user's
        ``categories`` and the user's ``settings``.

    Raises:
        HTTPException: 415 when the upload is not declared as JSON, 400 when
            it cannot be parsed into a JSON object, 500 when a replacement
            category image cannot be saved or the old one cannot be removed.
    """
    if file.content_type != "application/json":
        raise HTTPException(status_code=415, detail="File must be a JSON file")

    try:
        content = await file.read()
        data = json.loads(content)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid file") from exc
    # Valid JSON that is not an object (e.g. a list) cannot be an export.
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid file")

    # Payload images are keyed by the stringified id they had in the export.
    images = data.get("images") or {}

    # ------------------------------------------------------------------ #
    # Categories
    # ------------------------------------------------------------------ #
    existing_categories = {
        category.name: category
        for category in session.exec(select(Category).filter(Category.user == current_user)).all()
    }

    categories_to_add = []
    for category in data.get("categories") or []:
        category_name = category.get("name")
        image_ref = category.get("image_id")
        b64_image = images.get(str(image_ref)) if image_ref else None

        category_exists = existing_categories.get(category_name)
        if category_exists:
            # Update color if present in import data
            if category.get("color"):
                category_exists.color = category.get("color")

            if b64_image:
                image = _store_imported_image(
                    session, b64_image, settings.PLACE_IMAGE_SIZE, current_user
                )
                if image is None:
                    raise HTTPException(status_code=500, detail="Error saving image")

                # Drop the image being replaced, both on disk and in the database.
                if category_exists.image_id:
                    old_image = session.get(Image, category_exists.image_id)
                    try:
                        remove_image(old_image.filename)
                        session.delete(old_image)
                        category_exists.image_id = None
                        session.flush()
                    except Exception as exc:
                        raise HTTPException(
                            status_code=500, detail="Failed to remove old image during import"
                        ) from exc

                category_exists.image_id = image.id

            session.add(category_exists)
            continue

        if image_ref and b64_image is None:
            # NOTE(review): a new category that references an image absent from
            # the payload is skipped entirely (and so are the places that use
            # it). Existing behavior, kept as-is — confirm it is intended.
            continue

        category_data = {
            key: value
            for key, value in category.items()
            if key not in {"id", "image", "image_id"}
        }
        category_data["user"] = current_user

        if image_ref:
            image = _store_imported_image(
                session, b64_image, settings.PLACE_IMAGE_SIZE, current_user
            )
            if image is not None:
                category_data["image_id"] = image.id

        new_category = Category(**category_data)
        categories_to_add.append(new_category)
        session.add(new_category)

    if categories_to_add:
        # Flush so the new categories get the ids that places need below.
        session.flush()
        for category in categories_to_add:
            existing_categories[category.name] = category

    # ------------------------------------------------------------------ #
    # Places
    # ------------------------------------------------------------------ #
    # (id in the export, new Place) pairs. Pairing is recorded as each place is
    # created so that skipped places cannot shift the old-id -> new-id mapping.
    imported_places = []
    for place in data.get("places") or []:
        category_name = (place.get("category") or {}).get("name")
        category = existing_categories.get(category_name)
        if not category:
            continue

        place_data = {
            key: value
            for key, value in place.items()
            if key not in {"id", "image", "image_id", "category", "category_id"}
        }
        place_data["user"] = current_user
        place_data["category_id"] = category.id

        if place.get("image_id"):
            b64_image = images.get(str(place.get("image_id")))
            if b64_image:
                image = _store_imported_image(
                    session, b64_image, settings.PLACE_IMAGE_SIZE, current_user
                )
                if image is not None:
                    place_data["image_id"] = image.id

        imported_places.append((place.get("id"), Place(**place_data)))

    places = [new_place for _, new_place in imported_places]
    if places:
        session.add_all(places)
        session.flush()

    # ------------------------------------------------------------------ #
    # Settings
    # ------------------------------------------------------------------ #
    db_user = session.get(User, current_user)
    if data.get("settings"):
        settings_data = data["settings"]
        setting_fields = [
            "map_lat",
            "map_lng",
            "currency",
            "tile_layer",
            "mode_low_network",
            "mode_dark",
            "mode_gpx_in_place",
        ]
        for field in setting_fields:
            if field in settings_data:
                setattr(db_user, field, settings_data[field])

        if "do_not_display" in settings_data:
            # Same normalization as put_user_settings: falsy -> empty string.
            hidden = settings_data["do_not_display"]
            db_user.do_not_display = ",".join(hidden) if hidden else ""

        session.add(db_user)
        session.flush()

    # ------------------------------------------------------------------ #
    # Trips, days and items
    # ------------------------------------------------------------------ #
    trip_place_id_map = {
        old_id: new_place.id for old_id, new_place in imported_places if old_id is not None
    }

    items_to_add = []
    for trip in data.get("trips") or []:
        trip_data = {
            key: value
            for key, value in trip.items()
            if key not in {"id", "image", "image_id", "places", "days", "shared"}
        }
        trip_data["user"] = current_user

        if trip.get("image_id"):
            b64_image = images.get(str(trip.get("image_id")))
            if b64_image:
                image = _store_imported_image(
                    session, b64_image, settings.TRIP_IMAGE_SIZE, current_user
                )
                if image is not None:
                    trip_data["image_id"] = image.id

        new_trip = Trip(**trip_data)
        session.add(new_trip)
        # Flush now: days below need new_trip.id.
        session.flush()

        for place in trip.get("places") or []:
            new_place_id = trip_place_id_map.get(place.get("id"))
            if new_place_id:
                db_place = session.get(Place, new_place_id)
                if db_place:
                    new_trip.places.append(db_place)

        for day in trip.get("days") or []:
            day_data = {key: value for key, value in day.items() if key not in {"id", "items"}}
            new_day = TripDay(**day_data, trip_id=new_trip.id, user=current_user)
            session.add(new_day)
            # Flush now: items below need new_day.id.
            session.flush()

            for item in day.get("items") or []:
                item_data = {
                    key: value
                    for key, value in item.items()
                    if key not in {"id", "place", "place_id", "image", "image_id"}
                }
                item_data["day_id"] = new_day.id
                item_data["user"] = current_user

                place = item.get("place")
                if (
                    place
                    and (place_id := place.get("id"))
                    and (new_place_id := trip_place_id_map.get(place_id))
                ):
                    item_data["place_id"] = new_place_id

                if item.get("image_id"):
                    b64_image = images.get(str(item.get("image_id")))
                    if b64_image:
                        image = _store_imported_image(
                            session, b64_image, settings.TRIP_IMAGE_SIZE, current_user
                        )
                        if image is not None:
                            # The image belongs to the item, not to the parent trip.
                            item_data["image_id"] = image.id

                items_to_add.append(TripItem(**item_data))

    if items_to_add:
        session.add_all(items_to_add)

    session.commit()

    return {
        "places": [PlaceRead.serialize(p) for p in places],
        "categories": [
            CategoryRead.serialize(c)
            for c in session.exec(
                select(Category).options(selectinload(Category.image)).filter(Category.user == current_user)
            ).all()
        ],
        "settings": UserRead.serialize(session.get(User, current_user)),
    }