import regex
import string

import datasets

import evaluate
# Metric object obtained from the `evaluate` package under the name
# "exact_match".  It is created once, when this module is imported, and is
# used by `swap.process_results` (and every subclass) to score predictions.
exact_match = evaluate.load("exact_match")

from .task import register, Task

@register('swap')
class swap(Task):
    """Generation task read from two local JSON files.

    Only two fields of each document are read by this class: ``input`` (used
    to build the prompt) and ``output`` (the reference answer).  The model's
    greedy continuation, cut at the first newline, is compared with the
    reference token by token, where tokens are separated by a single space.
    """

    VERSION = 0
    DATASET_PATH = "json"
    DATASET_NAME = None

    # NOTE(review): this attribute is not read by any code in this module;
    # `download` only uses the `cache_dir` argument it is given.
    cache_dir = "./data/swap/cache"
    train_files = 'data/swap/swap_train.json'
    test_files = 'data/swap/swap_test.json'

    def _load_json(self, data_files, cache_dir=None, download_mode=None):
        """Load one data file and return it as a single dataset object.

        :param data_files:
            Path of the file to load (forwarded to ``datasets.load_dataset``).
        :param cache_dir:
            Forwarded to ``datasets.load_dataset``.
        :param download_mode:
            Forwarded to ``datasets.load_dataset``.
        """
        return datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_files=data_files,
            cache_dir=cache_dir,
            download_mode=download_mode,
            # The files have no train/val/test structure of their own, so the
            # whole file is requested through the "train" split.
            split="train",
        )

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load the train and test files and store them in ``self.dataset``.

        The test file is exposed under the ``"validation"`` key, the training
        file under ``"train"``.

        :param data_dir:
            Accepted for interface compatibility; not used here.
        :param cache_dir:
            Forwarded to ``datasets.load_dataset``.
        :param download_mode:
            Forwarded to ``datasets.load_dataset``.
        """
        testset = self._load_json(self.test_files, cache_dir, download_mode)
        trainset = self._load_json(self.train_files, cache_dir, download_mode)

        self.dataset = datasets.DatasetDict({
            "train": trainset,
            "validation": testset,
        })

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the training documents as a list, materialised on first use.

        ``self._training_docs`` is not assigned in this class; it is
        presumably initialised to ``None`` by the base class constructor.
        """
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the dataset stored under the ``"validation"`` key."""
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        """Build the prompt for *doc*; the model continues after ``output: ``."""
        return f"input: {doc['input']}\noutput: "

    def doc_to_target(self, doc):
        """Return the reference answer of *doc*."""
        return doc["output"]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # NOTE(review): `rf` is not imported anywhere in the visible part of
        # this module, so this line raises NameError unless the name is made
        # available some other way.  It presumably belongs next to the `Task`
        # import -- confirm its location and add the import.
        # Generation stops at the first newline: answers are single-line.
        return rf.greedy_until(ctx, {"until": ["\n"]})

    def _normalize_answer(self, text):
        """Strip surrounding whitespace from *text* if it starts with a space.

        Text that does not begin with a space character is returned unchanged,
        including any trailing whitespace it may carry.
        """
        if text.startswith(" "):
            return text.strip()
        return text

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        The prediction and the reference are both split on single spaces and
        compared position by position.  A prediction shorter than the
        reference is padded with empty strings (which count as mismatches
        against non-empty reference tokens); a longer prediction is cut to the
        reference length, so surplus tokens are ignored.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :returns:
            ``{"acc": <the metric's "exact_match" value for the aligned tokens>}``
        """
        prediction = self._normalize_answer(results[0])

        pred_tokens = prediction.split(" ")
        ref_tokens = doc["output"].split(" ")

        # The metric is given two lists of equal length, so align the
        # prediction to the reference before scoring.
        missing = len(ref_tokens) - len(pred_tokens)
        if missing > 0:
            pred_tokens.extend([""] * missing)
        else:
            pred_tokens = pred_tokens[:len(ref_tokens)]

        score = exact_match.compute(references=ref_tokens, predictions=pred_tokens)

        return {"acc": score['exact_match']}

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of a non-empty sequence of numbers."""
        return sum(values) / len(values)

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        # No `mean` function is imported into this module, so the average is
        # computed by a helper defined on the class.
        # NOTE(review): if the surrounding harness exports its own `mean` and
        # recognises it by identity, import and return that one instead.
        return {
            "acc": self._mean,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True,
        }


@register("names_upper")
class upper(swap):
    """Same task logic as ``swap``, reading the ``upper_*`` JSON files instead.

    Only the two file paths differ; prompting and scoring are inherited.
    The task is registered under the name ``names_upper``.
    """

    train_files = "data/swap/upper_train.json"
    test_files = "data/swap/upper_test.json"


@register('upper_swap')
class upper_swap(swap):
    """Task whose few-shot examples are drawn from two other training files.

    Evaluation and training documents come from the ``upper_swap_*`` files,
    while few-shot examples are sampled from ``task_1_files`` and
    ``task_2_files`` (the ``upper`` and ``swap`` training files).  Document
    access (``training_docs``, ``validation_docs``, ``has_*_docs``), prompting
    and scoring are inherited from ``swap``.
    """

    train_files = 'data/swap/upper_swap_train.json'
    test_files = 'data/swap/upper_swap_test.json'
    task_1_files = 'data/swap/upper_train.json'
    task_2_files = 'data/swap/swap_train.json'

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load all data files and reset the lazily filled document caches.

        The base-class constructor is not called; the arguments are passed
        straight to ``download``.

        :param data_dir:
            Forwarded to ``download`` (which does not use it).
        :param cache_dir:
            Forwarded to ``download``.
        :param download_mode:
            Forwarded to ``download``.
        """
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = None
        self._fewshot_docs = None
        self._task1_training_docs = None
        self._task2_training_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load the four data files and store them in ``self.dataset``.

        Keys of the resulting ``DatasetDict``: ``"train"`` (``train_files``),
        ``"validation"`` (``test_files``), ``"task1"`` (``task_1_files``) and
        ``"task2"`` (``task_2_files``).

        :param data_dir:
            Accepted for interface compatibility; not used here.
        :param cache_dir:
            Forwarded to ``datasets.load_dataset``.
        :param download_mode:
            Forwarded to ``datasets.load_dataset``.
        """
        def load(data_files):
            # Every file is loaded the same way; only the path differs.
            return datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                data_files=data_files,
                cache_dir=cache_dir,
                download_mode=download_mode,
                # The files have no train/val/test structure of their own, so
                # the whole file is requested through the "train" split.
                split="train",
            )

        testset = load(self.test_files)
        trainset = load(self.train_files)
        task_1_set = load(self.task_1_files)
        task_2_set = load(self.task_2_files)

        self.dataset = datasets.DatasetDict({
            "train": trainset,
            "validation": testset,
            "task1": task_1_set,
            "task2": task_2_set,
        })

    def fewshot_examples(self, k, rnd):
        """Return ``2 * k`` shuffled few-shot examples, ``k`` from each task file.

        :param k: int
            Number of examples to sample from EACH of the two task files.
        :param rnd:
            Random source providing ``sample`` and ``shuffle`` (for example a
            ``random.Random`` instance).
        :returns:
            A list with ``k`` documents from ``"task1"`` and ``k`` documents
            from ``"task2"``, in shuffled order.
        """
        # Materialise each task's documents once; later calls reuse the lists.
        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])

        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        shots = rnd.sample(self._task1_training_docs, k)
        shots += rnd.sample(self._task2_training_docs, k)
        rnd.shuffle(shots)
        return shots

@register("upper_swap_compose_incontext")
class upper_swap_compose_incontext(upper_swap):
    """``upper_swap`` variant whose few-shot prompt also includes examples
    sampled from the task's own training documents."""

    def fewshot_examples(self, k, rnd):
        """Return ``3 * k`` shuffled few-shot examples.

        ``k`` documents are sampled from the task's own training documents,
        ``k`` from ``"task1"`` and ``k`` from ``"task2"``; the combined list
        is shuffled before being returned.

        :param k: int
            Number of examples to sample from each of the three sources.
        :param rnd:
            Random source providing ``sample`` and ``shuffle``.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        # Sampled first, before the per-task draws below.
        own_shots = rnd.sample(self._training_docs, k)

        # Fill whichever per-task caches are still empty.
        for cache_name, split_name in (
            ("_task1_training_docs", "task1"),
            ("_task2_training_docs", "task2"),
        ):
            if getattr(self, cache_name) is None:
                setattr(self, cache_name, list(self.dataset[split_name]))

        first_task_shots = rnd.sample(self._task1_training_docs, k)
        second_task_shots = rnd.sample(self._task2_training_docs, k)

        shots = first_task_shots + second_task_shots + own_shots
        rnd.shuffle(shots)
        return shots


@register("swap_upper")
class swap_upper(upper_swap):
    """``upper_swap`` variant that reads the ``swap_upper_*`` train/test files.

    The two few-shot source files are set to the same paths the parent class
    uses; all behaviour is inherited.
    """

    train_files = "data/swap/swap_upper_train.json"
    test_files = "data/swap/swap_upper_test.json"
    task_1_files = "data/swap/upper_train.json"
    task_2_files = "data/swap/swap_train.json"

  

@register("swap_upper_compose_incontext")
class swap_upper_compose_incontext(swap_upper):
    """``swap_upper`` variant whose few-shot prompt also includes examples
    sampled from the task's own training documents."""

    def fewshot_examples(self, k, rnd):
        """Return ``3 * k`` shuffled few-shot examples.

        ``k`` documents are sampled from the task's own training documents,
        ``k`` from ``"task1"`` and ``k`` from ``"task2"``; the combined list
        is shuffled before being returned.

        :param k: int
            Number of examples to sample from each of the three sources.
        :param rnd:
            Random source providing ``sample`` and ``shuffle``.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        # Sampled first, before the per-task draws below.
        own_shots = rnd.sample(self._training_docs, k)

        # Fill whichever per-task caches are still empty.
        for cache_name, split_name in (
            ("_task1_training_docs", "task1"),
            ("_task2_training_docs", "task2"),
        ):
            if getattr(self, cache_name) is None:
                setattr(self, cache_name, list(self.dataset[split_name]))

        first_task_shots = rnd.sample(self._task1_training_docs, k)
        second_task_shots = rnd.sample(self._task2_training_docs, k)

        shots = first_task_shots + second_task_shots + own_shots
        rnd.shuffle(shots)
        return shots
