import regex
import string
import random
import datasets
import evaluate
exact_match = evaluate.load("exact_match")

from .task import register, Task

@register('names2_upper')
class upper(Task):
    """Exact-match task over JSON records with ``input`` and ``output`` fields.

    The prompt for a record is ``input: <doc["input"]>``, a newline, then
    ``output: ``. The model continuation is compared with ``doc["output"]``
    token by token (tokens are separated by single spaces). Subclasses in this
    module mostly just swap ``train_files`` / ``test_files``.
    """

    VERSION = 0
    DATASET_PATH = "json"
    DATASET_NAME = None

    # NOTE(review): download() uses its own `cache_dir` argument, not this
    # attribute; confirm whether anything else reads it.
    cache_dir = "./data/upper_plusOne/cache"
    train_files = 'data/upper_plusOne/upper_train.json'
    test_files = 'data/upper_plusOne/upper_test.json'

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load ``train_files`` and ``test_files`` into ``self.dataset``.

        Each file is read as the ``"train"`` split of a one-file JSON dataset.
        The results are exposed as the ``"train"`` and ``"validation"`` splits
        of a ``DatasetDict``. ``data_dir`` is accepted for interface
        compatibility and is not used.
        """

        def load(files):
            # Every file is loaded identically; only the path differs.
            return datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                data_files=files,
                cache_dir=cache_dir,
                download_mode=download_mode,
                split="train",
            )

        testset = load(self.test_files)
        trainset = load(self.train_files)

        self.dataset = datasets.DatasetDict({
            "train": trainset,
            "validation": testset,
        })

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the training split as a list, materialised once and cached."""
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the validation split (loaded from ``test_files``)."""
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        """Render the prompt for ``doc``; the target follows ``output: ``."""
        return f"input: {doc['input']}\noutput: "

    def doc_to_text1(self, doc):
        """Prompt variant that prefixes the input with ``* ``."""
        # The input is used verbatim: it may be a single token, so it must
        # not be unpacked into a fixed number of words.
        return f"input: * {doc['input']}\noutput: "

    def doc_to_text2(self, doc):
        """Prompt variant that prefixes the input with ``* ``."""
        # Same as doc_to_text1; the input is used verbatim.
        return f"input: * {doc['input']}\noutput: "

    def doc_to_target(self, doc):
        """Return the reference answer for ``doc``."""
        return doc["output"]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.
        :param doc:
                The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
                The context string, generated by fewshot_context. This includes the natural
                language description, as well as the few shot examples, and the question
                part of the document for `doc`.
        """
        # NOTE(review): `rf` is not among this module's imports; presumably it
        # is the harness RequestFactory and must be in scope at runtime.
        # Generation stops at the first newline.
        continuation = rf.greedy_until(ctx, {"until": ["\n"]})
        return continuation

    def _normalize_answer(self, text):
        """Return ``text`` with leading and trailing whitespace removed."""
        # Strip unconditionally. Leading tabs and trailing whitespace must not
        # depend on whether the text happens to start with a space.
        return text.strip()

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :returns: ``{"acc": fraction of reference tokens matched exactly}``
        """
        continuation = self._normalize_answer(results[0])
        refs = doc["output"].split(" ")
        preds = continuation.split(" ")

        # Pad with "" or truncate so each reference token is paired with
        # exactly one predicted token. Missing positions are scored against "".
        preds = (preds + [""] * len(refs))[:len(refs)]

        scores = exact_match.compute(references=refs, predictions=preds)
        return {"acc": scores["exact_match"]}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """

        def mean(values):
            # Arithmetic mean, defined locally because no `mean` helper is
            # imported into this module.
            return sum(values) / len(values)

        return {
            "acc": mean,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True,
        }


@register('plusOne')
class plusOne(upper):
    """Single-task variant of ``upper`` that reads the plusOne data files.

    Prompting, scoring and loading are inherited unchanged.
    """

    test_files = 'data/upper_plusOne/plusOne_test.json'
    train_files = 'data/upper_plusOne/plusOne_train.json'

