import os.path
import json
from configs.preprocessor import Preprocessor
import multiprocessing as mp
import random
from datasets import Dataset, load_dataset


# No task-specific special tokens for MediaSum.
# NOTE(review): presumably read by the evaluation harness — confirm.
special_tokens = list()

# Odd shot counts 1, 3, 5, 7, 9.
# NOTE(review): presumably the few-shot settings swept for this task — confirm.
fewshot_examples = list(range(1, 10, 2))


def load_data(input_dir, instruction, shot_count, eval_by_logits, tokenizer, max_input_tokens=512):
    """Load the MediaSum test split and render it with one instruction template.

    Args:
        input_dir: Forwarded unchanged to ``MediaSumPreprocessor``.
        instruction: Forwarded unchanged to ``MediaSumPreprocessor``;
            presumably a key of its ``instr2preprocessor`` registry — confirm.
        shot_count: Number of leading rows of the MediaSum validation split
            handed to the preprocessor as few-shot examples.
        eval_by_logits: Forwarded unchanged to ``MediaSumPreprocessor``.
        tokenizer: Callable whose result has an ``"input_ids"`` entry; used
            only to measure the token length of each test document.
        max_input_tokens: Test rows whose document has this many tokens or
            more are dropped. Defaults to 512, the previously fixed limit.

    Returns:
        The mapped test ``Dataset`` with the raw ``document`` and ``summary``
        columns (and a ``dialogue`` column, if one was produced) removed.
    """
    test_set = load_dataset("ccdv/mediasum", split="test")
    examples = load_dataset("ccdv/mediasum", split="validation").select(range(shot_count))

    # Only the test rows are length-filtered; the few-shot examples are not.
    test_set = test_set.filter(
        lambda example: len(tokenizer(example["document"])["input_ids"]) < max_input_tokens
    )

    preprocessor = MediaSumPreprocessor(instruction, examples, eval_by_logits, input_dir)
    preprocess = preprocessor.processor

    print(preprocessor.processor)
    test_set = test_set.map(preprocess, remove_columns=["document", "summary"], num_proc=1)
    # Inspect the schema rather than ``test_set[0]`` so this also works when
    # the length filter above removed every row (indexing would raise).
    if "dialogue" in test_set.column_names:
        test_set = test_set.remove_columns(["dialogue"])
    return test_set


