import codecs
import numpy as np
import torch.nn as nn
import torch
import pdb
import torchtext
from src.tlp_rnn_fusion import torchtext_datasets


def load_glove_embeddings(glove_path):
    """Load a GloVe-format text file into a weight matrix and a vocabulary.

    Each non-blank line is ``<word> <v1> <v2> ...``: the token, one space,
    then the vector components. Two extra rows are appended after the
    file's vectors: ``<PAD>`` (all zeros) and ``<UNK>`` (uniform random in
    [-0.05, 0.05)).

    Args:
        glove_path: path to the UTF-8 encoded embeddings file.

    Returns:
        A tuple ``(weights, word_idx)``. ``weights`` is a float32 array with
        one row per word in the file plus the two special rows. ``word_idx``
        maps each word (and ``'<PAD>'`` / ``'<UNK>'``) to its row index.

    Raises:
        ValueError: if the file contains no embedding vectors.
    """
    print('loading word embeddings from %s' % glove_path)
    weight_vectors = []
    word_idx = {}
    # Built-in open() with newline='\n' splits lines only on '\n'.
    # codecs.open() splits on every Unicode line boundary (U+2028, U+0085,
    # '\x0c', ...). That tears apart GloVe lines whose token contains such a
    # character and makes the unpacking below fail.
    with open(glove_path, encoding='utf-8', newline='\n') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            word, vec = line.split(' ', 1)
            word_idx[word] = len(weight_vectors)
            # vec.split() also drops the trailing '\n' (and '\r' on CRLF files).
            weight_vectors.append(np.array(vec.split(), dtype=np.float32))
    if not weight_vectors:
        raise ValueError('no embedding vectors found in %s' % glove_path)

    dim = weight_vectors[0].shape
    # Padding row: zeros, so padded positions contribute nothing.
    word_idx['<PAD>'] = len(weight_vectors)
    weight_vectors.append(np.zeros(dim, dtype=np.float32))
    # Unknown-word row: small random vector.
    word_idx['<UNK>'] = len(weight_vectors)
    weight_vectors.append(np.random.uniform(-0.05, 0.05, dim).astype(np.float32))
    return np.stack(weight_vectors), word_idx


class GloveEmbedding:
    """Lookup wrapper around a complete GloVe vocabulary.

    Every vector in ``glove_path`` is loaded, together with the ``<PAD>`` and
    ``<UNK>`` rows that ``load_glove_embeddings`` appends. The vectors go into
    an ``nn.Embedding`` that is queried without gradient tracking.
    """

    def __init__(self, glove_path=None, cuda=True):
        """Load ``glove_path``; move the table to the GPU if ``cuda`` and one exists."""
        matrix, vocab = load_glove_embeddings(glove_path)
        self.weight_matrix = matrix
        self.word_idx = vocab
        num_rows, dim = matrix.shape
        self.embedding = nn.Embedding(num_rows, dim)
        # Overwrite the default random initialisation with the GloVe matrix.
        self.embedding.weight.data = torch.tensor(matrix)
        if cuda and torch.cuda.is_available():
            self.embedding.cuda()
        print("Loaded glove embeddings from : {}".format(glove_path))

    def get_word_idx(self):
        """Return the word -> row-index mapping."""
        return self.word_idx

    def get_word_seq(self, tokens, max_seq_len):
        """Map a token list to exactly ``max_seq_len`` row indices.

        Short inputs are left-padded with ``<PAD>``. Long inputs keep their
        first ``max_seq_len`` tokens. Tokens absent from the vocabulary map
        to the ``<UNK>`` row. A non-positive ``max_seq_len`` yields an empty
        array.
        """
        padded = ['<PAD>'] * (max_seq_len - len(tokens)) + tokens
        window = padded[:max_seq_len] if max_seq_len > 0 else []
        lookup = self.word_idx
        return np.array([lookup[tok] if tok in lookup else lookup['<UNK>'] for tok in window])

    def get_batch_embedding(self, x):
        """Embed an index tensor ``x`` with autograd disabled."""
        with torch.no_grad():
            vectors = self.embedding(x)
        return vectors


