from util.logger import logger

from typing import Optional, Union, List

from pathlib import Path

import gc

from util.yaml_util import load_yaml
from util.json_util import load_json

from .prompt_manager import PromptManager


class HumanPreferenceDataset_v2(PromptManager):
    """
    Func:
        Prompt manager whose prompts come from per-category JSON files.
        The YAML config is read for two mappings:
        `category_path_dict` (category name -> prompt JSON path) and
        `prompt_idx_list_dict` (category name -> which prompts to keep).
    """

    def __init__(
        self,

        # ---------= [Category] =---------
        category_name_list: Optional[Union[str, List[str]]] = "anime",

        cfg_yaml_path: Optional[str] = None
    ):
        """
        Func:
            Read the dataset config and remember the requested categories.
            No prompt is loaded here; see `load_prompt_list()`.

        Args:
            category_name_list: One category name, or a list of names.
                Every name must be a key of the config's
                `category_path_dict`.
            cfg_yaml_path: Path of the dataset YAML config. `None` selects
                `./config/dataset/hpd_v2.yaml`.

        Raises:
            ValueError: A requested category is not a key of
                `category_path_dict`.
        """
        super().__init__()

        # A bare string is shorthand for a one-element list.
        requested_categories = (
            [category_name_list]
            if isinstance(category_name_list, str)
            else category_name_list
        )

        self._cfg_yaml_path = Path(
            "./config/dataset/hpd_v2.yaml"
            if cfg_yaml_path is None
            else cfg_yaml_path
        )

        self._dataset_cfg_dict = load_yaml(self._cfg_yaml_path)
        self._category_path_dict = self._dataset_cfg_dict["category_path_dict"]

        # Reject the first unknown category before storing anything.
        for category_name in requested_categories:
            if category_name in self._category_path_dict.keys():
                continue
            raise ValueError(
                f"Unsupported `category_name`, got `{category_name}`. "
            )

        self._category_name_list = requested_categories

    def load_prompt_list(
        self
    ):
        """
        Func:
            Extend `self.prompt_list` with the configured prompts of every
            selected category, in the order the categories were given.

            Per category, the entry of `prompt_idx_list_dict` decides what
            is kept:
                - `"all"`: everything loaded from the category's JSON.
                - `None` or empty: nothing.
                - otherwise: each element is used as an index into the
                  loaded prompts.
        """
        idx_spec_by_category = self._dataset_cfg_dict["prompt_idx_list_dict"]

        for category_name in self._category_name_list:
            # presumably a list of prompts; verify against the JSON files
            category_prompts = load_json(
                self._category_path_dict[category_name]
            )
            idx_spec = idx_spec_by_category[category_name]

            if idx_spec == "all":
                kept_prompts = category_prompts
            elif idx_spec is None or len(idx_spec) == 0:
                kept_prompts = []
            else:
                kept_prompts = [category_prompts[idx] for idx in idx_spec]

            # NOTE(review): assumes `self.prompt_list` already exists,
            # presumably created by `PromptManager.__init__()` -- confirm.
            self.prompt_list += kept_prompts

            # ---------= [Clean Up] =---------
            # Drop the local reference to the loaded file before the next
            # category is read, then run a collection pass.
            del category_prompts
            gc.collect()
