import numpy as np
import math
import torch
from torch.utils.data import Dataset
import torch.nn as nn


BATCH_SIZE = 512
# Run on the default CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

############################################################################

class PredictWordDataset(Dataset):
    """Map-style dataset of ``(expanded_tokens, label)`` pairs.

    Documents arrive in bag-of-words form: parallel per-document sequences of
    token ids and their counts. Each document is expanded into a flat 1-D
    tensor in which every token id appears as many times as its count, in the
    original token order (the layout ``generate_batch`` concatenates).
    """

    def __init__(self, tokens, counts, labels):
        """
        Args:
            tokens: iterable of per-document sequences of token ids.
            counts: iterable of per-document sequences of non-negative integer
                counts, parallel to ``tokens``.
            labels: iterable of per-document labels, parallel to ``tokens``.
        """
        expanded_docs = [
            torch.from_numpy(self._expand(doc_tokens, doc_counts))
            for doc_tokens, doc_counts in zip(tokens, counts)
        ]
        self.data = list(zip(expanded_docs, labels))

    @staticmethod
    def _expand(doc_tokens, doc_counts):
        """Return ``doc_tokens`` with each id repeated by its count, as a numpy array."""
        doc_tokens = np.asarray(doc_tokens)
        doc_counts = np.asarray(doc_counts)
        # Pair tokens and counts positionally; like zip(), any surplus on the
        # longer side is ignored.
        n = min(len(doc_tokens), len(doc_counts))
        if n == 0:
            # An empty document yields an empty int64 bag instead of crashing
            # in np.concatenate on an empty list.
            return np.empty(0, dtype=np.int64)
        # One vectorised call instead of a per-token np.repeat + concatenate;
        # the result keeps the token array's dtype.
        return np.repeat(doc_tokens[:n], doc_counts[:n])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

def generate_batch(batch):
    """Collate ``(token_tensor, label)`` pairs into ``EmbeddingBag`` inputs.

    Args:
        batch: sequence of ``(tokens, label)`` pairs, ``tokens`` a 1-D tensor.

    Returns:
        ``(texts, offsets, labels)``, all int64: ``texts`` is every document's
        tokens concatenated in order, ``offsets[i]`` is the index in ``texts``
        where document ``i`` starts, and ``labels[i]`` is its label.
    """
    labels = torch.tensor([sample[1] for sample in batch])
    docs = [sample[0] for sample in batch]
    flat_tokens = torch.cat(docs)
    doc_lengths = [len(doc) for doc in docs]
    # Document starts are the exclusive prefix sum of the document lengths.
    starts = torch.tensor([0] + doc_lengths[:-1]).cumsum(0)
    return flat_tokens.long(), starts.long(), labels.long()

def generate_attn_batch(batch):
    """Collate ``(token_sequence, label)`` pairs into dense int64 tensors.

    Every token sequence in ``batch`` must have the same length, because they
    are stacked into one ``(batch_size, seq_len)`` tensor.

    Returns:
        ``(texts, labels)`` as int64 tensors.
    """
    label_tensor = torch.tensor([sample[1] for sample in batch]).long()
    token_tensor = torch.tensor([sample[0] for sample in batch]).long()
    return token_tensor, label_tensor

def hidden_block(dropout_p, batchnorm, in_size, out_size):
    """Build a ``Dropout -> ReLU -> [BatchNorm1d] -> Linear`` block.

    The activation comes before the linear layer, so the block's input (not
    its output) is rectified and, optionally, batch-normalised.

    Args:
        dropout_p: dropout probability.
        batchnorm: if truthy, insert ``BatchNorm1d(in_size)`` before the linear layer.
        in_size: input feature width.
        out_size: output feature width.
    """
    modules = [nn.Dropout(p=dropout_p), nn.ReLU()]
    if batchnorm:
        modules.append(nn.BatchNorm1d(in_size))
    modules.append(nn.Linear(in_size, out_size))
    return nn.Sequential(*modules)


