import os.path
import json
from configs.preprocessor import Preprocessor
import multiprocessing as mp
import random
from datasets import Dataset, load_dataset


# Extra tokens this task contributes (presumably consumed by the evaluation
# harness that imports this module — TODO confirm); CNN/DM declares none.
special_tokens = []

# Odd shot counts 1, 3, 5, 7, 9 — presumably the few-shot settings the harness
# sweeps over (TODO confirm against the caller).
fewshot_examples = list(range(1, 10, 2))


def load_data(input_dir, instruction, shot_count, eval_by_logits, tokenizer, max_input_tokens=512):
    """Build the CNN/DailyMail 3.0.0 test split as preprocessed prompt/target rows.

    Args:
        input_dir: Forwarded unchanged to ``CNNDMPreprocessor``.
        instruction: Prompt-template key forwarded to ``CNNDMPreprocessor``.
        shot_count: Number of validation rows (the first ``shot_count``) handed
            to the preprocessor as few-shot examples.
        eval_by_logits: Forwarded unchanged to ``CNNDMPreprocessor``.
        tokenizer: Used only to measure article length; calling it on a list
            of strings must return a mapping whose ``"input_ids"`` holds one id
            list per string (HuggingFace tokenizer convention — assumed).
        max_input_tokens: Test articles whose tokenized length is
            ``>= max_input_tokens`` are dropped. Defaults to 512.

    Returns:
        The filtered test split mapped through ``preprocessor.processor``, with
        the raw ``article``, ``highlights`` and ``id`` columns removed.
    """
    test_set = load_dataset("cnn_dailymail", "3.0.0", split="test")
    examples = load_dataset("cnn_dailymail", "3.0.0", split="validation").select(range(shot_count))

    def _fits(batch):
        # Tokenize the whole batch in one call instead of one article per call;
        # the per-article predicate is unchanged.
        input_ids = tokenizer(batch["article"])["input_ids"]
        return [len(ids) < max_input_tokens for ids in input_ids]

    test_set = test_set.filter(_fits, batched=True)

    preprocessor = CNNDMPreprocessor(instruction, examples, eval_by_logits, input_dir)
    # ``processor`` comes from the Preprocessor base class; presumably the
    # builder selected by ``instruction`` (TODO confirm).
    preprocess = preprocessor.processor

    test_set = test_set.map(preprocess, remove_columns=["article", "highlights", "id"], num_proc=1)
    return test_set


