from util.logger import logger

from typing import List

from omegaconf import DictConfig

import gc

from pathlib import Path

from tqdm.auto import tqdm

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.image_util import load_img_path
from util.yaml_util import (
    load_yaml, 
    convert_numpy_type_to_native_type, 
    save_yaml
)

from diffusion_OT_MCTS.reward_model.util import get_reward_model


def _resolve_folder_name_list(
    setting_root_path: Path, 
    folder_name_list
) -> List[str]:
    """Expand the `folder_name_list` config value into a list of folder names.

    - `"all"`: every sub-directory of `setting_root_path` except `_metric`,
      sorted so the processing order is deterministic.
    - a list: returned unchanged.
    - any other single value: wrapped in a one-element list.
    """
    if folder_name_list == "all":
        # `.name` (not `.stem`): `.stem` drops everything after the last dot,
        # so a folder such as "run.v2" would become "run" and not resolve.
        return sorted(
            folder_path.name
                for folder_path in setting_root_path.iterdir()
                    if folder_path.is_dir() and folder_path.name != "_metric"
        )

    if not isinstance(folder_name_list, list):
        return [folder_name_list]

    return folder_name_list


def _get_reward_model_type_str(reward_model_type) -> str:
    """Return the `_metric` sub-folder name used to store results of this reward.

    For `color_channel_reward`, `laplacian_var_reward` and
    `compressibility_hps_v2`, a setting read from the matching YAML under
    `./config/reward_model/` is appended as a `-<suffix>`, so results computed
    with different settings are written to different folders. Any other
    reward model type is used as-is.
    """
    if reward_model_type == "color_channel_reward":
        reward_model_cfg_dict = load_yaml(Path("./config/reward_model/color_channel_reward.yaml"))
        target_channel_idx = reward_model_cfg_dict["color_channel_reward"]["target_channel_idx"]
        return f"{reward_model_type}-{target_channel_idx}"

    if reward_model_type == "laplacian_var_reward":
        reward_model_cfg_dict = load_yaml(Path("./config/reward_model/laplacian_var_reward.yaml"))
        tsfm_to_gray = reward_model_cfg_dict["laplacian_var_reward"]["tsfm_to_gray"]
        color_mode = "gray" if tsfm_to_gray else "rgb"
        return f"{reward_model_type}-{color_mode}"

    if reward_model_type == "compressibility_hps_v2":
        reward_model_cfg_dict = load_yaml(Path("./config/reward_model/compressibility_hps_v2.yaml"))
        lam = reward_model_cfg_dict["compressibility_hps_v2"]["lam"]
        return f"{reward_model_type}-{lam}"

    return reward_model_type


def _collect_prompt_and_img(
    setting_root_path: Path, 
    folder_name_list: List[str], 
    num_sample: int
):
    """Load every generated image together with the prompt that produced it.

    Folders without a `cfg.yaml` are skipped. For the others, the prompt is
    `cfg["sample"]["optimized_prompt"]` when present, otherwise
    `cfg["sample"]["prompt"]`. Images are read from `png/<sample_idx>/` for
    each `sample_idx` in `range(num_sample)`, ordered by the integer after the
    last `_` in the file stem.

    Returns:
        prompt_tuple_list: one `(prompt, folder_name, sample_idx)` per image.
        img_pil_list: the loaded images, index-aligned with `prompt_tuple_list`.
    """
    prompt_tuple_list = []  # (prompt, folder_name, sample_idx)
    img_pil_list = []

    for folder_name in folder_name_list:
        folder_path = setting_root_path / folder_name

        cfg_path = folder_path / "cfg.yaml"
        if not cfg_path.is_file():
            continue

        sample_cfg = load_yaml(cfg_path)["sample"]
        if "optimized_prompt" in sample_cfg:
            prompt = sample_cfg["optimized_prompt"]
        else:
            prompt = sample_cfg["prompt"]

        sample_root_path = folder_path / "png"

        for sample_idx in range(num_sample):
            png_root_path = sample_root_path / f"{sample_idx}"

            # Keep regular files only, so sub-directories or other non-file
            # entries are never handed to `load_img_path`.
            png_path_list = sorted(
                (png_path for png_path in png_root_path.iterdir() if png_path.is_file()), 
                key = lambda png_path: int(png_path.stem.split('_')[-1])
            )

            for png_path in png_path_list:
                prompt_tuple_list.append(
                    (prompt, folder_name, sample_idx)
                )
                img_pil_list.append(load_img_path(png_path))

    return prompt_tuple_list, img_pil_list