@register('upper_plusOne')
class upper_plusOne(upper):
    """Composed task with auxiliary single-task pools for demonstrations.

    ``train_files`` / ``test_files`` hold the composed task. ``task_1_files``,
    ``task_2_files`` and ``task_3_files`` are loaded as the extra splits
    ``"task1"``, ``"task2"`` and ``"task3"`` of ``self.dataset``.
    """

    train_files = 'data/upper_plusOne/upper_plusOne_train.json'
    test_files = 'data/upper_plusOne/upper_plusOne_test.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'#data/twoSum_reverse_cipher/
    task_3_files = 'data/twoSum_reverse_cipher/reverse_train.json'

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load all splits, then reset every lazily-filled document cache.

        ``Task.__init__`` is not called; the attributes this class relies on
        are initialised here.
        """
        self.download(data_dir, cache_dir, download_mode)
        for cache_attr in (
            "_training_docs",
            "_fewshot_docs",
            "_task1_training_docs",
            "_task2_training_docs",
            "_task3_training_docs",
        ):
            setattr(self, cache_attr, None)

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load test, train and the three auxiliary files into ``self.dataset``.

        Every file is read as the ``"train"`` split of a one-file JSON dataset.
        ``data_dir`` is accepted for interface compatibility and is unused.
        """

        def fetch(files):
            return datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                data_files=files,
                cache_dir=cache_dir,
                download_mode=download_mode,
                split="train",
            )

        # Load order: test, train, task1, task2, task3.
        testset = fetch(self.test_files)
        trainset = fetch(self.train_files)
        auxiliary = {
            f"task{idx}": fetch(files)
            for idx, files in enumerate(
                (self.task_1_files, self.task_2_files, self.task_3_files),
                start=1,
            )
        }

        self.dataset = datasets.DatasetDict(
            {"train": trainset, "validation": testset, **auxiliary}
        )

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the composed training split as a cached list."""
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the validation split (loaded from ``test_files``)."""
        return self.dataset["validation"]

    def fewshot_examples(self, k, k2, kc, rnd):
        """Sample ``k`` task-1 and ``k2`` task-2 docs and shuffle them together.

        ``kc`` is accepted for signature compatibility with the composed
        variants and is not used here.
        """
        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        picked = rnd.sample(self._task1_training_docs, k)
        picked += rnd.sample(self._task2_training_docs, k2)
        rnd.shuffle(picked)
        return picked
