From bf7c7bc204fc1d240c48794ef86d25808afae2dd Mon Sep 17 00:00:00 2001 From: itskovacs Date: Sat, 11 Oct 2025 17:04:39 +0200 Subject: [PATCH] :zap: optimize export and import --- backend/trip/routers/settings.py | 195 ++++++++++++++++++------------- 1 file changed, 116 insertions(+), 79 deletions(-) diff --git a/backend/trip/routers/settings.py b/backend/trip/routers/settings.py index 82cf81d..68775a3 100644 --- a/backend/trip/routers/settings.py +++ b/backend/trip/routers/settings.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Annotated from fastapi import APIRouter, Depends, File, HTTPException, UploadFile +from sqlalchemy.orm import selectinload from sqlmodel import select from ..config import settings @@ -53,29 +54,48 @@ async def check_version(session: SessionDep, current_user: Annotated[str, Depend @router.get("/export") def export_data(session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]): + trips_query = ( + select(Trip) + .where(Trip.user == current_user) + .options( + selectinload(Trip.days) + .selectinload(TripDay.items) + .options( + selectinload(TripItem.place).selectinload(Place.category).selectinload(Category.image), + selectinload(TripItem.place).selectinload(Place.image), + selectinload(TripItem.image), + ), + selectinload(Trip.places).options( + selectinload(Place.category).selectinload(Category.image), + selectinload(Place.image), + ), + selectinload(Trip.image), + selectinload(Trip.memberships), + selectinload(Trip.shares), + ) + ) + + user_settings = UserRead.serialize(session.get(User, current_user)) + categories = session.exec(select(Category).where(Category.user == current_user)).all() + places = session.exec(select(Place).where(Place.user == current_user)).all() + trips = session.exec(trips_query).all() + images = session.exec(select(Image).where(Image.user == current_user)).all() + data = { - "_": { - "at": datetime.timestamp(datetime.now()), - }, - "categories": [ - CategoryRead.serialize(c) - for c in session.exec(select(Category).filter(Category.user == current_user)) - ], - "places": [ - PlaceRead.serialize(place, exclude_gpx=False) - for place in session.exec(select(Place).filter(Place.user == current_user)) - ], + "_": {"at": datetime.timestamp(datetime.now())}, + "settings": user_settings, + "categories": [CategoryRead.serialize(c) for c in categories], + "places": [PlaceRead.serialize(place, exclude_gpx=False) for place in places], + "trips": [TripRead.serialize(t) for t in trips], "images": {}, - "trips": [ - TripRead.serialize(c) for c in session.exec(select(Trip).filter(Trip.user == current_user)) - ], - "settings": UserRead.serialize(session.get(User, current_user)), } - images = session.exec(select(Image).where(Image.user == current_user)) for im in images: - with open(Path(settings.ASSETS_FOLDER) / im.filename, "rb") as f: - data["images"][im.id] = b64e(f.read()) + try: + with open(Path(settings.ASSETS_FOLDER) / im.filename, "rb") as f: + data["images"][im.id] = b64e(f.read()) + except FileNotFoundError: + continue return data @@ -95,17 +115,21 @@ async def import_data( except Exception: raise HTTPException(status_code=400, detail="Invalid file") + existing_categories = { + category.name: category + for category in session.exec(select(Category).filter(Category.user == current_user)).all() + } + + categories_to_add = [] for category in data.get("categories", []): category_name = category.get("name") - category_exists = session.exec( - select(Category).filter(Category.user == current_user, Category.name == category_name) - ).first() + category_exists = existing_categories.get(category_name) + if category_exists: # Update color if present in import data if category.get("color"): category_exists.color = category.get("color") - # Handle image update if category.get("image_id"): b64_image = data.get("images", {}).get(str(category.get("image_id"))) if b64_image: @@ -117,7 +141,6 @@ async def import_data( image = Image(filename=filename, user=current_user) session.add(image) session.flush() - session.refresh(image) if category_exists.image_id: old_image = session.get(Image, category_exists.image_id) @@ -134,8 +157,7 @@ async def import_data( category_exists.image_id = image.id session.add(category_exists) - session.flush() - session.refresh(category_exists) + existing_categories[category_name] = category_exists continue category_data = { @@ -150,26 +172,26 @@ async def import_data( image_bytes = b64img_decode(b64_image) filename = save_image_to_file(image_bytes, settings.PLACE_IMAGE_SIZE) - if not filename: - raise HTTPException(status_code=500, detail="Error saving image") - - image = Image(filename=filename, user=current_user) - session.add(image) - session.flush() - session.refresh(image) - category_data["image_id"] = image.id + if filename: + image = Image(filename=filename, user=current_user) + session.add(image) + session.flush() + category_data["image_id"] = image.id new_category = Category(**category_data) + categories_to_add.append(new_category) session.add(new_category) + + if categories_to_add: session.flush() - session.refresh(new_category) + for category in categories_to_add: + existing_categories[category.name] = category places = [] + places_to_add = [] for place in data.get("places", []): category_name = place.get("category", {}).get("name") - category = session.exec( - select(Category).filter(Category.user == current_user, Category.name == category_name) - ).first() + category = existing_categories.get(category_name) if not category: continue @@ -183,57 +205,50 @@ async def import_data( if place.get("image_id"): b64_image = data.get("images", {}).get(str(place.get("image_id"))) - if b64_image is None: - continue - - image_bytes = b64img_decode(b64_image) - filename = save_image_to_file(image_bytes, settings.PLACE_IMAGE_SIZE) - if not filename: - raise HTTPException(status_code=500, detail="Error saving image") - - image = Image(filename=filename, user=current_user) - session.add(image) - session.flush() - session.refresh(image) - place_data["image_id"] = image.id + if b64_image: + image_bytes = b64img_decode(b64_image) + filename = save_image_to_file(image_bytes, settings.PLACE_IMAGE_SIZE) + if filename: + image = Image(filename=filename, user=current_user) + session.add(image) + session.flush() + place_data["image_id"] = image.id new_place = Place(**place_data) - session.add(new_place) - session.flush() + places_to_add.append(new_place) places.append(new_place) + if places_to_add: + session.add_all(places_to_add) + session.flush() + db_user = session.get(User, current_user) if data.get("settings"): settings_data = data["settings"] - if "map_lat" in settings_data: - db_user.map_lat = settings_data["map_lat"] + setting_fields = [ + "map_lat", + "map_lng", + "currency", + "tile_layer", + "mode_low_network", + "mode_dark", + "mode_gpx_in_place", + ] - if "map_lng" in settings_data: - db_user.map_lng = settings_data["map_lng"] - - if "currency" in settings_data: - db_user.currency = settings_data["currency"] - - if "tile_layer" in settings_data: - db_user.tile_layer = settings_data["tile_layer"] + for field in setting_fields: + if field in settings_data: + setattr(db_user, field, settings_data[field]) if "do_not_display" in settings_data: db_user.do_not_display = ",".join(settings_data["do_not_display"]) - if "mode_low_network" in settings_data: - db_user.mode_low_network = settings_data["mode_low_network"] - - if "mode_dark" in settings_data: - db_user.mode_dark = settings_data["mode_dark"] - - if "mode_gpx_in_place" in settings_data: - db_user.mode_gpx_in_place = settings_data["mode_gpx_in_place"] - session.add(db_user) session.flush() - session.refresh(db_user) trip_place_id_map = {p["id"]: new_p.id for p, new_p in zip(data.get("places", []), places)} + trips_to_add = [] + days_to_add = [] + items_to_add = [] for trip in data.get("trips", []): trip_data = { key: trip[key] @@ -251,13 +266,12 @@ async def import_data( image = Image(filename=filename, user=current_user) session.add(image) session.flush() - session.refresh(image) trip_data["image_id"] = image.id new_trip = Trip(**trip_data) session.add(new_trip) session.flush() - session.refresh(new_trip) + trips_to_add.append(new_trip) for place in trip.get("places", []): old_id = place["id"] @@ -272,10 +286,17 @@ async def import_data( new_day = TripDay(**day_data, trip_id=new_trip.id, user=current_user) session.add(new_day) session.flush() - session.refresh(new_day) + days_to_add.append(new_day) for item in day.get("items", []): - item_data = {key: item[key] for key in item if key not in {"id", "place"}} + item_data = { + key: item[key] + for key in item + if key not in {"id", "place", "place_id", "image", "image_id"} + } + item_data["day_id"] = new_day.id + item_data["user"] = current_user + place = item.get("place") if ( place @@ -284,16 +305,32 @@ async def import_data( ): item_data["place_id"] = new_place_id - item_data["day_id"] = new_day.id + if item.get("image_id"): + b64_image = data.get("images", {}).get(str(item.get("image_id"))) + if b64_image: + image_bytes = b64img_decode(b64_image) + filename = save_image_to_file(image_bytes, settings.TRIP_IMAGE_SIZE) + if filename: + image = Image(filename=filename, user=current_user) + session.add(image) + session.flush() + trip_data["image_id"] = image.id + trip_item = TripItem(**item_data) - session.add(trip_item) + items_to_add.append(trip_item) + + if items_to_add: + session.add_all(items_to_add) + session.commit() return { "places": [PlaceRead.serialize(p) for p in places], "categories": [ CategoryRead.serialize(c) - for c in session.exec(select(Category).filter(Category.user == current_user)) + for c in session.exec( + select(Category).options(selectinload(Category.image)).filter(Category.user == current_user) + ).all() ], "settings": UserRead.serialize(session.get(User, current_user)), }