from util.logger import logger

from typing import List

from omegaconf import DictConfig

import gc

from pathlib import Path

from tqdm.auto import tqdm

import torch

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.image_util import load_img_path
from util.yaml_util import (
    load_yaml, 
    convert_numpy_type_to_native_type, 
    save_yaml
)
from util.metric_util.lpips_util import cal_lpips

from diffusion_OT_MCTS.reward_model.util import get_reward_model


def _resolve_name_list(
    name_spec, 
    root_path: Path
) -> List[str]:
    """Normalise a name specification from the config into a list of names.

    Args:
        name_spec: Either the string ``"all"``, a list of names, or a single
            name.
        root_path: Directory that is scanned when ``name_spec`` is ``"all"``.

    Returns:
        ``name_spec`` itself if it is already a list; the sorted names of all
        sub-directories of ``root_path`` (except ``_metric``) if it is
        ``"all"``; otherwise a one-element list wrapping ``name_spec``.
    """
    if isinstance(name_spec, list):
        return name_spec

    if name_spec == "all":
        # Use `.name`, not `.stem`: `.stem` drops everything after the last
        # dot, so a directory such as "run_lr0.5" would be turned into the
        # non-existent "run_lr0". Only directories can hold results, so plain
        # files are skipped. Sorting makes the traversal order deterministic.
        return sorted(
            entry.name 
                for entry in root_path.iterdir() 
                    if entry.is_dir() and entry.stem != "_metric"
        )

    return [name_spec]


def _cal_sample_robustness(
    folder_path: Path, 
    sample_idx: int, 
    eps, 
    lpips_net_type, 
    lpips_batch_size, 
    device, 
    lpips_model
):
    """Compute the robustness ratios of one sample inside one result folder.

    Reads ``result/{sample_idx}.yaml`` (key ``best_merged_reward_list``) and
    the images ``png/{sample_idx}/{sample_idx}_{k}.png`` below ``folder_path``.
    For every pair of consecutive images the ratio
    ``abs(reward[k] - reward[k - 1]) / (lpips[k - 1] + eps)`` is produced.

    Args:
        folder_path: Folder that contains the ``result`` and ``png`` trees.
        sample_idx: Index of the sample to process.
        eps: Constant added to the LPIPS distance in the denominator.
        lpips_net_type: Forwarded to ``cal_lpips`` as ``lpips_net_type``.
        lpips_batch_size: Forwarded to ``cal_lpips`` as ``batch_size``.
        device: Forwarded to ``cal_lpips`` as ``device``.
        lpips_model: Model object handed to ``cal_lpips``; ``None`` on the
            first call (presumably lets ``cal_lpips`` build it -- confirm
            against ``cal_lpips``).

    Returns:
        A tuple ``(robustness_list, lpips_model)``. ``robustness_list`` is
        empty when the sample has fewer than two rewards; ``lpips_model`` is
        the model object returned by ``cal_lpips`` (or the input unchanged if
        ``cal_lpips`` was not called).

    Raises:
        FileNotFoundError: If the result YAML of the sample does not exist.
    """
    res_yaml_path = folder_path / "result" / f"{sample_idx}.yaml"

    if not res_yaml_path.is_file():
        logger(
            f"res_yaml_path: {res_yaml_path}", 
            log_type = "error"
        )

        # Fail loudly instead of dropping into a debugger: a missing result
        # file cannot be recovered from here.
        raise FileNotFoundError(
            f"Result file not found: {res_yaml_path}"
        )

    img_reward_list = load_yaml(res_yaml_path)["best_merged_reward_list"]
    num_img = len(img_reward_list)

    # A single image has no neighbour to compare against.
    if num_img < 2:
        return [], lpips_model

    png_root_path = folder_path / "png" / f"{sample_idx}"
    img_pil_list = [
        load_img_path(png_root_path / f"{sample_idx}_{img_idx}.png") 
            for img_idx in range(num_img)
    ]

    # Pair image k - 1 with image k by offsetting the same list by one.
    (
        lpips_list, 
        lpips_model
    ) = cal_lpips(
        img_pil_list_1 = img_pil_list[:-1], 
        img_pil_list_2 = img_pil_list[1:], 

        lpips_net_type = lpips_net_type, 
        batch_size = lpips_batch_size, 

        device = device, 

        model = lpips_model, 

        disable_tqdm = True
    )

    robustness_list = [
        abs(img_reward_list[img_idx] - img_reward_list[img_idx - 1]) 
            / (lpips_list[img_idx - 1] + eps) 
                for img_idx in range(1, num_img)
    ]

    return robustness_list, lpips_model


