import csv
import numpy as np
import re
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.utils import shuffle
import pandas as pd
import pickle
import json
from gensim.models import KeyedVectors as Vectors

class DataHelper():
    """Turn preprocessed CSV datasets into padded index matrices and batches.

    Text is encoded by looking each whitespace-separated token up in
    ``vocab_dict``; encoded rows are zero-initialised, so positions past the end
    of a text stay 0.
    """

    def __init__(self, vocab_dict, sequence_max_length=1024):
        # Token -> integer index used by char2vec.
        self.vocab_dict = vocab_dict
        # Every encoded text is truncated / zero-padded to this many tokens.
        self.sequence_max_length = sequence_max_length

    def char2vec(self, text):
        """Encode ``text`` as a float vector of vocabulary indices.

        Tokens beyond ``sequence_max_length`` are dropped and unused trailing
        positions stay 0. Raises ``KeyError`` for a token not in ``vocab_dict``.
        """
        data = np.zeros(self.sequence_max_length)
        for i, token in enumerate(text.split()[:self.sequence_max_length]):
            data[i] = self.vocab_dict[token]
        return data

    def load_csv_file(self, filename, num_classes, s1, train=True, one_hot=False):
        """Read a ``label,text`` CSV into an index matrix and a label array.

        Args:
            filename: CSV whose first column is a 1-based class id and whose
                last column is the text to encode.
            num_classes: number of classes; the width of one-hot labels.
            s1: number of rows to allocate (must be >= rows in the file).
            train: unused; kept for backward compatibility.
            one_hot: if True, labels are one-hot rows of width ``num_classes``;
                otherwise a ``(s1, 1)`` column of 0-based class ids.

        Returns:
            ``(all_data, labels)`` integer arrays with ``s1`` rows.
        """
        all_data = np.zeros(shape=(s1, self.sequence_max_length), dtype=int)
        # Separate width variable: the old code reused ``one_hot`` for the
        # vector itself, which broke both the flag test and the assignment.
        label_width = num_classes if one_hot else 1
        labels = np.zeros(shape=(s1, label_width), dtype=int)
        # pandas.to_csv writes UTF-8; the csv module wants newline=''.
        with open(filename, encoding='utf-8', newline='') as f:
            reader = csv.DictReader(f, fieldnames=['class'], restkey='fields')
            for i, row in enumerate(reader):
                class_index = int(row['class']) - 1
                if one_hot:
                    labels[i, class_index] = 1
                else:
                    labels[i] = class_index
                all_data[i] = self.char2vec(row['fields'][-1])
        return all_data, labels

    def init_embeddings(self, embeddings):
        """Fill ``embeddings`` in place with U(-b, b), b = sqrt(3 / dim)."""
        bias = np.sqrt(3.0 / embeddings.size(1))
        torch.nn.init.uniform_(embeddings, -bias, bias)

    @staticmethod
    def _glove_values(tokens):
        """Return the non-blank tokens of a space-split GloVe line's vector part."""
        return [t for t in tokens if t and not t.isspace()]

    def load_embeddings(self, emb_file, emb_format, word_map):
        """Build an embedding matrix for ``word_map`` from pretrained vectors.

        Args:
            emb_file: path to a GloVe text file or a binary word2vec file.
            emb_format: ``'glove'`` or ``'word2vec'``.
            word_map: token -> row index in the returned matrix.

        Returns:
            ``(embeddings, emb_dim)``; ``embeddings`` has ``len(word_map) + 1``
            rows. Rows of words found in the file hold their pretrained vector,
            all other rows keep the uniform initialisation.

        Raises:
            ValueError: if ``emb_format`` is not supported.
        """
        if emb_format not in {'glove', 'word2vec'}:
            raise ValueError(
                "emb_format must be 'glove' or 'word2vec', got {!r}".format(emb_format))

        vocab_size = len(word_map)
        print("Loading embedding...")
        cnt = 0

        if emb_format == 'glove':
            with open(emb_file, 'r', encoding='utf-8') as f:
                emb_dim = len(self._glove_values(f.readline().split(' ')[1:]))

            embeddings = torch.FloatTensor(vocab_size + 1, emb_dim)
            self.init_embeddings(embeddings)

            with open(emb_file, 'r', encoding='utf-8') as f:
                for line in f:
                    tokens = line.split(' ')
                    emb_word = tokens[0]
                    # Check the vocab first so irrelevant lines are never parsed.
                    if emb_word not in word_map:
                        continue
                    embedding = [float(t) for t in self._glove_values(tokens[1:])]
                    embeddings[word_map[emb_word]] = torch.FloatTensor(embedding)
                    cnt += 1
        else:
            vectors = Vectors.load_word2vec_format(emb_file, binary=True)
            print("Load successfully")
            emb_dim = vectors.vector_size
            embeddings = torch.FloatTensor(vocab_size + 1, emb_dim)
            self.init_embeddings(embeddings)

            for emb_word, index in word_map.items():
                # ``in vectors`` is a dict lookup (works in gensim 3 and 4);
                # ``index2word`` was a list scan and no longer exists in gensim 4.
                if emb_word in vectors:
                    embeddings[index] = torch.FloatTensor(vectors[emb_word])
                    cnt += 1

        print("Number of words read: ", cnt)
        print("Number of OOV: ", vocab_size + 1 - cnt)
        return embeddings, emb_dim

    def pretrain_embed(self, vocab, database_path,
                       emb_file='GoogleNews-vectors-negative300.bin',
                       emb_format='word2vec'):
        """Load pretrained vectors for ``vocab`` and save them with torch.save.

        The saved dict has keys ``'pretrain'`` (the matrix) and ``'dim'``.
        ``database_path`` is prefixed directly to the file name.
        """
        pretrain_embed, embed_dim = self.load_embeddings(emb_file, emb_format, vocab)
        embed = {'pretrain': pretrain_embed, 'dim': embed_dim}
        # NOTE(review): the file is named "glove_..." whatever emb_format is;
        # kept unchanged so existing loaders still find it.
        torch.save(embed, '{}glove_pretrain_embed.pth'.format(database_path))

    def load_dataset(self, dataset_path, train_len, text_len):
        """Load ``train.csv`` / ``test.csv`` from ``dataset_path``.

        ``dataset_path`` is concatenated directly with file names, so it should
        end with a path separator. ``classes.txt`` there is read only to count
        classes. ``train_len`` / ``text_len`` are the row counts of the train and
        test CSVs (``text_len`` keeps its historical name for callers).

        Returns:
            ``(train_data, train_label, test_data, test_label)``.
        """
        with open(dataset_path + "classes.txt") as f:
            classes = [line.strip() for line in f]
        num_classes = len(classes)
        train_data, train_label = self.load_csv_file(dataset_path + 'train.csv', num_classes, train_len)
        test_data, test_label = self.load_csv_file(dataset_path + 'test.csv', num_classes, text_len, train=False)
        return train_data, train_label, test_data, test_label

    def batch_iter(self, data, batch_size, num_epochs, sample=0, shuffle=True, classier=False):
        """Yield ``(batch_data, label)`` mini-batches from ``data``.

        Each row of ``data`` is ``sequence_max_length`` token indices followed
        by label column(s). ``num_epochs`` and ``sample`` are unused and kept
        for backward compatibility.

        Args:
            shuffle: randomly permute rows (via ``np.random``) before batching.
            classier: keep only rows whose last column equals 1.

        Yields:
            ``batch_data`` as an int array of shape (n, sequence_max_length) and
            the remaining label column(s) of the same rows.
        """
        data = np.asarray(data)
        if shuffle:
            data = data[np.random.permutation(len(data))]
        if classier:
            # Filter the (possibly shuffled) rows themselves; the old loop tested
            # unshuffled rows but appended shuffled ones.
            data = data[data[:, -1] == 1]

        data_size = len(data)
        # Ceiling division: yields nothing for empty data instead of one bad batch.
        num_batches_per_epoch = (data_size + batch_size - 1) // batch_size
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            batch = data[start_index:start_index + batch_size]
            batch_data, label = np.split(batch, [self.sequence_max_length], axis=1)
            yield np.array(batch_data, dtype=int), label



