import os
import torch
from torch.utils.data import Dataset, TensorDataset
import re
from tqdm import tqdm
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, ElectraTokenizerFast
from datasets import load_dataset


class WNLIDataset():
    """WNLI sentence pairs loaded from the local ``data/glue/wnli`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=100, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(sentence1_list, sentence2_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/wnli')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['sentence1'], test_data['sentence2'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['sentence1'], train_data['sentence2'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class RTEDataset():
    """RTE sentence pairs loaded from the local ``data/glue/rte`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=200, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(sentence1_list, sentence2_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/rte')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['sentence1'], test_data['sentence2'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['sentence1'], train_data['sentence2'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class QNLIDataset():
    """QNLI question/sentence pairs loaded from the local ``data/glue/qnli`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=120, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(question_list, sentence_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/qnli')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['question'], test_data['sentence'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['question'], train_data['sentence'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class MnliMismatchedSDataset():
    """MNLI premise/hypothesis pairs, evaluated on the mismatched validation set.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of ``data/glue/mnli`` ``train``.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split of ``data/glue/mnli_mismatched``.
    """

    def __init__(self, mode, max_len=200, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 3
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(premise_list, hypothesis_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        if mode == 'test':
            # Only the mismatched set is needed here; the full MNLI dataset
            # is no longer loaded just to be discarded.
            test_data = load_dataset('data/glue/mnli_mismatched')['validation']
            return test_data['premise'], test_data['hypothesis'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = load_dataset('data/glue/mnli')['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['premise'], train_data['hypothesis'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class MNLIDataset():
    """MNLI premise/hypothesis pairs loaded from the local ``data/glue/mnli`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=200, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 3
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(premise_list, hypothesis_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/mnli')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['premise'], test_data['hypothesis'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['premise'], train_data['hypothesis'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class STSBDataset():
    """STS-B sentence pairs loaded from the local ``data/glue/stsb`` dataset.

    Labels are returned as float tensors (see ``get_dataset``).

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=80, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        # NOTE(review): labels are float scores, yet label_num is 2 — confirm
        # how callers use label_num for this task before changing it.
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(sentence1_list, sentence2_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/stsb')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['sentence1'], test_data['sentence2'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['sentence1'], train_data['sentence2'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``.
            The first three are ``torch.long`` with ``MAX_LEN`` entries per row
            because of ``padding='max_length'`` plus truncation. ``labels`` is
            ``torch.float``.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.float))


class QQPDataset():
    """QQP question pairs loaded from the local ``data/glue/qqp`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=100, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(question1_list, question2_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/qqp')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['question1'], test_data['question2'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['question1'], train_data['question2'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class MRPCDataset():
    """MRPC sentence pairs loaded from the local ``data/glue/mrpc`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=100, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded pair is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(sentence1_list, sentence2_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/mrpc')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['sentence1'], test_data['sentence2'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['sentence1'], train_data['sentence2'], train_data['label']),
            mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text1, text2, ...)``; the result
                must expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text1_list, text2_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text1, text2 in tqdm(zip(text1_list, text2_list), total=len(label_list)):
            encoded_input = tokenizer(text1, text2,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class CoLADataset():
    """CoLA single sentences loaded from the local ``data/glue/cola`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=40, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded sentence is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(sentence_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        cola = load_dataset('data/glue/cola')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = cola['validation']
            return test_data['sentence'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = cola['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['sentence'], train_data['label']), mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text, ...)``; the result must
                expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text in tqdm(text_list, total=len(label_list)):
            encoded_input = tokenizer(text,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))


class SST2Dataset():
    """SST-2 single sentences loaded from the local ``data/glue/sst2`` dataset.

    Split selected by ``mode``:
      * ``'train'``: first 90% of the seeded shuffle of the ``train`` split.
      * ``'dev'``:   remaining 10% of that same shuffle.
      * ``'test'``:  the ``validation`` split.
    """

    def __init__(self, mode, max_len=45, seed=42):
        """
        Args:
            mode: 'train', 'dev' or 'test' (validated when data is loaded).
            max_len: token length every encoded sentence is padded/truncated to.
            seed: seed for shuffling ``train`` before the 90/10 cut. The
                'train' and 'dev' instances must share it, otherwise their
                subsets may overlap.
        """
        self.mode = mode
        self.label_num = 2
        self.MAX_LEN = max_len
        self.seed = seed

    @staticmethod
    def _train_dev_split(columns, mode, train_fraction=0.9):
        """Cut parallel columns at ``train_fraction``.

        Returns a tuple holding the leading part of every column for
        ``mode == 'train'`` and the trailing remainder otherwise ('dev').
        """
        cut = int(len(columns[0]) * train_fraction)
        part = slice(None, cut) if mode == 'train' else slice(cut, None)
        return tuple(column[part] for column in columns)

    def load_raw_data(self, mode):
        """Return ``(sentence_list, label_list)`` for ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        if mode not in ['train', 'dev', 'test']:
            raise ValueError('The value of mode can only be: train, dev or test!')

        data_ = load_dataset('data/glue/sst2')

        if mode == 'test':
            # The validation split serves as the held-out test set.
            test_data = data_['validation']
            return test_data['sentence'], test_data['label']

        # Seeded shuffle: every call produces the same permutation, so the
        # 'train' head and the 'dev' tail stay disjoint across instances.
        train_data = data_['train'].shuffle(seed=self.seed)
        return self._train_dev_split(
            (train_data['sentence'], train_data['label']), mode)

    def __len__(self):
        """Number of examples in the ``self.mode`` split (reloads the data)."""
        return len(self.load_raw_data(self.mode)[0])

    def get_dataset(self, tokenizer):
        """Tokenize the ``self.mode`` split into a ``TensorDataset``.

        Args:
            tokenizer: called as ``tokenizer(text, ...)``; the result must
                expose ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` (e.g. a Hugging Face tokenizer).

        Returns:
            ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``,
            all ``torch.long``; each of the first three rows has ``MAX_LEN``
            entries because of ``padding='max_length'`` plus truncation.
        """
        text_list, label_list = self.load_raw_data(self.mode)

        input_ids, attention_mask, token_type_ids = [], [], []
        for text in tqdm(text_list, total=len(label_list)):
            encoded_input = tokenizer(text,
                                      padding='max_length',
                                      max_length=self.MAX_LEN,
                                      truncation=True,
                                      return_token_type_ids=True)
            input_ids.append(encoded_input.input_ids)
            attention_mask.append(encoded_input.attention_mask)
            token_type_ids.append(encoded_input.token_type_ids)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_mask, dtype=torch.long),
                             torch.tensor(token_type_ids, dtype=torch.long),
                             torch.tensor(list(label_list), dtype=torch.long))
