import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

# Character id treated as padding; passed as `padding_idx` to the character
# embedding tables defined in this module.
PAD_CHAR = 0
# Number of distinct character ids: used as the one-hot width and as the
# number of rows in the character embedding tables.
# NOTE(review): presumably raw byte values 0-255 -- confirm against the caller.
NUM_CHARS = 256


class CNN(nn.Module):
    """Character-level 1-D CNN message encoder over one-hot characters.

    The network is six unpadded convolutions (kernel sizes 7, 7, 3, 3, 3, 3)
    with three max-pool stages, followed by a two-layer MLP. The first linear
    layer is sized for exactly 5 remaining time steps, which with these kernel
    and pool sizes means the message length must be between 231 and 257
    characters inclusive (e.g. 256).

    Args:
        msg_hdim: Number of channels in every conv layer and size of the
            final output vector.
    """

    def __init__(self, msg_hdim):
        super().__init__()
        self.msg_hdim = msg_hdim

        # NOTE: layer order/indices are part of the contract -- subclasses
        # replace ``conv_fc[0]`` and checkpoints key on these positions.
        self.conv_fc = nn.Sequential(
            # conv1
            nn.Conv1d(NUM_CHARS, self.msg_hdim, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=3),
            # conv2
            nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=3),
            # conv3
            nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
            nn.ReLU(),
            # conv4
            nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
            nn.ReLU(),
            # conv5
            nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
            nn.ReLU(),
            # conv6
            nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=3),
            # fc receives -- [ B x h_dim x 5 ]
            nn.Flatten(),
            nn.Linear(5 * self.msg_hdim, 2 * self.msg_hdim),
            nn.ReLU(),
            nn.Linear(2 * self.msg_hdim, self.msg_hdim),
        )  # final output -- [ B x h_dim ]

    def forward(self, messages, messages_len=None):
        """Encode a batch of messages.

        Args:
            messages: Integer (int64) tensor of character ids in
                ``[0, NUM_CHARS)``, indexed as ``[batch, length]``.
            messages_len: Unused; accepted so all encoders in this module
                share the same call signature.

        Returns:
            Floating-point tensor of shape ``[batch, msg_hdim]``.
        """
        one_hot = F.one_hot(messages, num_classes=NUM_CHARS)
        # F.one_hot yields an int64 tensor, which Conv1d rejects; cast to the
        # dtype of the first conv's weights so the conv stack can consume it.
        conv_dtype = self.conv_fc[0].weight.dtype
        char_emb = one_hot.transpose(1, 2).to(conv_dtype)
        return self.conv_fc(char_emb)


class LtCNN(CNN):
    """CNN message encoder fed by a learned character lookup table.

    Reuses the convolutional/FC stack built by ``CNN`` but maps characters
    through an ``nn.Embedding`` of width ``msg_edim`` rather than a one-hot
    encoding, so the first convolution is swapped for one that takes
    ``msg_edim`` input channels.

    Args:
        msg_hdim: Channel count of the conv stack and size of the output.
        msg_edim: Width of each character embedding vector.
    """

    def __init__(self, msg_hdim, msg_edim):
        super().__init__(msg_hdim)
        self.msg_edim = msg_edim

        # Character lookup table; PAD_CHAR is registered as the padding index.
        self.char_lt = nn.Embedding(
            num_embeddings=NUM_CHARS,
            embedding_dim=self.msg_edim,
            padding_idx=PAD_CHAR,
        )

        # The inherited stack begins with a NUM_CHARS-channel convolution;
        # substitute one whose input width matches the embedding dimension.
        first_conv = nn.Conv1d(
            in_channels=self.msg_edim,
            out_channels=self.msg_hdim,
            kernel_size=7,
        )
        self.conv_fc[0] = first_conv

    def forward(self, messages, messages_len=None):
        """Embed the character ids and run them through the conv/FC stack.

        Args:
            messages: Integer tensor of character ids, indexed as
                ``[batch, length]``.
            messages_len: Unused; kept for a uniform encoder signature.

        Returns:
            Whatever ``conv_fc`` produces for the embedded batch.
        """
        embedded = self.char_lt(messages)
        # Conv1d wants channels before the time axis: swap dims 1 and 2.
        channels_first = embedded.transpose(1, 2)
        return self.conv_fc(channels_first)


class WordGRU(nn.Module):
    """Encode a tokenized message as the final hidden state of a GRU.

    Args:
        msg_hdim: Hidden size of the GRU (and size of the output vector).
        msg_edim: Width of each token embedding.
        vocab_size: Number of rows in the token embedding table.
            NOTE(review): the default 30522 presumably matches the tokenizer
            used upstream -- confirm against the caller.
    """

    def __init__(self, msg_hdim, msg_edim, vocab_size=30522):
        super().__init__()
        self.msg_hdim = msg_hdim
        self.msg_edim = msg_edim

        # Token id 0 is the padding entry of the embedding table.
        self.emb = nn.Embedding(vocab_size, self.msg_edim, padding_idx=0)
        self.rnn = nn.GRU(self.msg_edim, self.msg_hdim, batch_first=True)

    def forward(self, messages, messages_len):
        """Return the last GRU hidden state for each message.

        Args:
            messages: Integer tensor of token ids, indexed as
                ``[batch, length]`` and padded on the right.
            messages_len: 1-D integer tensor with the number of valid tokens
                in each row. Entries below 1 are treated as 1, so an empty
                message is encoded from its first position (normally the
                padding token) instead of raising.

        Returns:
            Tensor of shape ``[batch, msg_hdim]``, in the input batch order.
        """
        messages_emb = self.emb(messages)
        # pack_padded_sequence needs CPU lengths and rejects values <= 0;
        # clamp so a single empty message cannot crash the whole batch.
        lengths = messages_len.cpu().clamp(min=1)
        packed_input = pack_padded_sequence(
            messages_emb, lengths, enforce_sorted=False, batch_first=True
        )
        _, hidden = self.rnn(packed_input)
        # hidden is [num_layers, batch, hidden]; the GRU has a single layer.
        return hidden[0]


class BoC(nn.Module):
    """Bag-of-characters message encoder.

    Each message is represented by the mean of the embeddings of its
    non-padding characters; character order is ignored.

    Args:
        msg_hdim: Width of the character embeddings and of the output vector.
    """

    def __init__(self, msg_hdim):
        super().__init__()
        self.msg_hdim = msg_hdim
        self.char_lt = nn.Embedding(NUM_CHARS, self.msg_hdim, padding_idx=PAD_CHAR)

    def forward(self, messages, messages_len=None):
        """Average the embeddings of the non-padding characters.

        Args:
            messages: Integer tensor of character ids whose last dimension
                is the character position.
            messages_len: Unused; kept for a uniform encoder signature.

        Returns:
            Tensor shaped like ``messages`` with the last dimension replaced
            by ``msg_hdim``. Messages made up only of padding produce zeros.
        """
        # Mask on the same padding id the embedding table uses, rather than
        # a literal 0, so the two cannot drift apart.
        mask = (messages != PAD_CHAR).float()
        char_emb = self.char_lt(messages)
        # Sum up non-masked embeddings
        char_rep = (char_emb * mask.unsqueeze(-1)).sum(-2)

        # Divide by number of non-masked embeddings, min 1 (avoids 0/0 for
        # all-padding messages)
        mask_sum = mask.sum(-1, keepdim=True)
        char_rep = char_rep / torch.clamp(mask_sum, min=1.0)
        return char_rep
