import regex
import string

import datasets

import evaluate
exact_match = evaluate.load("exact_match")
from .task import register, Task

class equation(Task):
    """Base task for the JSON-backed "equation" datasets in this module.

    Each record is read as a dict with an ``input`` and an ``output`` string.
    Subclasses set ``train_files`` / ``test_files``; the test file is exposed
    as the ``validation`` split and there is no separate test split.
    """

    VERSION = 0
    DATASET_PATH = "json"
    DATASET_NAME = None

    # NOTE(review): not read by download() below, which uses its own
    # ``cache_dir`` argument (default None); confirm whether this attribute
    # is consumed elsewhere (e.g. by Task) or was meant as the fallback.
    cache_dir = "./data/equation/cache"
    train_files = None
    test_files = None

    def _load_json_split(self, data_files, cache_dir=None, download_mode=None):
        """Load ``data_files`` through the ``datasets`` JSON builder.

        Files passed via ``data_files`` are exposed by ``datasets`` as a single
        "train" split, which is why ``split="train"`` is requested for both
        the training and the test file.
        """
        return datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            data_files=data_files,
            cache_dir=cache_dir,
            download_mode=download_mode,
            split="train",
        )

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Populate ``self.dataset`` with ``train`` and ``validation`` splits.

        :param data_dir: accepted for interface compatibility; unused.
        :param cache_dir: forwarded to ``datasets.load_dataset``.
        :param download_mode: forwarded to ``datasets.load_dataset``.
        """
        # The test file is loaded before the train file.
        testset = self._load_json_split(self.test_files, cache_dir, download_mode)
        trainset = self._load_json_split(self.train_files, cache_dir, download_mode)

        self.dataset = datasets.DatasetDict({
            "train": trainset,
            "validation": testset,
        })

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the train split as a list, cached after the first call.

        NOTE(review): assumes ``self._training_docs`` was initialised to None
        before the first call; this class does not do it (presumably
        Task.__init__ does -- confirm).
        """
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the validation split (loaded from ``test_files``)."""
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        """Render the prompt part of ``doc``: ``input: <input>\\noutput:``."""
        return f"input: {doc['input']}\noutput:"

    def doc_to_target(self, doc):
        """Return the expected answer string of ``doc``."""
        return doc["output"]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM. Greedy generation is requested
        until the first newline.

        :param doc:
                The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
                The context string, generated by fewshot_context. This includes the natural
                language description, as well as the few shot examples, and the question
                part of the document for `doc`.
        """
        # NOTE(review): `rf` is neither defined nor imported in this module,
        # so this line raises NameError unless `rf` is supplied some other
        # way. It is presumably the harness's request factory (e.g.
        # `lm_eval.base.rf` in lm-evaluation-harness); confirm and import it.
        continuation = rf.greedy_until(ctx, {"until": ["\n"]})
        return continuation

    def _normalize_answer(self, text):
        """Return ``text`` with all leading and trailing whitespace removed.

        Whitespace of any kind (spaces, tabs, newlines) is stripped on both
        ends, so " ABC" and "ABC " both normalise to "ABC".
        """
        return text.strip()

    def process_results(self, doc, results):
        """Score one document by position-wise token exact match.

        The normalised continuation and ``doc["output"]`` are split on single
        spaces. Predicted token ``i`` is compared with reference token ``i``.
        Surplus predicted tokens are ignored, and missing ones count as "".
        The score is the fraction of matching positions.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests;
            ``results[0]`` is the generated continuation.
        :returns: ``{"acc": float}`` with the value in ``[0, 1]``.
        """
        continuation = self._normalize_answer(results[0])
        refs = doc["output"].split(" ")

        # Pad with "" and then cut, so preds has exactly len(refs) tokens.
        preds = (continuation.split(" ") + [""] * len(refs))[: len(refs)]

        # Element-wise equality averaged over positions. This gives the same
        # value as evaluate's exact_match with default options, without
        # paying that metric's per-call overhead for every document.
        matches = sum(pred == ref for pred, ref in zip(preds, refs))
        return {"acc": matches / len(refs)}

    @staticmethod
    def _mean(values):
        """Arithmetic mean of ``values``; raises ZeroDivisionError if empty."""
        return sum(values) / len(values)

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        # No `mean` is imported in this module, so use the local helper.
        return {
            "acc": self._mean,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True,
        }