class PredictWordModel(nn.Module):
    """Bag-of-embeddings MLP that scores every vocabulary word for a document.

    A document's token ids are mean-pooled by ``EmbeddingBag`` and passed
    through ``n_layers`` ``hidden_block``s of widths
    embed_dim -> h_dim -> ... -> h_dim -> vocab_size. Only the first block
    contains batch norm.

    Args:
        vocab_size: number of token ids and of output scores.
        embed_dim: embedding width.
        h_dim: width of the hidden layers.
        dropout_p: dropout probability used in every block.
        n_layers: number of ``hidden_block``s.
        vectors: optional initial embedding weight, handed to ``EmbeddingBag``
            as ``_weight`` (shape ``(vocab_size, embed_dim)``).
        fine_tune_vectors: NOTE(review): accepted but never read here;
            ``optimize_embedding`` alone decides whether the embedding trains.
        optimize_embedding: value assigned to the embedding weight's ``requires_grad``.
    """
    def __init__(self, vocab_size, embed_dim, h_dim=512, dropout_p=0.5, n_layers=3, vectors=None,
                 fine_tune_vectors=True, optimize_embedding=True):
        super().__init__()
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean', _weight=vectors)
        self.embed.weight.requires_grad = optimize_embedding

        widths = [embed_dim, *([h_dim] * (n_layers - 1)), vocab_size]
        blocks = []
        for idx in range(n_layers):
            # Only the first block batch-normalises its input.
            blocks.append(hidden_block(dropout_p, idx == 0, widths[idx], widths[idx + 1]))
        self.layers = nn.Sequential(*blocks)

    def forward(self, text, offsets):
        """Return unnormalised vocabulary scores, one row per bag in ``offsets``."""
        pooled = self.embed(text, offsets)
        return self.layers(pooled)

    def get_word_probability(self, text, offsets):
        """Return ``forward``'s scores softmax-normalised over the last dimension."""
        return torch.softmax(self.forward(text, offsets), dim=-1)

class BaselineModel(nn.Module):
    """Linear baseline: mean-pooled token embeddings mapped straight to vocabulary scores.

    Args:
        vocab_size: number of token ids and of output scores.
        embed_dim: embedding width.
        vectors: optional initial embedding weight, handed to ``EmbeddingBag``
            as ``_weight`` (shape ``(vocab_size, embed_dim)``).
        optimize_embedding: value assigned to the embedding weight's ``requires_grad``.
    """
    def __init__(self, vocab_size, embed_dim, vectors=None, optimize_embedding=True):
        super().__init__()
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean', _weight=vectors)
        self.embed.weight.requires_grad = optimize_embedding
        self.linear = nn.Linear(embed_dim, vocab_size)

    def forward(self, text, offsets):
        """Return unnormalised vocabulary scores, one row per bag in ``offsets``."""
        pooled = self.embed(text, offsets)
        return self.linear(pooled)

    def get_word_probability(self, text, offsets):
        """Return ``forward``'s scores softmax-normalised over the last dimension."""
        return torch.softmax(self.forward(text, offsets), dim=-1)

def softmax_block(dropout_p, in_size, out_size, use_softmax):
    """Build ``Dropout -> Linear(in, in) -> ReLU -> Linear(in, out) [-> Softmax]``.

    Args:
        dropout_p: dropout probability.
        in_size: input (and hidden) width.
        out_size: output width.
        use_softmax: if truthy, append a softmax over the last dimension.
    """
    modules = [
        nn.Dropout(p=dropout_p),
        nn.Linear(in_size, in_size),
        nn.ReLU(),
        nn.Linear(in_size, out_size),
    ]
    if use_softmax:
        modules.append(nn.Softmax(dim=-1))
    return nn.Sequential(*modules)

