"""
dataset_utils.py

This module provides utilities for enriching image datasets for personalization tasks.
It supports generating additional images using external scripts, copying images between directories,
and augmenting datasets with both generated and raw images.

Main functionalities:
- Copying all files from one directory to another.
- Enriching a dataset by generating new images using external scripts (ic_light_run.py and generate_preservation.py).
- Optionally copying existing images instead of generating new ones.
- Adding a configurable ratio of raw images to the generated dataset.
- Managing prompts and metadata for generated images.

Key functions:
- copy_all_from_to(from_, to): Copies all files from the source directory to the destination directory.
- enrich_dataset(...): Main function to enrich a dataset with generated and/or copied images, and manage associated metadata.

Configuration:
- Paths to scripts and resources are defined at the top of the file.
- Uses numpy for reproducible random selection and torch for GPU memory management.

Typical usage:
    save_dir, iclight_dir, preservation_dir = enrich_dataset(
        image_dir="path/to/images",
        save_dir="path/to/save",
        iclight_prompt="A prompt for iclight generation",
        preservation_prompt="A prompt for preservation generation",
        ...
    )
"""

from pathlib import Path
import os
import json
import torch
import subprocess
import gc
import shutil
import numpy as np
from pathlib import Path


# Directory that contains this module; the helper scripts and the places
# file referenced below are looked up next to it.
CURRENT_PATH = Path(__file__).parent.resolve()

# Kept as plain strings because they are spliced into subprocess argv lists.
PLACES_PATH = str(CURRENT_PATH.joinpath("places.json"))
IC_LIGHT_RUN = str(CURRENT_PATH.joinpath("ic_light_run.py"))
PRESERVATION_RUN = str(CURRENT_PATH.joinpath("generate_preservation.py"))

def copy_all_from_to(from_, to):
    """Copy every regular file directly inside ``from_`` into directory ``to``.

    The copy is flat (non-recursive): sub-directories of ``from_`` are skipped
    instead of making ``shutil.copy2`` fail. File metadata is preserved by
    ``shutil.copy2`` and existing files with the same name are overwritten.

    Args:
        from_: Source directory (``str`` or ``Path``).
        to: Destination directory (``str`` or ``Path``). It is created,
            including missing parents, when it does not exist yet.
    """
    destination = Path(to)
    # Without an existing directory, copy2 would treat ``to`` as a file name
    # and every source file would overwrite the previous one (or fail).
    destination.mkdir(parents=True, exist_ok=True)

    for entry in Path(from_).glob("*"):
        if entry.is_file():
            shutil.copy2(entry, destination)

def _release_memory():
    """Run the garbage collector and release CUDA memory cached by torch."""
    gc.collect()
    torch.cuda.empty_cache()


def _run_generation_script(cmd):
    """Run an external generation script and raise if it exits with an error."""
    # Counts are passed as ints, so every argument is stringified first.
    # check=True: a failed generation must not be reported as "Done".
    subprocess.run([str(part) for part in cmd], check=True)


def _write_prompts_for_copied_images(images_dir, prompt):
    """Write ``prompts.json`` with one entry per file found under ``images_dir``."""
    prompts_path = images_dir / "prompts.json"
    # List the directory BEFORE opening prompts.json for writing and skip the
    # metadata file itself (and directories); otherwise prompts.json would be
    # described as if it were an image.
    image_files = sorted(
        path for path in images_dir.rglob("*")
        if path.is_file() and path != prompts_path
    )
    prompts = [
        {
            "prompt": prompt,
            "fg_name": path.name,
            "bg_name": "",
            "result_image": path.name,
        }
        for path in image_files
    ]
    with open(str(prompts_path), "w") as f:
        json.dump(prompts, f, indent=4)


def _append_raw_images(image_dir, images_dir, prompt, ratio, first_index):
    """Copy randomly chosen raw images into ``images_dir`` and extend ``prompts.json``.

    The number of copies is chosen so that raw images make up roughly
    ``ratio`` of the resulting dataset. ``ratio`` must be in ``[0, 1)``.
    """
    prompts_path = images_dir / "prompts.json"
    with open(prompts_path, "r") as f:
        prompts = json.load(f)
    generated_count = len(prompts)  # should be generate_iclight_n

    # Sorted so that the seeded random choice does not depend on the
    # filesystem's listing order; directories cannot be copied as images.
    raw_images = sorted(path for path in image_dir.rglob("*") if path.is_file())
    extra_count = int(np.ceil(ratio / (1 - ratio) * generated_count))

    if extra_count > 0 and not raw_images:
        print(f"[WARN] No raw images found in {image_dir}, nothing to add")
        extra_count = 0

    for j in range(extra_count):
        source_path = np.random.choice(raw_images)
        # The copy is always named *.png, whatever the source extension is.
        target_path = images_dir / f"{j + first_index:05}.png"
        shutil.copy2(source_path, target_path)
        prompts.append({
            "prompt": prompt,
            "fg_name": source_path.name,
            "bg_name": "",
            "result_image": target_path.name,
        })

    with open(str(prompts_path), "w") as f:
        json.dump(prompts, f, indent=4)


