import numpy as np
import math
import torch
from torch.utils.data import Dataset
import torch.nn as nn


BATCH_SIZE = 512
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

############################################################################

class PredictWordDataset(Dataset):
    """Bag-of-words documents expanded to flat token-id tensors, paired with labels.

    Each document is given as parallel sequences of token ids and counts; token
    ``t`` with count ``c`` is repeated ``c`` times in the expanded document.

    Args:
        tokens: per-document token-id sequences. Each is flattened row-major,
            so both 1-D sequences and ``(1, n)`` arrays are accepted.
        counts: per-document count sequences, aligned element-wise with ``tokens``.
        labels: per-document labels, aligned with ``tokens``.

    Raises:
        ValueError: if a document's token and count sequences differ in size.
    """

    def __init__(self, tokens, counts, labels):
        expanded_docs = [
            torch.from_numpy(self._expand_doc(doc_tokens, doc_counts))
            for doc_tokens, doc_counts in zip(tokens, counts)
        ]
        # NOTE: zip() silently truncates if labels and documents differ in count.
        self.data = list(zip(expanded_docs, labels))

    @staticmethod
    def _expand_doc(doc_tokens, doc_counts):
        """Return a 1-D int64 array in which each token id is repeated by its count."""
        token_ids = np.ravel(doc_tokens)
        token_counts = np.ravel(doc_counts)
        if token_ids.size != token_counts.size:
            raise ValueError(
                f"token/count length mismatch: {token_ids.size} tokens vs "
                f"{token_counts.size} counts"
            )
        if token_ids.size == 0:
            # An empty document is a valid empty bag; np.repeat on two empty
            # float arrays would fail the count cast, so short-circuit here.
            return np.empty(0, dtype=np.int64)
        # Uniform int64 keeps torch.cat in the collate step dtype-consistent.
        return np.repeat(token_ids, token_counts).astype(np.int64, copy=False)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

def generate_batch(batch):
    """Collate ``(token_ids, label)`` items into ``nn.EmbeddingBag`` inputs.

    Args:
        batch: sequence of dataset items; ``item[0]`` is a tensor of token ids
            (as produced by ``PredictWordDataset``) and ``item[1]`` its label.

    Returns:
        ``(text, offsets, labels)`` as int64 tensors. ``text`` is all documents
        concatenated; ``offsets[i]`` is the index in ``text`` where document
        ``i`` starts; ``labels`` holds the labels in batch order.
    """
    label_tensor = torch.tensor([item[1] for item in batch])
    docs = [item[0] for item in batch]
    flat_text = torch.cat(docs)
    # Each document starts where the preceding documents' total length ends.
    doc_lengths = torch.tensor([len(doc) for doc in docs])
    starts = torch.cumsum(doc_lengths, dim=0) - doc_lengths
    return flat_text.long(), starts.long(), label_tensor.long()


def hidden_block(dropout_p, batchnorm, in_size, out_size):
    """Build one ``Dropout -> ReLU -> [BatchNorm1d] -> Linear`` block.

    Args:
        dropout_p: dropout probability.
        batchnorm: if truthy, insert ``BatchNorm1d(in_size)`` before the linear layer.
        in_size: input feature size.
        out_size: output feature size.

    Returns:
        An ``nn.Sequential`` of those layers. Submodule indices (and therefore
        state_dict keys) are 0..3 with BatchNorm and 0..2 without.
    """
    modules = [nn.Dropout(p=dropout_p), nn.ReLU()]
    if batchnorm:
        modules.append(nn.BatchNorm1d(in_size))
    modules.append(nn.Linear(in_size, out_size))
    return nn.Sequential(*modules)


class PredictWordModel(nn.Module):
    """Feed-forward model scoring every vocabulary word from a bag of words.

    Token ids plus per-document ``offsets`` (as built by ``generate_batch``)
    are mean-pooled by ``nn.EmbeddingBag`` and passed through ``n_layers``
    blocks from ``hidden_block``. The last block projects to ``vocab_size``.

    Args:
        vocab_size: number of embedding rows and of output scores.
        embed_dim: embedding dimension.
        h_dim: width of the intermediate layers.
        dropout_p: dropout probability used in every block.
        n_layers: number of blocks (must be >= 1). Only the first block has BatchNorm.
        vectors: optional initial embedding matrix of shape
            ``(vocab_size, embed_dim)``, as a tensor or array-like.
        fine_tune_vectors: if False and ``vectors`` is given, the embedding
            weights are frozen (``requires_grad=False``).

    Raises:
        ValueError: if ``n_layers < 1``.
    """

    def __init__(self, vocab_size, embed_dim, h_dim=512, dropout_p=0.5, n_layers=3, vectors=None,
                 fine_tune_vectors=True):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        if vectors is not None:
            # Accept numpy/array-likes and match the default float dtype of the
            # layers below. A tensor already of that dtype is used without copying.
            vectors = torch.as_tensor(vectors, dtype=torch.get_default_dtype())
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean', _weight=vectors)
        if vectors is not None and not fine_tune_vectors:
            # Keep the supplied vectors fixed during training.
            self.embed.weight.requires_grad_(False)

        sizes = [embed_dim] + [h_dim] * (n_layers - 1) + [vocab_size]
        # BatchNorm only in the first block, directly after the pooled embedding.
        self.layers = nn.Sequential(
            *[hidden_block(dropout_p, i == 0, sizes[i], sizes[i + 1]) for i in range(n_layers)])

    def get_word_probability(self, text, offsets):
        """Return softmax probabilities over the vocabulary, one row per bag."""
        return torch.softmax(self.forward(text, offsets), dim=-1)

    def forward(self, text, offsets):
        """Return unnormalized vocabulary scores (logits), one row per bag."""
        return self.layers(self.embed(text, offsets))
