import os.path
import json
from configs.preprocessor import Preprocessor
import multiprocessing as mp
import random
from datasets import Dataset, load_dataset


# Extra tokenizer tokens for this task. The list is empty and is not referenced
# anywhere in the visible code — presumably read by external code; TODO confirm.
special_tokens = []


# Not referenced in the visible code. Presumably the few-shot counts supported
# for this task — TODO confirm against whatever imports this module.
fewshot_examples = [1, 3, 5, 7, 9]


def _preprocess(item):
    new_item = {
        "sent1": item["translation"]["fr"],
        "sent2": item["translation"]["de"],
        "lang1": "french",
        "lang2": "german",
    }
    return new_item

def load_data(input_dir, instruction, shot_count, eval_by_logits, tokenizer):
    """Load the WMT19 fr-de validation split and format it as prompts.

    The validation split serves as the evaluation set, and the first
    ``shot_count`` rows of the train split serve as the examples handed to
    the preprocessor.

    Args:
        input_dir: Forwarded unchanged to ``WMT19Preprocessor``.
        instruction: Forwarded unchanged to ``WMT19Preprocessor``.
        shot_count: Number of leading train rows to use as examples.
        eval_by_logits: Forwarded unchanged to ``WMT19Preprocessor``.
        tokenizer: Accepted for interface compatibility; not used here.

    Returns:
        The validation dataset after applying the preprocessor's
        ``processor`` callable, with the intermediate ``sent1``/``sent2``/
        ``lang1``/``lang2`` columns removed.
    """
    # Evaluation rows: flatten the nested "translation" column.
    eval_pairs = load_dataset("wmt19", "fr-de", split="validation").map(
        _preprocess, remove_columns=["translation"]
    )

    # Examples: the first `shot_count` training rows, flattened the same way.
    train_split = load_dataset("wmt19", "fr-de", split="train")
    demonstrations = train_split.select(range(shot_count)).map(
        _preprocess, remove_columns=["translation"]
    )

    builder = WMT19Preprocessor(instruction, demonstrations, eval_by_logits, input_dir)
    return eval_pairs.map(
        builder.processor,
        remove_columns=["sent1", "sent2", "lang1", "lang2"],
        num_proc=1,
    )


