import os.path
import json
from configs.preprocessor import Preprocessor
import multiprocessing as mp
import random
from datasets import Dataset, load_dataset


# Extra tokens this task contributes; XSum contributes none.
# Presumably read by the evaluation harness — confirm against the caller.
special_tokens = list()


# Few-shot counts: the odd numbers 1 through 9 (1, 3, 5, 7, 9).
fewshot_examples = list(range(1, 10, 2))

# XSum dataset columns: ['document', 'summary', 'id']

def load_data(input_dir, instruction, shot_count, eval_by_logits, tokenizer, max_doc_tokens=512):
    """Load the XSum test split, drop over-long documents, and map it to prompts.

    Args:
        input_dir: Forwarded unchanged to ``XSUMPreprocessor``.
        instruction: Instruction key forwarded to ``XSUMPreprocessor``.
        shot_count: Number of leading validation-split examples passed to the
            preprocessor (presumably as few-shot demonstrations).
        eval_by_logits: Forwarded unchanged to ``XSUMPreprocessor``.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry;
            used here only to measure each document's token length.
        max_doc_tokens: Documents are kept only when their tokenized length is
            strictly less than this value. Defaults to 512, the previously
            hard-coded limit.

    Returns:
        The test ``Dataset`` after ``preprocessor.processor`` has been mapped
        over it; every original column is removed, so only the fields the
        processor emits remain.
    """
    test_set = load_dataset("xsum", "3.0.0", split="test")
    examples = load_dataset("xsum", "3.0.0", split="validation").select(range(shot_count))

    # Keep only documents that fit within the token budget.
    test_set = test_set.filter(
        lambda example: len(tokenizer(example["document"])["input_ids"]) < max_doc_tokens
    )
    preprocessor = XSUMPreprocessor(instruction, examples, eval_by_logits, input_dir)

    # Remove all source columns (document/summary/id for XSum) so that only the
    # processor's output fields survive, without hand-maintaining the list.
    return test_set.map(
        preprocessor.processor,
        remove_columns=test_set.column_names,
        num_proc=1,
    )


