from util.logger import logger

from typing import List

from omegaconf import DictConfig

import concurrent.futures as cf

from pathlib import Path

from tqdm.auto import tqdm

import gc

import torch

from diffusers.utils.torch_utils import randn_tensor

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.image_util import load_img_path
from util.yaml_util import (
    load_yaml, 
    convert_numpy_type_to_native_type, 
    save_yaml
)

import util.metric_util.lpips_util as lpips_util


def cal_mean_pairwise_distance_implement(
    cfg: DictConfig
) -> None:
    """Compute the per-folder LPIPS mean pairwise distance (MPD) and save it.

    Reads two entries from ``cfg["task"]``:

    * ``lpips_cfg_path``: path of a YAML file whose ``lpips`` section
      provides ``net_type`` and ``batch_size``.
    * ``setting_root_path``: directory whose sub-directories (except
      ``_metric``) each contain a ``png`` folder of images.

    For every such sub-directory, the ``*.png`` files in its ``png`` folder
    are loaded and scored with
    ``lpips_util.cal_mean_pairwise_distance_lpips``. The score is stored in
    the result dict under the sub-directory name, and the arithmetic mean of
    all scores is stored under ``"mpd_prompt"``. The dict is then handed to
    ``save_yaml`` with root path ``<setting_root_path>/_metric`` and filename
    ``mpd_lpips.yaml``.

    Args:
        cfg: Configuration providing the ``task`` entries described above.

    Raises:
        ValueError: If ``setting_root_path`` contains no sub-directory to
            evaluate (previously surfaced as a ``ZeroDivisionError``).
    """
    # ---------= [Basic Global Variables] =---------
    device = get_global_variable("device")

    # ---------= [LPIPS Cfg Path] =---------
    logger(f"[LPIPS Cfg Path] Loading started. ")

    lpips_cfg_path = get_true_value(cfg["task"]["lpips_cfg_path"])

    logger(f"    lpips_cfg_path: {lpips_cfg_path}")

    logger(
        f"[LPIPS Cfg Path] Loading finished. "
        "\n"
    )

    # ---------= [Setting Root Path] =---------
    logger(f"[Setting Root Path] Loading started. ")

    setting_root_path = get_true_value(cfg["task"]["setting_root_path"])

    logger(f"    setting_root_path: {setting_root_path}")

    logger(
        f"[Setting Root Path] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare LPIPS Net Type] =---------
    lpips_cfg_path = Path(lpips_cfg_path)
    lpips_model_dict = load_yaml(lpips_cfg_path)

    lpips_net_type = lpips_model_dict["lpips"]["net_type"]
    lpips_batch_size = lpips_model_dict["lpips"]["batch_size"]

    # ---------= [Prepare Folder Root Path List] =---------
    setting_root_path = Path(setting_root_path)

    # Only real directories are prompt folders; stray files would otherwise
    # crash on `<file>/png`. `.name` (not `.stem`) keeps folder names that
    # contain a dot intact. Sorting makes the processing order, and thus the
    # key order of the saved YAML, reproducible (`iterdir()` order is
    # arbitrary).
    folder_root_path_list = sorted(
        folder_root_path
        for folder_root_path in setting_root_path.iterdir()
        if folder_root_path.is_dir() and folder_root_path.name != "_metric"
    )

    if not folder_root_path_list:
        # Fail early with a clear message instead of dividing by zero below.
        raise ValueError(
            f"No prompt folders found under `{setting_root_path}`; "
            "cannot compute the mean pairwise distance."
        )

    # ---------= [Compute MPD per Prompt] =---------
    # NOTE: only the per-folder scores and their average are computed here;
    # no score over the union of all images is produced.
    res_dict = {}

    # Filled by the first scoring call and passed back in afterwards so the
    # same model object is reused for every folder.
    lpips_model = None

    sum_lpips_score = 0.0

    for folder_root_path in tqdm(
        folder_root_path_list,

        desc = "[Compute MPD per Prompt]"
    ):
        png_root_path = folder_root_path / "png"

        png_path_list = sorted(
            png_path
            for png_path in png_root_path.iterdir()
            if png_path.suffix == ".png"
        )

        tmp_img_pil_list = [
            load_img_path(png_path)
            for png_path in png_path_list
        ]

        (
            tmp_lpips_score,
            _lpips_model
        ) = lpips_util.cal_mean_pairwise_distance_lpips(
            img_pil_list = tmp_img_pil_list,

            lpips_net_type = lpips_net_type,
            batch_size = lpips_batch_size,

            device = device,

            model = lpips_model,

            disable_tqdm = True
        )

        if lpips_model is None:
            lpips_model = _lpips_model

        res_dict[folder_root_path.name] = tmp_lpips_score

        sum_lpips_score += tmp_lpips_score

        # --------= [Clean Up] =---------
        # Drop this folder's images before loading the next folder.
        del png_path_list
        del tmp_img_pil_list
        gc.collect()

    avg_lpips_score = sum_lpips_score / len(folder_root_path_list)
    res_dict["mpd_prompt"] = avg_lpips_score

    logger(f"mpd_prompt: {avg_lpips_score}")

    # ---------= [Save Results] =---------
    metric_root_path = setting_root_path / "_metric"

    res_dict = convert_numpy_type_to_native_type(res_dict)
    save_yaml(
        res_dict,

        yaml_root_path = metric_root_path,
        yaml_filename = "mpd_lpips.yaml"
    )

    # ---------= [Clean Up] =---------
    del lpips_model
    gc.collect()
    torch.cuda.empty_cache()


def cal_mean_pairwise_distance(cfg: DictConfig) -> None:
    """Task entry point: run the mean pairwise distance computation.

    Forwards ``cfg`` unchanged to ``cal_mean_pairwise_distance_implement``
    and returns nothing.

    Args:
        cfg: Configuration object passed straight through to the
            implementation function.
    """
    cal_mean_pairwise_distance_implement(cfg)
