from util.logger import logger

from typing import List

from omegaconf import DictConfig

import gc

from pathlib import Path

from tqdm.auto import tqdm

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.image_util import load_img_path
from util.yaml_util import (
    load_yaml, 
    convert_numpy_type_to_native_type, 
    save_yaml
)

from diffusion_OT_MCTS.reward_model.util import get_reward_model


def _resolve_folder_name_list(
    folder_root_path: Path,
    folder_name_list
) -> List[str]:
    """Normalize the configured `folder_name_list` into a de-duplicated list of folder names.

    `"all"` expands to every sub-directory of `folder_root_path` except `_metric`
    (sorted, so the processing order is deterministic); a single name is wrapped
    into a one-element list.
    """
    if folder_name_list == "all":
        # Use `.name`, not `.stem`: a directory such as `run.v2` has stem `run`,
        # which would point at a non-existent folder. Plain files are skipped.
        return sorted(
            folder_path.name
            for folder_path in folder_root_path.iterdir()
            if folder_path.is_dir() and folder_path.name != "_metric"
        )

    if isinstance(folder_name_list, (str, int)):
        folder_name_list = [folder_name_list]

    # `dict.fromkeys` drops duplicates while keeping the configured order, so a
    # folder listed twice is not scored (and its metric file written) twice.
    return list(dict.fromkeys(str(folder_name) for folder_name in folder_name_list))


def _get_reward_model_type_str(
    reward_model_type: str
) -> str:
    """Return the metric-file stem for `reward_model_type`.

    For reward models with a distinguishing hyper-parameter, that value (read from
    `./config/reward_model/<reward_model_type>.yaml`) is appended as a suffix so
    runs with different settings do not overwrite each other's metric files.
    """
    reward_model_cfg_yaml_path = Path("./config/reward_model") / f"{reward_model_type}.yaml"

    if reward_model_type == "color_channel_reward":
        target_channel_idx = load_yaml(reward_model_cfg_yaml_path)[reward_model_type]["target_channel_idx"]
        return f"{reward_model_type}-{target_channel_idx}"

    if reward_model_type == "laplacian_var_reward":
        tsfm_to_gray = load_yaml(reward_model_cfg_yaml_path)[reward_model_type]["tsfm_to_gray"]
        return f"{reward_model_type}-{'gray' if tsfm_to_gray else 'rgb'}"

    if reward_model_type == "compressibility_hps_v2":
        lam = load_yaml(reward_model_cfg_yaml_path)[reward_model_type]["lam"]
        return f"{reward_model_type}-{lam}"

    return reward_model_type


def _load_prompts_and_images(
    folder_root_path: Path,
    folder_name_list: List[str],
    num_sample: int
):
    """Load the prompt, owning folder name and image of every available sample.

    A sample `<folder_root_path>/<folder_name>/{cfg,png}/<idx>.{yaml,png}` is used
    only when both files exist; the three returned lists are index-aligned.

    Returns:
        (prompt_list, sample_folder_name_list, img_pil_list)
    """
    prompt_list = []
    sample_folder_name_list = []
    img_pil_list = []

    for folder_name in folder_name_list:
        folder_path = folder_root_path / folder_name

        cfg_root_path = folder_path / "cfg"
        png_root_path = folder_path / "png"

        for sample_idx in range(num_sample):
            cfg_path = cfg_root_path / f"{sample_idx}.yaml"
            png_path = png_root_path / f"{sample_idx}.png"

            # Skip samples missing either file (e.g. from an interrupted run).
            if not (cfg_path.is_file() and png_path.is_file()):
                continue

            prompt_list.append(load_yaml(cfg_path)["sample"]["prompt"])
            sample_folder_name_list.append(folder_name)
            img_pil_list.append(load_img_path(png_path))

    return prompt_list, sample_folder_name_list, img_pil_list


def _group_rewards_by_folder(
    sample_folder_name_list: List[str],
    final_reward_list
):
    """Group per-image rewards by the folder each image came from.

    Grouping is keyed on the folder (not the prompt), so folders that share a
    prompt stay separate and each folder gets exactly its own rewards.
    Insertion order of folders and of rewards within a folder is preserved.
    """
    folder_reward_dict = {}
    for folder_name, final_reward in zip(sample_folder_name_list, final_reward_list):
        folder_reward_dict.setdefault(folder_name, []).append(final_reward.item())
    return folder_reward_dict


