import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn as nn


############################################################################

class SingleDataset(Dataset):
    """Dataset of documents stored as flat tensors of repeated token ids.

    Every document is supplied as a pair of arrays (token ids, counts) and is
    expanded so that each token id appears ``count`` times; for example tokens
    ``[3, 7]`` with counts ``[2, 1]`` become ``tensor([3, 3, 7])``.
    """

    def __init__(self, tokens, counts):
        """Expand every (tokens, counts) pair into one 1-D tensor.

        Args:
            tokens: Iterable holding one array-like of token ids per document.
            counts: Iterable holding one array-like of occurrence counts per
                document, aligned element-wise with ``tokens``.

        Raises:
            ValueError: If the token and count arrays of a document differ
                in size.
        """
        self.data = [self._expand(tok, count) for tok, count in zip(tokens, counts)]

    @staticmethod
    def _expand(tok, count):
        """Repeat each token id by its count and return a 1-D tensor."""
        # reshape(-1) instead of squeeze(): squeeze() turns a document with a
        # single distinct token into a 0-d array, which cannot be iterated.
        tok = np.asarray(tok).reshape(-1)
        count = np.asarray(count).reshape(-1)
        if tok.size != count.size:
            raise ValueError(
                "tokens and counts must have the same number of elements, "
                "got {} and {}".format(tok.size, count.size)
            )
        if tok.size == 0:
            # An empty document is valid input: represent it as an empty bag.
            return torch.from_numpy(np.empty(0, dtype=np.int64))
        return torch.from_numpy(np.repeat(tok, count))

    def __len__(self):
        """Return the number of documents."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the expanded token tensor of document ``idx``."""
        return self.data[idx]

def generate_single_batch(batch):
    """Collate a list of 1-D token tensors into a flat tensor plus offsets.

    Args:
        batch: List of 1-D tensors, one per document.

    Returns:
        A pair ``(tokens, offsets)`` of int64 tensors: all documents
        concatenated end to end, and the start position of each document
        inside that concatenation.
    """
    doc_lengths = [len(doc) for doc in batch]
    # A document starts where the documents before it end, so take the running
    # sum of the lengths shifted right by one (the last length is not needed).
    starts = torch.tensor(([0] + doc_lengths)[:-1]).cumsum(dim=0)
    flat_tokens = torch.cat(batch)
    return flat_tokens.long(), starts.long()


############################################################################



########### Contrastive datasets ###########################################
class ContrastiveDataset(Dataset):
    """Dataset of labelled document pairs.

    Every item is a tuple ``(label, doc_1, doc_2)`` where each document is a
    1-D tensor in which every token id is repeated ``count`` times.
    """

    def __init__(self, tokens_1, counts_1, tokens_2, counts_2, labels):
        """Expand both sides of every pair and attach the labels.

        Args:
            tokens_1: Iterable with one array-like of token ids per first
                document.
            counts_1: Iterable with the matching occurrence counts.
            tokens_2: Iterable with one array-like of token ids per second
                document.
            counts_2: Iterable with the matching occurrence counts.
            labels: Iterable with one label per pair.

        Raises:
            ValueError: If the token and count arrays of a document differ
                in size.
        """
        docs_1 = [self._expand(tok, count) for tok, count in zip(tokens_1, counts_1)]
        docs_2 = [self._expand(tok, count) for tok, count in zip(tokens_2, counts_2)]
        self.data = list(zip(labels, docs_1, docs_2))

    @staticmethod
    def _expand(tok, count):
        """Repeat each token id by its count and return a 1-D tensor."""
        # Flatten first so that row vectors, column vectors and plain lists
        # are all handled the same way.
        tok = np.asarray(tok).reshape(-1)
        count = np.asarray(count).reshape(-1)
        if tok.size != count.size:
            raise ValueError(
                "tokens and counts must have the same number of elements, "
                "got {} and {}".format(tok.size, count.size)
            )
        if tok.size == 0:
            # An empty document is valid input: represent it as an empty bag
            # (concatenating zero per-token arrays would raise instead).
            return torch.from_numpy(np.empty(0, dtype=np.int64))
        return torch.from_numpy(np.repeat(tok, count))

    def __len__(self):
        """Return the number of document pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(label, doc_1, doc_2)`` tuple at position ``idx``."""
        return self.data[idx]


