import numpy as np
import math
import torch
from torch.utils.data import Dataset
import torch.nn as nn


# Module-level mini-batch size constant (not referenced in the visible part of this file).
BATCH_SIZE = 512
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

############################################################################

class SingleDataset(Dataset):
    """Dataset of documents expanded from (token, count) pairs.

    Every document is supplied as an array of token ids plus a parallel array
    of counts, and is stored as a 1-D tensor in which each token id is
    repeated ``count`` times.
    """

    def __init__(self, tokens, counts):
        """Expand every document once, up front.

        Args:
            tokens: iterable with one array-like of token ids per document.
            counts: iterable with one array-like of integer repeat counts per
                document, parallel to ``tokens``.
        """
        self.data = [self._expand(tok, count) for tok, count in zip(tokens, counts)]

    @staticmethod
    def _expand(tok, count):
        """Return a 1-D tensor with each token id repeated ``count`` times."""
        # Flatten rather than squeeze: squeezing a single-token document
        # yields a 0-d array, which cannot be iterated.
        tok = np.asarray(tok).reshape(-1)
        count = np.asarray(count).reshape(-1)
        # Pair tokens with counts the way zip() does (stop at the shorter one).
        n = min(tok.size, count.size)
        if n == 0:
            # An empty document becomes an empty tensor instead of an error.
            return torch.from_numpy(tok[:0])
        return torch.from_numpy(np.repeat(tok[:n], count[:n]))

    def __len__(self):
        """Number of documents."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the expanded document tensor at position ``idx``."""
        return self.data[idx]

def generate_single_batch(batch):
    """Collate a list of 1-D document tensors into ``(text, offsets)``.

    Args:
        batch: list of document tensors.

    Returns:
        A pair of int64 tensors: all documents concatenated end to end, and
        for each document the index at which it starts in that flat tensor.
    """
    # Start positions are the running sum of the lengths of the preceding
    # documents, so the last length is never needed.
    starts = [0]
    starts.extend(len(doc) for doc in batch)
    starts.pop()
    start_offsets = torch.tensor(starts).cumsum(dim=0)

    flat_text = torch.cat(batch)
    return flat_text.long(), start_offsets.long()


#############################
#Contrastive 
class ContrastiveDataset(Dataset):
    """Dataset of labelled document pairs expanded from (token, count) arrays.

    Each item is a ``(label, doc_1, doc_2)`` tuple where both documents are
    1-D tensors in which every token id is repeated ``count`` times.
    """

    def __init__(self, tokens_1, counts_1, tokens_2, counts_2, labels):
        """Expand both sides of every pair once, up front.

        Args:
            tokens_1: iterable with one array-like of token ids per first document.
            counts_1: repeat counts parallel to ``tokens_1``.
            tokens_2: iterable with one array-like of token ids per second document.
            counts_2: repeat counts parallel to ``tokens_2``.
            labels: one label per pair.
        """
        docs_1 = [self._expand(tok, count) for tok, count in zip(tokens_1, counts_1)]
        docs_2 = [self._expand(tok, count) for tok, count in zip(tokens_2, counts_2)]
        self.data = list(zip(labels, docs_1, docs_2))

    @staticmethod
    def _expand(tok, count):
        """Return a 1-D tensor with each token id repeated ``count`` times."""
        tok = np.asarray(tok).reshape(-1)
        count = np.asarray(count).reshape(-1)
        # Pair tokens with counts the way zip() does (stop at the shorter one).
        n = min(tok.size, count.size)
        if n == 0:
            # An empty document becomes an empty tensor instead of an error.
            return torch.from_numpy(tok[:0])
        return torch.from_numpy(np.repeat(tok[:n], count[:n]))

    def __len__(self):
        """Number of document pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(label, doc_1, doc_2)`` tuple at position ``idx``."""
        return self.data[idx]