def _save_final_reward(
    setting_root_path: Path, 
    prompt_tuple_list, 
    final_reward_list, 
    reward_model_type, 
    reward_model_type_str: str
):
    """Write per-sample rewards to `<folder>/_metric/<reward_model_type_str>/<sample_idx>.yaml`.

    `prompt_tuple_list` and `final_reward_list` are index-aligned, one entry per
    image. `_collect_prompt_and_img` emits the images of each
    `(folder_name, sample_idx)` contiguously, so each contiguous run is saved
    as one YAML file of the form `{reward_model_type: [reward, ...]}`.
    """
    from itertools import groupby

    num_img = len(prompt_tuple_list)

    with tqdm(
        total = num_img, 
        desc = "[Saving Results]"
    ) as bar:
        for (folder_name, sample_idx), group in groupby(
            zip(prompt_tuple_list, final_reward_list), 
            key = lambda pair: (pair[0][1], pair[0][2])
        ):
            batch_final_reward_list = [
                final_reward.item() \
                    for _, final_reward in group
            ]

            res_dict = convert_numpy_type_to_native_type(
                {reward_model_type: batch_final_reward_list}
            )

            save_metric_root_path = setting_root_path / folder_name / "_metric" / reward_model_type_str

            save_yaml(
                res_dict, 

                yaml_root_path = save_metric_root_path, 
                yaml_filename = f"{sample_idx}.yaml"
            )

            # Advance by the number of images actually written so the bar
            # ends exactly at `num_img`.
            bar.update(len(batch_final_reward_list))


def cal_final_reward_ours_implement(
    cfg: DictConfig
):
    """Compute the final reward of every generated image and save it per sample.

    Config keys read:
        cfg["task"]["reward_model"]["reward_model_type"]
        cfg["task"]["reward_model"]["cal_final_reward_batch_size"]
        cfg["task"]["path"]["setting_root_path"]
        cfg["task"]["path"]["folder_name_list"]   # "all", a name, or a list
        cfg["task"]["num_sample"]
    Global variables read: `device`, `vae_decode_batch_size`.

    Images under `<setting_root_path>/<folder>/png/<sample_idx>/` are scored
    with the configured reward model. Results are written to
    `<setting_root_path>/<folder>/_metric/<reward type>/<sample_idx>.yaml`.
    Returns early without loading the reward model when no image is found.
    """
    # ---------= [Basic Global Variables] =---------
    device = get_global_variable("device")
    vae_decode_batch_size = get_global_variable("vae_decode_batch_size")

    # ---------= [Reward Model] =---------
    logger(f"[Reward Model] Loading started. ")

    reward_model_type = get_true_value(cfg["task"]["reward_model"]["reward_model_type"])
    cal_final_reward_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_final_reward_batch_size"])

    logger(f"    reward_model_type: {reward_model_type}")
    logger(f"    cal_final_reward_batch_size: {cal_final_reward_batch_size}")

    logger(
        f"[Reward Model] Loading finished. "
        "\n"
    )

    # ---------= [Path] =---------
    logger(f"[Path] Loading started. ")

    setting_root_path = get_true_value(cfg["task"]["path"]["setting_root_path"])
    folder_name_list = get_true_value(cfg["task"]["path"]["folder_name_list"])

    logger(f"    setting_root_path: {setting_root_path}")
    logger(f"    folder_name_list: {folder_name_list}")

    logger(
        f"[Path] Loading finished. "
        "\n"
    )

    # ---------= [Num Sample] =---------
    logger(f"[Num Sample] Loading started. ")

    num_sample = get_true_value(cfg["task"]["num_sample"])

    logger(f"    num_sample: {num_sample}")

    logger(
        f"[Num Sample] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Path] =---------
    setting_root_path = Path(setting_root_path)
    folder_name_list = _resolve_folder_name_list(setting_root_path, folder_name_list)

    # ---------= [Prepare Reward Model Type] =---------
    reward_model_type_str = _get_reward_model_type_str(reward_model_type)

    # ---------= [Prepare Prompt List & PIL List] =---------
    prompt_tuple_list, img_pil_list = _collect_prompt_and_img(
        setting_root_path, folder_name_list, num_sample
    )
    num_img = len(img_pil_list)

    if num_img == 0:
        # Nothing to score: avoid loading the reward model and calling it on empty input.
        logger(f"[Cal Final Reward] No image found under {setting_root_path}. Nothing to do. ")
        return

    # ---------= [Prepare Reward Model] =---------
    reward_shape = (1, )
    torch_dtype = "float16"

    reward_model = get_reward_model(
        reward_model_type = reward_model_type, 

        # ---------= [Reward] =---------
        reward_shape = reward_shape, 
        reward_dtype = torch_dtype, 

        # ---------= [Parallel] =---------
        cal_final_reward_batch_size = cal_final_reward_batch_size, 

        device = device, 

        vae_decode_batch_size = vae_decode_batch_size
    )

    # ---------= [Cal Final Reward] =---------
    prompt_list = [
        prompt_tuple[0] \
            for prompt_tuple in prompt_tuple_list
    ]

    final_reward_list = reward_model.cal_final_reward(
        img_pil_list = img_pil_list, 
        prompt_list = prompt_list
    )
    # One scalar per image (reward_shape is (1,)).
    final_reward_list = final_reward_list.reshape(num_img, )

    # ---------= [Save Result] =---------
    _save_final_reward(
        setting_root_path, 
        prompt_tuple_list, 
        final_reward_list, 
        reward_model_type, 
        reward_model_type_str
    )


def cal_final_reward_ours(
    cfg: DictConfig
):
    """Public entry point: score and save final rewards as configured by `cfg`.

    All work is delegated to `cal_final_reward_ours_implement`; nothing is returned.
    """
    cal_final_reward_ours_implement(cfg)
