from util.logger import logger

from typing import List

from omegaconf import DictConfig

import concurrent.futures as cf

from pathlib import Path

from tqdm.auto import tqdm

import torch

import gc

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.yaml_util import (
    load_yaml, 
    convert_numpy_type_to_native_type, 
    save_yaml
)


def _collect_optimized_prompts(
    src_setting_root_path: Path
) -> dict:
    """Read ``sample.optimized_prompt`` from every sample folder of a source setting.

    Every immediate sub-directory of ``src_setting_root_path`` is treated as a
    sample folder and must contain a ``cfg.yaml`` holding a
    ``sample.optimized_prompt`` entry. Non-directory entries are ignored.

    Args:
        src_setting_root_path: Directory whose sub-directories are sample folders.

    Returns:
        A dict mapping each sample folder name to the value of
        ``cfg["sample"]["optimized_prompt"]`` loaded from that folder.
    """
    optimized_prompt_dict = {}

    # Sorted so the processing (and failure) order is deterministic across
    # file systems; stray files next to the sample folders are skipped.
    for folder_path in sorted(src_setting_root_path.iterdir()):
        if not folder_path.is_dir():
            continue

        cfg_dict = load_yaml(folder_path / "cfg.yaml")
        optimized_prompt_dict[folder_path.name] = cfg_dict["sample"]["optimized_prompt"]

    return optimized_prompt_dict


def _add_optimized_prompt_to_setting(
    dst_setting_root_path: Path, 
    optimized_prompt_dict: dict
):
    """Write optimized prompts into every sample ``cfg.yaml`` of one destination setting.

    For each sub-directory of ``dst_setting_root_path``, the prompt stored in
    ``optimized_prompt_dict`` under the same folder name is written to
    ``cfg["sample"]["optimized_prompt"]``, unless that key already exists.
    The cfg is then passed through ``convert_numpy_type_to_native_type`` and
    saved back to ``<folder>/cfg.yaml``. Non-directory entries are ignored.

    Args:
        dst_setting_root_path: Directory whose sub-directories are sample folders.
        optimized_prompt_dict: Folder name -> optimized prompt, as produced by
            ``_collect_optimized_prompts``.

    Raises:
        KeyError: If a destination folder has no matching source folder.
    """
    for folder_path in sorted(dst_setting_root_path.iterdir()):
        if not folder_path.is_dir():
            continue

        folder_name = folder_path.name
        # Fail before touching the file, with a message that says which folder is unmatched.
        if folder_name not in optimized_prompt_dict:
            raise KeyError(
                f"No optimized prompt found in the source setting for "
                f"destination folder '{folder_name}' ({folder_path}). "
            )

        cfg_dict = load_yaml(folder_path / "cfg.yaml")

        # The prompt is stored under `sample`, so that is where the
        # "already present" check must look; checking the top-level keys
        # always misses it and overwrites an existing prompt.
        sample_cfg = cfg_dict["sample"]
        if "optimized_prompt" not in sample_cfg:
            sample_cfg["optimized_prompt"] = optimized_prompt_dict[folder_name]

        cfg_dict = convert_numpy_type_to_native_type(cfg_dict)

        save_yaml(
            cfg_dict, 

            yaml_root_path = folder_path, 
            yaml_filename = "cfg.yaml"
        )


def add_optimized_prompt_implement(
    cfg: DictConfig
):
    """Copy ``sample.optimized_prompt`` from a source setting into destination settings.

    Config keys read (each resolved through ``get_true_value``):

    - ``cfg["task"]["src_setting_root_path"]``: setting whose sample folders
      provide the optimized prompts.
    - ``cfg["task"]["dst_path"]["exp_root_path"]``: experiment root containing
      the destination settings.
    - ``cfg["task"]["dst_path"]["setting_name_list"]``: names of the
      destination settings under ``exp_root_path`` to update.

    Sample folders are matched between source and destination by folder name.
    A destination ``cfg.yaml`` that already has ``sample.optimized_prompt``
    keeps its value; every visited destination ``cfg.yaml`` is re-saved.

    Args:
        cfg: Task configuration.

    Raises:
        KeyError: If a destination sample folder has no matching source folder.
    """
    # ---------= [Source Setting Root Path] =---------
    logger(f"[Source Setting Root Path] Loading started. ")

    src_setting_root_path = get_true_value(cfg["task"]["src_setting_root_path"])

    logger(f"    src_setting_root_path: {src_setting_root_path}")

    logger(
        f"[Source Setting Root Path] Loading finished. "
        "\n"
    )

    # ---------= [Destination Path] =---------
    logger(f"[Destination Path] Loading started. ")

    dst_exp_root_path = get_true_value(cfg["task"]["dst_path"]["exp_root_path"])
    dst_setting_name_list = get_true_value(cfg["task"]["dst_path"]["setting_name_list"])

    logger(f"    dst_exp_root_path: {dst_exp_root_path}")
    logger(f"    dst_setting_name_list: {dst_setting_name_list}")

    logger(
        f"[Destination Path] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Everything] =---------
    src_setting_root_path = Path(src_setting_root_path)
    dst_exp_root_path = Path(dst_exp_root_path)

    # ---------= [Prepare Source] =---------
    # Read every source prompt up front, so a broken source setting fails
    # before any destination file has been rewritten.
    optimized_prompt_dict = _collect_optimized_prompts(src_setting_root_path)

    # ---------= [Add Optimized Prompt] =---------
    for setting_name in dst_setting_name_list:
        _add_optimized_prompt_to_setting(
            dst_setting_root_path = dst_exp_root_path / setting_name, 
            optimized_prompt_dict = optimized_prompt_dict
        )


def add_optimized_prompt(
    cfg: DictConfig
):
    """Task entry point for adding optimized prompts.

    All of the work is delegated to ``add_optimized_prompt_implement``; this
    wrapper only forwards the task configuration and returns ``None``.

    Args:
        cfg: Task configuration, passed through unchanged.
    """
    add_optimized_prompt_implement(cfg = cfg)