def generate_contrast_batch(batch):
    """Collate ``(label, doc_1, doc_2)`` items into flat tensors.

    Args:
        batch: list of ``(label, doc_1, doc_2)`` tuples with 1-D document tensors.

    Returns:
        ``(text_1, offsets_1, text_2, offsets_2, label)``: for each side the
        documents concatenated end to end plus the start index of every
        document (all int64), and the labels as a float tensor.
    """

    def flatten(docs):
        # Start of each document = running sum of the preceding lengths.
        starts = [0]
        starts.extend(len(doc) for doc in docs)
        starts.pop()
        return torch.cat(docs), torch.tensor(starts).cumsum(dim=0)

    labels = torch.tensor([item[0] for item in batch])
    first_text, first_offsets = flatten([item[1] for item in batch])
    second_text, second_offsets = flatten([item[2] for item in batch])

    return (
        first_text.long(),
        first_offsets.long(),
        second_text.long(),
        second_offsets.long(),
        labels.float(),
    )

########### Reference Datasets #############################################
class ReferenceDataset(Dataset):
    """Dataset pairing every document with every reference document.

    Documents and references are expanded from (token, count) arrays into 1-D
    tensors in which each token id is repeated ``count`` times. Item ``idx``
    is the pair ``(data[idx // nrefs], refs[idx % nrefs])``.
    """

    def __init__(self, tokens, counts, ref_tokens, ref_counts):
        """Expand all documents and references once, up front.

        Args:
            tokens: iterable with one array-like of token ids per document.
            counts: repeat counts parallel to ``tokens``.
            ref_tokens: iterable with one array-like of token ids per reference.
            ref_counts: repeat counts parallel to ``ref_tokens``.
        """
        self.data = [self._expand(tok, count) for tok, count in zip(tokens, counts)]
        self.refs = [self._expand(tok, count) for tok, count in zip(ref_tokens, ref_counts)]

    @staticmethod
    def _expand(tok, count):
        """Return a 1-D tensor with each token id repeated ``count`` times."""
        # Flatten rather than squeeze: squeezing a single-token document
        # yields a 0-d array, which cannot be iterated.
        tok = np.asarray(tok).reshape(-1)
        count = np.asarray(count).reshape(-1)
        # Pair tokens with counts the way zip() does (stop at the shorter one).
        n = min(tok.size, count.size)
        if n == 0:
            # An empty document becomes an empty tensor instead of an error.
            return torch.from_numpy(tok[:0])
        return torch.from_numpy(np.repeat(tok[:n], count[:n]))

    def __len__(self):
        """Number of (document, reference) pairs."""
        return len(self.data) * len(self.refs)

    def __getitem__(self, idx):
        """Return the ``(document, reference)`` pair encoded by ``idx``."""
        nrefs = len(self.refs)
        return self.data[idx // nrefs], self.refs[idx % nrefs]

def generate_reference_batch(batch):
    """Collate ``(document, reference)`` pairs into flat tensors.

    Args:
        batch: list of ``(document, reference)`` tuples of 1-D tensors.

    Returns:
        ``(text_1, offsets_1, text_2, offsets_2)``, all int64: for each side
        the tensors concatenated end to end and the start index of every
        entry in that flat tensor.
    """

    def flatten(docs):
        # Start of each entry = running sum of the preceding lengths.
        starts = [0]
        starts.extend(len(doc) for doc in docs)
        starts.pop()
        return torch.cat(docs), torch.tensor(starts).cumsum(dim=0)

    doc_text, doc_offsets = flatten([pair[0] for pair in batch])
    ref_text, ref_offsets = flatten([pair[1] for pair in batch])

    return doc_text.long(), doc_offsets.long(), ref_text.long(), ref_offsets.long()



###############################
## Defines the type of hidden block we use
def hidden_block(dropout_p, batchnorm, in_size, out_size):
    """Build one hidden block: Dropout -> ReLU -> [BatchNorm1d] -> Linear.

    Args:
        dropout_p: dropout probability of the leading ``nn.Dropout``.
        batchnorm: when truthy, insert ``nn.BatchNorm1d(in_size)`` before the
            linear layer.
        in_size: input feature size of the block.
        out_size: output feature size of the linear layer.

    Returns:
        An ``nn.Sequential`` holding the stages in the order listed above.
    """
    stages = [nn.Dropout(p=dropout_p), nn.ReLU()]
    if batchnorm:
        # Normalisation acts on the block input width, ahead of the projection.
        stages.append(nn.BatchNorm1d(in_size))
    stages.append(nn.Linear(in_size, out_size))
    return nn.Sequential(*stages)

############################################################################
##  Contrastive model.
##    Input: A pair of documents (real or contrastive).
##    Output: A real number.

class ContrastiveModel(nn.Module):
    """Two-tower model that scores a pair of documents with an inner product.

    Each tower mean-pools token embeddings with its own ``nn.EmbeddingBag``
    and feeds the result through its own stack of ``hidden_block`` layers.
    The forward pass returns the row-wise dot product of the tower outputs.
    """

    def __init__(self, vocab_size, embed_dim, c_dim=100, h_dim=256, dropout_p=0.5, n_layers=6, vectors=None, fine_tune_vectors=True):
        """
        Args:
            vocab_size: number of rows of each embedding table.
            embed_dim: embedding width (input size of the first hidden block).
            c_dim: output width of each tower.
            h_dim: width of the intermediate hidden blocks.
            dropout_p: dropout probability used in every hidden block.
            n_layers: number of hidden blocks per tower.
            vectors: optional tensor of initial embedding weights with shape
                ``(vocab_size, embed_dim)``; each tower receives its own copy.
            fine_tune_vectors: when ``vectors`` is given and this is false,
                both embedding tables are frozen.
        """
        super().__init__()
        # Each table gets its own copy. Passing the same tensor to both would
        # make the two towers (and the caller's tensor) share storage, so an
        # update to one embedding would silently change the other.
        self.embed_1 = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean', _weight=self._private_copy(vectors))
        self.embed_2 = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean', _weight=self._private_copy(vectors))
        if vectors is not None and not fine_tune_vectors:
            self.embed_1.weight.requires_grad_(False)
            self.embed_2.weight.requires_grad_(False)

        sizes = [embed_dim] + [h_dim] * (n_layers - 1) + [c_dim]
        # Only the first block of each tower uses batch normalisation.
        batch_norms = [True] + [False] * (n_layers - 1)

        def build_tower():
            return nn.Sequential(*[
                hidden_block(dropout_p, batch_norms[i], sizes[i], sizes[i + 1])
                for i in range(n_layers)
            ])

        # The extra nn.Sequential wrapper is kept so parameter names
        # (state_dict keys) stay the same as before.
        self.layers_1 = nn.Sequential(build_tower())
        self.layers_2 = nn.Sequential(build_tower())

    @staticmethod
    def _private_copy(vectors):
        """Return a detached copy of ``vectors``, or ``None`` when not given."""
        if vectors is None:
            return None
        return vectors.detach().clone()

    def get_embedding(self, text, offsets):
        """Return the first tower's representation of the given documents."""
        return self.layers_1(self.embed_1(text, offsets))

    def forward(self, text_1, offsets_1, text_2, offsets_2):
        """Return the dot product of the two tower outputs, one value per row."""
        x1 = self.layers_1(self.embed_1(text_1, offsets_1))
        x2 = self.layers_2(self.embed_2(text_2, offsets_2))
        return (x1 * x2).sum(1, keepdim=True)



#############################################
#PredictDataset2 takes in no longer tokens counts but directly the expanded doc
######
class PredictDataset2(Dataset):
    """Dataset of ``(label, document)`` pairs built from ready-made documents.

    Unlike the (token, count) datasets, each entry of ``docs`` is converted
    to a tensor as-is with ``torch.from_numpy``.
    """

    def __init__(self, docs, labels):
        """Convert every document to a tensor and pair it with its label."""
        doc_tensors = list(map(torch.from_numpy, docs))
        self.data = list(zip(labels, doc_tensors))

    def __len__(self):
        """Number of ``(label, document)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(label, document)`` pair at position ``idx``."""
        return self.data[idx]

# class AttnCTMDataset(Dataset):
#     def __init__(self, docs, labels):
#         self.data = list(zip(docs, labels))

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, item):
#         return self.data[item]