class MediaSumPreprocessor(Preprocessor):
    """Prompt builders for the MediaSum dialogue-summarization task.

    Every builder maps one MediaSum row, read through its ``"document"`` and
    ``"summary"`` keys, to a prompt. ``add_unobserved_instructions`` registers
    the builders in ``self.instr2preprocessor`` under keys such as
    ``"MedSum/Unobserved/3"`` or ``"FLAN/XSUM/2"``.

    NOTE(review): ``instr2preprocessor``, ``processor`` and the ``alpaca_*`` /
    ``flan_*`` / ``niv2_*`` builders come from the base ``Preprocessor``, which
    is not visible here. Presumably the base constructor calls
    ``add_unobserved_instructions`` and resolves ``instruction`` to
    ``self.processor`` — confirm.
    """

    # Instructions used with the Alpaca prompt format: "Alpaca/Unobserved/N"
    # uses entry N-1. These strings are runtime prompt text, so the original
    # spelling ("essense", "dialoge", "an costumer") is kept verbatim.
    # Correcting it would change the prompts being evaluated.
    _UNOBSERVED_ALPACA_INSTRUCTIONS = [
        "Summarize the essense of the dialogue below",
        "What is the abstract of this dialogue?",
        "You are provided with a dialoge. The task is to summarize the dialogue into a few sentences.",
        "Task: Summarize the given text.",
        "Summarize it into a short paragraph.",
        "Question: what is discussed in this dialogue? Answer in  1-2 sentences.",
        "Summarization",
        "Using your understanding of the text, summarize it into a few sentences.",
        "You are an costumer service agent. You are asked to summarize the dialogue below into a few sentences.",
        "Summarize the dialogue below into a few sentences. Try to use the words from the dialogue as much as possible"
    ]

    def __init__(self, instruction, examples, eval_by_logits, input_dir):
        super().__init__(instruction, examples, eval_by_logits, input_dir)

    def alpaca_preprocess(self, item):
        """Rename a MediaSum row to the ``input_text``/``output_text`` keys of the Alpaca builders."""
        new_item = {
            "input_text": item["document"],
            "output_text": item["summary"],
        }
        return new_item

    def add_unobserved_instructions(self):
        """Register every MediaSum instruction template in ``self.instr2preprocessor``.

        Keys are registered in this order:

        * ``MedSum/Unobserved/1..10``: the bound methods ``unobserved1..10``.
        * ``Alpaca/Observed/1..20``: ``alpaca_summarization_N`` applied to
          ``alpaca_preprocess(item)``.
        * ``Alpaca/Unobserved/1..10``: ``alpaca_nlg`` with
          ``_UNOBSERVED_ALPACA_INSTRUCTIONS[N-1]``.
        * ``FLAN/XSUM/1..8`` and ``NIV2/XSUM/1..10``: the XSum builders,
          applied to the row with ``id`` set to 0.
        * ``FLAN/CNN/1..8`` and ``NIV2/CNN/1..10``: the CNN/DM builders,
          applied to ``article``/``highlights``/``id=0``.

        Inherited builders are looked up by name only when a row is
        processed, so a missing one fails only if its key is actually used.
        """
        registry = self.instr2preprocessor

        for n in range(1, 11):
            registry[f"MedSum/Unobserved/{n}"] = getattr(self, f"unobserved{n}")

        for n in range(1, 21):
            registry[f"Alpaca/Observed/{n}"] = self._alpaca_observed_builder(f"alpaca_summarization_{n}")

        for pos in range(len(self._UNOBSERVED_ALPACA_INSTRUCTIONS)):
            registry[f"Alpaca/Unobserved/{pos + 1}"] = self._alpaca_unobserved_builder(pos)

        for n in range(1, 9):
            registry[f"FLAN/XSUM/{n}"] = self._xsum_builder(f"flan_xsum_{n}")
        for n in range(1, 11):
            registry[f"NIV2/XSUM/{n}"] = self._xsum_builder(f"niv2_xsum_{n}")

        for n in range(1, 9):
            registry[f"FLAN/CNN/{n}"] = self._cnndm_builder(f"flan_cnndm_{n}")
        for n in range(1, 11):
            registry[f"NIV2/CNN/{n}"] = self._cnndm_builder(f"niv2_cnn_dm_{n}")

    # Factories return a fresh closure per call. This avoids the late-binding
    # pitfall of defining lambdas inside a loop, and keeps each registered
    # callable's signature a single ``item`` argument.

    def _alpaca_observed_builder(self, method_name):
        """Return ``item -> self.<method_name>(self.alpaca_preprocess(item))``."""
        def build(item):
            return getattr(self, method_name)(self.alpaca_preprocess(item))
        return build

    def _alpaca_unobserved_builder(self, position):
        """Return a builder that calls ``alpaca_nlg`` with unobserved Alpaca instruction *position*."""
        def build(item):
            return self.alpaca_nlg(self.alpaca_preprocess(item), self._UNOBSERVED_ALPACA_INSTRUCTIONS[position])
        return build

    def _xsum_builder(self, method_name):
        """Return a builder that passes the row, with ``id`` forced to 0, to an XSum template."""
        def build(item):
            # Use ``{**item, "id": 0}`` rather than ``dict(**item, id=0)``. The
            # keyword form raises TypeError ("multiple values for keyword
            # argument 'id'") whenever the row already has an "id" column.
            return getattr(self, method_name)({**item, "id": 0})
        return build

    def _cnndm_builder(self, method_name):
        """Return a builder that renames the row to CNN/DailyMail fields for a CNN/DM template."""
        def build(item):
            cnndm_item = dict(article=item["document"], highlights=item["summary"], id=0)
            return getattr(self, method_name)(cnndm_item)
        return build

    @staticmethod
    def _summary_example(item, input_text):
        """Pair a rendered prompt with *item*'s reference summary; ``label_space`` is None."""
        return {"input_text": input_text, "output_text": item["summary"], "label_space": None}

    # MediaSum-specific templates. Each one wraps the raw ``document`` text in
    # a fixed instruction and keeps ``summary`` as the target output.

    def unobserved1(self, item):
        text = item["document"]
        return self._summary_example(item, f"{text}\n\nSummarize the essense of the dialogue above:")

    def unobserved2(self, item):
        text = item["document"]
        return self._summary_example(item, f"{text}.\n\nWhat is the abstract of this dialogue?\n\nAbstract:")

    def unobserved3(self, item):
        text = item["document"]
        return self._summary_example(item, f"You are provided with a dialoge. The task is to summarize the dialogue into a few sentences.\n\nDialogue: {text}.\n\nSummarization:")

    def unobserved4(self, item):
        text = item["document"]
        return self._summary_example(item, f"Task: Summarize the given text.\n\nText: {text}.\nSummarization:")

    def unobserved5(self, item):
        text = item["document"]
        return self._summary_example(item, f"Dialogue: {text}. Summarize it into a short paragraph.\n\nSummarization:")

    def unobserved6(self, item):
        text = item["document"]
        return self._summary_example(item, f"{text}\n\nQuestion: what is discussed in this dialogue? Answer in  1-2 sentences.\n\nAnswer:")

    def unobserved7(self, item):
        text = item["document"]
        return self._summary_example(item, f"{text}. Summarization: ")

    def unobserved8(self, item):
        text = item["document"]
        return self._summary_example(item, f"{text}. Using your understanding of the text, summarize it into a few sentences.\n\nOutput: ")

    def unobserved9(self, item):
        text = item["document"]
        return self._summary_example(item, f"You are an costumer service agent. You are asked to summarize the dialogue below into a few sentences.\n\nDialogue: {text}.\n\nSummarization:")

    def unobserved10(self, item):
        text = item["document"]
        return self._summary_example(item, f"Summarize the dialogue below into a few sentences. Try to use the words from the dialogue as much as possible\n\nDialogue: {text}.\n\nSummarization:")



    