from util.logger import logger

from typing import List

from omegaconf import DictConfig

import concurrent.futures as cf

from pathlib import Path

from tqdm.auto import tqdm

import gc

import torch

from diffusers.utils.torch_utils import randn_tensor

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.image_util import (
    load_img_path, 
    load_img_folder_as_pil_list
)
from util.yaml_util import (
    load_yaml, 
    convert_numpy_type_to_native_type, 
    save_yaml
)

from util.metric_util.fid_util import cal_fid


def _normalize_name_list(
    name_list, 
    root_path: Path
) -> List[str]:
    """Turn a config name selector into a concrete list of names.

    - ``"all"`` expands to the names of every sub-directory of ``root_path``
      except ``_metric``. Plain files are skipped because they cannot hold
      generated images. The result is sorted so processing and logging order
      are deterministic (``Path.iterdir`` yields in arbitrary order).
    - A ``list`` is returned unchanged.
    - Any other value is wrapped in a one-element list.
    """
    if name_list == "all":
        return sorted(
            sub_path.name \
                for sub_path in root_path.iterdir() \
                    if sub_path.is_dir() and sub_path.name != "_metric"
        )

    if isinstance(name_list, list):
        return name_list

    return [name_list]


def _load_generated_img_pil_list(
    setting_root_path: Path, 
    folder_name_list, 
    num_sample: int
) -> list:
    """Load the generated images of one experiment.

    For every selected folder under ``setting_root_path``, loads
    ``<folder>/png/0.png`` .. ``<folder>/png/<num_sample - 1>.png`` with
    ``load_img_path``. Returns the images concatenated folder by folder.
    """
    tmp_folder_name_list = _normalize_name_list(
        name_list = folder_name_list, 
        root_path = setting_root_path
    )

    img_pil_list = []

    for folder_name in tqdm(
        tmp_folder_name_list, 

        desc = f"[Folder]"
    ):
        png_root_path = setting_root_path / folder_name / "png"

        for sample_idx in range(num_sample):
            img_pil_list.append(
                load_img_path(png_root_path / f"{sample_idx}.png")
            )

    return img_pil_list


def cal_mscoco_fid_implement(
    cfg: DictConfig
):
    """Compute and log FID of generated images against the MSCOCO-2014 5k test set.

    Config keys read from ``cfg["task"]``:
        fid_cfg_path: FID model config. Falls back to
            ``./config/metric_model/fid.yaml`` when None. Its ``fid`` section
            provides ``feature_dim``, ``batch_size`` and ``num_worker``.
        exp_root_path: Directory containing one sub-directory per experiment.
        exp_name_list: Experiment names, a single name, or ``"all"``.
        folder_name_list: Folder names inside each experiment, a single name,
            or ``"all"``.
        num_sample: Number of images (``0.png`` .. ``<num_sample-1>.png``)
            loaded from each ``<exp>/<folder>/png`` directory.

    Reference images come from the folder described by
    ``./config/dataset/mscoco_2014_5k_test.yaml``, which provides the keys
    ``dataset_root_path`` and ``img_folder_name``.

    The per-experiment FID and the average are written through ``logger``.
    Nothing is returned.

    Raises:
        RuntimeError: If no experiment was evaluated (e.g. an empty
            ``exp_name_list``). The original code dropped into
            ``breakpoint()`` and then divided by zero here.
    """
    # ---------= [Basic Global Variables] =---------
    device = get_global_variable("device")

    # ---------= [FID Cfg Path] =---------
    logger(f"[FID Cfg Path] Loading started. ")

    fid_cfg_path = get_true_value(cfg["task"]["fid_cfg_path"])

    logger(f"    fid_cfg_path: {fid_cfg_path}")

    logger(
        f"[FID Cfg Path] Loading finished. "
        "\n"
    )

    # ---------= [Path] =---------
    logger(f"[Path] Loading started. ")

    exp_root_path = get_true_value(cfg["task"]["exp_root_path"])
    exp_name_list = get_true_value(cfg["task"]["exp_name_list"])
    folder_name_list = get_true_value(cfg["task"]["folder_name_list"])
    num_sample = get_true_value(cfg["task"]["num_sample"])

    logger(f"    exp_root_path: {exp_root_path}")
    logger(f"    exp_name_list: {exp_name_list}")
    logger(f"    folder_name_list: {folder_name_list}")
    logger(f"    num_sample: {num_sample}")

    logger(
        f"[Path] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare FID Model] =---------
    if fid_cfg_path is None:
        fid_cfg_path = "./config/metric_model/fid.yaml"
    fid_cfg_path = Path(fid_cfg_path)

    fid_model_dict = load_yaml(fid_cfg_path)

    fid_feature_dim = fid_model_dict["fid"]["feature_dim"]
    fid_batch_size = fid_model_dict["fid"]["batch_size"]
    fid_num_worker = fid_model_dict["fid"]["num_worker"]

    # ---------= [Prepare Path] =---------
    cfg_yaml_path = Path("./config/dataset/mscoco_2014_5k_test.yaml")
    cfg_dict = load_yaml(cfg_yaml_path)

    dataset_root_path = Path(cfg_dict["dataset_root_path"])
    img_folder_name = cfg_dict["img_folder_name"]

    img_root_path = dataset_root_path / img_folder_name
    # The reference set is loaded once and shared by every experiment.
    ref_img_pil_list = load_img_folder_as_pil_list(
        img_root_path = img_root_path
    )

    exp_root_path = Path(exp_root_path)

    exp_name_list = _normalize_name_list(
        name_list = exp_name_list, 
        root_path = exp_root_path
    )

    # ---------= [Cal FID] =---------
    fid_list = []

    # Passed back into `cal_fid` on every call. Presumably `cal_fid` builds the
    # model when given None and returns it for reuse (verify in fid_util).
    fid_model = None

    for exp_name in tqdm(
        exp_name_list, 
        desc = f"[Exp]"
    ):
        img_pil_list = _load_generated_img_pil_list(
            setting_root_path = exp_root_path / exp_name, 
            folder_name_list = folder_name_list, 
            num_sample = num_sample
        )

        (
            fid, 
            fid_model
        ) = cal_fid(
            img_pil_list_1 = ref_img_pil_list, 
            img_pil_list_2 = img_pil_list, 

            batch_size = fid_batch_size, 
            feature_dim = fid_feature_dim, 
            num_worker = fid_num_worker, 

            device = device, 

            model = fid_model
        )

        # Release this experiment's images before loading the next one.
        del img_pil_list
        gc.collect()

        fid_list.append(fid)

        logger(
            f"{exp_name}: {fid}"
        )

    num_exp = len(fid_list)
    if num_exp <= 0:
        logger(
            "No valid sample. ", 
            log_type = "error"
        )

        raise RuntimeError(
            f"No experiment was evaluated under {exp_root_path}. "
        )

    avg_fid = sum(fid_list) / num_exp

    fid_list_str = ", ".join(
        f"{fid:.4f}" for fid in fid_list
    )

    avg_fid_str = f"{avg_fid:.4f}"

    logger(
        f"fid_list: {fid_list_str}"
    )
    logger(
        f"avg_fid: {avg_fid_str}"
    )

    # ---------= [Clean Up] =---------
    del fid_list
    del fid_model
    del ref_img_pil_list
    gc.collect()
    torch.cuda.empty_cache()


def cal_mscoco_fid(cfg: DictConfig):
    """Public entry point for the MSCOCO FID task.

    Forwards ``cfg`` unchanged to ``cal_mscoco_fid_implement``, which does all
    of the work. Returns None.
    """
    cal_mscoco_fid_implement(cfg=cfg)
