import torch
from datasets import load_dataset

def postprocess_output(output_ids, eos_token_id):
    """
    Truncate a generated sequence right after its first EOS token.

    output_ids : list
        Generated token ids.
    eos_token_id : int
        Id of the end-of-sequence token.

    Returns the prefix of ``output_ids`` up to and including the first
    occurrence of ``eos_token_id``. If the token does not occur, returns
    ``output_ids`` itself (the same object, not a copy).
    """
    # Single scan: .index raises ValueError when the token is absent, which
    # avoids the separate `in` membership pass.
    try:
        eos_index = output_ids.index(eos_token_id)
    except ValueError:
        # No EOS was generated, so keep the whole sequence.
        return output_ids
    return output_ids[:eos_index + 1]

    
def postprocess_input(input_ids, bos_token_id):
    """
    Drop every token that precedes the first BOS token (e.g. left padding).

    input_ids : list
        Token ids of an encoded prompt.
    bos_token_id : int
        Id of the beginning-of-sequence token.

    Returns the suffix of ``input_ids`` starting at the first occurrence of
    ``bos_token_id``. If the token does not occur (e.g. a tokenizer that does
    not prepend BOS), returns ``input_ids`` itself unchanged instead of
    raising. This mirrors how ``postprocess_output`` treats a missing EOS.
    """
    try:
        bos_index = input_ids.index(bos_token_id)
    except ValueError:
        # Nothing to strip: there is no BOS marker to anchor on.
        return input_ids
    return input_ids[bos_index:]


class MyDataset(torch.utils.data.Dataset):
    """Split each tokenized sequence into a prompt prefix and a held-out suffix.

    Wraps an indexable ``dataset`` whose rows are mappings with an
    ``"input_ids"`` sequence (tensor or list). Rows with ``max_length`` or more
    tokens are skipped. For each kept row, the first
    ``int(input_ratio * len(ids))`` tokens become ``"input_ids"`` and the
    remaining tokens become ``"target_ids"``.

    Args:
        dataset: indexable collection of rows, each with an "input_ids" entry.
        max_length: rows whose length is >= this value are excluded
            (default 4096, the previously hard-coded limit).
        input_ratio: fraction of each sequence used as the input prefix
            (default 0.9, the previously hard-coded ratio).
    """

    def __init__(self, dataset, max_length=4096, input_ratio=0.9):
        self.dataset = dataset
        self.input_ratio = input_ratio
        # Positions in `dataset` of the rows short enough to keep.
        self.indexes = [
            i for i in range(len(dataset))
            if len(dataset[i]["input_ids"]) < max_length
        ]

    def __len__(self):
        return len(self.indexes)

    def __getitem__(self, idx):
        # Fetch the row once. Each lookup on a formatted dataset re-decodes it.
        ids = self.dataset[self.indexes[idx]]["input_ids"]
        # len() works for both tensors and lists, unlike .shape[0].
        split = int(self.input_ratio * len(ids))
        return {"input_ids": ids[:split], "target_ids": ids[split:]}


class _WMTTranslationDataset(torch.utils.data.Dataset):
    """Tokenized test split of a WMT translation corpus, in one direction.

    Subclasses choose the corpus and the translation direction through the
    class attributes below. Each item is a dict with ``"input_ids"`` (the
    tokenized source sentence) and ``"target_ids"`` (the tokenized reference
    translation), formatted as torch tensors.

    Args:
        tokenizer: callable that maps a string to a mapping containing
            ``"input_ids"``.
        path: cache directory passed to ``load_dataset``.
    """

    # `load_dataset` corpus name and config, e.g. "wmt16" / "de-en".
    _dataset_name = None
    _dataset_config = None
    # Keys into each example's "translation" dict for source and target text.
    _source_lang = None
    _target_lang = None

    def __init__(self, tokenizer, path="/data1/dataset/"):
        # Bind the language keys to locals so the mapped function closes over
        # the tokenizer and two strings only, not over `self`.
        src, tgt = self._source_lang, self._target_lang
        dataset = load_dataset(self._dataset_name, self._dataset_config,
                               split="test", cache_dir=path)
        dataset = dataset.map(lambda example: {
            "input_ids": tokenizer(example["translation"][src])["input_ids"],
            "target_ids": tokenizer(example["translation"][tgt])["input_ids"],
        })
        dataset.set_format(type="torch", columns=["input_ids", "target_ids"])
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once. Each lookup on a formatted dataset re-decodes it.
        example = self.dataset[idx]
        return {"input_ids": example["input_ids"], "target_ids": example["target_ids"]}


class WMT16_DeEn(_WMTTranslationDataset):
    """WMT16 de-en test split: German source, English target."""

    _dataset_name, _dataset_config = "wmt16", "de-en"
    _source_lang, _target_lang = "de", "en"


class WMT16_EnDe(_WMTTranslationDataset):
    """WMT16 de-en test split: English source, German target."""

    _dataset_name, _dataset_config = "wmt16", "de-en"
    _source_lang, _target_lang = "en", "de"


class WMT14_FrEn(_WMTTranslationDataset):
    """WMT14 fr-en test split: French source, English target."""

    _dataset_name, _dataset_config = "wmt14", "fr-en"
    _source_lang, _target_lang = "fr", "en"


class WMT14_EnFr(_WMTTranslationDataset):
    """WMT14 fr-en test split: English source, French target."""

    _dataset_name, _dataset_config = "wmt14", "fr-en"
    _source_lang, _target_lang = "en", "fr"