class CompactGloveEmbedding:
    """GloVe embedding restricted to the vocabulary of one training set.

    Only tokens that occur at least twice in the chosen dataset's training
    split are kept, plus ``<PAD>`` and ``<UNK>``. Tokens missing from GloVe
    receive a small uniform-random vector.
    """

    # Dataset names loaded through ``torchtext_datasets.ClassificationDatasetSplits``.
    _CLASSIFICATION_DATASETS = (
        'AG_NEWS', 'SogouNews', 'DBpedia', 'YelpReviewPolarity',
        'YelpReviewFull', 'YahooAnswers', 'AmazonReviewPolarity',
        'AmazonReviewFull',
    )
    # Reserved vocabulary entries; always placed last, exactly once.
    _SPECIAL_TOKENS = ('<PAD>', '<UNK>')

    def __init__(self, glove_path=None, cuda=True, dataset_name=None, data_path=None):
        """Build the compact embedding table.

        Args:
            glove_path: GloVe-format file passed to ``load_glove_embeddings``.
            cuda: move the table to the GPU when one is available.
            dataset_name: one of the names supported by ``get_unique_tokens``.
            data_path: root directory handed to the dataset loader.
        """
        unique_tokens = self.get_unique_tokens(dataset_name, data_path)
        self.weight_matrix, self.word_idx = self.preprocess_embeddings(unique_tokens, glove_path)
        num_rows, dim = self.weight_matrix.shape
        self.embedding = nn.Embedding(num_rows, dim)
        # Replace the default random initialisation with the compacted rows.
        self.embedding.weight.data = torch.tensor(self.weight_matrix)
        if torch.cuda.is_available() and cuda:
            self.embedding.cuda()
        print("Loaded glove embeddings from : {}".format(glove_path))
        print("Embedding size: {}".format(self.embedding.weight.size()))

    @classmethod
    def _build_vocab(cls, token_stream, min_freq=2):
        """Return tokens seen >= ``min_freq`` times, then the special tokens.

        Tokens keep their first-seen order. Dataset occurrences of the special
        tokens are dropped so each special token appears exactly once.
        """
        import collections
        counts = collections.Counter(token_stream)
        vocab = [tok for tok, n in counts.items()
                 if n >= min_freq and tok not in cls._SPECIAL_TOKENS]
        vocab.extend(cls._SPECIAL_TOKENS)
        return vocab

    def get_unique_tokens(self, dataset_name, data_path):
        """Collect the training-set vocabulary for ``dataset_name``.

        Returns:
            List of tokens occurring at least twice in the training split,
            followed by ``'<PAD>'`` and ``'<UNK>'``.

        Raises:
            NotImplementedError: if ``dataset_name`` is not supported.
        """
        if dataset_name == 'UDPOS':
            trainset, _, _ = torchtext.datasets.UDPOS(data_path)
            # NOTE(review): only this branch lowercases tokens and get_word_seq
            # does not lowercase, so callers presumably lowercase UDPOS input
            # themselves. Confirm.
            stream = (tok.lower() for line in trainset for tok in line[0])
        elif dataset_name in self._CLASSIFICATION_DATASETS:
            trainset, _ = torchtext_datasets.ClassificationDatasetSplits(data_path).splits()
            stream = (tok for line in trainset for tok in line[1])
        elif dataset_name == 'SSTPT':
            trainset, _, _ = torchtext.datasets.SST2(data_path)
            stream = (tok for phrase, _label in trainset for tok in phrase.split(" "))
        else:
            raise NotImplementedError('Unsupported dataset: {}'.format(dataset_name))
        tokens = self._build_vocab(stream)
        print('Extracted tokens: ', len(tokens))
        return tokens

    def preprocess_embeddings(self, tokens, glove_path):
        """Build a weight matrix containing one row per entry of ``tokens``.

        Rows are copied from the full GloVe matrix when the token is known.
        Otherwise a uniform-random vector in [-0.05, 0.05) is drawn. A
        repeated token keeps its first row, so no row is orphaned.

        Returns:
            ``(weights, word_idx)``: a stacked array with one row per distinct
            token, and a dict mapping each token to its row.
        """
        print("Processing embeddings again!")
        weight_matrix, word_idx = load_glove_embeddings(glove_path)
        dim = weight_matrix.shape[1]
        new_word_idx = {}
        new_weight_vectors = []
        for token in tokens:
            if token in new_word_idx:
                continue
            new_word_idx[token] = len(new_weight_vectors)
            glove_row = word_idx.get(token)
            if glove_row is not None:
                new_weight_vectors.append(weight_matrix[glove_row])
            else:
                new_weight_vectors.append(
                    np.random.uniform(-0.05, 0.05, dim).astype(np.float32))
        print("Done!")
        return np.stack(new_weight_vectors), new_word_idx

    def get_word_idx(self):
        """Return the token -> row-index mapping."""
        return self.word_idx

    def get_word_seq(self, tokens, max_seq_len):
        """Map a token list to exactly ``max_seq_len`` row indices.

        Short inputs are left-padded with ``<PAD>``. Long inputs keep their
        first ``max_seq_len`` tokens. Out-of-vocabulary tokens map to the
        ``<UNK>`` row.
        """
        padded = ['<PAD>'] * (max_seq_len - len(tokens)) + tokens
        window = padded[:max(max_seq_len, 0)]
        lookup = self.word_idx
        return np.array([lookup[tok] if tok in lookup else lookup['<UNK>'] for tok in window])

    def get_batch_embedding(self, x):
        """Embed an index tensor ``x`` with autograd disabled."""
        with torch.no_grad():
            return self.embedding(x)


