


from util.logger import logger

from typing import Optional, List, Dict

from omegaconf import DictConfig

import concurrent.futures as cf

from pathlib import Path

import random

import gc

from tqdm.auto import tqdm

import numpy as np

import torch

from diffusers.schedulers import DDIMScheduler, DDIMInverseScheduler

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.numpy_util import tsfm_to_1d_array
from util.pkl_util import load_pkl_cached
from util.image_util import save_pil_as_png
from util.yaml_util import (
    save_yaml, 
    convert_numpy_type_to_native_type
)
from util.torch_util import get_latent
from util.pipeline_util import (
    load_pipeline, load_scheduler, 
    get_inference_step_minus_one, 
    get_pipeline_category_and_type, 
    img_latent_to_pil, 
    get_folder_name
)

from task.util.seed_list_util import prepare_seed_list

from prompt_manager.util import get_prompt_manager


@torch.no_grad()
def run_z_sampling_implement(
    cfg: DictConfig
):
    """
    Run Z-Sampling over a prompt list and save every generated sample.

    For each prompt, ``num_sample_per_prompt`` samples are generated in batches of
    ``batch_size`` through ``pipeline.z_sampling_forward()``. The latents are decoded
    with ``img_latent_to_pil()``. Each sample is written to
    ``<exp_root_path>/<folder_name>/png/<k>.png``, and the config used for it goes to
    ``<exp_root_path>/<folder_name>/cfg/<k>.yaml``. Here ``k`` is the sample's index
    within its prompt.

    Args:
        cfg: OmegaConf config with a ``pipeline`` section and a ``task`` section
            (``prompt_list``, ``init_latent``, ``sample``, ``z_sampling``, ``task``,
            ``save_sample``).

    Raises:
        ValueError: if ``task.task.batch_size`` is missing or smaller than 1.
        NotImplementedError: if the pipeline is not in the SDXL family.
    """
    # ---------= [Basic Global Variables] =---------
    exp_name = get_global_variable("exp_name")
    start_time = get_global_variable("start_time")
    device = get_global_variable("device")
    seed = get_global_variable("seed")
    exp_time_str = f"{exp_name}_{start_time}"

    concurrent_max_worker = get_global_variable("concurrent_max_worker")

    vae_decode_batch_size = get_global_variable("vae_decode_batch_size")

    # ---------= [Pipeline] =---------
    logger(f"[Pipeline] Loading started. ")

    pipeline_type = get_true_value(cfg["pipeline"]["pipeline_type"])
    pipeline_path = get_true_value(cfg["pipeline"]["pipeline_path"])
    pipeline_torch_dtype = get_true_value(cfg["pipeline"]["torch_dtype"])
    pipeline_variant = get_true_value(cfg["pipeline"]["variant"])

    logger(f"    pipeline_type: {pipeline_type}")
    logger(f"    pipeline_path: {pipeline_path}")
    logger(f"    pipeline_torch_dtype: {pipeline_torch_dtype}")
    logger(f"    pipeline_variant: {pipeline_variant}")

    scheduler_type = get_true_value(cfg["pipeline"]["scheduler"])

    logger(f"    scheduler_type: {scheduler_type}")

    logger(
        f"[Pipeline] Loading finished. "
        "\n"
    )

    # ---------= [Prompt List] =---------
    logger(f"[Prompt List] Loading started. ")

    num_prompt = get_true_value(cfg["task"]["prompt_list"]["num_prompt"])
    prompt_manager_dict = get_true_value(cfg["task"]["prompt_list"]["prompt_manager_dict"])

    logger(f"    num_prompt: {num_prompt}")
    logger(f"    prompt_manager_dict: {prompt_manager_dict}")

    logger(
        f"[Prompt List] Loading finished. "
        "\n"
    )

    # ---------= [Init Latent] =---------
    logger(f"[Init Latent] Loading started. ")

    init_latent_root_path = get_true_value(cfg["task"]["init_latent"]["root_path"])

    logger(f"    init_latent_root_path: {init_latent_root_path}")

    init_latent_random = get_true_value(cfg["task"]["init_latent"]["random"])

    logger(f"    init_latent_random: {init_latent_random}")

    if init_latent_random:
        # Random mode: every sample draws its seed uniformly from [seed_st, seed_ed].
        init_latent_seed_st = get_true_value(cfg["task"]["init_latent"]["seed_st"])
        init_latent_seed_ed = get_true_value(cfg["task"]["init_latent"]["seed_ed"])

        logger(f"    init_latent_seed_st: {init_latent_seed_st}")
        logger(f"    init_latent_seed_ed: {init_latent_seed_ed}")
    else:
        # Fixed mode: an explicit seed list, expanded later to `num_sample_per_prompt`.
        init_latent_seed_list = get_true_value(cfg["task"]["init_latent"]["seed_list"])
        init_latent_seed_auto_increment = get_true_value(cfg["task"]["init_latent"]["seed_auto_increment"])

        logger(f"    init_latent_seed_list: {init_latent_seed_list}")
        logger(f"    init_latent_seed_auto_increment: {init_latent_seed_auto_increment}")

    logger(
        f"[Init Latent] Loading finished. "
        "\n"
    )

    # NOTE: the `eps` / `eta` config sections are not used by this task.

    # ---------= [Sample] =---------
    logger(f"[Sample] Loading started. ")

    prompt_2 = get_true_value(cfg["task"]["sample"]["prompt_2"])
    negative_prompt = get_true_value(cfg["task"]["sample"]["negative_prompt"])
    negative_prompt_2 = get_true_value(cfg["task"]["sample"]["negative_prompt_2"])

    logger(f"    prompt_2: {prompt_2}")
    logger(f"    negative_prompt: {negative_prompt}")
    logger(f"    negative_prompt_2: {negative_prompt_2}")

    height = get_true_value(cfg["task"]["sample"]["height"])
    width = get_true_value(cfg["task"]["sample"]["width"])
    down_sampling_ratio = get_true_value(cfg["task"]["sample"]["down_sampling_ratio"])

    logger(f"    height: {height}")
    logger(f"    width: {width}")
    logger(f"    down_sampling_ratio: {down_sampling_ratio}")

    num_inference_step = get_true_value(cfg["task"]["sample"]["num_inference_step"])
    guidance_scale = get_true_value(cfg["task"]["sample"]["guidance_scale"])

    logger(f"    num_inference_step: {num_inference_step}")
    logger(f"    guidance_scale: {guidance_scale}")

    logger(
        f"[Sample] Loading finished. "
        "\n"
    )

    # ---------= [Z-Sampling] =---------
    logger(f"[Z-Sampling] Loading started. ")

    inv_guidance_scale = get_true_value(cfg["task"]["z_sampling"]["inv_guidance_scale"])
    z_sampling_max_timestep_idx = get_true_value(cfg["task"]["z_sampling"]["max_timestep_idx"])
    num_zig_zag_per_step = get_true_value(cfg["task"]["z_sampling"]["num_zig_zag_per_step"])

    logger(f"    inv_guidance_scale: {inv_guidance_scale}")
    logger(f"    z_sampling_max_timestep_idx: {z_sampling_max_timestep_idx}")
    logger(f"    num_zig_zag_per_step: {num_zig_zag_per_step}")

    logger(
        f"[Z-Sampling] Loading finished. "
        "\n"
    )

    # ---------= [Task] =---------
    logger(f"[Task] Loading started. ")

    num_sample_per_prompt = get_true_value(cfg["task"]["task"]["num_sample_per_prompt"])
    batch_size = get_true_value(cfg["task"]["task"]["batch_size"])

    logger(f"    num_sample_per_prompt: {num_sample_per_prompt}")
    logger(f"    batch_size: {batch_size}")

    # Fail fast, before the (expensive) pipeline load; batch_size <= 0 would otherwise
    # surface later as a ZeroDivisionError in the batch-count computation.
    if batch_size is None or batch_size < 1:
        raise ValueError(
            f"`batch_size` must be a positive integer, got `{batch_size}`. "
        )

    logger(
        f"[Task] Loading finished. "
        "\n"
    )

    # ---------= [Save Sample] =---------
    logger(f"[Save Sample] Loading started. ")

    save_sample_root_path = get_true_value(cfg["task"]["save_sample"]["save_sample_root_path"])

    logger(f"    save_sample_root_path: {save_sample_root_path}")

    logger(
        f"[Save Sample] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Pipeline] =---------
    (
        pipeline_category_name,
        pipeline_type_name
    ) = get_pipeline_category_and_type(
        pipeline_path = pipeline_path
    )

    pipeline = load_pipeline(
        pipeline_type = pipeline_type,
        pipeline_path = pipeline_path,
        torch_dtype = pipeline_torch_dtype,
        variant = pipeline_variant
    )

    # Z-Sampling needs a forward DDIM scheduler and a DDIM inverse scheduler.
    # NOTE(review): `scheduler_type` is only used for `inference_step_minus_one`; the
    # sampling scheduler itself is always DDIM here — confirm this is intended.
    pipeline.scheduler = DDIMScheduler.from_config(
        pipeline.scheduler.config
    )

    pipeline.inv_scheduler = DDIMInverseScheduler.from_pretrained(
        pipeline_path,
        subfolder = "scheduler"
    )

    inference_step_minus_one = get_inference_step_minus_one(scheduler_type)

    if pipeline_category_name == "sdxl_family":
        from .pipeline.pipeline_stable_diffusion_xl import register_pipeline_stable_diffusion_xl

        # register custom methods: `forward()`, `prepare_everything()`, `get_noise_pred()`, `step()`
        register_pipeline_stable_diffusion_xl(pipeline)
    else:
        raise NotImplementedError(
            f"Only support SDXL pipeline. "
        )

    # move to GPU after registration
    pipeline.to(device)

    # save VRAM by offloading the model to CPU
    pipeline.enable_model_cpu_offload()

    logger(f"    pipeline: {type(pipeline)}")

    # ---------= [Prompt List] =---------
    prompt_manager = get_prompt_manager(
        prompt_manager_dict = prompt_manager_dict
    )

    prompt_manager.load_prompt_list()
    prompt_manager.prepare_everything(
        shuffle = True
    )

    prompt_list = prompt_manager.prompt_list
    folder_name_list = prompt_manager.folder_name_list

    if num_prompt is not None:
        if num_prompt > len(prompt_list):
            logger(
                f"`num_prompt` ({num_prompt}) exceeds the number of available prompts "
                f"({len(prompt_list)}), using all available prompts. "
            )

        prompt_list = prompt_list[: num_prompt]
        folder_name_list = folder_name_list[: num_prompt]

    # Always derive the count from the (possibly truncated) list so that the sample
    # indexing below can never run past the end of `prompt_tuple_list`.
    num_prompt = len(prompt_list)

    # (prompt_idx, prompt) pairs, each repeated `num_sample_per_prompt` times in a row,
    # so global sample i belongs to prompt i // num_sample_per_prompt.
    # len(prompt_tuple_list) == num_prompt * num_sample_per_prompt
    prompt_tuple_list = [
        (prompt_idx, prompt)
        for prompt_idx, prompt in enumerate(prompt_list)
        for _ in range(num_sample_per_prompt)
    ]

    # ---------= [Prepare Z-Sampling] =---------
    if z_sampling_max_timestep_idx is None:
        logger(
            f"`z_sampling_max_timestep_idx` is not provided, set to `num_inference_step - 1` ({num_inference_step - 1}). "
        )

        z_sampling_max_timestep_idx = num_inference_step - 1

    # ---------= [Prepare Task] =---------
    num_sample = num_prompt * num_sample_per_prompt

    num_batch = (num_sample + batch_size - 1) // batch_size

    # ---------= [Prepare Init Latent] =---------
    # NOTE(review): 4 latent channels is hard-coded (matches the SDXL VAE) — confirm if
    # other pipeline families are enabled.
    latent_num_channel = 4
    latent_height = height // down_sampling_ratio
    latent_width = width // down_sampling_ratio

    latent_shape = (latent_num_channel, latent_height, latent_width)
    latent_shape_str = f"{latent_num_channel}_{latent_height}_{latent_width}"

    init_latent_root_path = Path(init_latent_root_path) / latent_shape_str

    if not init_latent_random:
        # Expand the configured seeds to exactly one seed per sample of a prompt.
        # (`auto_inrement` is the helper's own keyword name.)
        init_latent_seed_list = prepare_seed_list(
            seed_list = init_latent_seed_list,
            target_length = num_sample_per_prompt,

            auto_inrement = init_latent_seed_auto_increment
        )

    # ---------= [Prepare Save Sample] =---------
    save_sample_root_path = Path(save_sample_root_path)
    pipeline_root_path = save_sample_root_path / pipeline_type_name

    dataset_type_str = prompt_manager_dict["prompt_manager_type"]
    dataset_root_path = pipeline_root_path / dataset_type_str

    exp_parent_path = (
        dataset_root_path
        / f"{num_inference_step}"
        / f"{guidance_scale}_{inv_guidance_scale}"
        / f"{num_zig_zag_per_step}"
    )

    if exp_name in ["ddpm", "deterministic_ddim"]:
        # Baseline runs are stored under the bare experiment name (no timestamp).
        # NOTE: an eps-seed suffix for "ddpm" is not possible here because the eps
        # config section is disabled for this task.
        exp_root_path = exp_parent_path / exp_name
    else:
        exp_root_path = exp_parent_path / exp_time_str

    # ---------= [Prepare Task Cfg] =---------
    # Per-sample fields (init_latent_seed, prompt, folder_name) are filled in right
    # before each sample's YAML is written.
    cfg_dict = {
        "global_variable": {
            "exp_name": exp_name,
            "start_time": start_time,
            "device": device,
            "seed": seed
        },

        "pipeline": {
            "pipeline_type": pipeline_type,
            "pipeline_path": pipeline_path,
            "pipeline_torch_dtype": pipeline_torch_dtype,
            "pipeline_variant": pipeline_variant,

            "pipeline_category_name": pipeline_category_name,
            "pipeline_type_name": pipeline_type_name,

            "scheduler_type": scheduler_type
        },

        "init_latent": {
            "init_latent_root_path": init_latent_root_path,
            "init_latent_seed": None
        },

        "sample": {
            "prompt_2": prompt_2,
            "negative_prompt": negative_prompt,
            "negative_prompt_2": negative_prompt_2,

            "height": height,
            "width": width,

            "num_inference_step": num_inference_step,
            "inference_step_minus_one": inference_step_minus_one,

            "guidance_scale": guidance_scale
        },

        "z_sampling": {
            "inv_guidance_scale": inv_guidance_scale,
            "max_timestep_idx": z_sampling_max_timestep_idx,
            "num_zig_zag_per_step": num_zig_zag_per_step
        },

        "task": {
            "num_sample_per_prompt": num_sample_per_prompt,
            "batch_size": batch_size
        },

        "save_sample": {
            "save_sample_root_path": save_sample_root_path,
            "pipeline_type_name": pipeline_type_name,
        }
    }

    # ---------= [Do Sampling] =---------
    def implement_batch(
        batch_idx: int,
        true_batch_size: int,

        batch_init_latent_seed_list: List[int]
    ) -> List[torch.Tensor]:
        """Generate one batch of latents with Z-Sampling.

        Args:
            batch_idx: index of the batch; the batch covers global samples
                ``[batch_idx * batch_size, batch_idx * batch_size + true_batch_size)``.
            true_batch_size: number of samples in this (possibly partial) batch.
            batch_init_latent_seed_list: one init-latent seed per sample in the batch.

        Returns:
            Whatever ``pipeline.z_sampling_forward()`` returns for the batch (the
            caller treats it as the image latents to decode).
        """
        # ---------= [Prepare Batch Prompt List] =---------
        sample_idx_st = batch_idx * batch_size
        sample_idx_ed_plus_one = sample_idx_st + true_batch_size

        batch_prompt_tuple_list = prompt_tuple_list[sample_idx_st: sample_idx_ed_plus_one]

        logger(
            f"[Batch {batch_idx}] batch_prompt_tuple_list: {batch_prompt_tuple_list}"
        )

        batch_prompt_list = [prompt for _, prompt in batch_prompt_tuple_list]

        # ---------= [Prepare Batch Latent List] =---------
        init_latent_list = [
            get_latent(
                shape = latent_shape,

                seed = latent_seed,

                device = device,

                dtype = pipeline_torch_dtype
            )
            for latent_seed in batch_init_latent_seed_list
        ]

        # shape: (true_batch_size, *latent_shape)
        batch_init_latent_list = torch.stack(init_latent_list)

        # ---------= [Z-Sampling] =---------
        batch_latent_list = pipeline.z_sampling_forward(
            prompt = batch_prompt_list,

            guidance_scale = guidance_scale,
            num_inference_steps = num_inference_step,

            latents = batch_init_latent_list,

            max_timestep_idx = z_sampling_max_timestep_idx,
            num_zig_zag_per_step = num_zig_zag_per_step,
            inv_guidance_scale = inv_guidance_scale,
        )

        # ---------= [Clean Up] =---------
        del batch_prompt_list
        del init_latent_list, batch_init_latent_list
        gc.collect()
        torch.cuda.empty_cache()

        return batch_latent_list

    for batch_idx in tqdm(
        range(num_batch),

        desc = "[Z-Sampling]"
    ):
        # ---------= [Sample Batch] =---------
        sample_idx_st = batch_idx * batch_size
        # Only the last batch can be partial.
        true_batch_size = min(batch_size, num_sample - sample_idx_st)

        # ---------= [Prepare Init Latent Seed List] =---------
        if init_latent_random:
            batch_init_latent_seed_list = [
                random.randint(init_latent_seed_st, init_latent_seed_ed)
                for _ in range(true_batch_size)
            ]
        else:
            # Sample k of every prompt always uses the k-th configured seed.
            batch_init_latent_seed_list = [
                init_latent_seed_list[true_sample_idx % num_sample_per_prompt]
                for true_sample_idx in range(sample_idx_st, sample_idx_st + true_batch_size)
            ]

        # ---------= [Do Sampling] =---------
        batch_img_latent_list = implement_batch(
            batch_idx = batch_idx,
            true_batch_size = true_batch_size,

            batch_init_latent_seed_list = batch_init_latent_seed_list
        )

        # ---------= [Decode Sample] =---------
        batch_img_pil_list = img_latent_to_pil(
            img_latent_list = batch_img_latent_list,
            pipeline = pipeline,

            batch_size = vae_decode_batch_size
        )

        # ---------= [Save Sample] =---------
        # YAML configs are written synchronously (they share the mutable `cfg_dict`);
        # PNG writes are submitted to the pool and only awaited after all are queued,
        # so they actually run concurrently.
        future_list = []

        with cf.ThreadPoolExecutor(
            max_workers = concurrent_max_worker
        ) as executor:
            for sample_idx in tqdm(
                range(true_batch_size),

                desc = "[Saving Sample]"
            ):
                true_sample_idx = sample_idx_st + sample_idx
                prompt_true_sample_idx = true_sample_idx % num_sample_per_prompt

                prompt_idx, prompt = prompt_tuple_list[true_sample_idx]

                folder_name = folder_name_list[prompt_idx]

                folder_root_path = exp_root_path / folder_name
                cfg_root_path = folder_root_path / "cfg"
                png_root_path = folder_root_path / "png"

                # ---------= [Save Task Cfg] =---------
                cfg_dict["init_latent"]["init_latent_seed"] = batch_init_latent_seed_list[sample_idx]
                cfg_dict["sample"]["prompt"] = prompt
                cfg_dict["save_sample"]["folder_name"] = folder_name

                cfg_dict = convert_numpy_type_to_native_type(cfg_dict)
                save_yaml(
                    cfg_dict,

                    yaml_root_path = cfg_root_path,
                    yaml_filename = f"{prompt_true_sample_idx}.yaml"
                )

                # ---------= [Save Sample] =---------
                # Create the target directory here (idempotent) so concurrent writers
                # never race on creating the same folder.
                png_root_path.mkdir(parents = True, exist_ok = True)

                future_list.append(
                    executor.submit(
                        save_pil_as_png,

                        pil = batch_img_pil_list[sample_idx],

                        png_root_path = png_root_path,
                        png_filename = f"{prompt_true_sample_idx}.png"
                    )
                )

            # Best-effort saving: report failures but keep processing the run.
            for future in cf.as_completed(future_list):
                try:
                    future.result()
                except Exception as e:
                    logger(
                        f"`save_pil_as_png()` throws an exception: `{e}`. ",
                        log_type = "error"
                    )

        # ---------= [Clean Up] =---------
        del future_list
        del batch_init_latent_seed_list
        del batch_img_latent_list
        del batch_img_pil_list
        gc.collect()
        torch.cuda.empty_cache()

def run_z_sampling(
    cfg: DictConfig
):
    """
    Public entry point of the Z-Sampling task.

    Forwards ``cfg`` unchanged to ``run_z_sampling_implement()``, which runs under
    ``torch.no_grad()``. Nothing is returned.
    """
    run_z_sampling_implement(cfg = cfg)
