import regex
import string

import datasets

import evaluate
exact_match = evaluate.load("exact_match")

import random

from .task import register, Task

# Sub-category names grouped under the superclass they belong to. Group order
# and member order fix the insertion order of `a_level_dict` below.
_SUBCLASSES_BY_SUPERCLASS = {
    "animal": (
        "Mammals",
        "Birds",
        "Reptiles",
        "Amphibians",
        "Fish",
        "Insects",
        "Arachnids",
        "Crustaceans",
        "Mollusks",
    ),
    "food": (
        "Fruit",
        "Vegetable",
        "Meat",
        "Seafood",
        "Dairy",
        "Grain",
        "Nut_and_Seed",
        "Beverage",
        "Spice_and_Herb",
        "Confectionery",
    ),
    "vehicle": (
        "Car",
        "Bicycle",
        "Motorcycle",
        "Truck",
        "Bus",
        "Tram",
        "Train",
        "Airplane",
        "Boat",
    ),
}

# Flat lookup: sub-category name -> superclass name ("animal", "food", "vehicle").
a_level_dict = {
    subclass: superclass
    for superclass, subclasses in _SUBCLASSES_BY_SUPERCLASS.items()
    for subclass in subclasses
}

# Two-character token used in place of each superclass name by the
# `*_symbol` task variants.
superclass_symbol = dict(
    animal="*&",
    food="!%",
    vehicle="$#",
)



# Public names exported by a star-import of this module.
__all__ = [
    "a_level",
    "b_level",
    "ab_level",
    "ab_level_compose_incontext",
    "a_level_symbol",
    "ab_level_symbol",
    "ab_level_compose_incontext_symbol",
]

def _mean(values):
    """Return the arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)


@register('a_level')
class a_level(Task):
    """Generation task backed by the ``a_level`` JSON files in ./data/hierarchy.

    Every document provides an ``input`` and an ``output`` field. The prompt is
    ``"<input> is "`` and the expected continuation is ``output``. Scoring is a
    word-by-word exact match between the model continuation and the target.
    """

    VERSION = 0
    DATASET_PATH = "json"
    DATASET_NAME = None

    # cache_dir = "./cache"
    train_files = './data/hierarchy/a_level_train.json'
    test_files = './data/hierarchy/a_level_test.json'

    def _load_json_split(self, data_files, cache_dir=None, download_mode=None):
        """Load one JSON file (or file list) as a single ``datasets.Dataset``."""
        return datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_files=data_files,
            cache_dir=cache_dir,
            download_mode=download_mode,
            # The files carry no train/val/test partition of their own, so the
            # whole file is requested through the "train" split.
            split="train",
        )

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Populate ``self.dataset`` with the train and validation documents.

        :param data_dir:
            Accepted for interface compatibility; not used here.
        :param cache_dir:
            Forwarded to ``datasets.load_dataset``.
        :param download_mode:
            Forwarded to ``datasets.load_dataset``.
        """
        testset = self._load_json_split(self.test_files, cache_dir, download_mode)
        trainset = self._load_json_split(self.train_files, cache_dir, download_mode)

        # The test file is exposed under the "validation" key; this task
        # reports no separate test docs (see has_test_docs).
        self.dataset = datasets.DatasetDict({
            "train": trainset,
            "validation": testset,
        })

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the training documents as a list, materialised on first use."""
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the validation split of ``self.dataset``."""
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        """Return the prompt for ``doc``: its ``input`` followed by ``" is "``."""
        return f"{doc['input']} is "

    def doc_to_target(self, doc):
        """Return the expected continuation for ``doc`` (its ``output`` field)."""
        return doc["output"]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.
        :param doc:
                The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
                The context string, generated by fewshot_context. This includes the natural
                language description, as well as the few shot examples, and the question
                part of the document for `doc`.
        """
        # NOTE(review): `rf` is not imported anywhere in the visible part of
        # this module, so this call raises NameError unless the name is
        # provided some other way - confirm and add the import.
        # Generation stops at the first newline.
        continuation = rf.greedy_until(ctx, {"until": ["\n"]})
        return continuation

    def _normalize_answer(self, text):
        """Return ``text`` with leading and trailing whitespace removed."""
        # Strip unconditionally: the previous guard only fired when the first
        # character was a space, leaving e.g. trailing whitespace in place.
        return text.strip()

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :returns:
            ``{"acc": score}`` where ``score`` is the ``exact_match`` value
            computed over the space-separated words of the target.
        """
        continuation = self._normalize_answer(results[0])
        answers = self.doc_to_target(doc)

        preds = continuation.split(" ")
        refs = answers.split(" ")

        # The metric compares the two lists pairwise, so force the prediction
        # list to the reference length: pad missing words with "" (which count
        # as mismatches) and ignore any extra predicted words.
        if len(refs) > len(preds):
            preds.extend([""] * (len(refs) - len(preds)))
        elif len(preds) > len(refs):
            preds = preds[:len(refs)]

        scores = exact_match.compute(references=refs, predictions=preds)

        return {"acc": scores['exact_match']}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            # `mean` was never imported in this module; use the local helper.
            "acc": _mean,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True,
        }