def generate_predict_batch2(batch):
    """Collate ``(label, document)`` items into ``(text, offsets, label)``.

    Args:
        batch: list of ``(label, document)`` tuples with 1-D document tensors.

    Returns:
        Three int64 tensors: the documents concatenated end to end, the start
        index of every document in that flat tensor, and the labels.
    """
    labels = torch.tensor([item[0] for item in batch])
    docs = [item[1] for item in batch]

    # Start of each document = running sum of the preceding lengths.
    starts = [0]
    starts.extend(len(doc) for doc in docs)
    starts.pop()
    start_offsets = torch.tensor(starts).cumsum(dim=0)

    flat_text = torch.cat(docs)
    return flat_text.long(), start_offsets.long(), labels.long()



###### ATTENTION

from collections import defaultdict
class ByLengthSampler(torch.utils.data.Sampler):
    """
    Sampler whose iterator yields lists of example indices in which every
    example has the same sequence length; adapted from
    https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284/13.
    """

    def __init__(self, dataset, batchsize, key=-1, shuffle=True, ignore_len = 0):
        """
        Args:
            dataset: iterable of examples.
            batchsize: maximum number of indices per yielded batch.
            key: index of the sequence inside each example whose length is
                used; ``None`` measures the example itself.
            shuffle: randomise example order and batch order when true.
            ignore_len: examples shorter than this are never yielded.
        """
        self.batchsize = batchsize
        self.shuffle = shuffle
        self.ignore_len = ignore_len
        if key is None:
            lengths = [len(example) for example in dataset]
        else:
            lengths = [len(example[key]) for example in dataset]
        self.seqlens = torch.LongTensor(lengths)
        # Batch count is taken from one trial partition of the data.
        self.nbatches = len(self._generate_batches())

    def _generate_batches(self):
        """Partition the kept example indices into same-length batches."""
        n_examples = self.seqlens.size(0)
        if self.shuffle:
            visit_order = torch.randperm(n_examples)
        else:
            visit_order = torch.arange(n_examples)

        finished = []
        pending = defaultdict(list)  # sequence length -> indices collected so far
        for example_idx in visit_order.tolist():
            length = self.seqlens[example_idx].item()
            if length < self.ignore_len:
                # Too short: leave this example out entirely.
                continue
            bucket = pending[length]
            bucket.append(example_idx)
            if len(bucket) >= self.batchsize:
                # Bucket is full: emit it and start a fresh one for this length.
                finished.append(pending.pop(length))

        # Partially filled buckets become (smaller) batches of their own.
        finished.extend(bucket for bucket in pending.values() if len(bucket) > 0)

        # Reorder the batches so iteration does not always start with the
        # most common lengths.
        n_batches = len(finished)
        if self.shuffle:
            batch_order = torch.randperm(n_batches)
        else:
            batch_order = torch.arange(n_batches)
        return [finished[pos] for pos in batch_order.tolist()]

    def batch_count(self):
        """Number of batches produced by one pass over the data."""
        return self.nbatches

    def __len__(self):
        # NOTE(review): this is the number of examples, not the number of
        # batches yielded by __iter__; use batch_count() for the latter.
        return len(self.seqlens)

    def __iter__(self):
        """Yield one list of example indices per batch, regenerated on each call."""
        yield from self._generate_batches()