def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
    """Build MNIST data loaders.

    Returns ``(train_loader, test_loader, train_eval_loader)``:
      * train_loader: training split with random-crop augmentation, shuffled.
      * test_loader: test split, un-augmented, not shuffled.
      * train_eval_loader: training split without augmentation, shuffled,
        batched with ``test_batch_size``.

    Every loader uses ``drop_last=True``, so a trailing partial batch is
    discarded (NOTE(review): this includes the test loader). ``perc`` is
    accepted but not used.
    """
    normalize = transforms.Normalize((0.5,), (0.5,))
    augmented = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    plain = transforms.Compose([transforms.ToTensor(), normalize])

    def make_loader(train_split, transform, size, shuffle_rows):
        dataset = datasets.MNIST(root='.data/mnist', train=train_split, download=True, transform=transform)
        return DataLoader(dataset, batch_size=size, shuffle=shuffle_rows, num_workers=2, drop_last=True)

    train_loader = make_loader(True, augmented, batch_size, True)
    train_eval_loader = make_loader(True, plain, test_batch_size, True)
    test_loader = make_loader(False, plain, test_batch_size, False)

    return train_loader, test_loader, train_eval_loader


def get_cifar_loaders(batch_size=128, test_batch_size=1000):
    """Build CIFAR-10 data loaders.

    Returns ``(train_loader, test_loader, None)``. The training split uses
    random-crop + horizontal-flip augmentation and is shuffled; the test split
    is only normalised and keeps its order. The third slot is always ``None``
    so the tuple has the same length as ``get_mnist_loaders``'s.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_set = datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=train_transform)
    test_set = datasets.CIFAR10(root='.data/cifar', train=False, download=True, transform=eval_transform)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=3),
        DataLoader(test_set, batch_size=test_batch_size, shuffle=False, num_workers=3),
        None,
    )

# (pattern, replacement) pairs applied in order by clean_str after the initial
# character filter. The "(", ")" and "?" replacements are raw strings whose
# value contains a backslash (the same value the former non-raw literals had,
# minus the invalid-escape warning). re.sub keeps unknown non-letter escapes
# such as "\(" verbatim, so those tokens come out as "\(", "\)" and "\?".
_CLEAN_STR_RULES = (
    (r"\'s", " 's"),
    (r"\'ve", " 've"),
    (r"n\'t", " n't"),
    (r"\'re", " 're"),
    (r"\'d", " 'd"),
    (r"\'ll", " 'll"),
    (r",", " , "),
    (r"!", " ! "),
    (r"\(", r" \( "),
    (r"\)", r" \) "),
    (r"\?", r" \? "),
)


def clean_str(string, TREC=False):
    """Normalise and tokenise ``string`` into space-separated tokens.

    Characters outside ``[A-Za-z0-9(),!?'`]`` become spaces, clitics such as
    "'s" / "n't" are split off, punctuation is space-padded, and runs of
    whitespace are collapsed to one space.

    Args:
        string: raw text.
        TREC: if True, preserve case; otherwise the result is lower-cased.

    Returns:
        The cleaned, stripped string.
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    for pattern, replacement in _CLEAN_STR_RULES:
        string = re.sub(pattern, replacement, string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip() if TREC else string.strip().lower()

def clean_str_sst(string):
    """Lightly clean SST text: keep a small character set, squeeze spaces, lower-case.

    Characters outside ``[A-Za-z0-9(),!?'`]`` become spaces; punctuation is not
    split off (unlike ``clean_str``).
    """
    allowed_only = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    squeezed = re.sub(r"\s{2,}", " ", allowed_only)
    return squeezed.strip().lower()

def read_MR(random_state=None):
    """Build train/test CSVs and a word map for the MR polarity corpus.

    Lines of ``data_source/MR/rt-polarity.pos`` get label 2 and lines of
    ``data_source/MR/rt-polarity.neg`` label 1. Each line is cleaned with
    ``clean_str``; empty results are skipped. Rows are shuffled and split
    90/10 into ``data/MR/train.csv`` / ``data/MR/test.csv`` (columns: label,
    text), and the vocabulary is written to ``data/MR/wordmap.json``.

    Args:
        random_state: optional seed / RandomState forwarded to sklearn's
            ``shuffle`` for a reproducible split; ``None`` keeps the previous
            non-deterministic behaviour.

    Returns:
        ``(len(vocab) + 1, longest sequence in tokens, train row count,
        test row count, vocab)`` where ``vocab`` maps token -> index from 1.
    """
    import os

    data = []
    vocab = {}
    sequence_max_length = 0
    sources = (
        ("data_source/MR/rt-polarity.pos", 2),
        ("data_source/MR/rt-polarity.neg", 1),
    )
    for path, label in sources:
        # errors="replace": undecodable bytes become U+FFFD, which clean_str
        # turns into a space like any other non-ASCII character.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                line = clean_str(line.rstrip("\n"))
                sequence = line.split()
                if not sequence:
                    continue
                for word in sequence:
                    # Unseen words get the next free index (1, 2, ...).
                    vocab.setdefault(word, len(vocab) + 1)
                sequence_max_length = max(sequence_max_length, len(sequence))
                data.append([label, line])

    data = shuffle(data, random_state=random_state)
    test_idx = len(data) // 10 * 9

    os.makedirs("data/MR", exist_ok=True)
    pd.DataFrame(data=data[:test_idx]).to_csv("data/MR/train.csv", index=False, header=False)
    pd.DataFrame(data=data[test_idx:]).to_csv("data/MR/test.csv", index=False, header=False)
    with open('data/MR/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, test_idx, len(data) - test_idx, vocab


def read_SUBJ(random_state=None):
    """Build train/test CSVs and a word map for the SUBJ corpus.

    Each line of ``data_source/SUBJ/all.txt`` is ``<label> <text...>``; the
    stored class is ``int(label) + 1``. Text is cleaned with ``clean_str``;
    blank lines and lines whose cleaned text is empty are skipped. Rows are
    shuffled and split 90/10 into ``data/SUBJ/train.csv`` /
    ``data/SUBJ/test.csv``; the vocabulary goes to ``data/SUBJ/wordmap.json``.

    Args:
        random_state: optional seed / RandomState forwarded to sklearn's
            ``shuffle`` for a reproducible split; ``None`` keeps the previous
            non-deterministic behaviour.

    Returns:
        ``(len(vocab) + 1, longest sequence in tokens, train row count,
        test row count, vocab)`` where ``vocab`` maps token -> index from 1.
    """
    import os

    data = []
    vocab = {}
    sequence_max_length = 0
    with open("data_source/SUBJ/all.txt", "r", encoding="ISO-8859-1") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line: no label to read
            sequence = clean_str(" ".join(fields[1:])).split()
            if not sequence:
                continue
            for word in sequence:
                vocab.setdefault(word, len(vocab) + 1)
            sequence_max_length = max(sequence_max_length, len(sequence))
            data.append([int(fields[0]) + 1, " ".join(sequence)])

    data = shuffle(data, random_state=random_state)
    test_idx = len(data) // 10 * 9

    os.makedirs("data/SUBJ", exist_ok=True)
    pd.DataFrame(data=data[:test_idx]).to_csv("data/SUBJ/train.csv", index=False, header=False)
    pd.DataFrame(data=data[test_idx:]).to_csv("data/SUBJ/test.csv", index=False, header=False)
    with open('data/SUBJ/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, test_idx, len(data) - test_idx, vocab

def read_CR(random_state=None):
    """Build train/test CSVs and a word map for the CR corpus.

    Each line of ``data_source/CR/all.txt`` is ``<label> <text...>``; the
    stored class is ``int(label) + 1``. Text is cleaned with ``clean_str``;
    blank lines and lines whose cleaned text is empty are skipped. Rows are
    shuffled and split 90/10 into ``data/CR/train.csv`` / ``data/CR/test.csv``;
    the vocabulary goes to ``data/CR/wordmap.json``.

    Args:
        random_state: optional seed / RandomState forwarded to sklearn's
            ``shuffle`` for a reproducible split; ``None`` keeps the previous
            non-deterministic behaviour.

    Returns:
        ``(len(vocab) + 1, longest sequence in tokens, train row count,
        test row count, vocab)`` where ``vocab`` maps token -> index from 1.
    """
    import os

    data = []
    vocab = {}
    sequence_max_length = 0
    # errors="replace": undecodable bytes become U+FFFD, which clean_str turns
    # into a space like any other non-ASCII character.
    with open("data_source/CR/all.txt", "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line: no label to read
            sequence = clean_str(" ".join(fields[1:])).split()
            if not sequence:
                continue
            for word in sequence:
                vocab.setdefault(word, len(vocab) + 1)
            sequence_max_length = max(sequence_max_length, len(sequence))
            data.append([int(fields[0]) + 1, " ".join(sequence)])

    data = shuffle(data, random_state=random_state)
    test_idx = len(data) // 10 * 9

    os.makedirs("data/CR", exist_ok=True)
    pd.DataFrame(data=data[:test_idx]).to_csv("data/CR/train.csv", index=False, header=False)
    pd.DataFrame(data=data[test_idx:]).to_csv("data/CR/test.csv", index=False, header=False)
    with open('data/CR/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, test_idx, len(data) - test_idx, vocab

def read_MPQA(random_state=None):
    """Build train/test CSVs and a word map for the MPQA corpus.

    Each line of ``data_source/MPQA/all.txt`` is ``<label> <text...>``; the
    stored class is ``int(label) + 1``. Text is cleaned with ``clean_str``;
    blank lines and lines whose cleaned text is empty are skipped. Rows are
    shuffled and split 90/10 into ``data/MPQA/train.csv`` /
    ``data/MPQA/test.csv``; the vocabulary goes to ``data/MPQA/wordmap.json``.

    Args:
        random_state: optional seed / RandomState forwarded to sklearn's
            ``shuffle`` for a reproducible split; ``None`` keeps the previous
            non-deterministic behaviour.

    Returns:
        ``(len(vocab) + 1, longest sequence in tokens, train row count,
        test row count, vocab)`` where ``vocab`` maps token -> index from 1.
    """
    import os

    data = []
    vocab = {}
    sequence_max_length = 0
    # errors="replace": undecodable bytes become U+FFFD, which clean_str turns
    # into a space like any other non-ASCII character.
    with open("data_source/MPQA/all.txt", "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line: no label to read
            sequence = clean_str(" ".join(fields[1:])).split()
            if not sequence:
                continue
            for word in sequence:
                vocab.setdefault(word, len(vocab) + 1)
            sequence_max_length = max(sequence_max_length, len(sequence))
            data.append([int(fields[0]) + 1, " ".join(sequence)])

    data = shuffle(data, random_state=random_state)
    test_idx = len(data) // 10 * 9

    os.makedirs("data/MPQA", exist_ok=True)
    pd.DataFrame(data=data[:test_idx]).to_csv("data/MPQA/train.csv", index=False, header=False)
    pd.DataFrame(data=data[test_idx:]).to_csv("data/MPQA/test.csv", index=False, header=False)
    with open('data/MPQA/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, test_idx, len(data) - test_idx, vocab

def read_TREC(random_state=None):
    """Build train/test CSVs and a word map for the TREC question corpus.

    Lines of ``data_source/TREC/TREC_train.txt`` and ``TREC_test.txt`` are
    ``<COARSE:fine> <question...>``; the coarse label is mapped to 1..6 via
    ``classes`` (unknown labels raise ``KeyError``). Text is cleaned with
    ``clean_str(..., TREC=True)`` (case preserved); blank lines and lines whose
    cleaned text is empty are skipped. The given train/test split is kept; each
    split is shuffled and written to ``data/TREC/train.csv`` /
    ``data/TREC/test.csv``, and the shared vocabulary to
    ``data/TREC/wordmap.json``.

    Args:
        random_state: optional seed / RandomState forwarded to sklearn's
            ``shuffle``; ``None`` keeps the previous non-deterministic order.

    Returns:
        ``(len(vocab) + 1, longest sequence in tokens, train row count,
        test row count, vocab)`` where ``vocab`` maps token -> index from 1.
    """
    import os

    classes = {'DESC': 1, 'ENTY': 2, 'ABBR': 3, 'HUM': 4, 'LOC': 5, 'NUM': 6}
    vocab = {}
    sequence_max_length = 0
    splits = []
    # Train is read first so its words get the lower vocabulary indices.
    for path in ("data_source/TREC/TREC_train.txt", "data_source/TREC/TREC_test.txt"):
        rows = []
        with open(path, "r", encoding="ISO-8859-1") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line: no label to read
                sequence = clean_str(" ".join(fields[1:]), TREC=True).split()
                if not sequence:
                    continue
                for word in sequence:
                    vocab.setdefault(word, len(vocab) + 1)
                sequence_max_length = max(sequence_max_length, len(sequence))
                rows.append([classes[fields[0].split(":")[0]], " ".join(sequence)])
        splits.append(rows)
    data_train, data_test = splits

    data_train = shuffle(data_train, random_state=random_state)
    data_test = shuffle(data_test, random_state=random_state)

    os.makedirs("data/TREC", exist_ok=True)
    pd.DataFrame(data=data_train).to_csv("data/TREC/train.csv", index=False, header=False)
    pd.DataFrame(data=data_test).to_csv("data/TREC/test.csv", index=False, header=False)
    with open('data/TREC/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, len(data_train), len(data_test), vocab


def read_SST1(random_state=None):
    """Build train/test CSVs and a word map for SST-1.

    Lines of ``data_source/SST1/train.txt`` and ``test.txt`` are
    ``<label> <text...>``; the stored class is ``int(label) + 1``. Text is
    cleaned with ``clean_str_sst``; blank lines and lines whose cleaned text is
    empty are skipped. The given split is kept; each side is shuffled and
    written to ``data/SST1/train.csv`` / ``data/SST1/test.csv``, and the shared
    vocabulary to ``data/SST1/wordmap.json``.

    Args:
        random_state: optional seed / RandomState forwarded to sklearn's
            ``shuffle``; ``None`` keeps the previous non-deterministic order.

    Returns:
        ``(len(vocab) + 1, longest sequence in tokens, train row count,
        test row count, vocab)`` where ``vocab`` maps token -> index from 1.
    """
    import os

    vocab = {}
    sequence_max_length = 0
    splits = []
    # Train is read first so its words get the lower vocabulary indices.
    for path in ("data_source/SST1/train.txt", "data_source/SST1/test.txt"):
        rows = []
        # errors="replace": undecodable bytes become U+FFFD, which
        # clean_str_sst turns into a space like any other non-ASCII character.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line: no label to read
                sequence = clean_str_sst(" ".join(fields[1:])).split()
                if not sequence:
                    continue
                for word in sequence:
                    vocab.setdefault(word, len(vocab) + 1)
                sequence_max_length = max(sequence_max_length, len(sequence))
                rows.append([int(fields[0]) + 1, " ".join(sequence)])
        splits.append(rows)
    data_train, data_test = splits

    data_train = shuffle(data_train, random_state=random_state)
    data_test = shuffle(data_test, random_state=random_state)

    os.makedirs("data/SST1", exist_ok=True)
    pd.DataFrame(data=data_train).to_csv("data/SST1/train.csv", index=False, header=False)
    pd.DataFrame(data=data_test).to_csv("data/SST1/test.csv", index=False, header=False)
    with open('data/SST1/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, len(data_train), len(data_test), vocab

def read_SST2(random_state=None):
    """Build train/test CSVs and a word map for SST-2.

    Lines of ``data_source/SST2/train.txt`` and ``test.txt`` are
    ``<label> <text...>``; the stored class is ``int(label) + 1``. Text is
    cleaned with ``clean_str_sst``; blank lines and lines whose cleaned text is
    empty are skipped. The given split is kept; each side is shuffled and
    written to ``data/SST2/train.csv`` / ``data/SST2/test.csv``, and the shared
    vocabulary to ``data/SST2/wordmap.json``.

    Args:
        random_state: optional seed / RandomState forwarded to sklearn's
            ``shuffle``; ``None`` keeps the previous non-deterministic order.

    Returns:
        ``(len(vocab) + 1, longest sequence in tokens, train row count,
        test row count, vocab)`` where ``vocab`` maps token -> index from 1.
    """
    import os

    vocab = {}
    sequence_max_length = 0
    splits = []
    # Train is read first so its words get the lower vocabulary indices.
    for path in ("data_source/SST2/train.txt", "data_source/SST2/test.txt"):
        rows = []
        with open(path, "r", encoding="ISO-8859-1") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line: no label to read
                sequence = clean_str_sst(" ".join(fields[1:])).split()
                if not sequence:
                    continue
                for word in sequence:
                    vocab.setdefault(word, len(vocab) + 1)
                sequence_max_length = max(sequence_max_length, len(sequence))
                rows.append([int(fields[0]) + 1, " ".join(sequence)])
        splits.append(rows)
    data_train, data_test = splits

    data_train = shuffle(data_train, random_state=random_state)
    data_test = shuffle(data_test, random_state=random_state)

    os.makedirs("data/SST2", exist_ok=True)
    pd.DataFrame(data=data_train).to_csv("data/SST2/train.csv", index=False, header=False)
    pd.DataFrame(data=data_test).to_csv("data/SST2/test.csv", index=False, header=False)
    with open('data/SST2/wordmap.json', 'w') as j:
        json.dump(vocab, j)
    return len(vocab) + 1, sequence_max_length, len(data_train), len(data_test), vocab

if __name__ == '__main__':
    # Preprocess MR (writes data/MR/train.csv, test.csv and wordmap.json), then
    # encode a dataset with the resulting vocabulary and row counts.
    # NOTE(review): the CSVs are read back from MR_class/ rather than data/MR/,
    # and load_dataset also expects a classes.txt there — confirm intended.
    _vocab_size, max_len, n_train, n_test, word_map = read_MR()

    helper = DataHelper(word_map, sequence_max_length=max_len)
    train_data, train_label, test_data, test_label = helper.load_dataset(
        'MR_class/', n_train, n_test)