class SoftmaxBlockModel(nn.Module):
    """Stack of ``softmax_block``s applied to a document's relative token frequencies.

    ``convert`` is a frozen identity ``EmbeddingBag`` in ``'mean'`` mode. It
    turns each bag of token ids into the average of their one-hot vectors, a
    length-``vocab_size`` vector. That vector passes through ``n_layers``
    blocks of widths vocab_size -> h_dim -> ... -> h_dim -> vocab_size.
    Every block except the last ends in a softmax; the last returns
    unnormalised scores.
    """
    def __init__(self, vocab_size, h_dim=512, dropout_p=0.5, n_layers=3):
        super().__init__()
        self.vocab_size = vocab_size
        dropouts = [dropout_p]*n_layers
        softmaxs = [True]*(n_layers-1) + [False]
        sizes = [vocab_size] + [h_dim]*(n_layers-1) + [vocab_size]
        self.layers = nn.Sequential(
            *[softmax_block(dropouts[i], sizes[i], sizes[i+1], softmaxs[i]) for i in range(n_layers)]
        )

        # Frozen identity lookup table (vocab_size x vocab_size, dense).
        # torch.eye builds it directly in float32. This avoids a Python-level
        # fill loop and a float64 numpy temporary twice the size of the final
        # matrix.
        id_mat = torch.eye(vocab_size, dtype=torch.float32)
        self.convert = nn.EmbeddingBag(vocab_size, vocab_size, mode='mean', _weight=id_mat)
        self.convert.weight.requires_grad = False

    def forward(self, text, offsets):
        """Return unnormalised vocabulary scores, one row per bag in ``offsets``."""
        x = self.convert(text, offsets)
        return self.layers(x)

    def get_word_probability(self, text, offsets):
        """Return ``forward``'s scores softmax-normalised over the last dimension."""
        return nn.Softmax(dim=-1)(self.forward(text, offsets))


