optimize export and import

This commit is contained in:
itskovacs 2025-10-11 17:04:39 +02:00
parent c2058f0a39
commit bf7c7bc204

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import selectinload
from sqlmodel import select from sqlmodel import select
from ..config import settings from ..config import settings
@ -53,29 +54,48 @@ async def check_version(session: SessionDep, current_user: Annotated[str, Depend
@router.get("/export") @router.get("/export")
def export_data(session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]): def export_data(session: SessionDep, current_user: Annotated[str, Depends(get_current_username)]):
trips_query = (
select(Trip)
.where(Trip.user == current_user)
.options(
selectinload(Trip.days)
.selectinload(TripDay.items)
.options(
selectinload(TripItem.place).selectinload(Place.category).selectinload(Category.image),
selectinload(TripItem.place).selectinload(Place.image),
selectinload(TripItem.image),
),
selectinload(Trip.places).options(
selectinload(Place.category).selectinload(Category.image),
selectinload(Place.image),
),
selectinload(Trip.image),
selectinload(Trip.memberships),
selectinload(Trip.shares),
)
)
user_settings = UserRead.serialize(session.get(User, current_user))
categories = session.exec(select(Category).where(Category.user == current_user)).all()
places = session.exec(select(Place).where(Place.user == current_user)).all()
trips = session.exec(trips_query).all()
images = session.exec(select(Image).where(Image.user == current_user)).all()
data = { data = {
"_": { "_": {"at": datetime.timestamp(datetime.now())},
"at": datetime.timestamp(datetime.now()), "settings": user_settings,
}, "categories": [CategoryRead.serialize(c) for c in categories],
"categories": [ "places": [PlaceRead.serialize(place, exclude_gpx=False) for place in places],
CategoryRead.serialize(c) "trips": [TripRead.serialize(t) for t in trips],
for c in session.exec(select(Category).filter(Category.user == current_user))
],
"places": [
PlaceRead.serialize(place, exclude_gpx=False)
for place in session.exec(select(Place).filter(Place.user == current_user))
],
"images": {}, "images": {},
"trips": [
TripRead.serialize(c) for c in session.exec(select(Trip).filter(Trip.user == current_user))
],
"settings": UserRead.serialize(session.get(User, current_user)),
} }
images = session.exec(select(Image).where(Image.user == current_user))
for im in images: for im in images:
try:
with open(Path(settings.ASSETS_FOLDER) / im.filename, "rb") as f: with open(Path(settings.ASSETS_FOLDER) / im.filename, "rb") as f:
data["images"][im.id] = b64e(f.read()) data["images"][im.id] = b64e(f.read())
except FileNotFoundError:
continue
return data return data
@ -95,17 +115,21 @@ async def import_data(
except Exception: except Exception:
raise HTTPException(status_code=400, detail="Invalid file") raise HTTPException(status_code=400, detail="Invalid file")
existing_categories = {
category.name: category
for category in session.exec(select(Category).filter(Category.user == current_user)).all()
}
categories_to_add = []
for category in data.get("categories", []): for category in data.get("categories", []):
category_name = category.get("name") category_name = category.get("name")
category_exists = session.exec( category_exists = existing_categories.get(category_name)
select(Category).filter(Category.user == current_user, Category.name == category_name)
).first()
if category_exists: if category_exists:
# Update color if present in import data # Update color if present in import data
if category.get("color"): if category.get("color"):
category_exists.color = category.get("color") category_exists.color = category.get("color")
# Handle image update
if category.get("image_id"): if category.get("image_id"):
b64_image = data.get("images", {}).get(str(category.get("image_id"))) b64_image = data.get("images", {}).get(str(category.get("image_id")))
if b64_image: if b64_image:
@ -117,7 +141,6 @@ async def import_data(
image = Image(filename=filename, user=current_user) image = Image(filename=filename, user=current_user)
session.add(image) session.add(image)
session.flush() session.flush()
session.refresh(image)
if category_exists.image_id: if category_exists.image_id:
old_image = session.get(Image, category_exists.image_id) old_image = session.get(Image, category_exists.image_id)
@ -134,8 +157,7 @@ async def import_data(
category_exists.image_id = image.id category_exists.image_id = image.id
session.add(category_exists) session.add(category_exists)
session.flush() existing_categories[category_name] = category_exists
session.refresh(category_exists)
continue continue
category_data = { category_data = {
@ -150,26 +172,26 @@ async def import_data(
image_bytes = b64img_decode(b64_image) image_bytes = b64img_decode(b64_image)
filename = save_image_to_file(image_bytes, settings.PLACE_IMAGE_SIZE) filename = save_image_to_file(image_bytes, settings.PLACE_IMAGE_SIZE)
if not filename: if filename:
raise HTTPException(status_code=500, detail="Error saving image")
image = Image(filename=filename, user=current_user) image = Image(filename=filename, user=current_user)
session.add(image) session.add(image)
session.flush() session.flush()
session.refresh(image)
category_data["image_id"] = image.id category_data["image_id"] = image.id
new_category = Category(**category_data) new_category = Category(**category_data)
categories_to_add.append(new_category)
session.add(new_category) session.add(new_category)
if categories_to_add:
session.flush() session.flush()
session.refresh(new_category) for category in categories_to_add:
existing_categories[category.name] = category
places = [] places = []
places_to_add = []
for place in data.get("places", []): for place in data.get("places", []):
category_name = place.get("category", {}).get("name") category_name = place.get("category", {}).get("name")
category = session.exec( category = existing_categories.get(category_name)
select(Category).filter(Category.user == current_user, Category.name == category_name)
).first()
if not category: if not category:
continue continue
@ -183,57 +205,50 @@ async def import_data(
if place.get("image_id"): if place.get("image_id"):
b64_image = data.get("images", {}).get(str(place.get("image_id"))) b64_image = data.get("images", {}).get(str(place.get("image_id")))
if b64_image is None: if b64_image:
continue
image_bytes = b64img_decode(b64_image) image_bytes = b64img_decode(b64_image)
filename = save_image_to_file(image_bytes, settings.PLACE_IMAGE_SIZE) filename = save_image_to_file(image_bytes, settings.PLACE_IMAGE_SIZE)
if not filename: if filename:
raise HTTPException(status_code=500, detail="Error saving image")
image = Image(filename=filename, user=current_user) image = Image(filename=filename, user=current_user)
session.add(image) session.add(image)
session.flush() session.flush()
session.refresh(image)
place_data["image_id"] = image.id place_data["image_id"] = image.id
new_place = Place(**place_data) new_place = Place(**place_data)
session.add(new_place) places_to_add.append(new_place)
session.flush()
places.append(new_place) places.append(new_place)
if places_to_add:
session.add_all(places_to_add)
session.flush()
db_user = session.get(User, current_user) db_user = session.get(User, current_user)
if data.get("settings"): if data.get("settings"):
settings_data = data["settings"] settings_data = data["settings"]
if "map_lat" in settings_data: setting_fields = [
db_user.map_lat = settings_data["map_lat"] "map_lat",
"map_lng",
"currency",
"tile_layer",
"mode_low_network",
"mode_dark",
"mode_gpx_in_place",
]
if "map_lng" in settings_data: for field in setting_fields:
db_user.map_lng = settings_data["map_lng"] if field in settings_data:
setattr(db_user, field, settings_data[field])
if "currency" in settings_data:
db_user.currency = settings_data["currency"]
if "tile_layer" in settings_data:
db_user.tile_layer = settings_data["tile_layer"]
if "do_not_display" in settings_data: if "do_not_display" in settings_data:
db_user.do_not_display = ",".join(settings_data["do_not_display"]) db_user.do_not_display = ",".join(settings_data["do_not_display"])
if "mode_low_network" in settings_data:
db_user.mode_low_network = settings_data["mode_low_network"]
if "mode_dark" in settings_data:
db_user.mode_dark = settings_data["mode_dark"]
if "mode_gpx_in_place" in settings_data:
db_user.mode_gpx_in_place = settings_data["mode_gpx_in_place"]
session.add(db_user) session.add(db_user)
session.flush() session.flush()
session.refresh(db_user)
trip_place_id_map = {p["id"]: new_p.id for p, new_p in zip(data.get("places", []), places)} trip_place_id_map = {p["id"]: new_p.id for p, new_p in zip(data.get("places", []), places)}
trips_to_add = []
days_to_add = []
items_to_add = []
for trip in data.get("trips", []): for trip in data.get("trips", []):
trip_data = { trip_data = {
key: trip[key] key: trip[key]
@ -251,13 +266,12 @@ async def import_data(
image = Image(filename=filename, user=current_user) image = Image(filename=filename, user=current_user)
session.add(image) session.add(image)
session.flush() session.flush()
session.refresh(image)
trip_data["image_id"] = image.id trip_data["image_id"] = image.id
new_trip = Trip(**trip_data) new_trip = Trip(**trip_data)
session.add(new_trip) session.add(new_trip)
session.flush() session.flush()
session.refresh(new_trip) trips_to_add.append(new_trip)
for place in trip.get("places", []): for place in trip.get("places", []):
old_id = place["id"] old_id = place["id"]
@ -272,10 +286,17 @@ async def import_data(
new_day = TripDay(**day_data, trip_id=new_trip.id, user=current_user) new_day = TripDay(**day_data, trip_id=new_trip.id, user=current_user)
session.add(new_day) session.add(new_day)
session.flush() session.flush()
session.refresh(new_day) days_to_add.append(new_day)
for item in day.get("items", []): for item in day.get("items", []):
item_data = {key: item[key] for key in item if key not in {"id", "place"}} item_data = {
key: item[key]
for key in item
if key not in {"id", "place", "place_id", "image", "image_id"}
}
item_data["day_id"] = new_day.id
item_data["user"] = current_user
place = item.get("place") place = item.get("place")
if ( if (
place place
@ -284,16 +305,32 @@ async def import_data(
): ):
item_data["place_id"] = new_place_id item_data["place_id"] = new_place_id
item_data["day_id"] = new_day.id if item.get("image_id"):
b64_image = data.get("images", {}).get(str(item.get("image_id")))
if b64_image:
image_bytes = b64img_decode(b64_image)
filename = save_image_to_file(image_bytes, settings.TRIP_IMAGE_SIZE)
if filename:
image = Image(filename=filename, user=current_user)
session.add(image)
session.flush()
trip_data["image_id"] = image.id
trip_item = TripItem(**item_data) trip_item = TripItem(**item_data)
session.add(trip_item) items_to_add.append(trip_item)
if items_to_add:
session.add_all(items_to_add)
session.commit() session.commit()
return { return {
"places": [PlaceRead.serialize(p) for p in places], "places": [PlaceRead.serialize(p) for p in places],
"categories": [ "categories": [
CategoryRead.serialize(c) CategoryRead.serialize(c)
for c in session.exec(select(Category).filter(Category.user == current_user)) for c in session.exec(
select(Category).options(selectinload(Category.image)).filter(Category.user == current_user)
).all()
], ],
"settings": UserRead.serialize(session.get(User, current_user)), "settings": UserRead.serialize(session.get(User, current_user)),
} }