from util.logger import logger

from typing import Optional, Union, List, Tuple

from pathlib import Path

import gc

from util.yaml_util import load_yaml
from util.json_util import load_json

from .prompt_manager import PromptManager


class DrawBench(PromptManager):
    """
    Func:
        Prompt manager that selects prompts, per category, from one flat prompt
        list according to a YAML dataset config.

    Config keys read (from the YAML at `cfg_yaml_path`):
        `prompt_list_json_path`: Path passed to `load_json()`; the loaded object is
            sliced/indexed as a flat sequence of prompts.
        `category_num_prompt_list`: Ordered `[category_name, num_prompt]` pairs. Each
            category is taken to occupy a contiguous slice of the prompt list, in
            this order.
        `prompt_idx_list_dict`: Maps `category_name` to `"all"`, a list of
            category-local prompt indices, or `None` / an empty list.
    """

    def __init__(
        self, 

        # ---------= [Category] =---------
        category_name_list: Optional[Union[str, List[str]]] = "anime", 

        cfg_yaml_path: Optional[Union[str, Path]] = None
    ):
        """
        Func:
            Normalize the arguments and load the dataset config YAML.

        Args:
            `category_name_list` (`str` | `List[str]` | `None`): Category name(s). A
                single `str` is wrapped into a one-element list and stored on
                `self._category_name_list`.
                NOTE(review): `load_prompt_list()` does not consult this value. It
                selects categories from the config's `prompt_idx_list_dict` instead.
                Confirm whether filtering by this argument was intended.
            `cfg_yaml_path` (`str` | `Path` | `None`): Path of the dataset config YAML.
                When `None`, `./config/dataset/draw_bench.yaml` is used. That is a
                relative path, so it resolves against the current working directory.
        """
        super().__init__()

        if isinstance(category_name_list, str):
            category_name_list = [category_name_list]
        # Keep the normalized value rather than discarding it (see NOTE above).
        self._category_name_list = category_name_list

        self._cfg_yaml_path = (
            Path("./config/dataset/draw_bench.yaml")
            if cfg_yaml_path is None
            else Path(cfg_yaml_path)
        )

        self._dataset_cfg_dict = load_yaml(self._cfg_yaml_path)

        # `__init__()` done
        pass


    def _get_category_st_idx_length_dict(
        self, 

        category_num_prompt_list: List[Tuple[str, int]]
    ):
        """
        Func:
            Prepare `(st_idx, length)` for every category. Categories are laid out
            back-to-back in the order given, so each `st_idx` is the running sum of
            the preceding categories' `num_prompt`.

        Args:
            `category_num_prompt_list` (`List[Tuple[str, int]]`): Ordered
                `(category_name, num_prompt)` pairs. Two-element lists (as YAML
                produces) unpack the same way.

        Ret:
            `category_st_idx_length_dict` (`Dict[str, Tuple[int, int]]`): Maps each
                `category_name` to the `(st_idx, length)` of its slice.

        Raises:
            `ValueError`: If a `category_name` appears more than once. Without this
                check the later entry would silently overwrite the earlier one and
                make its slice unreachable.
        """

        category_st_idx_length_dict = {}

        cur_idx = 0

        for (category_name, num_prompt) in category_num_prompt_list:
            if category_name in category_st_idx_length_dict:
                raise ValueError(
                    f"Duplicated `category_name` in `category_num_prompt_list`, got `{category_name}`. "
                )

            category_st_idx_length_dict[category_name] = (
                cur_idx, 
                num_prompt
            )

            cur_idx += num_prompt

            # goto `for (category_name, num_prompt)`
            pass

        # `_get_category_st_idx_length_dict()` done
        return category_st_idx_length_dict


    def load_prompt_list(
        self
    ):
        """
        Func:
            Load the `self.prompt_list`. For every category in the config's
            `prompt_idx_list_dict`, append its selected prompts to `self.prompt_list`.

            Each value of `prompt_idx_list_dict` may be:
                - `"all"`: every prompt of the category;
                - a non-empty list of category-local indices (`0 <= idx < length`);
                - `None` or an empty list: nothing from the category.

            NOTE(review): prompts are appended (`+=`), so calling this twice
            duplicates them unless `PromptManager` resets `self.prompt_list`.
            Confirm this against the base class.

        Raises:
            `ValueError`: If a category is not listed in `category_num_prompt_list`,
                or its selection is a string other than `"all"`.
            `IndexError`: If a category-local prompt index is outside `[0, length)`.
        """

        prompt_list = load_json(self._dataset_cfg_dict["prompt_list_json_path"])

        prompt_idx_list_dict = self._dataset_cfg_dict["prompt_idx_list_dict"]

        category_st_idx_length_dict = self._get_category_st_idx_length_dict(
            category_num_prompt_list = self._dataset_cfg_dict["category_num_prompt_list"]
        )

        for (category_name, prompt_idx_list) in prompt_idx_list_dict.items():
            if category_name not in category_st_idx_length_dict:
                raise ValueError(
                    f"Unsupported `category_name`, got `{category_name}`. "
                )

            (
                st_idx, 
                length
            ) = category_st_idx_length_dict[category_name]

            if prompt_idx_list == "all":
                selected_prompt_list = prompt_list[st_idx: st_idx + length]
            elif isinstance(prompt_idx_list, str):
                # Any other string is a config error. Without this check it would be
                # iterated character by character.
                raise ValueError(
                    f"Unsupported `prompt_idx_list` for category `{category_name}`, got `{prompt_idx_list}`. "
                )
            elif (prompt_idx_list is None) or (len(prompt_idx_list) == 0):
                selected_prompt_list = []
            else:
                # Indices are local to the category. An unchecked out-of-range index
                # would silently pick a prompt from a neighbouring category.
                for prompt_idx in prompt_idx_list:
                    if not (0 <= prompt_idx < length):
                        raise IndexError(
                            f"`prompt_idx` out of range for category `{category_name}`, "
                            f"got `{prompt_idx}`, expected `0 <= prompt_idx < {length}`. "
                        )

                selected_prompt_list = [
                    prompt_list[st_idx + prompt_idx]
                    for prompt_idx in prompt_idx_list
                ]

            self.prompt_list += selected_prompt_list

            # goto `for (category_name, prompt_idx_list)`
            pass

        # `load_prompt_list()` done
        pass