def cal_final_reward_baseline_implement(
    cfg: DictConfig
):
    """Score saved baseline samples with a reward model and write per-folder metrics.

    Reads from `cfg["task"]`:
        reward_model.reward_model_type, reward_model.cal_final_reward_batch_size,
        path.folder_root_path, path.folder_name_list ("all", one name, or a list),
        path.num_sample.

    Every sample `<folder_root_path>/<folder>/cfg/<idx>.yaml` + `png/<idx>.png`
    (idx < num_sample) that exists on disk is scored. For each folder the rewards
    are written to `<folder_root_path>/<folder>/_metric/<reward_model_type_str>.yaml`
    as `{reward_model_type: [reward, ...]}`, in sample-index order.
    """
    # ---------= [Basic Global Variables] =---------
    device = get_global_variable("device")
    vae_decode_batch_size = get_global_variable("vae_decode_batch_size")

    # ---------= [Reward Model] =---------
    logger(f"[Reward Model] Loading started. ")

    reward_model_type = get_true_value(cfg["task"]["reward_model"]["reward_model_type"])
    cal_final_reward_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_final_reward_batch_size"])

    logger(f"    reward_model_type: {reward_model_type}")
    logger(f"    cal_final_reward_batch_size: {cal_final_reward_batch_size}")

    logger(
        f"[Reward Model] Loading finished. "
        "\n"
    )

    # ---------= [Path] =---------
    logger(f"[Path] Loading started. ")

    folder_root_path = get_true_value(cfg["task"]["path"]["folder_root_path"])
    folder_name_list = get_true_value(cfg["task"]["path"]["folder_name_list"])
    num_sample = get_true_value(cfg["task"]["path"]["num_sample"])

    logger(f"    folder_root_path: {folder_root_path}")
    logger(f"    folder_name_list: {folder_name_list}")
    logger(f"    num_sample: {num_sample}")

    logger(
        f"[Path] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Path & Metric Name] =---------
    folder_root_path = Path(folder_root_path)
    folder_name_list = _resolve_folder_name_list(folder_root_path, folder_name_list)
    reward_model_type_str = _get_reward_model_type_str(reward_model_type)

    # ---------= [Prepare Prompt List & PIL List] =---------
    prompt_list, sample_folder_name_list, img_pil_list = _load_prompts_and_images(
        folder_root_path, folder_name_list, num_sample
    )
    num_img = len(img_pil_list)

    # Nothing to score: avoid loading the (expensive) reward model at all.
    if num_img == 0:
        logger(f"[Cal Final Reward] No samples found under {folder_root_path}. ")
        return

    # ---------= [Prepare Reward Model] =---------
    reward_shape = (1, )
    torch_dtype = "float16"

    reward_model = get_reward_model(
        reward_model_type = reward_model_type, 

        # ---------= [Reward] =---------
        reward_shape = reward_shape, 
        reward_dtype = torch_dtype, 

        # ---------= [Parallel] =---------
        cal_final_reward_batch_size = cal_final_reward_batch_size, 

        device = device, 

        vae_decode_batch_size = vae_decode_batch_size
    )

    # ---------= [Cal Final Reward] =---------
    final_reward_list = reward_model.cal_final_reward(
        img_pil_list = img_pil_list, 
        prompt_list = prompt_list
    )
    final_reward_list = final_reward_list.reshape(num_img, )

    # The decoded images are no longer needed; release them once here rather
    # than running a collection per folder while saving.
    del img_pil_list
    gc.collect()

    # ---------= [Save Result] =---------
    folder_reward_dict = _group_rewards_by_folder(sample_folder_name_list, final_reward_list)

    with tqdm(
        total = num_img, 
        desc = "[Saving Results]"
    ) as bar:
        for folder_name, folder_final_reward_list in folder_reward_dict.items():
            num_folder_sample = len(folder_final_reward_list)
            if num_folder_sample != num_sample:
                logger(
                    f"[Saving Results] Warning: `{folder_name}` has "
                    f"{num_folder_sample} / {num_sample} valid samples. "
                )

            res_dict = convert_numpy_type_to_native_type({
                reward_model_type: folder_final_reward_list
            })

            save_yaml(
                res_dict, 

                yaml_root_path = folder_root_path / folder_name / "_metric", 
                yaml_filename = f"{reward_model_type_str}.yaml"
            )

            # Advance by the exact number of images saved for this folder.
            bar.update(num_folder_sample)

    # `cal_final_reward_baseline_implement()` done
    pass


def cal_final_reward_baseline(
    cfg: DictConfig
):
    """Public entry point for scoring baseline sample folders with a reward model.

    Delegates all work to `cal_final_reward_baseline_implement`; returns nothing.
    """
    cal_final_reward_baseline_implement(cfg = cfg)