class PredictDataset2_attn(torch.utils.data.Dataset):
    """Dataset of ``(label, document)`` pairs that can also be indexed by a list.

    Each entry of ``tokens`` is converted to a tensor as-is with
    ``torch.from_numpy`` (no token/count expansion is performed here).
    """

    def __init__(self, tokens, labels):
        """Convert every document to a tensor and pair it with its label."""
        doc_tensors = [torch.from_numpy(doc) for doc in tokens]
        self.data = list(zip(labels, doc_tensors))

    def __len__(self):
        """Number of ``(label, document)`` pairs."""
        return len(self.data)

    def __getitem__(self, item):
        """Return one pair, or a list of pairs when ``item`` is a list of indices."""
        if not isinstance(item, list):
            return self.data[item]
        # A list of indices (e.g. a whole batch from a sampler) is resolved
        # element by element.
        return [self.data[position] for position in item]


def generate_predict_batch2_attn(batch):
    """Collate ``(label, document)`` items into ``(texts, labels)``.

    All documents in the batch must have the same shape; they are stacked
    along a new leading dimension.

    Returns:
        Two int64 tensors: the stacked documents and the labels.
    """
    label_values = []
    docs = []
    for item in batch:
        label_values.append(item[0])
        docs.append(item[1])
    labels = torch.tensor(label_values)
    texts = torch.stack(docs, dim=0)
    return texts.long(), labels.long()