def enrich_dataset(
        image_dir,
        save_dir,
        
        generate_iclight_n=32,
        iclight_prompt=None,
        just_cp_iclight_dir=None,
        do_generate_iclight=True,
        add_raw_images_ratio=0.15,

        generate_preservation_n=32,
        preservation_prompt=None,
        just_cp_preservation_dir=None,

        prob=0.9,
        with_preservation=None,
        fixed_place_prompt=None,
):
    """Build an enriched dataset under ``save_dir`` from the images in ``image_dir``.

    Two sub-directories of ``save_dir`` are involved:

    * ``generated_images_iclight`` - filled in one of three ways: copied from
      ``just_cp_iclight_dir``; or, when ``do_generate_iclight`` is false, a
      copy of ``image_dir`` plus a freshly written ``prompts.json``; or
      generated by running ``ic_light_run.py`` and then optionally topped up
      with randomly chosen raw images.
    * ``generated_images_obj`` - only created when preservation is enabled;
      either copied from ``just_cp_preservation_dir`` or generated by running
      ``generate_preservation.py``.

    Note: this function seeds NumPy's global random generator with 0.

    Args:
        image_dir: Directory with the raw input images.
        save_dir: Root output directory (created if missing).
        generate_iclight_n: Passed as ``--num_images`` to ``ic_light_run.py``;
            also the first index used when naming copied raw images.
        iclight_prompt: Prompt passed to ``ic_light_run.py`` and stored in the
            ``prompts.json`` entries written by this function.
        just_cp_iclight_dir: If given, copy its files instead of generating.
        do_generate_iclight: If false (and ``just_cp_iclight_dir`` is None),
            copy ``image_dir`` as-is and write ``prompts.json`` for it.
        add_raw_images_ratio: Target share of raw images in the generated
            IC-Light set. ``None`` or a negative value disables the step;
            values >= 1 are rejected with a warning because the share cannot
            be reached.
        generate_preservation_n: Passed as ``--num_images`` to
            ``generate_preservation.py``.
        preservation_prompt: Prompt passed to ``generate_preservation.py``.
        just_cp_preservation_dir: If given, copy its files instead of
            generating preservation images.
        prob: Only used for logging and to derive the default of
            ``with_preservation`` (enabled when ``prob < 1.0``).
        with_preservation: Explicit switch for the preservation step.
        fixed_place_prompt: If given, forwarded to ``ic_light_run.py`` as
            ``--fixed_place_prompt``.

    Returns:
        Tuple ``(save_dir, save_iclight_dir, save_preservation_dir)`` of
        ``Path`` objects. The preservation path is returned even when the
        preservation step is disabled and the directory was not created.

    Raises:
        subprocess.CalledProcessError: If a generation script exits with a
            non-zero status.
    """
    np.random.seed(0)
    image_dir = Path(image_dir)
    save_dir = Path(save_dir)

    save_iclight_dir = save_dir / "generated_images_iclight"
    save_preservation_dir = save_dir / "generated_images_obj"

    if with_preservation is None:
        with_preservation = (prob < 1.0)

    print(f"[INFO] Build dataset: with_preservation = {with_preservation}, prob = {prob}")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_iclight_dir, exist_ok=True)

    if with_preservation:
        os.makedirs(save_preservation_dir, exist_ok=True)

        if just_cp_preservation_dir is not None:
            # copy all images and maybe text descriptions
            copy_all_from_to(just_cp_preservation_dir, save_preservation_dir)
            print("[INFO] Done copying preservation images")
        else:
            # generate object images from scratch
            _run_generation_script([
                "python3", PRESERVATION_RUN,
                "--prompt", preservation_prompt,
                "--num_images", generate_preservation_n,
                "--output_dir", str(save_preservation_dir),
            ])
            print("[INFO] Done generating preservation images")

        _release_memory()
    else:
        print("[INFO] Do not use preservation")

    if just_cp_iclight_dir is not None:
        # just copy images
        print("[INFO] Copying iclight images")
        print("from:", just_cp_iclight_dir)
        print("to:", save_iclight_dir)
        # Compare resolved paths so that different spellings of the same
        # directory do not trigger a copy of every file onto itself.
        if Path(just_cp_iclight_dir).resolve() != save_iclight_dir.resolve():
            copy_all_from_to(just_cp_iclight_dir, save_iclight_dir)

        # "- 1" presumably discounts the prompts.json stored next to the images.
        if len(os.listdir(just_cp_iclight_dir)) - 1 < generate_iclight_n:
            print(f"[WARN] {just_cp_iclight_dir} is small")
    elif not do_generate_iclight:
        # just copy images without iclight image generation
        copy_all_from_to(image_dir, save_iclight_dir)
        _write_prompts_for_copied_images(save_iclight_dir, iclight_prompt)
    else:
        # generate more
        if fixed_place_prompt is not None:
            place_args = ["--fixed_place_prompt", fixed_place_prompt]
        else:
            place_args = []

        _run_generation_script([
            "python3", IC_LIGHT_RUN,
            "--process_dir",
            "--input_fg_dir", str(image_dir),
            "--prompt", iclight_prompt,
            "--prompts_dir", PLACES_PATH,
            *place_args,
            "--num_images", generate_iclight_n,
            "--output_dir", str(save_iclight_dir),
            "--position_light", "left",
        ])
        _release_memory()

        print("[INFO] Done generating iclight images")

        if add_raw_images_ratio is not None and add_raw_images_ratio >= 0.0:
            if add_raw_images_ratio >= 1.0:
                # ratio / (1 - ratio) is undefined at 1 and negative above it.
                print(
                    f"[WARN] add_raw_images_ratio = {add_raw_images_ratio} "
                    "must be < 1.0, skip adding raw images"
                )
            else:
                _append_raw_images(
                    image_dir,
                    save_iclight_dir,
                    iclight_prompt,
                    add_raw_images_ratio,
                    generate_iclight_n,
                )
                print("[INFO] Done copying raw images")

    # free resources
    _release_memory()

    print("[INFO] Done dataset preparation!")
    return save_dir, save_iclight_dir, save_preservation_dir