class ResBlock(nn.Module):
    """Residual block ``x + ReLU(Linear(Dropout(x)))``, optionally batch-normalised.

    Args:
        h_dim: feature width (input and output).
        dropout_p: dropout probability.
        batchnorm: if truthy, apply a non-affine ``BatchNorm1d`` to the sum.
            The norm layer is always created, so the state_dict layout does
            not depend on this flag.
    """
    def __init__(self, h_dim, dropout_p, batchnorm):
        super().__init__()
        self.use_batchnorm = batchnorm
        self.bnlayer = nn.BatchNorm1d(h_dim, affine=False)
        self.components = nn.Sequential(
            nn.Dropout(p=dropout_p),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        residual = self.components(x) + x
        return self.bnlayer(residual) if self.use_batchnorm else residual

class ResBlockModel(nn.Module):
    """Residual MLP over a document's relative token frequencies.

    Pipeline:

    - ``convert``: a frozen identity ``EmbeddingBag`` in ``'mean'`` mode. Each
      bag of token ids becomes the average of their one-hot vectors.
    - ``head``: Linear, vocab_size -> h_dim.
    - ``n_layers`` ``ResBlock``s, all with batch norm.
    - ``tail``: Linear, h_dim -> vocab_size, producing unnormalised scores.
    """
    def __init__(self, vocab_size, h_dim=512, dropout_p=0.5, n_layers=3):
        super().__init__()
        self.vocab_size = vocab_size
        self.head = nn.Linear(vocab_size, h_dim)
        self.tail = nn.Linear(h_dim, vocab_size)
        # Every residual block uses batch norm.
        self.layers = nn.Sequential(
            *[ResBlock(h_dim, dropout_p, True) for _ in range(n_layers)]
        )

        # Frozen identity lookup table (vocab_size x vocab_size, dense).
        # torch.eye builds it directly in float32. This avoids a Python-level
        # fill loop and a float64 numpy temporary twice the size of the final
        # matrix.
        id_mat = torch.eye(vocab_size, dtype=torch.float32)
        self.convert = nn.EmbeddingBag(vocab_size, vocab_size, mode='mean', _weight=id_mat)
        self.convert.weight.requires_grad = False

    def forward(self, text, offsets):
        """Return unnormalised vocabulary scores, one row per bag in ``offsets``."""
        x = self.convert(text, offsets)
        x = self.head(x)
        return self.tail(self.layers(x))

    def get_word_probability(self, text, offsets):
        """Return ``forward``'s scores softmax-normalised over the last dimension."""
        return nn.Softmax(dim=-1)(self.forward(text, offsets))

class ResSoftmaxBlock(nn.Module):
    """Residual block ``x + ReLU(Linear(Dropout(x)))``, optionally softmaxed.

    Args:
        h_dim: feature width (input and output).
        dropout_p: dropout probability.
        soft: if truthy, apply a softmax over the last dimension to the sum.
    """
    def __init__(self, h_dim, dropout_p, soft):
        super().__init__()
        self.use_softmax = soft
        self.softmax = nn.Softmax(dim=-1)
        self.components = nn.Sequential(
            nn.Dropout(p=dropout_p),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        summed = self.components(x) + x
        if not self.use_softmax:
            return summed
        return self.softmax(summed)

class ResSoftmaxModel(nn.Module):
    """Residual model whose hidden blocks each end in a softmax.

    Pipeline:

    - ``convert``: a frozen identity ``EmbeddingBag`` in ``'mean'`` mode. Each
      bag of token ids becomes the average of their one-hot vectors.
    - ``head``: Linear, vocab_size -> h_dim.
    - ``n_layers`` ``ResSoftmaxBlock``s, all with their softmax enabled.
    - ``tail``: Linear, h_dim -> vocab_size, producing unnormalised scores.
    """
    def __init__(self, vocab_size, h_dim=512, dropout_p=0.5, n_layers=3):
        super().__init__()
        self.vocab_size = vocab_size
        self.head = nn.Linear(vocab_size, h_dim)
        self.tail = nn.Linear(h_dim, vocab_size)
        # Every residual block applies its softmax (the ``soft`` flag).
        self.layers = nn.Sequential(
            *[ResSoftmaxBlock(h_dim, dropout_p, True) for _ in range(n_layers)]
        )

        # Frozen identity lookup table (vocab_size x vocab_size, dense).
        # torch.eye builds it directly in float32. This avoids a Python-level
        # fill loop and a float64 numpy temporary twice the size of the final
        # matrix.
        id_mat = torch.eye(vocab_size, dtype=torch.float32)
        self.convert = nn.EmbeddingBag(vocab_size, vocab_size, mode='mean', _weight=id_mat)
        self.convert.weight.requires_grad = False

    def forward(self, text, offsets):
        """Return unnormalised vocabulary scores, one row per bag in ``offsets``."""
        x = self.convert(text, offsets)
        x = self.head(x)
        return self.tail(self.layers(x))

    def get_word_probability(self, text, offsets):
        """Return ``forward``'s scores softmax-normalised over the last dimension."""
        return nn.Softmax(dim=-1)(self.forward(text, offsets))

class SimpleAttentionBlock(nn.Module):
    """Self-attention block with residual sub-layers.

    The block has two sub-layers: residual multi-head self-attention, then a
    residual feed-forward layer. Each sub-layer is followed by a non-affine
    batch norm over the feature dimension.

    Args:
        h_dim: feature width; must be divisible by ``n_heads``.
        dropout_p: dropout probability inside the feed-forward layer.
        n_heads: number of attention heads.
    """
    def __init__(self, h_dim=512, dropout_p=0.5, n_heads=1):
        super().__init__()
        self.attn = nn.MultiheadAttention(h_dim, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Dropout(p=dropout_p),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
        )
        self.bn1 = nn.BatchNorm1d(h_dim, affine=False)
        self.bn2 = nn.BatchNorm1d(h_dim, affine=False)

    @staticmethod
    def _normalize_features(norm, x):
        # BatchNorm1d expects channels on dim 1, so swap h_dim into place and back.
        return norm(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        '''
        :param x: (batch_size, num_words, h_dim)
        :return: (batch_size, num_words, h_dim)
        '''
        attended = self.attn(x, x, x)[0]
        hidden = self._normalize_features(self.bn1, x + attended)
        return self._normalize_features(self.bn2, hidden + self.ff(hidden))

class AttentionBlock(nn.Module):
    """Self-attention block with extra learned Q/K/V projections.

    The block has two sub-layers: residual multi-head attention over the
    projected query/key/value, then a residual feed-forward layer. Each
    sub-layer is followed by a non-affine batch norm over the feature
    dimension.

    NOTE(review): ``nn.MultiheadAttention`` already applies its own learned
    input projections to query, key and value. Each of ``to_q``/``to_k``/
    ``to_v`` therefore composes with one of those into a single affine map:
    it adds parameters but no expressiveness. The layers are kept so existing
    checkpoints still load.

    Args:
        h_dim: feature width; must be divisible by ``n_heads``.
        dropout_p: dropout probability inside the feed-forward layer.
        n_heads: number of attention heads.
    """
    def __init__(self, h_dim=512, dropout_p=0.5, n_heads=1):
        super().__init__()
        self.attn = nn.MultiheadAttention(h_dim, n_heads, batch_first=True)
        self.to_q = nn.Linear(h_dim, h_dim)
        self.to_k = nn.Linear(h_dim, h_dim)
        self.to_v = nn.Linear(h_dim, h_dim)
        self.ff = nn.Sequential(
            nn.Dropout(p=dropout_p),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
        )
        self.bn1 = nn.BatchNorm1d(h_dim, affine=False)
        self.bn2 = nn.BatchNorm1d(h_dim, affine=False)

    @staticmethod
    def _normalize_features(norm, x):
        # BatchNorm1d expects channels on dim 1, so swap h_dim into place and back.
        return norm(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        '''
        :param x: (batch_size, num_words, h_dim)
        :return: (batch_size, num_words, h_dim)
        '''
        attended = self.attn(self.to_q(x), self.to_k(x), self.to_v(x))[0]
        hidden = self._normalize_features(self.bn1, x + attended)
        return self._normalize_features(self.bn2, hidden + self.ff(hidden))

class AttentionModel(nn.Module):
    """Stack of ``AttentionBlock``s over token embeddings, mean-pooled into vocabulary scores.

    Args:
        vocab_size: number of token ids and of output scores.
        h_dim: embedding / attention width.
        dropout_p: dropout probability passed to every ``AttentionBlock``.
        n_layers: number of ``AttentionBlock``s.
        init_weight: optional initial embedding weight, handed to ``nn.Embedding``
            as ``_weight`` (shape ``(vocab_size, h_dim)``).
    """
    def __init__(self, vocab_size, h_dim=512, dropout_p=0.5, n_layers=3, init_weight=None):
        super().__init__()
        # Submodules are created in the same order as before, so random
        # initialisation under a fixed seed is unchanged.
        self.layers = nn.Sequential(*(AttentionBlock(h_dim, dropout_p) for _ in range(n_layers)))
        self.convert = nn.Embedding(vocab_size, h_dim, _weight=init_weight)
        self.tail = nn.Sequential(
            nn.Linear(h_dim, vocab_size),
            nn.ReLU(),
            nn.Linear(vocab_size, vocab_size),
        )

    def forward(self, x):
        """Map a ``(batch_size, num_words)`` id tensor to ``(batch_size, vocab_size)`` scores."""
        hidden = self.layers(self.convert(x))
        # Average the contextualised token vectors over the word dimension.
        pooled = hidden.mean(dim=1)
        return self.tail(pooled)

    def get_word_probability(self, x):
        """Return ``forward``'s scores softmax-normalised over the last dimension."""
        return torch.softmax(self.forward(x), dim=-1)

class AttnCTMDataset(Dataset):
    """Map-style dataset that pairs each document with its label, unchanged.

    Items are ``(doc, label)`` tuples, intended for ``generate_attn_batch``.
    """
    def __init__(self, docs, labels):
        self.data = [(doc, label) for doc, label in zip(docs, labels)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        pair = self.data[item]
        return pair