@register('upper')
class upper(equation):
    """Single-skill task backed by the ``upper`` train/test JSON files."""

    train_files = "data/equation/upper_twoSum/upper_train.json"
    test_files = "data/equation/upper_twoSum/upper_test.json"


@register('twoSum')
class twoSum(equation):
    """Single-skill task backed by the ``two_sum`` train/test JSON files."""

    train_files = "data/equation/upper_twoSum/two_sum_train.json"
    test_files = "data/equation/upper_twoSum/two_sum_test.json"

@register('upper_twoSum')
class upper_twoSum(equation):
    """Task on the composed ``upper_twoSum`` files plus auxiliary train sets.

    Besides the composed train/test files, three extra training sets are
    loaded as splits ``task1``, ``task2`` and ``task3`` (paths below) so that
    few-shot prompts can draw demonstrations from them.
    """

    train_files = 'data/equation/upper_twoSum/upper_twoSum_train.json'
    test_files = 'data/equation/upper_twoSum/upper_twoSum_test.json'
    task_1_files = 'data/equation/upper_twoSum/upper_train.json'
    task_2_files = 'data/equation/upper_twoSum/two_sum_train.json'
    task_3_files = 'data/swap1/swap_train.json'

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
        # Does not call super().__init__(): data is loaded eagerly here and
        # every lazily-built doc cache starts out empty.
        self.download(data_dir, cache_dir, download_mode)
        self._training_docs = self._fewshot_docs = None
        self._task1_training_docs = None
        self._task2_training_docs = None
        self._task3_training_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Load the test, train and three auxiliary JSON files.

        ``data_dir`` is accepted for interface compatibility and is unused.
        Files are loaded in the order test, train, task1, task2, task3.
        """
        def load(data_files):
            # A JSON builder fed data_files exposes them as one "train" split.
            return datasets.load_dataset(
                path=self.DATASET_PATH,
                name=self.DATASET_NAME,
                data_files=data_files,
                cache_dir=cache_dir,
                download_mode=download_mode,
                split="train",
            )

        validation = load(self.test_files)
        train = load(self.train_files)
        task1, task2, task3 = (
            load(files)
            for files in (self.task_1_files, self.task_2_files, self.task_3_files)
        )

        self.dataset = datasets.DatasetDict({
            "train": train,
            "validation": validation,
            "task1": task1,
            "task2": task2,
            "task3": task3,
        })

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the composed train split as a list (cached)."""
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the validation split (loaded from ``test_files``)."""
        return self.dataset["validation"]

    def fewshot_examples(self, k, k2, kc, rnd):
        """Sample ``k`` task1 docs and ``k2`` task2 docs, shuffled together.

        ``kc`` is accepted for signature compatibility with the compose
        variants and is ignored here.
        """
        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        picked = rnd.sample(self._task1_training_docs, k)
        picked += rnd.sample(self._task2_training_docs, k2)
        rnd.shuffle(picked)
        return picked

@register('upper_twoSum_compose_incontext')
class upper_twoSum_compose_incontext(upper_twoSum):
    """upper_twoSum variant whose prompts also include composed examples.

    The demonstration pool holds ``k`` task1 docs, ``k2`` task2 docs and ``kc``
    docs from the composed train split. All are rendered identically and
    then shuffled together.
    """

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return [k task1 docs] + [k2 task2 docs] + [kc composed docs].

        The list is not shuffled here; ``fewshot_context`` shuffles the
        rendered prompts. Random draws happen in the order composed, task1,
        task2.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        simple1 = rnd.sample(self._task1_training_docs, k)
        simple2 = rnd.sample(self._task2_training_docs, k2)
        return simple1 + simple2 + composed

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: optional description, shuffled demos, then query.

        NOTE(review): demonstrations are omitted entirely whenever
        ``num_fewshot`` is 0, even if ``num_fewshot2`` or ``kc`` is positive.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        prefix = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            examples = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            # task1, task2 and composed demos all share a single format.
            rendered = [
                self.doc_to_text(ex) + self.doc_to_target(ex) for ex in examples
            ]
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return prefix + demos + self.doc_to_text(doc)
@register('upper_twoSum_compose_incontext_tag')
class upper_twoSum_compose_incontext_tag(upper_twoSum):
    """Compose-in-context prompts in which every example carries a group tag.

    Demonstrations are prefixed with "(simple1)", "(simple2)" or "(compose)"
    according to the group they were sampled from. The query is prefixed
    with "(compose)".
    """

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return [k task1 docs] + [k2 task2 docs] + [kc composed docs].

        The list is left unshuffled. Random draws happen in the order
        composed, task1, task2.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        simple1 = rnd.sample(self._task1_training_docs, k)
        simple2 = rnd.sample(self._task2_training_docs, k2)
        return simple1 + simple2 + composed

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the tagged prompt: description, shuffled demos, tagged query.

        NOTE(review): demonstrations are omitted entirely whenever
        ``num_fewshot`` is 0, even if ``num_fewshot2`` or ``kc`` is positive.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        prefix = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            examples = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            end1 = num_fewshot
            end2 = num_fewshot + num_fewshot2
            tagged_groups = (
                ("(simple1)", examples[:end1]),
                ("(simple2)", examples[end1:end2]),
                ("(compose)", examples[end2:]),
            )
            rendered = [
                tag + self.doc_to_text(ex) + self.doc_to_target(ex)
                for tag, group in tagged_groups
                for ex in group
            ]
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return prefix + demos + "(compose)" + self.doc_to_text(doc)
@register('upper_twoSum_compose_incontext_re1')
class upper_twoSum_compose_incontext_re1(upper_twoSum_compose_incontext):
    """Compose-in-context prompts, evaluated on the ``upper`` test file.

    Prompt construction (``fewshot_examples`` / ``fewshot_context``) is
    inherited from ``upper_twoSum_compose_incontext``. Only the data paths
    differ.
    """

    train_files = 'data/equation/upper_twoSum/upper_twoSum_train.json'
    test_files = 'data/equation/upper_twoSum/upper_test.json'
    task_1_files = 'data/equation/upper_twoSum/upper_train.json'
    task_2_files = 'data/equation/upper_twoSum/two_sum_train.json'

    # NOTE(review): doc_to_target1/2 are not referenced by the prompt
    # construction in this file; kept for any external callers.
    def doc_to_target1(self, doc):
        """Return ``doc["output"]`` unchanged."""
        return doc["output"]

    def doc_to_target2(self, doc):
        """Return ``doc["output"]`` unchanged."""
        return doc["output"]
@register('upper_twoSum_compose_incontext_irrl')
class upper_twoSum_compose_incontext_irrl(upper_twoSum_compose_incontext):
    """Compose-in-context variant whose first demo group comes from task3.

    ``fewshot_examples`` draws the ``k`` group from the task3 split instead
    of task1. ``fewshot_context`` is inherited from
    ``upper_twoSum_compose_incontext``.
    """

    train_files = 'data/equation/upper_twoSum/upper_twoSum_train.json'
    test_files = 'data/equation/upper_twoSum/upper_test.json'
    task_1_files = 'data/equation/upper_twoSum/upper_train.json'
    task_2_files = 'data/equation/upper_twoSum/two_sum_train.json'
    task_3_files = 'data/swap1/swap_train.json'

    # NOTE(review): doc_to_target1/2/3 are not referenced by the prompt
    # construction in this file; kept for any external callers.
    def doc_to_target1(self, doc):
        """Return ``doc["output"]`` unchanged."""
        return doc["output"]

    def doc_to_target2(self, doc):
        """Return ``doc["output"]`` unchanged."""
        return doc["output"]

    def doc_to_target3(self, doc):
        """Return ``doc["output"]`` unchanged."""
        return doc["output"]

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return [k task3 docs] + [k2 task2 docs] + [kc composed docs].

        The list is left unshuffled (``fewshot_context`` shuffles the rendered
        prompts). Random draws happen in the order composed, task3, task2.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        compose = rnd.sample(self._training_docs, kc)

        if self._task3_training_docs is None:
            self._task3_training_docs = list(self.dataset["task3"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        return (
            rnd.sample(self._task3_training_docs, k)
            + rnd.sample(self._task2_training_docs, k2)
            + compose
        )
@register('upper_twoSum_compose_incontext_distribution')
class upper_twoSum_compose_incontext_distribution(upper_twoSum_compose_incontext):
    """Compose-in-context variant with alternative target renderers.

    Example sampling and prompt construction are inherited from
    ``upper_twoSum_compose_incontext``.
    """

    train_files = 'data/equation/upper_twoSum/upper_twoSum_train.json'
    test_files = 'data/equation/upper_twoSum/upper_twoSum_test.json'
    task_1_files = 'data/equation/upper_twoSum/upper_train.json'
    task_2_files = 'data/equation/upper_twoSum/two_sum_train.json'

    # NOTE(review): doc_to_target1/2 are not called by the inherited prompt
    # construction, which uses doc_to_target for every demo. Prompts for this
    # task are therefore identical to the parent's. Confirm whether these
    # were meant to render the task1/task2 demonstrations.
    def doc_to_target1(self, doc):
        """Rebuild a target string from ``doc["input"]``.

        Drops the first 3 and the last 2 characters of the input and splits
        the rest on "@". Exactly one "@" is required; otherwise the unpacking
        raises ValueError. Returns the left part concatenated with the right
        part minus its first and last character.
        """
        a, b = doc["input"][3:-2].split("@")
        return a + b[1:-1]

    def doc_to_target2(self, doc):
        """Return ``doc["output"]`` lower-cased."""
        return doc["output"].lower()
@register('upper_twoSum_compose_incontext_cot')
class upper_twoSum_compose_incontext_cot(upper_twoSum_compose_incontext):
    """Variant in which composed demos show an intermediate step.

    Simple demos are rendered as ``<text>-><target>``. Composed demos are
    rendered as ``<text>->*( <lower-cased target> )-><target>``. Example
    sampling is inherited from ``upper_twoSum_compose_incontext``.
    """

    train_files = 'data/equation/upper_twoSum/upper_twoSum_train.json'
    test_files = 'data/equation/upper_twoSum/upper_twoSum_test.json'
    task_1_files = 'data/equation/upper_twoSum/upper_train.json'
    task_2_files = 'data/equation/upper_twoSum/two_sum_train.json'

    def doc_to_target_rand(self, doc):
        """Return the intermediate step shown in composed demos."""
        return "*( " + doc["output"].lower() + " )"

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: optional description, shuffled demos, then query.

        NOTE(review): demos render answers after ``output:->`` but the query
        ends at ``output:``. The inherited process_results compares the
        continuation only with the space-separated tokens of ``doc["output"]``,
        so a continuation that follows the demonstrated ``->...`` format is
        scored with that prefix attached. Confirm the intended scoring. Also,
        demos are omitted whenever ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        prefix = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            examples = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            # The first num_fewshot + num_fewshot2 docs are the simple
            # (task1, task2) demos; the remainder are composed demos.
            n_simple = num_fewshot + num_fewshot2
            rendered = [
                self.doc_to_text(ex) + "->" + self.doc_to_target(ex)
                for ex in examples[:n_simple]
            ] + [
                self.doc_to_text(ex)
                + "->"
                + self.doc_to_target_rand(ex)
                + "->"
                + self.doc_to_target(ex)
                for ex in examples[n_simple:]
            ]
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return prefix + demos + self.doc_to_text(doc)
@register('upper_twoSum_compose_incontext_expcot')
class upper_twoSum_compose_incontext_expcot(upper_twoSum_compose_incontext):
    """Explicit step-by-step variant with "step1/step2/step3" demonstrations.

    Example sampling is inherited from ``upper_twoSum_compose_incontext``.
    """

    train_files = 'data/equation/upper_twoSum/upper_twoSum_train.json'
    test_files = 'data/equation/upper_twoSum/upper_twoSum_test.json'
    task_1_files = 'data/equation/upper_twoSum/upper_train.json'
    task_2_files = 'data/equation/upper_twoSum/two_sum_train.json'

    def doc_to_target_rand(self, doc):
        """Return the intermediate (step 2) text shown in composed demos."""
        return "*( " + doc["output"].lower() + " )"

    def doc_pure_text(self, doc):
        """Return the raw input followed by a newline."""
        return doc["input"] + "\n"

    def doc_to_query(self, doc):
        """Render the query: step 1 holds the input and step 2 is left open."""
        return "step1: " + doc["input"] + "\nstep2:"

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        r"""Build the prompt: optional description, shuffled demos, then query.

        Demo formats, listed in the order they are built before shuffling:
          * task2 docs: "step1: <input>\n step2: <target>\nstep3: ???"
          * task1 docs: "step1:???\nstep2:<input>\nstep3:<target>"
          * composed:   "step1: <input>\nstep2: *( <lower target> )\nstep3: <target>"

        NOTE(review): the task2 format has a space before "step2:", and the
        task1 format has no space after its colons, unlike the composed
        format. Confirm these differences are intended.
        NOTE(review): the query stops at "step2:" and generation stops at the
        first newline (inherited construct_requests). The scored continuation
        is therefore the step-2 text, which in composed demos is
        "*( ... )", not the target compared by process_results.
        NOTE(review): demos are omitted whenever ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        prefix = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            examples = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            end1 = num_fewshot
            end2 = num_fewshot + num_fewshot2
            # Build order (task2, task1, composed) affects the post-shuffle
            # arrangement, so it is kept as is.
            rendered = (
                [
                    "step1: " + self.doc_pure_text(ex) + " step2: "
                    + self.doc_to_target(ex) + "\nstep3: ???"
                    for ex in examples[end1:end2]
                ]
                + [
                    "step1:???\nstep2:" + self.doc_pure_text(ex)
                    + "step3:" + self.doc_to_target(ex)
                    for ex in examples[:end1]
                ]
                + [
                    "step1: " + self.doc_pure_text(ex) + "step2: "
                    + self.doc_to_target_rand(ex) + "\nstep3: "
                    + self.doc_to_target(ex)
                    for ex in examples[end2:]
                ]
            )
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return prefix + demos + self.doc_to_query(doc)
# ================================================