def generate_contrast_batch(batch):
    """Collate ``(label, doc_1, doc_2)`` samples into flat tensors and offsets.

    Args:
        batch: List of tuples whose first element is a label and whose second
            and third elements are 1-D token tensors.

    Returns:
        A tuple ``(text_1, offsets_1, text_2, offsets_2, label)``. The text
        and offset tensors are int64; the label tensor is float.
    """
    def _pack(position):
        # Gather one side of every pair, then flatten it and record where
        # each document starts inside the flattened tensor.
        docs = [sample[position] for sample in batch]
        lengths = [len(doc) for doc in docs]
        starts = torch.tensor(([0] + lengths)[:-1]).cumsum(dim=0)
        return torch.cat(docs).long(), starts.long()

    targets = torch.tensor([sample[0] for sample in batch])
    tokens_a, starts_a = _pack(1)
    tokens_b, starts_b = _pack(2)
    return tokens_a, starts_a, tokens_b, starts_b, targets.float()

############################################################################

########### Reference Datasets #############################################
class ReferenceDataset(Dataset):
    """Dataset pairing every document with every reference document.

    Items enumerate the full cross product: index ``idx`` maps to document
    ``idx // n_refs`` and reference ``idx % n_refs``.
    """

    def __init__(self, tokens, counts, ref_tokens, ref_counts):
        """Expand the documents and the references into 1-D tensors.

        Args:
            tokens: Iterable with one array-like of token ids per document.
            counts: Iterable with the matching occurrence counts.
            ref_tokens: Iterable with one array-like of token ids per
                reference document.
            ref_counts: Iterable with the matching occurrence counts.

        Raises:
            ValueError: If the token and count arrays of a document differ
                in size.
        """
        self.data = [self._expand(tok, count) for tok, count in zip(tokens, counts)]
        self.refs = [self._expand(tok, count) for tok, count in zip(ref_tokens, ref_counts)]

    @staticmethod
    def _expand(tok, count):
        """Repeat each token id by its count and return a 1-D tensor."""
        # reshape(-1) instead of squeeze(): squeeze() turns a document with a
        # single distinct token into a 0-d array, which cannot be iterated.
        tok = np.asarray(tok).reshape(-1)
        count = np.asarray(count).reshape(-1)
        if tok.size != count.size:
            raise ValueError(
                "tokens and counts must have the same number of elements, "
                "got {} and {}".format(tok.size, count.size)
            )
        if tok.size == 0:
            # An empty document is valid input: represent it as an empty bag.
            return torch.from_numpy(np.empty(0, dtype=np.int64))
        return torch.from_numpy(np.repeat(tok, count))

    def __len__(self):
        """Return the number of (document, reference) combinations."""
        return len(self.data) * len(self.refs)

    def __getitem__(self, idx):
        """Return the ``(document, reference)`` pair for flat index ``idx``."""
        nrefs = len(self.refs)
        return self.data[idx // nrefs], self.refs[idx % nrefs]

def generate_reference_batch(batch):
    """Collate ``(document, reference)`` pairs into flat tensors and offsets.

    Args:
        batch: List of pairs of 1-D token tensors.

    Returns:
        A tuple ``(text_1, offsets_1, text_2, offsets_2)`` of int64 tensors,
        where each text tensor is the concatenation of one side of the pairs
        and each offsets tensor holds the start of every document in it.
    """
    collated = []
    for position in (0, 1):
        docs = [pair[position] for pair in batch]
        # Prefix a zero and drop the last length so the running sum yields
        # the start position of each document.
        boundaries = [0]
        boundaries.extend(len(doc) for doc in docs)
        starts = torch.tensor(boundaries[:-1]).cumsum(dim=0)
        collated.append(torch.cat(docs).long())
        collated.append(starts.long())
    return tuple(collated)


###############################
## Defines the type of hidden block we use
def hidden_block(dropout_p, in_size, out_size):
    """Build one hidden block: Dropout -> ReLU -> BatchNorm1d -> Linear.

    Args:
        dropout_p: Dropout probability.
        in_size: Number of input features (also the BatchNorm1d width).
        out_size: Number of output features of the linear layer.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    stages = [
        nn.Dropout(p=dropout_p),
        nn.ReLU(),
        nn.BatchNorm1d(in_size),
        nn.Linear(in_size, out_size),
    ]
    return nn.Sequential(*stages)

############################################################################
##  Contrastive model.
##    Input: A pair of documents (real or contrastive).
##    Output: A real number.

class ContrastiveModel(nn.Module):
    """Two-tower model that scores a pair of documents with a dot product.

    Each tower is an ``nn.EmbeddingBag`` in ``mean`` mode followed by
    ``n_layers`` hidden blocks and a final ``BatchNorm1d``. The two towers
    have separate parameters.
    """

    def __init__(self, vocab_size, embed_dim, c_dim=100, h_dim=256, dropout_p=0.5, n_layers=3, fine_tune_vectors=True):
        """Build both towers.

        Args:
            vocab_size: Number of rows in each embedding table.
            embed_dim: Width of the embedding vectors.
            c_dim: Width of the final representation of each tower.
            h_dim: Width of the intermediate hidden layers.
            dropout_p: Dropout probability used in every hidden block.
            n_layers: Number of hidden blocks per tower.
            fine_tune_vectors: When false, the embedding weights of both
                towers are frozen (``requires_grad=False``) so that only the
                layers above them are trained.
        """
        super().__init__()
        self.embed_1 = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')
        self.embed_2 = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')

        # Honour ``fine_tune_vectors``: the flag used to be accepted but
        # ignored. The default (True) keeps the embeddings trainable.
        for embed in (self.embed_1, self.embed_2):
            embed.weight.requires_grad_(bool(fine_tune_vectors))

        # Layer widths: embed_dim -> h_dim -> ... -> h_dim -> c_dim.
        sizes = [embed_dim] + [h_dim] * (n_layers - 1) + [c_dim]
        hidden_blocks_1 = nn.Sequential(*[hidden_block(dropout_p, sizes[i], sizes[i + 1]) for i in range(n_layers)])
        hidden_blocks_2 = nn.Sequential(*[hidden_block(dropout_p, sizes[i], sizes[i + 1]) for i in range(n_layers)])

        self.layers_1 = nn.Sequential(
            hidden_blocks_1,
            nn.BatchNorm1d(c_dim)
        )
        self.layers_2 = nn.Sequential(
            hidden_blocks_2,
            nn.BatchNorm1d(c_dim)
        )

    def get_embedding(self, text, offsets):
        """Return the first tower's representation of a batch of documents.

        Args:
            text: Flat tensor of token ids for all documents in the batch.
            offsets: Start position of each document inside ``text``.
        """
        return self.layers_1(self.embed_1(text, offsets))

    def forward(self, text_1, offsets_1, text_2, offsets_2):
        """Score each document pair as the dot product of its two towers.

        Args:
            text_1: Flat token ids of the first documents.
            offsets_1: Start position of each first document in ``text_1``.
            text_2: Flat token ids of the second documents.
            offsets_2: Start position of each second document in ``text_2``.

        Returns:
            Tensor with one score per pair; the summed feature dimension is
            kept with size 1.
        """
        x1 = self.layers_1(self.embed_1(text_1, offsets_1))
        x2 = self.layers_2(self.embed_2(text_2, offsets_2))
        return (x1 * x2).sum(1, keepdim=True)