class CNNDMPreprocessor(Preprocessor):
    """Prompt builders for CNN/DailyMail summarization.

    Each builder maps a raw dataset row with ``"article"`` and ``"highlights"``
    fields to a prompt/target record. ``add_unobserved_instructions`` registers
    the builders in ``self.instr2preprocessor`` under the keys
    ``"CNN/Unobserved/<k>"``, ``"Alpaca/Observed/<k>"`` and
    ``"Alpaca/Unobserved/<k>"``. The ``alpaca_summarization_<k>`` and
    ``alpaca_nlg`` helpers, and ``instr2preprocessor`` itself, are expected to
    come from the ``Preprocessor`` base class (not visible here — TODO confirm).
    """

    # Instructions for the "Alpaca/Unobserved/<k>" templates: entry k-1 is
    # passed to ``alpaca_nlg`` for key "Alpaca/Unobserved/<k>". Their wording
    # roughly mirrors ``unobserved1`` .. ``unobserved7``.
    _UNOBSERVED_ALPACA_INSTRUCTIONS = (
        "Summarize the following article into a few sentences",
        "Read the following news article from CNN/DailyMail. Summarize the article into less than 10 lines",
        "Summarize this new article into highlights with no more than 10 sentences.",
        "Summarize the main points of the above article in a few sentences.",
        "Please provide a brief summary of the above article",
        "I need you to give me a short overview of the following article.",
        "Consider the above news article and write a short summary based on it",
    )

    # Number of "Alpaca/Observed/<k>" templates; template k delegates to the
    # ``alpaca_summarization_<k>`` method.
    _NUM_ALPACA_OBSERVED = 20

    # Number of CNN-specific templates ``unobserved1`` .. ``unobserved7``.
    _NUM_CNN_UNOBSERVED = 7

    def __init__(self, instruction, examples, eval_by_logits, input_dir):
        super().__init__(instruction, examples, eval_by_logits, input_dir)

    def alpaca_preprocess(self, item):
        """Rename a raw row's fields to the ``input_text``/``output_text`` keys used by the Alpaca templates."""
        return {
            "input_text": item["article"],
            "output_text": item["highlights"],
        }

    def _alpaca_observed(self, index):
        """Return a builder that applies ``alpaca_summarization_<index>`` to a renamed row.

        The method is resolved when the builder runs, not when it is created,
        so a template that is never selected never needs its method to exist.
        """
        method_name = f"alpaca_summarization_{index}"

        def build(item):
            return getattr(self, method_name)(self.alpaca_preprocess(item))

        return build

    def _alpaca_unobserved(self, instruction):
        """Return a builder that applies ``alpaca_nlg`` with a fixed ``instruction`` to a renamed row."""

        def build(item):
            return self.alpaca_nlg(self.alpaca_preprocess(item), instruction)

        return build

    def add_unobserved_instructions(self):
        """Register every CNN/DM prompt builder in ``self.instr2preprocessor``.

        Keys are inserted in this order: ``CNN/Unobserved/1..7``,
        ``Alpaca/Observed/1..20``, ``Alpaca/Unobserved/1..7``. The mapping is
        assumed to be a dict created by the base class (TODO confirm).
        """
        registry = self.instr2preprocessor
        for k in range(1, self._NUM_CNN_UNOBSERVED + 1):
            registry[f"CNN/Unobserved/{k}"] = getattr(self, f"unobserved{k}")
        for k in range(1, self._NUM_ALPACA_OBSERVED + 1):
            # A factory call (not a loop-body lambda) binds ``k`` per key and
            # avoids the late-binding-closure pitfall.
            registry[f"Alpaca/Observed/{k}"] = self._alpaca_observed(k)
        for k, instruction in enumerate(self._UNOBSERVED_ALPACA_INSTRUCTIONS, start=1):
            registry[f"Alpaca/Unobserved/{k}"] = self._alpaca_unobserved(instruction)

    # The ``unobserved<k>`` builders below each wrap the article in a prompt,
    # use the highlights as the target, and set ``label_space`` to None
    # (presumably because summarization has no closed label set).

    def unobserved1(self, item):
        """Instruction before the article, ``Summary:`` cue after it."""
        input_text, output_text = item["article"], item["highlights"]
        input_text = f"Summarize the following article into a few sentences: {input_text}\n\nSummary: "
        return {"input_text": input_text, "output_text": output_text, "label_space": None}

    def unobserved2(self, item):
        """Framing sentence before the article, length-limited instruction after it."""
        input_text, output_text = item["article"], item["highlights"]
        # A blank line separates the article from the instruction, as in unobserved1.
        input_text = f"Read the following news article from CNN/DailyMail: {input_text}\n\nSummarize the article into less than 10 lines: "
        return {"input_text": input_text, "output_text": output_text, "label_space": None}

    def unobserved3(self, item):
        """Article first, then a highlights instruction and ``Highlights:`` cue."""
        input_text, output_text = item["article"], item["highlights"]
        input_text = f"{input_text}. Summarize this new article into highlights with no more than 10 sentences.\nHighlights: "
        return {"input_text": input_text, "output_text": output_text, "label_space": None}

    def unobserved4(self, item):
        """Article first, then a main-points instruction on its own line."""
        input_text, output_text = item["article"], item["highlights"]
        input_text = f"{input_text}\nSummarize the main points of the above article in a few sentences.\n"
        return {"input_text": input_text, "output_text": output_text, "label_space": None}

    def unobserved5(self, item):
        """Article first, then a brief-summary request."""
        input_text, output_text = item["article"], item["highlights"]
        input_text = f"{input_text}. Please provide a brief summary of the above article:"
        return {"input_text": input_text, "output_text": output_text, "label_space": None}

    def unobserved6(self, item):
        """Overview request, then the article, then an ``Overview:`` cue."""
        input_text, output_text = item["article"], item["highlights"]
        input_text = f"I need you to give me a short overview of the following article.\n{input_text}\n Overview:\n"
        return {"input_text": input_text, "output_text": output_text, "label_space": None}

    def unobserved7(self, item):
        """Article first, then a short-summary request."""
        input_text, output_text = item["article"], item["highlights"]
        input_text = f"{input_text}. Consider the above news article and write a short summary based on it: "
        return {"input_text": input_text, "output_text": output_text, "label_space": None}



    