class mod(equation):
    """Single-skill task backed by the ``mod`` train/test JSON files.

    NOTE(review): unlike most task classes in this module, this one has no
    @register decorator; confirm whether it should be registered.
    """

    train_files = "data/equation/mod_twoSum/mod_train.json"
    test_files = "data/equation/mod_twoSum/mod_test.json"


class twoSumPlus(equation):
    """Single-skill task backed by the mod_twoSum ``two_sum`` JSON files.

    NOTE(review): unlike most task classes in this module, this one has no
    @register decorator; confirm whether it should be registered.
    """

    train_files = "data/equation/mod_twoSum/two_sum_train.json"
    test_files = "data/equation/mod_twoSum/two_sum_test.json"

@register('mod_twoSum')
class mod_twoSum(upper_twoSum):
    """Composition task on the ``mod_twoSum`` files.

    Loading and ``fewshot_examples`` (k task1 + k2 task2 docs, shuffled) are
    inherited from ``upper_twoSum``. Only the data paths and the scoring
    differ.
    """

    train_files = 'data/equation/mod_twoSum/mod_twoSum_train.json'
    test_files = 'data/equation/mod_twoSum/mod_twoSum_test.json'
    task_1_files = 'data/equation/mod_twoSum/mod_train.json'
    task_2_files = 'data/equation/mod_twoSum/two_sum_train.json'

    def process_results(self, doc, results):
        """Score one document: is the first predicted token a reference token?

        :param doc: document with an ``output`` string.
        :param results: results of the requests from construct_requests;
            ``results[0]`` is the generated continuation.
        :returns: ``{"acc": 1.0}`` if the first space-separated token of the
            normalised continuation appears anywhere among the space-separated
            tokens of ``doc["output"]``, otherwise ``{"acc": 0.0}``.
        """
        continuation = self._normalize_answer(results[0])
        # Text before the first space (same as continuation.split(" ")[0]).
        first_pred = continuation.partition(" ")[0]
        refs = doc["output"].split(" ")
        return {"acc": float(first_pred in refs)}