class SingleDataset_attn(Dataset):
    """Dataset of ``(label, document)`` pairs expanded from (token, count) arrays.

    Every document is stored as a 1-D tensor in which each token id is
    repeated ``count`` times. Indexing with a list returns a list of pairs.
    """

    def __init__(self, tokens, counts,labels):
        """Expand every document once and pair it with its label.

        Args:
            tokens: iterable with one array-like of token ids per document.
            counts: iterable with one array-like of integer repeat counts per
                document, parallel to ``tokens``.
            labels: one label per document.
        """
        docs = [self._expand(tok, count) for tok, count in zip(tokens, counts)]
        self.data = list(zip(labels, docs))

    @staticmethod
    def _expand(tok, count):
        """Return a 1-D tensor with each token id repeated ``count`` times."""
        # Flatten rather than squeeze: squeezing a single-token document
        # yields a 0-d array, which cannot be iterated.
        tok = np.asarray(tok).reshape(-1)
        count = np.asarray(count).reshape(-1)
        # Pair tokens with counts the way zip() does (stop at the shorter one).
        n = min(tok.size, count.size)
        if n == 0:
            # An empty document becomes an empty tensor instead of an error.
            return torch.from_numpy(tok[:0])
        return torch.from_numpy(np.repeat(tok[:n], count[:n]))

    def __len__(self):
        """Number of ``(label, document)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one pair, or a list of pairs when ``idx`` is a list of indices."""
        if isinstance(idx, list):
            return [self.data[i] for i in idx]
        return self.data[idx]

def generate_single_batch_attn(batch):
    """Collate ``(label, document)`` items into ``(texts, labels)``.

    All documents in the batch must have the same shape; they are stacked
    along a new leading dimension.

    Returns:
        Two int64 tensors: the stacked documents and the labels.
    """
    label_values = []
    docs = []
    for item in batch:
        label_values.append(item[0])
        docs.append(item[1])
    labels = torch.tensor(label_values)
    texts = torch.stack(docs, dim=0)
    return texts.long(), labels.long()
'''
Predict Next Word Model with different flavors:

Model Identification id             Description
============================||======================||
base                          (resblock) end of last layer residual block as representation
sec2last                      (resblock) have another variable representation layer after last layer of resblock before the final layer as representation layer
word2vec                      (resblock) apply word2vec pretrained matrix after last layer output
base_rsm                      softmax(random square matrix * second to last layer)
lastlayer                     representation is giant last layer + softmax
attention                       
'''

class ResBlock(nn.Module):
    """Residual block: ``BatchNorm(ReLU(Linear(Dropout(x))) + x)``.

    The block keeps the feature width unchanged, and the batch normalisation
    has no learnable affine parameters.
    """

    def __init__(self,dropout,in_size):
        """
        Args:
            dropout: dropout probability applied to the block input.
            in_size: feature width of both the input and the output.
        """
        super().__init__()
        branch = [
            nn.Dropout(p=dropout),
            nn.Linear(in_size, in_size),
            nn.ReLU(),
        ]
        self.layer = nn.Sequential(*branch)
        self.bn_layer = nn.BatchNorm1d(in_size, affine=False)

    def forward(self,x):
        """Add the transformed input to the input itself, then normalise."""
        residual_sum = self.layer(x) + x
        return self.bn_layer(residual_sum)