@register('upper_plusOne_com_incontext')
class upper_plusOne_compose_incontext(upper_plusOne):
    """``upper_plusOne`` whose demonstrations also include composed-task docs."""

    def fewshot_examples(self, k, k2, kc, rnd):
        """Sample ``kc`` composed, ``k`` task-1 and ``k2`` task-2 docs, shuffled.

        :param k: number of task-1 docs.
        :param k2: number of task-2 docs.
        :param kc: number of composed-task (training split) docs.
        :param rnd: ``random.Random`` used for sampling and shuffling.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())

        # The composed sample is drawn before the single-task ones. The order
        # of rnd calls determines which documents are selected.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        demos = [
            *rnd.sample(self._task1_training_docs, k),
            *rnd.sample(self._task2_training_docs, k2),
            *composed,
        ]
        rnd.shuffle(demos)
        return demos

@register('plusOne_upper')
class plusOne_upper(upper_plusOne):
    """``upper_plusOne`` pointed at the plusOne_upper composed data files.

    The task-1 / task-2 pools are the upper and plusOne training files.
    ``task_3_files`` is inherited from ``upper_plusOne``.
    """

    test_files = 'data/upper_plusOne/plusOne_upper_test.json'
    train_files = 'data/upper_plusOne/plusOne_upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
  

@register('plusOne_upper_compose_incontext')
class plusOne_upper_compose_incontext(plusOne_upper):
    """plusOne_upper with task-1, task-2 and composed demonstrations in context."""

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return ``k`` task-1, ``k2`` task-2 and ``kc`` composed docs, in that order.

        The list is not shuffled here; ``fewshot_context`` shuffles the
        rendered prompts instead.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        task1_docs = rnd.sample(self._task1_training_docs, k)
        task2_docs = rnd.sample(self._task2_training_docs, k2)
        return task1_docs + task2_docs + composed

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: optional description, shuffled demos, then the query.

        :param doc: document being evaluated.
        :param kc: number of composed-task demonstrations.
        :param num_fewshot: number of task-1 demonstrations.
        :param num_fewshot2: number of task-2 demonstrations.
        :param rnd: ``random.Random`` for sampling and shuffling (required).
        :param description: optional text placed before the demonstrations.
        :returns: the prompt string, ending with ``doc_to_text(doc)``.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        # NOTE(review): only num_fewshot gates the demonstrations. When it is 0,
        # the num_fewshot2 and kc demos are dropped as well; confirm intended.
        demos = ""
        if num_fewshot != 0:
            sampled = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            rendered = [
                self.doc_to_text(ex) + self.doc_to_target(ex) for ex in sampled
            ]
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return header + demos + self.doc_to_text(doc)
        
@register('plusOne_upper_compose_incontext_tag')
class plusOne_upper_compose_incontext_tag(plusOne_upper):
    """Composed-in-context variant that tags every prompt with its origin.

    Task-1 demos are prefixed ``(simple1)`` and task-2 demos ``(simple2)``.
    Composed demos and the final query are prefixed ``(compose)``.
    """

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return ``k`` task-1, ``k2`` task-2 and ``kc`` composed docs, unshuffled."""
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        simple1 = rnd.sample(self._task1_training_docs, k)
        simple2 = rnd.sample(self._task2_training_docs, k2)
        return simple1 + simple2 + composed

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build a tagged prompt: description, shuffled tagged demos, tagged query.

        No demonstrations are included when ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            sampled = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            end_simple1 = num_fewshot
            end_simple2 = num_fewshot + num_fewshot2
            # fewshot_examples returns task-1, task-2, composed docs in order.
            segments = (
                ("(simple1)", sampled[:end_simple1]),
                ("(simple2)", sampled[end_simple1:end_simple2]),
                ("(compose)", sampled[end_simple2:]),
            )
            rendered = [
                tag + self.doc_to_text(ex) + self.doc_to_target(ex)
                for tag, group in segments
                for ex in group
            ]
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return header + demos + "(compose)" + self.doc_to_text(doc)
@register('plusOne_upper_compose_incontext_re1')
class plusOne_upper_compose_incontext_re1(plusOne_upper_compose_incontext):
    """Composed-in-context demos, evaluated on the upper test file.

    Composed demos come from ``plusOne_upper_train.json``, while validation
    uses ``upper_test.json``.
    """

    train_files = 'data/upper_plusOne/plusOne_upper_train.json'
    test_files = 'data/upper_plusOne/upper_test.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'

    def doc_to_target1(self, doc):
        """Return ``doc["output"]`` (not used by the code in this class)."""
        return doc["output"]

    def doc_to_target2(self, doc):
        """Return ``doc["output"]`` (not used by the code in this class)."""
        return doc["output"]

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return ``k`` task-1, ``k2`` task-2 and ``kc`` composed docs, unshuffled."""
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        pool = rnd.sample(self._task1_training_docs, k)
        pool.extend(rnd.sample(self._task2_training_docs, k2))
        pool.extend(composed)
        return pool

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: description, shuffled demos, then ``doc_to_text(doc)``.

        No demonstrations are included when ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        if num_fewshot == 0:
            shots = ""
        else:
            pool = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            lines = [self.doc_to_text(ex) + self.doc_to_target(ex) for ex in pool]
            rnd.shuffle(lines)
            shots = "\n\n".join(lines) + "\n\n"

        query = self.doc_to_text(doc)
        return header + shots + query
@register('plusOne_upper_compose_incontext_irrl')
class plusOne_upper_compose_incontext_irrl(plusOne_upper_compose_incontext):
    """Composed-in-context variant whose first demo pool is task 3.

    Single-task demos are drawn from ``task_3_files`` (reverse_train.json) and
    ``task_2_files`` instead of tasks 1 and 2. Validation uses
    ``upper_test.json``.
    """

    train_files = 'data/upper_plusOne/plusOne_upper_train.json'
    test_files = 'data/upper_plusOne/upper_test.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'
    task_3_files = 'data/twoSum_reverse_cipher/reverse_train.json'

    def doc_to_target1(self, doc):
        """Return ``doc["output"]`` (not used by the code in this class)."""
        return doc["output"]

    def doc_to_target2(self, doc):
        """Return ``doc["output"]`` (not used by the code in this class)."""
        return doc["output"]

    def doc_to_target3(self, doc):
        """Return ``doc["output"]`` (not used by the code in this class)."""
        return doc["output"]

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return ``k`` task-3, ``k2`` task-2 and ``kc`` composed docs, unshuffled."""
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task3_training_docs is None:
            self._task3_training_docs = list(self.dataset["task3"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        task3_docs = rnd.sample(self._task3_training_docs, k)
        task2_docs = rnd.sample(self._task2_training_docs, k2)
        return [*task3_docs, *task2_docs, *composed]

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: description, shuffled demos, then ``doc_to_text(doc)``.

        No demonstrations are included when ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        shots = ""
        if num_fewshot != 0:
            chosen = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            rendered = [
                self.doc_to_text(ex) + self.doc_to_target(ex) for ex in chosen
            ]
            rnd.shuffle(rendered)
            shots = "\n\n".join(rendered) + "\n\n"

        return header + shots + self.doc_to_text(doc)
@register('plusOne_upper_compose_incontext_distribution')
class plusOne_upper_compose_incontext_distribution(plusOne_upper_compose_incontext):
    """Composed-in-context variant with extra two-token target helpers.

    ``doc_to_target1`` / ``doc_to_target2`` mix tokens from ``doc["input"]``
    and ``doc["output"]``. They are not called by the code in this class.
    """

    train_files = 'data/upper_plusOne/plusOne_upper_train.json'
    test_files = 'data/upper_plusOne/plusOne_upper_test.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'

    def doc_to_target1(self, doc):
        """Return the output's first token followed by the input's second token.

        Raises ValueError unless both fields contain exactly two
        space-separated tokens.
        """
        out_first, _out_second = doc["output"].split(' ')
        _in_first, in_second = doc["input"].split(' ')
        return f"{out_first} {in_second}"

    def doc_to_target2(self, doc):
        """Return the output's second token followed by the input's first token.

        Raises ValueError unless both fields contain exactly two
        space-separated tokens.
        """
        _out_first, out_second = doc["output"].split(' ')
        in_first, _in_second = doc["input"].split(' ')
        return f"{out_second} {in_first}"

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return ``k`` task-1, ``k2`` task-2 and ``kc`` composed docs, unshuffled."""
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        first = rnd.sample(self._task1_training_docs, k)
        second = rnd.sample(self._task2_training_docs, k2)
        return first + second + composed

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: description, shuffled demos, then ``doc_to_text(doc)``.

        No demonstrations are included when ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        prefix = description + "\n\n" if description else ""

        if num_fewshot == 0:
            context = ""
        else:
            docs = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            pieces = [self.doc_to_text(d) + self.doc_to_target(d) for d in docs]
            rnd.shuffle(pieces)
            context = "\n\n".join(pieces) + "\n\n"

        return prefix + context + self.doc_to_text(doc)
@register('plusOne_upper_compose_incontext_cot')
class plusOne_upper_compose_incontext_cot(plusOne_upper_compose_incontext):
    """Composed-in-context variant with an explicit intermediate step.

    Single-task demos render as ``<prompt>-><target>``. Composed demos render
    as ``<prompt>-><intermediate>-><target>``, where the intermediate comes
    from ``doc_to_target_rand``.
    """

    train_files = 'data/upper_plusOne/plusOne_upper_train.json'
    test_files = 'data/upper_plusOne/plusOne_upper_test.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'

    def doc_to_target_rand(self, doc):
        """Return the output's first token followed by the input's second token.

        Raises ValueError unless both fields contain exactly two
        space-separated tokens.
        """
        out_first, _out_second = doc["output"].split(' ')
        _in_first, in_second = doc["input"].split(' ')
        return f"{out_first} {in_second}"

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return ``k`` task-1, ``k2`` task-2 and ``kc`` composed docs, unshuffled."""
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        chosen = rnd.sample(self._task1_training_docs, k)
        chosen.extend(rnd.sample(self._task2_training_docs, k2))
        chosen.extend(composed)
        return chosen

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the chain-of-thought prompt; the query has no ``->`` suffix.

        No demonstrations are included when ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            chosen = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            # The first num_fewshot + num_fewshot2 docs are single-task demos.
            # The remaining docs are composed demos.
            boundary = num_fewshot + num_fewshot2
            rendered = [
                self.doc_to_text(ex) + "->" + self.doc_to_target(ex)
                for ex in chosen[:boundary]
            ]
            rendered.extend(
                self.doc_to_text(ex)
                + "->"
                + self.doc_to_target_rand(ex)
                + "->"
                + self.doc_to_target(ex)
                for ex in chosen[boundary:]
            )
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return header + demos + self.doc_to_text(doc)
@register('plusOne_upper_compose_incontext_expcot')
class plusOne_upper_compose_incontext_expcot(plusOne_upper_compose_incontext):
    """Composed-in-context variant using explicit ``step1/step2/step3`` prompts.

    Task-2 demos fill steps 1-2 and leave step 3 as ``???``. Task-1 demos
    leave step 1 as ``???`` and fill steps 2-3. Composed demos fill all three
    steps. The query is ``doc_to_query(doc)``.
    """

    train_files = 'data/upper_plusOne/plusOne_upper_train.json'
    test_files = 'data/upper_plusOne/plusOne_upper_test.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'

    def doc_to_target_rand(self, doc):
        """Return the input's first token followed by the output's second token.

        NOTE(review): the cot variant in this module uses the opposite pairing
        (output first + input second) for its intermediate step. Here the
        task-2 demos are shown as step1->step2. Confirm which intermediate is
        intended.
        """
        _out_first, out_second = doc["output"].split(' ')
        in_first, _in_second = doc["input"].split(' ')
        return f"{in_first} {out_second}"

    def doc_pure_text(self, doc):
        """Return the raw input followed by a newline."""
        return f"{doc['input']}\n"

    def doc_to_query(self, doc):
        """Return the query prompt, which asks the model to continue from step 2."""
        return f"step1: {doc['input']}\nstep2:"

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return ``k`` task-1, ``k2`` task-2 and ``kc`` composed docs, unshuffled."""
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        task1_docs = rnd.sample(self._task1_training_docs, k)
        task2_docs = rnd.sample(self._task2_training_docs, k2)
        return [*task1_docs, *task2_docs, *composed]

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the step-structured prompt ending with ``doc_to_query(doc)``.

        No demonstrations are included when ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            chosen = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            mid = num_fewshot
            end = num_fewshot + num_fewshot2
            task1_docs = chosen[:mid]
            task2_docs = chosen[mid:end]
            composed_docs = chosen[end:]

            # Task-2 entries are listed first, then task-1, then composed. The
            # pre-shuffle order is kept because it affects the shuffle result.
            rendered = []
            for ex in task2_docs:
                rendered.append(
                    "step1: " + self.doc_pure_text(ex)
                    + "step2: " + self.doc_to_target(ex)
                    + "\nstep3: ???"
                )
            for ex in task1_docs:
                rendered.append(
                    "step1:???\nstep2: " + self.doc_pure_text(ex)
                    + "step3: " + self.doc_to_target(ex)
                )
            for ex in composed_docs:
                rendered.append(
                    "step1: " + self.doc_pure_text(ex)
                    + "step2: " + self.doc_to_target_rand(ex)
                    + "\nstep3: " + self.doc_to_target(ex)
                )
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return header + demos + self.doc_to_query(doc)
@register('plusOne_upper_compose_incontext_special')
class plusOne_upper_compose_incontext_special(plusOne_upper):
    """plusOne_upper variant whose prompt uses a fixed, hard-coded demo set."""

    def fewshot_examples(self, k, k2, kc, rnd):
        """Sample ``kc`` composed, ``k`` task-1 and ``k2`` task-2 docs, shuffled.

        Not called by this class's ``fewshot_context``.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        mixed = [
            *rnd.sample(self._task1_training_docs, k),
            *rnd.sample(self._task2_training_docs, k2),
            *composed,
        ]
        rnd.shuffle(mixed)
        return mixed

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Return description + the fixed demonstrations + ``doc_to_text(doc)``.

        ``kc``, ``num_fewshot`` and ``num_fewshot2`` are ignored. ``rnd`` is
        required but no random numbers are consumed.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        header = description + "\n\n" if description else ""

        # One demonstration per line, each terminated by a blank line.
        fixed_demos = (
            "input: 8048\noutput: 8049\n\n"
            "input: 5947\noutput: 5948\n\n"
            "input: 4389\noutput: 4390\n\n"
            "input: basset\noutput: BASSET\n\n"
            "input: 3665\noutput: 3666\n\n"
            "input: 1526\noutput: 1527\n\n"
            "input: 2764\noutput: 2765\n\n"
            "input: shih\noutput: SHIH\n\n"
            "input: oscilloscope\noutput: OSCILLOSCOPE\n\n"
            "input: 2799 teller\noutput: 2800 TELLER\n\n"
            "input: 8264\noutput: 8265\n\n"
            "input: cauliflower\noutput: CAULIFLOWER\n\n"
            "input: 7409\noutput: 7410\n\n"
            "input: 8157 unicycle\noutput: 8158 UNICYCLE\n\n"
            "input: 9356\noutput: 9357\n\n"
            "input: safety\noutput: SAFETY\n\n"
            "input: baby\noutput: BABY\n\n"
            "input: adult\noutput: ADULT\n\n"
            "input: 2742\noutput: 2743\n\n"
            "input: harmonica\noutput: HARMONICA\n\n"
            "input: eagle\noutput: EAGLE\n\n"
            "input: 4014 groenendael\noutput: 4015 GROENENDAEL\n\n"
            "input: 2768 valley\noutput: 2769 VALLEY\n\n"
            "input: sheepdog\noutput: SHEEPDOG\n\n"
            "input: 8321 vizsla\noutput: 8322 VIZSLA\n\n"
        )

        return header + fixed_demos + self.doc_to_text(doc)
@register('plusOne_upper_compose_incontext_re1_special')
class plusOne_upper_compose_incontext_re1_special(plusOne_upper):
    """Fixed-demo plusOne_upper variant evaluated on the upper test file.

    The prompt uses a hard-coded demonstration set. Validation uses
    ``upper_test.json``.
    """

    train_files = 'data/upper_plusOne/plusOne_upper_train.json'
    test_files = 'data/upper_plusOne/upper_test.json'
    task_1_files = 'data/upper_plusOne/upper_train.json'
    task_2_files = 'data/upper_plusOne/plusOne_train.json'

    def doc_to_text1(self, doc):
        """Prompt variant that prefixes the input with ``* ``."""
        # The input is used verbatim. It is not split into two tokens because
        # single-token inputs would raise ValueError.
        return f"input: * {doc['input']}\noutput: "

    def doc_to_target1(self, doc):
        """Return the reference output unchanged."""
        return doc["output"]

    def doc_to_text2(self, doc):
        """Prompt variant that appends `` #`` to the input."""
        # The input is used verbatim; see doc_to_text1.
        return f"input: {doc['input']} #\noutput: "

    def doc_to_target2(self, doc):
        """Return the two space-separated input tokens in swapped order.

        Raises ValueError unless ``doc["input"]`` has exactly two tokens.
        """
        first, second = doc['input'].split(" ")
        return " ".join([second, first])

    def fewshot_examples(self, k, k2, kc, rnd):
        """Sample ``kc`` composed, ``k`` task-1 and ``k2`` task-2 docs, shuffled.

        Not called by this class's ``fewshot_context``.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        # Drawn first: rnd call order affects the selection.
        compose = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        retval = (
            rnd.sample(self._task1_training_docs, k)
            + rnd.sample(self._task2_training_docs, k2)
            + compose
        )
        rnd.shuffle(retval)
        return retval

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Return description + the fixed demonstrations + ``doc_to_text(doc)``.

        ``kc``, ``num_fewshot`` and ``num_fewshot2`` are ignored. ``rnd`` is
        required but no random numbers are consumed.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        description = description + "\n\n" if description else ""

        # One demonstration per line, each terminated by a blank line.
        labeled_examples = (
            "input: 8048\noutput: 8049\n\n"
            "input: 5947\noutput: 5948\n\n"
            "input: 4389\noutput: 4390\n\n"
            "input: basset\noutput: BASSET\n\n"
            "input: 3665\noutput: 3666\n\n"
            "input: 1526\noutput: 1527\n\n"
            "input: 2764\noutput: 2765\n\n"
            "input: shih\noutput: SHIH\n\n"
            "input: oscilloscope\noutput: OSCILLOSCOPE\n\n"
            "input: 2799 teller\noutput: 2800 TELLER\n\n"
            "input: 8264\noutput: 8265\n\n"
            "input: cauliflower\noutput: CAULIFLOWER\n\n"
            "input: 7409\noutput: 7410\n\n"
            "input: 8157 unicycle\noutput: 8158 UNICYCLE\n\n"
            "input: 9356\noutput: 9357\n\n"
            "input: safety\noutput: SAFETY\n\n"
            "input: baby\noutput: BABY\n\n"
            "input: adult\noutput: ADULT\n\n"
            "input: 2742\noutput: 2743\n\n"
            "input: harmonica\noutput: HARMONICA\n\n"
            "input: eagle\noutput: EAGLE\n\n"
            "input: 4014 groenendael\noutput: 4015 GROENENDAEL\n\n"
            "input: 2768 valley\noutput: 2769 VALLEY\n\n"
            "input: sheepdog\noutput: SHEEPDOG\n\n"
            "input: 8321 vizsla\noutput: 8322 VIZSLA\n\n"
        )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