@register("b_level")
class b_level(a_level):
    """Same task logic as ``a_level``, reading the ``b_level`` JSON files."""

    train_files = "./data/hierarchy/b_level_train.json"
    test_files = "./data/hierarchy/b_level_test.json"

@register('ab_level')
class ab_level(a_level):
    """Task over the ``ab_level`` files whose prompt mixes two other tasks.

    Few-shot examples are drawn from the ``a_level`` and ``b_level`` training
    files. In addition, the ``b_level`` test document at the same index as the
    evaluated document (after a fixed-seed shuffle) is shown, followed by one
    line mapping that document's target to its superclass via ``a_level_dict``.
    """

    train_files = './data/hierarchy/ab_level_train.json'
    test_files = './data/hierarchy/ab_level_test.json'
    task_1_files = './data/hierarchy/a_level_train.json'
    task_2_files = './data/hierarchy/b_level_train.json'
    task_2_files_test = './data/hierarchy/b_level_test.json'

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load all datasets, then reset the lazily-filled document caches."""
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = self._fewshot_docs = None
        self._task1_training_docs = self._task2_training_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Populate ``self.dataset`` and ``self.b_level_task_docs``.

        ``data_dir`` is accepted but unused; ``cache_dir`` and
        ``download_mode`` are forwarded to ``datasets.load_dataset``.
        """

        def load_split(json_files):
            # Each file has no train/val/test partition of its own, so the
            # whole file is requested through the "train" split.
            return datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                data_files=json_files,
                cache_dir=cache_dir,
                download_mode=download_mode,
                split="train",
            )

        held_out = load_split(self.test_files)
        composed_train = load_split(self.train_files)
        first_task = load_split(self.task_1_files)
        second_task = load_split(self.task_2_files)
        second_task_held_out = load_split(self.task_2_files_test)

        self.dataset = datasets.DatasetDict({
            "train": composed_train,
            "validation": held_out,
            "task1": first_task,
            "task2": second_task,
            "task2_test": second_task_held_out,
        })

        # Fixed seed so the shuffled order is the same on every run; the
        # result is later indexed by `doc_id`.
        shuffled = list(self.dataset["task2_test"])
        random.Random(3407).shuffle(shuffled)
        self.b_level_task_docs = shuffled

    def fewshot_examples(self, k, rnd, doc_id):
        """Return a shuffled list of ``2 * k + 1`` demonstration documents.

        :param k: number of documents sampled from each of "task1" and "task2".
        :param rnd: ``random.Random``-like object used for sampling and shuffling.
        :param doc_id: index into ``self.b_level_task_docs`` of the extra document.
        """
        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        # Sampling order (task1, then task2) is kept so the RNG stream is stable.
        shots = rnd.sample(self._task1_training_docs, k)
        shots += rnd.sample(self._task2_training_docs, k)
        shots.append(self.b_level_task_docs[doc_id])
        rnd.shuffle(shots)
        return shots

    def fewshot_context(
        self, doc, num_fewshot, rnd=None, description=None, doc_id = 0
    ):
        """Build the full prompt string for ``doc``.

        Layout: optional description, the few-shot demonstrations (omitted when
        ``num_fewshot`` is 0), one "<category> is <superclass>" line for the
        b-level document at ``doc_id``, and finally the query text of ``doc``.
        Sections are separated by blank lines.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        demonstrations = ""
        if num_fewshot != 0:
            shots = self.fewshot_examples(k=num_fewshot, rnd=rnd, doc_id=doc_id)
            rendered = [
                self.doc_to_text(shot) + self.doc_to_target(shot) for shot in shots
            ]
            demonstrations = "\n\n".join(rendered) + "\n\n"

        # Always appended, even in the zero-shot case.
        # NOTE(review): the superclass here comes straight from a_level_dict,
        # so subclasses that override doc_to_target do not affect this line -
        # confirm that is intended.
        bridge_category = self.doc_to_target(self.b_level_task_docs[doc_id])
        demonstrations += f"{bridge_category} is {a_level_dict[bridge_category]}" + "\n\n"

        return header + demonstrations + self.doc_to_text(doc)

@register("ab_level_compose_incontext")
class ab_level_compose_incontext(ab_level):
    """Variant of ``ab_level`` whose demonstrations come from its own training set.

    Unlike the parent class, the prompt does not include the b-level document
    or the "<category> is <superclass>" line.
    """

    def fewshot_examples(self, k, rnd, doc_id=None):
        """Return ``k`` documents sampled from this task's training docs.

        :param k: number of demonstrations to sample.
        :param rnd: ``random.Random``-like object used for sampling.
        :param doc_id:
            Ignored. Accepted only so the signature stays compatible with
            ``ab_level.fewshot_examples``.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        compose = rnd.sample(self._training_docs, k)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])

        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        # NOTE(review): this mixed list is built and shuffled but never
        # returned; only `compose` is. The calls are kept because they advance
        # `rnd`, so removing them would change later samples. Confirm whether
        # returning `retval` was intended.
        retval = rnd.sample(self._task1_training_docs, k) + rnd.sample(self._task2_training_docs, k) + compose
        rnd.shuffle(retval)

        return compose

    def fewshot_context(
        self, doc, num_fewshot, rnd=None, description=None, doc_id=0
    ):
        """Build the prompt: description, demonstrations, then the query text.

        :param doc: document being evaluated.
        :param num_fewshot: number of demonstrations; 0 produces none.
        :param rnd: required ``random.Random``-like object.
        :param description: optional text placed first, followed by a blank line.
        :param doc_id:
            Ignored. Accepted so callers can pass the same keyword arguments
            they pass to ``ab_level.fewshot_context``.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        description = description + "\n\n" if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            labeled_examples = (
                "\n\n".join(
                    self.doc_to_text(shot) + self.doc_to_target(shot)
                    for shot in fewshotex
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)

        return description + labeled_examples + example


### Symbol variants: same tasks, with targets mapped through `superclass_symbol`
@register('a_level_symbol')
class a_level_symbol(a_level):
    """``a_level`` with superclass targets replaced by their symbol tokens."""

    def doc_to_target(self, doc):
        """Return the symbol for ``doc["output"]``, or the output itself if it has none."""
        raw_target = doc["output"]
        symbol = superclass_symbol.get(raw_target)
        return symbol if symbol else raw_target

@register('ab_level_symbol')
class ab_level_symbol(ab_level):
    """``ab_level`` with superclass targets replaced by their symbol tokens."""

    def doc_to_target(self, doc):
        """Return the symbol for ``doc["output"]``, or the output itself if it has none."""
        raw_target = doc["output"]
        symbol = superclass_symbol.get(raw_target)
        return symbol if symbol else raw_target

@register('ab_level_compose_incontext_symbol')
class ab_level_compose_incontext_symbol(ab_level_compose_incontext):
    """``ab_level_compose_incontext`` with superclass targets replaced by symbols."""

    def doc_to_target(self, doc):
        """Return the symbol for ``doc["output"]``, or the output itself if it has none."""
        raw_target = doc["output"]
        symbol = superclass_symbol.get(raw_target)
        return symbol if symbol else raw_target