class PredictModel_lastlayer(nn.Module):
    """Prediction model whose representation is the softmax of its final output.

    Token embeddings are mean-pooled with an ``nn.EmbeddingBag`` and passed
    through a hidden block, ``n_layers`` residual blocks and a final hidden
    block that maps to ``vocab_size`` outputs.
    """

    def __init__(self, vocab_size, embed_dim, h_dim=256,dropout_p=0.5, n_layers=6):
        """
        Args:
            vocab_size: number of embedding rows and width of the output.
            embed_dim: embedding width.
            h_dim: width of the residual blocks.
            dropout_p: dropout probability used in every block.
            n_layers: number of residual blocks.
        """
        super().__init__()
        # An earlier variant (a frozen identity-matrix embedding of width
        # vocab_size) was disabled in favour of a learned embedding.
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')

        stack = [hidden_block(dropout_p, True, embed_dim, h_dim)]
        stack.extend(ResBlock(dropout_p, h_dim) for _ in range(n_layers))
        stack.append(hidden_block(dropout_p, True, h_dim, vocab_size))
        self.layers = nn.Sequential(*stack)

        self.softmax = nn.Softmax(dim=1)

    def get_embedding(self, text, offsets):
        """Return the softmax (over dim 1) of the network output."""
        pooled = self.embed(text, offsets)
        scores = self.layers(pooled)
        return self.softmax(scores)

    def forward(self, text, offsets):
        """Return the raw (pre-softmax) network output."""
        pooled = self.embed(text, offsets)
        return self.layers(pooled)





