:fix: Delete created files on failed import

itskovacs 2025-11-11 18:23:06 +01:00
parent cfe4baa794
commit e8c9fe86f2

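The change follows a track-and-undo pattern: every image file and attachment folder written during the import is recorded, and the failure path that rolls back the database session also deletes those files before raising. A minimal, self-contained sketch of the same idea, assuming nothing from the project (the write/unlink calls below are stand-ins, not the repo's save_image_to_file / remove_image / attachments_trip_folder_path utilities):

    from pathlib import Path


    def import_entries(entries: dict[str, bytes], dest: Path) -> None:
        """Write backup entries to disk; on any failure, delete whatever was created."""
        created: list[Path] = []  # plays the role of created_image_filenames / created_attachment_trips
        try:
            for name, payload in entries.items():
                target = dest / name
                target.write_bytes(payload)
                created.append(target)  # record only after the write succeeded
                if not payload:  # stand-in for any validation error raised mid-import
                    raise ValueError(f"empty payload for {name!r}")
        except Exception:
            # Best-effort cleanup: try every recorded path, skip individual failures,
            # then re-raise so the caller still sees the original error.
            for path in created:
                try:
                    path.unlink(missing_ok=True)
                except OSError:
                    continue
            raise

Recording a path only after its write succeeds keeps the cleanup list accurate, and using continue (rather than returning) inside the cleanup loop lets every recorded file be attempted while the original exception is still propagated.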

@@ -19,7 +19,7 @@ from ..models.models import (Backup, BackupStatus, Category, CategoryRead,
TripRead, User, UserRead)
from .date import dt_utc, iso_to_dt
from .utils import (assets_folder_path, attachments_trip_folder_path,
b64img_decode, save_image_to_file)
b64img_decode, remove_image, save_image_to_file)
def process_backup_export(session: SessionDep, backup_id: int):
@@ -147,7 +147,6 @@ async def process_backup_import(
except Exception:
raise HTTPException(status_code=400, detail="Invalid file")
try:
with ZipFile(io.BytesIO(zip_content), "r") as zipf:
zip_filenames = zipf.namelist()
if "data.json" not in zip_filenames:
@@ -170,6 +169,9 @@ async def process_backup_import(
if path.startswith("attachments/") and not path.endswith("/")
}
error_details = "Bad request"
created_image_filenames = []
created_attachment_trips = []
try:
existing_categories = {
category.name: category
@@ -196,6 +198,7 @@ async def process_backup_import(
session.add(image)
session.flush()
session.refresh(image)
created_image_filenames.append(filename)
if category_exists.image_id:
old_image = session.get(Image, category_exists.image_id)
@@ -213,9 +216,7 @@ async def process_backup_import(
continue
new_category = {
key: category[key]
for key in category.keys()
if key not in {"id", "image", "image_id"}
key: category[key] for key in category.keys() if key not in {"id", "image", "image_id"}
}
new_category["user"] = current_user
@@ -230,6 +231,7 @@ async def process_backup_import(
session.add(image)
session.flush()
session.refresh(image)
created_image_filenames.append(filename)
new_category["image_id"] = image.id
except Exception:
pass
@@ -270,6 +272,7 @@ async def process_backup_import(
session.add(image)
session.flush()
session.refresh(image)
created_image_filenames.append(filename)
new_place["image_id"] = image.id
except Exception:
pass
@@ -341,6 +344,7 @@ async def process_backup_import(
session.add(image)
session.flush()
session.refresh(image)
created_image_filenames.append(filename)
new_trip["image_id"] = image.id
except Exception:
pass
@@ -372,13 +376,14 @@ async def process_backup_import(
new_attachment = {
key: attachment[key]
for key in attachment
if key not in {"id", "trip_id", "trip"}
if key not in {"id", "trip_id", "trip", "uploaded_by"}
}
new_attachment["trip_id"] = new_trip.id
new_attachment["user"] = current_user
new_attachment["uploaded_by"] = current_user
new_attachment_obj = TripAttachment(**new_attachment)
attachment_path = attachments_trip_folder_path(new_trip.id) / stored_filename
created_attachment_trips.append(new_trip.id)
attachment_path.write_bytes(attachment_bytes)
session.add(new_attachment_obj)
session.flush()
@@ -397,13 +402,19 @@ async def process_backup_import(
session.flush()
for item in day.get("items", []):
if item.get("paid_by"):
u = item.get("paid_by")
db_user = session.get(User, u)
if not db_user:
error_details = f"User <{u}> does not exist and is specified in Paid By"
raise
item_data = {
key: item[key]
for key in item
if key not in {"id", "place", "place_id", "image", "image_id", "attachments"}
}
item_data["day_id"] = new_day.id
item_data["user"] = current_user
place = item.get("place")
if place and (place_id := place.get("id")):
@@ -421,6 +432,7 @@ async def process_backup_import(
session.add(image)
session.flush()
session.refresh(image)
created_image_filenames.append(filename)
item_data["image_id"] = image.id
except Exception:
pass
@@ -482,14 +494,21 @@ async def process_backup_import(
"settings": UserRead.serialize(session.get(User, current_user)),
}
except Exception as exc:
except Exception:
session.rollback()
print(exc)
raise HTTPException(status_code=400, detail="Bad request")
except Exception as exc:
print(exc)
raise HTTPException(status_code=400, detail="Bad request")
for filename in created_image_filenames:
remove_image(filename)
for trip_id in created_attachment_trips:
try:
folder = attachments_trip_folder_path(trip_id)
if not folder.exists():
return
for file in folder.iterdir():
file.unlink()
folder.rmdir()
except Exception:
pass
raise HTTPException(status_code=400, detail=error_details)
async def process_legacy_import(