class XSUMPreprocessor(Preprocessor):
    """Prompt construction for the XSum summarization task.

    ``add_unobserved_instructions`` registers three families of handlers in
    ``self.instr2preprocessor`` (presumably a dict created by ``Preprocessor``;
    confirm in the base class):

    * ``XSUM/Unobserved/1`` .. ``/10`` -> ``unobserved1`` .. ``unobserved10``,
      XSum-specific templates built from an item's ``"document"``.
    * ``Alpaca/Observed/1`` .. ``/20`` -> the inherited
      ``alpaca_summarization_N`` applied to ``alpaca_preprocess(item)``.
    * ``Alpaca/Unobserved/1`` .. ``/9`` -> the inherited ``alpaca_nlg`` applied
      to ``alpaca_preprocess(item)`` and one entry of
      ``_UNOBSERVED_ALPACA_INSTRUCTIONS``.
    """

    # Instructions passed to ``alpaca_nlg``; entry i backs key
    # "Alpaca/Unobserved/{i + 1}".
    _UNOBSERVED_ALPACA_INSTRUCTIONS = [
        "You are given a short paragraph. In your own word, summarize it into a few sentence.",
        "Your task: summarize the above article into a few sentences.",
        "I need you to summarize this document for me. Your summarization should be as simple as possible.",
        "Summarization.",
        "You are an expert at document summarization. Please summarize this.",
        "The abstract of this text is: ",
        "Instruction: read the following article carefully, and provide a brief summarization.",
        "Summarize this text into sentences. Requirement: the summarization should be about 1-2 sentences.",
        "You are asked to summarize this text."
    ]

    # Number of XSum-specific templates, implemented as unobserved1..unobservedN.
    _NUM_UNOBSERVED_XSUM = 10

    # Number of inherited alpaca_summarization_N methods exposed as
    # "Alpaca/Observed/N".
    _NUM_OBSERVED_ALPACA = 20

    def __init__(self, instruction, examples, eval_by_logits, input_dir):
        super().__init__(instruction, examples, eval_by_logits, input_dir)

    def alpaca_preprocess(self, item):
        """Rename XSum fields to the ``input_text``/``output_text`` keys used by the Alpaca templates."""
        return {
            "input_text": item["document"],
            "output_text": item["summary"],
        }

    def add_unobserved_instructions(self):
        """Register every prompt handler in ``self.instr2preprocessor``.

        Insertion order matches the original hand-written registration:
        XSum templates, then observed Alpaca prompts, then unobserved ones.
        """
        registry = self.instr2preprocessor

        for idx in range(1, self._NUM_UNOBSERVED_XSUM + 1):
            registry[f"XSUM/Unobserved/{idx}"] = getattr(self, f"unobserved{idx}")

        for idx in range(1, self._NUM_OBSERVED_ALPACA + 1):
            # Bind the method *name* as a default argument so that (a) each
            # lambda keeps its own idx (no late-binding closure bug) and
            # (b) the method is still looked up at call time.
            registry[f"Alpaca/Observed/{idx}"] = (
                lambda item, name=f"alpaca_summarization_{idx}": getattr(self, name)(
                    self.alpaca_preprocess(item)
                )
            )

        for pos in range(len(self._UNOBSERVED_ALPACA_INSTRUCTIONS)):
            # Bind the list position (not the string) so the instruction is
            # read from the class list at call time.
            registry[f"Alpaca/Unobserved/{pos + 1}"] = (
                lambda item, pos=pos: self.alpaca_nlg(
                    self.alpaca_preprocess(item),
                    self._UNOBSERVED_ALPACA_INSTRUCTIONS[pos],
                )
            )

    @staticmethod
    def _as_example(prompt, item):
        """Pair a rendered prompt with ``item``'s reference summary.

        All XSum templates in this class set ``label_space`` to ``None``.
        """
        return {"input_text": prompt, "output_text": item["summary"], "label_space": None}

    def unobserved1(self, item):
        """Template: intro sentence, the document, then a request to summarize it."""
        document = item["document"]
        # "\n\n" is intended here; a previous "\n\In" was an invalid escape
        # sequence that put a literal backslash into the prompt.
        prompt = f"You are given a short paragraph: {document}\n\nIn your own word, summarize it into a few sentence:"
        return self._as_example(prompt, item)

    def unobserved2(self, item):
        """Template: the document followed by the task statement."""
        document = item["document"]
        prompt = f"{document}\nYour task: summarize the above article into a few sentences.\n\n"
        return self._as_example(prompt, item)

    def unobserved3(self, item):
        """Template: request, the document, a simplicity requirement, then an answer cue."""
        document = item["document"]
        prompt = (
            f"I need you to summarize this document for me: {document}.\n\n"
            "Your summarization should be as simple as possible.\n\n Summarization: "
        )
        return self._as_example(prompt, item)

    def unobserved4(self, item):
        """Template: "Document:" / "Summarization:" fields."""
        document = item["document"]
        prompt = f"Document: {document}\n\nSummarization: "
        return self._as_example(prompt, item)

    def unobserved5(self, item):
        """Template: expert persona, the document, then an "Output:" cue."""
        document = item["document"]
        prompt = (
            "You are an expert at document summarization. "
            f"Please summarize this: {document}.\n\n\nOutput: "
        )
        return self._as_example(prompt, item)

    def unobserved6(self, item):
        """Template: "Text:" field followed by an abstract cue."""
        document = item["document"]
        prompt = f"Text: {document}.\nThe abstract of this text is: "
        return self._as_example(prompt, item)

    def unobserved7(self, item):
        """Template: the document followed by a bare "Abstract:" cue."""
        document = item["document"]
        prompt = f"{document}. Abstract: "
        return self._as_example(prompt, item)

    def unobserved8(self, item):
        """Template: "Instruction:" line with the document, then an answer cue."""
        document = item["document"]
        prompt = (
            "Instruction: read the following article carefully, and provide a brief "
            f"summarization: {document}.\n\nSummarization:"
        )
        return self._as_example(prompt, item)

    def unobserved9(self, item):
        """Template: request, the document, a 1-2 sentence length requirement, then a cue."""
        document = item["document"]
        prompt = (
            f"Summarize this text into sentences: {document}.\n\n"
            "Requirement: the summarization should be about 1-2 sentences.\n\n"
            "Summarization:"
        )
        return self._as_example(prompt, item)

    def unobserved10(self, item):
        """Template: plain request with the document, then an answer cue."""
        document = item["document"]
        prompt = f"You are asked to summarize this text: {document}.\n\nSummarization:"
        return self._as_example(prompt, item)



    