class PredictModel_base(nn.Module):
    """Prediction model whose representation comes from the last residual block.

    Token embeddings are mean-pooled with an ``nn.EmbeddingBag`` and passed
    through a hidden block and ``n_layers`` residual blocks; a separate final
    hidden block maps to ``vocab_size`` outputs for prediction.
    """

    def __init__(self, vocab_size, embed_dim, h_dim=256,dropout_p=0.5, n_layers=6):
        """
        Args:
            vocab_size: number of embedding rows and width of the prediction output.
            embed_dim: embedding width.
            h_dim: width of the residual blocks.
            dropout_p: dropout probability used in every block.
            n_layers: number of residual blocks.
        """
        super().__init__()
        # An earlier variant (a frozen identity-matrix embedding of width
        # vocab_size) was disabled in favour of a learned embedding.
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')

        trunk = [hidden_block(dropout_p, True, embed_dim, h_dim)]
        trunk.extend(ResBlock(dropout_p, h_dim) for _ in range(n_layers))
        self.layers = nn.Sequential(*trunk)

        self.lastlayer = hidden_block(dropout_p, True, h_dim, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def get_embedding(self, text, offsets):
        """Return the softmax (over dim 1) of the trunk output; ``lastlayer`` is not applied."""
        pooled = self.embed(text, offsets)
        hidden = self.layers(pooled)
        return self.softmax(hidden)

    def forward(self, text, offsets):
        """Return the raw prediction scores from ``lastlayer``."""
        pooled = self.embed(text, offsets)
        hidden = self.layers(pooled)
        return self.lastlayer(hidden)

class PredictModel_rsm(nn.Module):
    """Prediction model whose representation is a projected trunk output.

    The trunk (EmbeddingBag -> hidden block -> ``n_layers`` residual blocks)
    is shared by prediction and representation. The representation is
    ``softmax(trunk_output @ rs_matrix)``; prediction uses ``lastlayer``.
    """

    def __init__(self, vocab_size, embed_dim, h_dim=256,dropout_p=0.5, n_layers=6,rs_matrix=None):
        """
        Args:
            vocab_size: number of embedding rows and width of the prediction output.
            embed_dim: embedding width.
            h_dim: width of the residual blocks.
            dropout_p: dropout probability used in every block.
            n_layers: number of residual blocks.
            rs_matrix: array-like or tensor multiplied onto the trunk output
                in ``get_embedding``; its first dimension must match the
                trunk output width. Required only by ``get_embedding``.
        """
        super().__init__()
        # An earlier variant (a frozen identity-matrix embedding of width
        # vocab_size) was disabled in favour of a learned embedding.
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')
        self.layers = nn.Sequential(
            hidden_block(dropout_p, True, embed_dim, h_dim),
            *[ResBlock(dropout_p, h_dim) for i in range(n_layers)],
        )
        self.lastlayer=hidden_block(dropout_p, True, h_dim, vocab_size)
        self.softmax = nn.Softmax(dim=1)
        self.rs_matrix=rs_matrix

    def get_embedding(self, text, offsets):
        """Return ``softmax(trunk_output @ rs_matrix)``, computed on the CPU.

        Raises:
            ValueError: if no ``rs_matrix`` has been set on the model.
        """
        if self.rs_matrix is None:
            raise ValueError("rs_matrix must be set before calling get_embedding")
        output = self.layers(self.embed(text, offsets)).cpu()
        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when rs_matrix is already a tensor. The result is detached and
        # placed on the CPU to match the trunk output above.
        projection = torch.as_tensor(self.rs_matrix, dtype=torch.float32, device="cpu").detach()
        # NOTE(review): the original notes describe output as (batchsize, h_dim)
        # and rs_matrix as (h_dim, h_dim), later (h_dim, 2*h_dim) - confirm.
        return self.softmax(torch.matmul(output, projection))

    def forward(self, text, offsets):
        """Return the raw prediction scores from ``lastlayer``."""
        return self.lastlayer(self.layers(self.embed(text, offsets)))






class PredictModel_sec2last(nn.Module):
    """Prediction model with a dedicated representation layer before the output.

    The trunk (EmbeddingBag -> hidden block -> ``n_layers`` residual blocks)
    feeds a ``representation`` hidden block of width ``representation_dim``,
    which in turn feeds ``lastlayer`` for prediction.
    """

    def __init__(self, vocab_size, embed_dim, h_dim=256,representation_dim=4096, dropout_p=0.5, n_layers=6):
        """
        Args:
            vocab_size: number of embedding rows and width of the prediction output.
            embed_dim: embedding width.
            h_dim: width of the residual blocks.
            representation_dim: width of the representation layer.
            dropout_p: dropout probability used in every block.
            n_layers: number of residual blocks.
        """
        super().__init__()
        # An earlier variant loaded a fixed diagonal matrix as pretrained
        # embedding weights; a learned embedding is used instead.
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')

        trunk = [hidden_block(dropout_p, True, embed_dim, h_dim)]
        trunk.extend(ResBlock(dropout_p, h_dim) for _ in range(n_layers))
        self.layers = nn.Sequential(*trunk)

        self.representation = hidden_block(dropout_p, True, h_dim, representation_dim)
        self.lastlayer = hidden_block(dropout_p, True, representation_dim, vocab_size)
        # Kept as an attribute although get_embedding below does not apply it.
        self.softmax = nn.Softmax(dim=1)

    def get_embedding(self,text,offsets):
        """Return the representation-layer output (no softmax applied)."""
        hidden = self.layers(self.embed(text, offsets))
        return self.representation(hidden)

    def forward(self, text, offsets):
        """Return the raw prediction scores from ``lastlayer``."""
        hidden = self.layers(self.embed(text, offsets))
        return self.lastlayer(self.representation(hidden))

class PredictModel_word2vec(nn.Module):
    """Prediction model whose representation mixes word vectors by predicted probability.

    The network (EmbeddingBag -> hidden block -> ``n_layers`` residual blocks
    -> ``lastlayer``) produces ``vocab_size`` scores. The representation is
    ``softmax(scores) @ word2vec_matrix``.
    """

    def __init__(self, vocab_size, embed_dim, h_dim=256, dropout_p=0.5, n_layers=6, word2vec_matrix=None):
        """
        Args:
            vocab_size: number of embedding rows and width of the prediction output.
            embed_dim: embedding width.
            h_dim: width of the residual blocks.
            dropout_p: dropout probability used in every block.
            n_layers: number of residual blocks.
            word2vec_matrix: array-like or tensor with ``vocab_size`` rows that
                is multiplied onto the softmax output in ``get_embedding``.
                Required only by ``get_embedding``.
        """
        super().__init__()
        # An earlier variant loaded a fixed diagonal matrix as pretrained
        # embedding weights; a learned embedding is used instead.
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')
        self.word2vec_matrix=word2vec_matrix
        self.layers = nn.Sequential(
            hidden_block(dropout_p, True, embed_dim, h_dim),
            *[ResBlock(dropout_p, h_dim) for i in range(n_layers)],
        )

        self.lastlayer=hidden_block(dropout_p, True, h_dim, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def get_embedding(self, text, offsets):
        """Return ``softmax(scores) @ word2vec_matrix``, computed on the CPU.

        Raises:
            ValueError: if no ``word2vec_matrix`` has been set on the model.
        """
        if self.word2vec_matrix is None:
            raise ValueError("word2vec_matrix must be set before calling get_embedding")
        # Scores have one column per vocabulary entry.
        last_layer_output = self.lastlayer(self.layers(self.embed(text, offsets))).cpu()
        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when the matrix is already a tensor. The result is detached and
        # placed on the CPU to match the scores above.
        word_vectors = torch.as_tensor(self.word2vec_matrix, dtype=torch.float32, device="cpu").detach()
        return torch.matmul(self.softmax(last_layer_output), word_vectors)

    def forward(self, text, offsets):
        """Return the raw prediction scores from ``lastlayer``."""
        return self.lastlayer(self.layers(self.embed(text, offsets)))

class AttentionBlock(nn.Module):
    """Attention block with residual connections and batch normalisation.

    Computes ``x = BN1(x + Attn(q(x), k(x), v(x)))`` followed by
    ``BN2(x + FF(x))``, where ``FF`` is Dropout -> Linear -> ReLU.
    """

    def __init__(self, h_dim=512, dropout_p=0.5, n_heads=1):
        """
        Args:
            h_dim: feature width of the input and output.
            dropout_p: dropout probability inside the feed-forward branch.
            n_heads: number of attention heads.
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(h_dim, n_heads, batch_first=True)
        self.to_q = nn.Linear(h_dim, h_dim)
        self.to_k = nn.Linear(h_dim, h_dim)
        self.to_v = nn.Linear(h_dim, h_dim)
        feed_forward = [
            nn.Dropout(p=dropout_p),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
        ]
        self.ff = nn.Sequential(*feed_forward)
        self.bn1 = nn.BatchNorm1d(h_dim, affine=False)
        self.bn2 = nn.BatchNorm1d(h_dim, affine=False)

    @staticmethod
    def _normalize_features(norm, x):
        """Apply ``norm`` over the last dimension of a (batch, words, h_dim) tensor."""
        # BatchNorm1d treats dim 1 as the feature dimension, so swap the word
        # and feature dimensions around the call.
        return torch.transpose(norm(torch.transpose(x, 1, 2)), 1, 2)

    def forward(self, x):
        '''
        :param x: (batch_size, num_words, h_dim)
        :return: (batch_size, num_words, h_dim)
        '''
        attended = self.attn(self.to_q(x), self.to_k(x), self.to_v(x))[0]
        x = self._normalize_features(self.bn1, x + attended)
        return self._normalize_features(self.bn2, x + self.ff(x))

class PredictModel_attention(nn.Module):
    """Attention-based prediction model.

    Token ids are embedded with ``nn.Embedding``, projected by a hidden block,
    passed through ``n_layers`` attention blocks, averaged over dim 1 and fed
    to the ``fc`` representation layer and then the ``tail`` output layer.
    """

    def __init__(self, vocab_size, embed_dim=4096, h_dim=512, representation_dim=4096, dropout_p=0.5, n_layers=3, init_weight=None):
        """
        Args:
            vocab_size: number of embedding rows and width of the prediction output.
            embed_dim: embedding width.
            h_dim: feature width of the attention blocks.
            representation_dim: width of the ``fc`` representation layer.
            dropout_p: dropout probability used in every block.
            n_layers: number of attention blocks.
            init_weight: optional initial weight tensor for the embedding.
        """
        super().__init__()
        body = [hidden_block(dropout_p, False, embed_dim, h_dim)]
        body.extend(AttentionBlock(h_dim, dropout_p) for _ in range(n_layers))
        self.layers = nn.Sequential(*body)

        self.convert = nn.Embedding(vocab_size, embed_dim, _weight=init_weight)

        self.fc = hidden_block(dropout_p, False, h_dim, representation_dim)
        self.tail = hidden_block(dropout_p, False, representation_dim, vocab_size)

    def forward(self, x):
        """Return the raw prediction scores for a batch of token-id sequences."""
        attended = self.layers(self.convert(x))
        pooled = torch.mean(attended, 1)
        return self.tail(self.fc(pooled))

    def get_embedding(self, x):
        """Return the softmax (over the last dim) of the ``fc`` output."""
        attended = self.layers(self.convert(x))
        pooled = torch.mean(attended, 1)
        return torch.softmax(self.fc(pooled), dim=-1)
