import pdb

from torch.utils.data import Dataset, DataLoader
import torch
import os
import numpy as np
from src.tlp_rnn_fusion.rnn_utils import *
import codecs
import torch.nn as nn
import torchtext

from src.tlp_rnn_fusion import embedding
from src.tlp_rnn_fusion import torchtext_datasets
    

class HomogeneousMNIST(Dataset):
    """Image-classification dataset loaded from a comma-separated file.

    Every row is parsed as floats (no header is skipped): column 0 is the
    class label and the remaining columns are the flattened pixel values,
    reshaped to ``image_shape`` on access.
    """

    def __init__(self, data_path, transform=None, image_shape=(28, 28)):
        """
        Args:
            data_path: path to the comma-separated data file.
            transform: optional callable applied to each image. It receives a
                tensor of shape ``(1, *image_shape)``; its output is squeezed.
            image_shape: shape the pixel columns of a row are reshaped to.
                Defaults to 28x28 (MNIST).
        """
        data_file = np.genfromtxt(data_path, skip_header=0, dtype=float, delimiter=',')
        # genfromtxt returns a 1-D array when the file has a single row; promote
        # it to one 2-D row so the [row, column] indexing in __getitem__ works.
        # An empty file (size 0) is left alone so len() stays 0.
        if data_file.ndim == 1 and data_file.size > 0:
            data_file = data_file.reshape(1, -1)
        self.dataset = torch.from_numpy(data_file)
        self.transform = transform
        self.image_shape = tuple(image_shape)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` as (float32 ``image_shape`` tensor, int64 scalar)."""
        image = self.dataset[idx, 1:].float().reshape(self.image_shape)
        if self.transform is not None:
            # Give the transform a leading channel dim, then drop singleton dims.
            image = self.transform(image.unsqueeze(0)).squeeze()
        label = self.dataset[idx, 0].long()
        return image, label


def load_glove_embeddings(glove_path):
    """Loads embeddings, returns weight matrix and dict from words to indices.

    Each non-blank line of the file is a word, one space, then the vector
    components separated by whitespace. Two extra rows follow the file's
    vectors:

    * ``'<PAD>'`` -> an all-zeros vector;
    * ``'<UNK>'`` -> a vector drawn from ``np.random.uniform(-0.05, 0.05)``
      (NumPy's global random state).

    Args:
        glove_path: path to the UTF-8 encoded embedding text file.

    Returns:
        ``(weights, word_idx)``: ``weights`` is a float32 array of shape
        ``(num_lines + 2, dim)`` and ``word_idx`` maps each word to its row.

    Raises:
        ValueError: if the file contains no embedding lines.
    """
    print('loading word embeddings from %s' % glove_path)
    weight_vectors = []
    word_idx = {}
    # newline='\n' splits lines only on '\n'. codecs.open readers split on every
    # Unicode line boundary (e.g. '\u2028', '\x85'), which can occur inside tokens.
    with open(glove_path, encoding='utf-8', newline='\n') as f:
        for line in f:
            if not line.strip():
                continue  # e.g. a trailing empty line
            word, vec = line.split(' ', 1)
            word_idx[word] = len(weight_vectors)
            weight_vectors.append(np.array(vec.split(), dtype=np.float32))
    if not weight_vectors:
        raise ValueError('no embeddings found in %s' % glove_path)
    dim = weight_vectors[0].shape
    # Special tokens are appended after all file vectors.
    word_idx['<PAD>'] = len(weight_vectors)
    weight_vectors.append(np.zeros(dim, dtype=np.float32))
    word_idx['<UNK>'] = len(weight_vectors)
    weight_vectors.append(np.random.uniform(-0.05, 0.05, dim).astype(np.float32))
    return np.stack(weight_vectors), word_idx


class SSTDatasetPT(Dataset):
    """SST-2 sentiment dataset built on ``torchtext.datasets.SST2``.

    ``SST2(data_path)`` is unpacked into three splits. The first is used when
    ``train`` is truthy, otherwise the second (dev, given torchtext's
    (train, dev, test) ordering). Phrases are split on single spaces and
    converted to word ids once, at construction time.
    """

    def __init__(self, data_path, glove_embedding, train=True, max_seq_len=56):
        self.max_seq_len = max_seq_len
        train_split, dev_split, _ = torchtext.datasets.SST2(data_path)
        source = train_split if train else dev_split
        # Each entry is [label, tokens].
        self.dataset = [[int(label), phrase.split(" ")] for phrase, label in source]
        self.glove_embedding = glove_embedding
        self.word_indices = []
        self.preload_indices()
        print("Dataset setup!")

    def __len__(self):
        return len(self.dataset)

    def preload_indices(self):
        """Append the word-id sequence of every entry to ``self.word_indices``."""
        print("Loading indices")
        to_ids = self.glove_embedding.get_word_seq
        self.word_indices.extend(
            to_ids(tokens, self.max_seq_len) for _, tokens in self.dataset
        )

    def __getitem__(self, idx):
        """Return ``(word_ids, label)`` as int64 tensors."""
        ids = self.word_indices[idx]
        label = self.dataset[idx][0]
        return torch.tensor(ids).long(), torch.tensor(label).long()


class ClassificationDataset(Dataset):
    """Text-classification dataset over ``ClassificationDatasetSplits``.

    Each example's text is converted to word ids once, at construction time,
    with ``glove_embedding.get_word_seq``. ``__getitem__`` only wraps the
    cached values in tensors.
    """

    def __init__(self, root_dir, glove_embedding, tag='train', max_seq_len=1):
        """
        Args:
            root_dir: directory passed to ``ClassificationDatasetSplits``.
            glove_embedding: object exposing ``get_word_seq(text, max_seq_len)``.
            tag: ``'train'`` or ``'test'``; selects the first or second value
                returned by ``splits()``.
            max_seq_len: forwarded to ``get_word_seq``.

        Raises:
            NotImplementedError: if ``tag`` is neither ``'train'`` nor ``'test'``.
        """
        # Validate before constructing the splits object so a bad tag fails
        # fast and says why (the exception type is kept for existing callers).
        if tag not in ('train', 'test'):
            raise NotImplementedError(
                "unsupported tag %r; expected 'train' or 'test'" % (tag,))
        self.max_seq_len = max_seq_len
        self.splits = torchtext_datasets.ClassificationDatasetSplits(root_dir)
        # NOTE(review): never populated by this class; kept in case callers read it.
        self.dataset = []
        train_data, test_data = self.splits.splits()
        self.data = train_data if tag == 'train' else test_data
        self.glove_embedding = glove_embedding
        self.word_indices = []
        self.labels = []
        self.preload_indices()
        print("Dataset setup!")

    def __len__(self):
        return len(self.word_indices)

    def preload_indices(self):
        """Cache word ids and labels for every example in ``self.data``."""
        print('Loading indices!')
        for line in self.data:
            # Examples are indexed as line[0] -> label, line[1] -> text.
            word_ids = self.glove_embedding.get_word_seq(line[1], self.max_seq_len)
            self.word_indices.append(word_ids)
            self.labels.append(line[0])

    def __getitem__(self, idx):
        """Return ``(word_ids, label)`` as int64 tensors."""
        label = torch.tensor(self.labels[idx]).long()
        word_ids = torch.tensor(self.word_indices[idx]).long()
        return word_ids, label


"""
Testing
"""

# Disabled manual smoke test for HomogeneousMNIST: builds a DataLoader over a
# local CSV and stops in pdb. The (doubly commented) loop after it prints the
# sizes and dtypes of the first three batches.
# NOTE(review): the CSV path is environment-specific — confirm it exists before enabling.
#if __name__ == "__main__":
    #dataset = HomogeneousMNIST("./data/homogeneous_MNIST/mnist_train_0.csv")
    #loader = DataLoader(dataset, batch_size=10)
    #pdb.set_trace()
    # for i_batch, samples_batched in enumerate(iter(loader)):
    #     if i_batch <3:
    #         x_batched,y_batched = samples_batched
    #         print(x_batched.size())
    #         print(x_batched.dtype)
    #         print(y_batched.size())
    #         print(y_batched.dtype)