class WMT19Preprocessor(Preprocessor):
    """Prompt templates for French-to-German WMT19 translation items.

    Every ``unobservedN`` method takes an item with the keys ``sent1``
    (source sentence), ``sent2`` (reference translation), ``lang1`` and
    ``lang2`` (language names) and returns a dict with:

    * ``input_text``  - the prompt built from the template,
    * ``output_text`` - ``sent2`` unchanged,
    * ``label_space`` - always ``None``.

    ``add_unobserved_instructions`` registers these templates, plus one
    entry per string in ``_UNOBSERVED_ALPACA_INSTRUCTIONS``, in
    ``self.instr2preprocessor`` (presumably created by the base class -
    TODO confirm).
    """

    # Instruction strings passed as the second argument to ``alpaca_trans``;
    # entry i is registered as "Alpaca/Unobserved/<i + 1>".
    _UNOBSERVED_ALPACA_INSTRUCTIONS = [
        "Translation",
        "You are a translator. Translate the following sentence into German:",
        "Translate this from French to German",
        "You are given a sentence in French. Your task is to translate it into German",
        "Translate the following sentence into German",
        "Instructions: Translate the sentence into German:",
        "Task definition: Translate from French to German.",
        "What does this input mean in German?",
        "Show me the equivalent of this input in German.",
        "Express this input in German for me please.",
    ]

    # Number of ``unobservedN`` template methods defined on this class.
    _NUM_WMT19_TEMPLATES = 11

    def __init__(self, instruction, examples, eval_by_logits, input_dir):
        """Forward all arguments unchanged to the base ``Preprocessor``."""
        super().__init__(instruction, examples, eval_by_logits, input_dir)

    def add_unobserved_instructions(self):
        """Register every template of this class in ``self.instr2preprocessor``.

        Keys are ``"WMT19/Unobserved/<n>"`` for the ``unobserved<n>`` methods
        and ``"Alpaca/Unobserved/<n>"`` for the n-th (1-based) entry of
        ``_UNOBSERVED_ALPACA_INSTRUCTIONS``.
        """
        for number in range(1, self._NUM_WMT19_TEMPLATES + 1):
            key = f"WMT19/Unobserved/{number}"
            self.instr2preprocessor[key] = getattr(self, f"unobserved{number}")

        for number, instruction in enumerate(self._UNOBSERVED_ALPACA_INSTRUCTIONS, start=1):
            key = f"Alpaca/Unobserved/{number}"
            self.instr2preprocessor[key] = self._alpaca_preprocessor(instruction)

    def _alpaca_preprocessor(self, instruction):
        """Return a one-argument callable applying ``alpaca_trans`` with ``instruction``.

        ``alpaca_trans`` is not defined in this class; presumably it is
        inherited from ``Preprocessor`` - TODO confirm.
        """
        def preprocess(item):
            return self.alpaca_trans(item, instruction)

        return preprocess

    @staticmethod
    def _fields(item):
        """Return ``(sent1, sent2, lang1, lang2)``; raises ``KeyError`` if one is missing."""
        return item["sent1"], item["sent2"], item["lang1"], item["lang2"]

    @staticmethod
    def _record(input_text, output_text):
        """Build the result dict shared by all templates (no label space)."""
        return {"input_text": input_text, "output_text": output_text, "label_space": None}

    def unobserved1(self, item):
        """Template: ``Translation:`` header followed by labelled source/target lines."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(f"Translation:\n\n{lang1}: {sent1}\n{lang2}: ", sent2)

    def unobserved2(self, item):
        """Template: "You are a translator. Translate the following sentence into ..."."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(
            f"You are a translator. Translate the following sentence into {lang2}:\n\n{sent1}",
            sent2,
        )

    def unobserved3(self, item):
        """Template: quoted sentence with "from <lang1> to <lang2>"."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(f"Translate \"{sent1}\" from {lang1} to {lang2}:", sent2)

    def unobserved4(self, item):
        """Template: task description, the sentence, then a ``Translation:`` cue."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(
            f"You are given a sentence in {lang1}. Your task is to translate it into {lang2}:\n\n{sent1}\n\nTranslation: ",
            sent2,
        )

    def unobserved5(self, item):
        """Template: "Translate the following sentence into <lang2>"."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(f"Translate the following sentence into {lang2}:\n\n{sent1}", sent2)

    def unobserved6(self, item):
        """Template: ``Instructions:`` / ``Sentence:`` / ``Translation:`` layout."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(
            f"Instructions: Translate the sentence into {lang2}:\n\nSentence: {sent1}\n\nTranslation: ",
            sent2,
        )

    def unobserved7(self, item):
        """Template: ``Task definition:`` / ``Sentence:`` / ``Translation:`` layout."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(
            f"Task definition: Translate from {lang1} to {lang2}:\n\nSentence: {sent1}\n\nTranslation:",
            sent2,
        )

    def unobserved8(self, item):
        """Template: "What does "<sentence>" mean in <lang2>?"."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(f"What does \"{sent1}\" mean in {lang2}?", sent2)

    def unobserved9(self, item):
        """Template: "Express "<sentence>" in <lang2> for me please."."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(f"Express \"{sent1}\" in {lang2} for me please.", sent2)

    def unobserved10(self, item):
        """Template: "In <lang1>, we say "<sentence>". In <lang2>, we say:"."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(f"In {lang1}, we say \"{sent1}\". In {lang2}, we say:", sent2)

    # NOTE(review): this template needs its own method name; two methods
    # sharing the name ``unobserved9`` would leave only the later definition
    # reachable. It is numbered 11 so that keys 9 and 10 keep their templates.
    # The Alpaca list orders "Show me ..." before "Express ..."; confirm
    # whether the WMT19 keys should be renumbered to match.
    def unobserved11(self, item):
        """Template: "Show me the equivalent of "<sentence>" in <lang2>."."""
        sent1, sent2, lang1, lang2 = self._fields(item)
        return self._record(f"Show me the equivalent of \"{sent1}\" in {lang2}.", sent2)


    