def cal_reward_robustness_implement(
    cfg: DictConfig
):
    """Compute and log the mean reward robustness over stored results.

    For every selected experiment, every selected folder and every sample
    index in ``range(num_sample)``, the ratio
    ``abs(reward[k] - reward[k - 1]) / (lpips[k - 1] + eps)`` is computed for
    each pair of consecutive images. The mean over all ratios is logged as
    ``reward_robustness``. Nothing is returned.

    Args:
        cfg: Configuration. The following entries of ``cfg["task"]`` are read:
            ``lpips_cfg_path`` (falls back to
            ``./config/metric_model/lpips.yaml`` when ``None``), ``eps``,
            ``num_sample``, ``path.exp_root_path``, ``path.exp_name_list`` and
            ``path.folder_name_list``. Both name lists may be ``"all"``, a
            single name or a list of names.

    Raises:
        FileNotFoundError: If the result YAML of a requested sample is missing.
        ValueError: If no pair of consecutive images was found at all.
    """
    # ---------= [Basic Global Variables] =---------
    device = get_global_variable("device")

    # ---------= [LPIPS Cfg Path] =---------
    logger("[LPIPS Cfg Path] Loading started. ")

    lpips_cfg_path = get_true_value(cfg["task"]["lpips_cfg_path"])

    logger(f"    lpips_cfg_path: {lpips_cfg_path}")

    logger(
        "[LPIPS Cfg Path] Loading finished. "
        "\n"
    )

    # ---------= [Eps] =---------
    logger("[Eps] Loading started. ")

    eps = get_true_value(cfg["task"]["eps"])

    logger(f"    eps: {eps}")

    logger(
        "[Eps] Loading finished. "
        "\n"
    )

    # ---------= [Path] =---------
    logger("[Path] Loading started. ")

    exp_root_path = get_true_value(cfg["task"]["path"]["exp_root_path"])
    exp_name_list = get_true_value(cfg["task"]["path"]["exp_name_list"])
    folder_name_list = get_true_value(cfg["task"]["path"]["folder_name_list"])

    logger(f"    exp_root_path: {exp_root_path}")
    logger(f"    exp_name_list: {exp_name_list}")
    logger(f"    folder_name_list: {folder_name_list}")

    logger(
        "[Path] Loading finished. "
        "\n"
    )

    # ---------= [Num Sample] =---------
    logger("[Num Sample] Loading started. ")

    num_sample = get_true_value(cfg["task"]["num_sample"])

    logger(f"    num_sample: {num_sample}")

    logger(
        "[Num Sample] Loading finished. "
        "\n"
    )

    # ---------= [Prepare LPIPS Net Type] =---------
    if lpips_cfg_path is None:
        lpips_cfg_path = "./config/metric_model/lpips.yaml"
    lpips_cfg_path = Path(lpips_cfg_path)

    lpips_model_dict = load_yaml(lpips_cfg_path)

    lpips_net_type = lpips_model_dict["lpips"]["net_type"]
    lpips_batch_size = lpips_model_dict["lpips"]["batch_size"]

    # ---------= [All Components Loaded] =---------
    logger(
        "All components loaded. "
        "\n"
    )

    # ---------= [Prepare Path] =---------
    exp_root_path = Path(exp_root_path)

    exp_name_list = _resolve_name_list(
        name_spec = exp_name_list, 
        root_path = exp_root_path
    )

    # ---------= [Cal Reward Robustness] =---------
    sample_robustness_list = []

    # Handed to and received back from every `cal_lpips` call so that the
    # same model object is passed along instead of starting from `None`.
    lpips_model = None

    try:
        for exp_name in exp_name_list:
            exp_path = exp_root_path / exp_name

            # "all" is resolved per experiment because each experiment may
            # contain a different set of folders.
            tmp_folder_name_list = _resolve_name_list(
                name_spec = folder_name_list, 
                root_path = exp_path
            )

            for folder_name in tmp_folder_name_list:
                folder_path = exp_path / folder_name

                for sample_idx in range(num_sample):
                    (
                        robustness_list, 
                        lpips_model
                    ) = _cal_sample_robustness(
                        folder_path = folder_path, 
                        sample_idx = sample_idx, 

                        eps = eps, 

                        lpips_net_type = lpips_net_type, 
                        lpips_batch_size = lpips_batch_size, 

                        device = device, 

                        lpips_model = lpips_model
                    )

                    sample_robustness_list.extend(robustness_list)

            # ---------= [Clean Up] =---------
            gc.collect()

        # Kept separate from `num_sample`, which is the configured number of
        # samples per folder.
        num_pair = len(sample_robustness_list)
        if num_pair <= 0:
            logger(
                "No valid sample. ", 
                log_type = "error"
            )

            # Without this guard the mean below would divide by zero.
            raise ValueError(
                "No valid sample: no consecutive image pair was found."
            )

        reward_robustness = sum(sample_robustness_list) / num_pair

        logger(
            f"reward_robustness: {reward_robustness:.4f}"
        )
    finally:
        # ---------= [Clean Up] =---------
        # Runs on failure as well, so the model and cached memory are
        # released even when the computation aborts.
        del sample_robustness_list
        del lpips_model
        gc.collect()
        torch.cuda.empty_cache()


def cal_reward_robustness(cfg: DictConfig) -> None:
    """Run the reward-robustness computation for ``cfg``.

    Thin wrapper: forwards ``cfg`` unchanged to
    ``cal_reward_robustness_implement`` and returns ``None``.

    Args:
        cfg: Configuration object passed through as-is.
    """
    cal_reward_robustness_implement(cfg = cfg)
