"""Dataset."""

import json
import random
import datasets

# Module-level logger obtained from the `datasets` logging utilities.
logger = datasets.logging.get_logger(__name__)
# Prompt template filled with str.format; `{user_message}` is its only
# placeholder. The trailing space after "Output:" and the final newline are
# part of the prompt text.
llama2_prompt = """<s> Input: {user_message}
Output: 
"""

class DataConfig(datasets.BuilderConfig):
    """Builder configuration that carries the dataset directory.

    Args:
        data_dir: task data dir, which contains the corresponding dataset dirs.
        *args: positional arguments forwarded unchanged to the base config.
        **kwargs: keyword arguments forwarded unchanged to the base config.
    """

    def __init__(self, *args, data_dir=None, **kwargs):
        # The base class is initialised first; `data_dir` is assigned
        # afterwards so the value given here is the one left on the config.
        super().__init__(*args, **kwargs)
        self.data_dir = data_dir


# TODO: few-shot support -- the example values need to be prepared at load time and stored under "Examples".
class DataInstructions(datasets.GeneratorBasedBuilder):
    """Dataset builder for instruction data stored as one JSON file per split.

    The directory given by ``config.data_dir`` is expected to contain
    ``train.json``, ``eval.json`` and ``test.json``. Each file is read as a
    JSON object whose ``"Instances"`` entry is a list of records with
    ``"input"`` and ``"output"`` fields.
    """

    def _info(self):
        """Describe the features of a generated example.

        Every example has a single ``"Instance"`` entry holding four string
        fields: ``id``, ``sentence``, ``label`` and ``instruction``.
        """
        return datasets.DatasetInfo(
            features=datasets.Features(
                {
                    "Instance": {
                        "id": datasets.Value("string"),
                        "sentence": datasets.Value("string"),
                        "label": datasets.Value("string"),
                        "instruction": datasets.Value("string"),
                    }
                }
            ),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Return the train/validation/test SplitGenerators.

        Args:
            dl_manager: unused; the split files are read from ``config.data_dir``.

        Raises:
            ValueError: if ``config.data_dir`` is not set.
        """
        data_dir = self.config.data_dir
        if data_dir is None:
            # Fail here with a clear message instead of crashing on the path
            # construction below.
            message = "Please provide right input: data_dir!"
            logger.error(message)
            raise ValueError(message)

        # (split name, file name inside data_dir, subset tag passed to
        # _generate_examples)
        split_files = (
            (datasets.Split.TRAIN, "train.json", "train"),
            (datasets.Split.VALIDATION, "eval.json", "dev"),
            (datasets.Split.TEST, "test.json", "test"),
        )

        # An f-string yields the same path as `data_dir + '/name'` for str
        # input and additionally accepts os.PathLike values.
        return [
            datasets.SplitGenerator(
                name=split_name,
                gen_kwargs={
                    "path": f"{data_dir}/{file_name}",
                    "subset": subset,
                },
            )
            for split_name, file_name, subset in split_files
        ]

    def _load_dataset(self, dataset_path):
        """Parse the UTF-8 JSON file at ``dataset_path`` and return its content."""
        with open(dataset_path, encoding="utf-8") as task_f:
            return json.load(task_f)

    def load_dataset(self, dataset_path):
        """Yield one example dict per record of the JSON file at ``dataset_path``.

        Each yielded dict has the shape declared in ``_info``: an
        ``"Instance"`` entry with the record's position as ``id``, its
        ``"input"`` as ``sentence``, its ``"output"`` as ``label`` and the
        formatted prompt as ``instruction``.

        Raises:
            KeyError: if the file has no ``"Instances"`` entry or a record
                lacks ``"input"`` or ``"output"``.
        """
        data = self._load_dataset(dataset_path)
        logger.info("dataset_name: %s", dataset_path)
        logger.info("dataset keys: %s", list(data.keys()))

        for idx, instance in enumerate(data["Instances"]):
            user_message = instance["input"]
            # `system_prompt` is not referenced by the current template;
            # str.format ignores unused keyword arguments.
            instruction = llama2_prompt.format(
                system_prompt="",
                user_message=user_message,
            )
            yield {
                "Instance": {
                    "id": str(idx),
                    "sentence": user_message,
                    "label": instance["output"],
                    "instruction": instruction,
                }
            }

    def _generate_examples(self, path=None, subset=None):
        """Yield ``(key, example)`` pairs for the split file at ``path``.

        Args:
            path: JSON file of the split to read.
            subset: split tag supplied by ``_split_generators``; not used here.

        Keys have the form ``"<path>##<index>"``.
        """
        logger.info("Generating tasks from = %s", path)

        # Stream the examples; nothing is accumulated in memory.
        for idx, sample in enumerate(self.load_dataset(path)):
            yield f"{path}##{idx}", sample