@register('mod_twoSum_compose_incontext')
class mod_twoSum_compose_incontext(mod_twoSum):
    """mod_twoSum variant whose prompts also include composed examples."""

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return k task1 + k2 task2 + kc composed docs, shuffled.

        Random draws happen in the order composed, task1, task2, then the
        shuffle.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        pool = rnd.sample(self._task1_training_docs, k)
        pool += rnd.sample(self._task2_training_docs, k2)
        pool += composed
        rnd.shuffle(pool)
        return pool

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: optional description, shuffled demos, then query.

        NOTE(review): fewshot_examples already shuffles its result, and the
        rendered list is shuffled again here. Both shuffles consume ``rnd``,
        so dropping either would change the sampled prompts. Demos are also
        omitted whenever ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        prefix = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            examples = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            rendered = [
                self.doc_to_text(ex) + self.doc_to_target(ex) for ex in examples
            ]
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return prefix + demos + self.doc_to_text(doc)

@register('mod_twoSum_compose_incontext_re1')
class mod_twoSum_compose_incontext_re1(mod_twoSum):
    """Compose-in-context prompts for mod_twoSum, evaluated on ``mod_test``."""

    train_files = 'data/equation/mod_twoSum/mod_twoSum_train.json'
    test_files = 'data/equation/mod_twoSum/mod_test.json'
    task_1_files = 'data/equation/mod_twoSum/mod_train.json'
    task_2_files = 'data/equation/mod_twoSum/two_sum_train.json'

    # NOTE(review): doc_to_target1/2 are not referenced by the prompt
    # construction in this file.
    def doc_to_target1(self, doc):
        """Return ``doc["output"]`` unchanged."""
        return doc["output"]

    def doc_to_target2(self, doc):
        """Return ``doc["output"]`` unchanged."""
        return doc["output"]

    def fewshot_examples(self, k, k2, kc, rnd):
        """Return k task1 + k2 task2 + kc composed docs, shuffled.

        Random draws happen in the order composed, task1, task2, then the
        shuffle.
        """
        if self._training_docs is None:
            self._training_docs = list(self.training_docs())
        composed = rnd.sample(self._training_docs, kc)

        if self._task1_training_docs is None:
            self._task1_training_docs = list(self.dataset["task1"])
        if self._task2_training_docs is None:
            self._task2_training_docs = list(self.dataset["task2"])

        pool = rnd.sample(self._task1_training_docs, k)
        pool += rnd.sample(self._task2_training_docs, k2)
        pool += composed
        rnd.shuffle(pool)
        return pool

    def fewshot_context(
        self, doc, kc, num_fewshot, num_fewshot2, rnd=None, description=None
    ):
        """Build the prompt: optional description, shuffled demos, then query.

        NOTE(review): fewshot_examples already shuffles its result, and the
        rendered list is shuffled again here; both consume ``rnd``. Demos
        are omitted whenever ``num_fewshot`` is 0.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        prefix = description + "\n\n" if description else ""

        if num_fewshot == 0:
            demos = ""
        else:
            examples = self.fewshot_examples(
                k=num_fewshot, k2=num_fewshot2, kc=kc, rnd=rnd
            )
            rendered = [
                self.doc_to_text(ex) + self.doc_to_target(ex) for ex in examples
            ]
            rnd.shuffle(rendered)
            demos = "\n\n".join(rendered) + "\n\n"

        return prefix + demos + self.doc_to_text(doc)