from util.logger import logger

from typing import List

from omegaconf import DictConfig

import concurrent.futures as cf

from pathlib import Path

from tqdm.auto import tqdm

import torch

from diffusers.utils.torch_utils import randn_tensor

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.pkl_util import save_pkl
from util.torch_util import get_generator


def _generate_latent_batch(
    seed_sublist: List[int], 
    device, 
    latent_shape: tuple, 
    torch_dtype: torch.dtype
) -> List[torch.Tensor]:
    """Draw one random latent per seed and return them as independent CPU tensors.

    Args:
        seed_sublist: Seeds for this batch; one generator (and one latent) per seed.
        device: Device passed both to `get_generator` and (via `torch.device`)
            to `randn_tensor`.
        latent_shape: Per-sample latent shape, `(num_channel, height, width)`.
        torch_dtype: dtype of the generated latents.

    Returns:
        A list of `len(seed_sublist)` CPU tensors, each of shape `latent_shape`,
        each owning its own storage.
    """
    generator_list = [
        get_generator(
            seed = seed, 
            device = device
        ) \
            for seed in seed_sublist
    ]

    # batch_latent.shape = (len(seed_sublist), *latent_shape)
    batch_latent = randn_tensor(
        shape = (len(seed_sublist), *latent_shape), 
        generator = generator_list, 

        device = torch.device(device), 
        dtype = torch_dtype
    )

    batch_latent = batch_latent.detach().cpu()

    # `unbind` returns views sharing the whole batch's storage. Pickling a torch
    # view serializes its *entire* underlying storage, so without `.clone()`
    # every per-sample file would contain the full batch.
    return [
        latent.clone() \
            for latent in batch_latent.unbind(dim = 0)
    ]


def _save_latent_list(
    latent_list: List[torch.Tensor], 
    save_root_path: Path, 
    max_workers
):
    """Save `latent_list[i]` as `<save_root_path>/<i>.pkl` using a thread pool.

    Failures are logged per sample rather than raised, so that one bad write
    does not abort the remaining saves.
    """
    # Create the target directory once up front. Concurrent `save_pkl()` calls
    # therefore never race on creating it (harmless if `save_pkl` also creates it).
    save_root_path.mkdir(parents = True, exist_ok = True)

    with cf.ThreadPoolExecutor(
        max_workers = max_workers
    ) as executor:
        # Submit everything first so the writes actually overlap. Waiting on
        # each future right after `submit()` would serialize them.
        future_to_sample_idx = {
            executor.submit(
                save_pkl, 
                var = latent, 

                pkl_root_path = save_root_path, 
                pkl_filename = f"{sample_idx}.pkl"
            ): sample_idx \
                for sample_idx, latent in enumerate(latent_list)
        }

        for future in tqdm(
            cf.as_completed(future_to_sample_idx), 
            total = len(future_to_sample_idx), 

            desc = "[Saving Sample]"
        ):
            try:
                future.result()
            except Exception as e:
                # Completion order is arbitrary, so name the failing sample.
                logger(
                    f"`save_pkl()` throws an exception for sample "
                    f"{future_to_sample_idx[future]}: `{e}`. ", 
                    log_type = "error"
                )


def generate_init_latent_implement(
    cfg: DictConfig
):
    """Generate seeded random initial latents and pickle one file per sample.

    Config keys read (each passed through `get_true_value`):
      - `task.seed.seed_st`: seed of sample 0; sample `i` uses `seed_st + i`.
      - `task.sample.{torch_dtype, height, width, down_sampling_ratio}`:
        `torch_dtype` is a name resolved on the `torch` module via `get_attr`.
        Latent spatial size is `height // down_sampling_ratio` by
        `width // down_sampling_ratio`.
      - `task.task.{num_sample, batch_size}`.
      - `task.save_sample.save_sample_root_path`: files are written to
        `<root>/<num_channel>_<h>_<w>/<sample_idx>.pkl`.

    Global variables read: `device`, `concurrent_max_worker`.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    # ---------= [Basic Global Variables] =---------
    device = get_global_variable("device")
    concurrent_max_worker = get_global_variable("concurrent_max_worker")

    # ---------= [Seed] =---------
    logger(f"[Seed] Loading started. ")

    seed_st = get_true_value(cfg["task"]["seed"]["seed_st"])

    logger(f"    seed_st: {seed_st}")

    logger(
        f"[Seed] Loading finished. "
        "\n"
    )

    # ---------= [Sample] =---------
    logger(f"[Sample] Loading started. ")

    torch_dtype = get_true_value(cfg["task"]["sample"]["torch_dtype"])

    logger(f"    torch_dtype: {torch_dtype}")

    height = get_true_value(cfg["task"]["sample"]["height"])
    width = get_true_value(cfg["task"]["sample"]["width"])
    down_sampling_ratio = get_true_value(cfg["task"]["sample"]["down_sampling_ratio"])

    logger(f"    height: {height}")
    logger(f"    width: {width}")
    logger(f"    down_sampling_ratio: {down_sampling_ratio}")

    logger(
        f"[Sample] Loading finished. "
        "\n"
    )

    # ---------= [Task] =---------
    logger(f"[Task] Loading started. ")

    num_sample = get_true_value(cfg["task"]["task"]["num_sample"])
    batch_size = get_true_value(cfg["task"]["task"]["batch_size"])

    logger(f"    num_sample: {num_sample}")
    logger(f"    batch_size: {batch_size}")

    if batch_size <= 0:
        raise ValueError(f"`batch_size` must be positive, got {batch_size}. ")

    logger(
        f"[Task] Loading finished. "
        "\n"
    )

    # ---------= [Save Sample] =---------
    logger(f"[Save Sample] Loading started. ")

    save_sample_root_path = get_true_value(cfg["task"]["save_sample"]["save_sample_root_path"])

    logger(f"    save_sample_root_path: {save_sample_root_path}")

    logger(
        f"[Save Sample] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Everything] =---------
    seed_list = [
        seed_st + i \
            for i in range(num_sample)
    ]

    torch_dtype = get_attr("torch", torch_dtype)

    # NOTE(review): hard-coded latent channel count; presumably matches the
    # target VAE's latent channels — confirm before reusing for other models.
    num_channel = 4
    height //= down_sampling_ratio
    width //= down_sampling_ratio
    latent_shape = (num_channel, height, width)

    num_batch = (num_sample + batch_size - 1) // batch_size

    save_sample_root_path = Path(save_sample_root_path) / f"{num_channel}_{height}_{width}"

    # ---------= [Generate Init Latent] =---------
    latent_list = []

    for batch_idx in tqdm(
        range(num_batch), 

        desc = "[Generating Init Latent]"
    ):
        sample_idx_st = batch_idx * batch_size
        # Slicing naturally yields a shorter final batch when
        # `num_sample % batch_size != 0`.
        latent_list += _generate_latent_batch(
            seed_sublist = seed_list[sample_idx_st:sample_idx_st + batch_size], 
            device = device, 
            latent_shape = latent_shape, 
            torch_dtype = torch_dtype
        )

    # ---------= [Save Samples] =---------
    _save_latent_list(
        latent_list = latent_list, 
        save_root_path = save_sample_root_path, 
        max_workers = concurrent_max_worker
    )


def generate_init_latent(
    cfg: DictConfig
):
    """Public entry point: generate and save initial latents described by `cfg`.

    Delegates entirely to `generate_init_latent_implement`; returns nothing.
    """
    generate_init_latent_implement(cfg = cfg)
