import torch
import torch.nn as nn
import torch.nn.functional as F
from egru.egrud_experimental import ScriptEGRUD
from egru.modules import VariationalDropout, WeightDrop
from lm.embedding_dropout import embedded_dropout
from typing import Union


class Decoder(nn.Module):
    """Map RNN hidden states to vocabulary logits, with an optional projection."""

    def __init__(self,
                 ninp: int,
                 ntokens: int,
                 project: bool = False,
                 nemb: Union[None, int] = None,
                 dropout: float = 0.0):
        """
        Takes hidden states of RNNs, optionally applies a projection operation and decodes to output tokens
        :param ninp: Input dimension
        :param ntokens: Number of tokens of the language model
        :param project: If True, applies a linear projection onto the embedding dimension
        :param nemb: If projection is True, specifies the dimension of the projection.
                     Ignored when project is False (the output layer then consumes ninp features).
        :param dropout: Dropout rate applied to the projector
        :raises ValueError: if project is True and nemb is missing (or zero)
        """
        super(Decoder, self).__init__()

        # Explicit exception rather than `assert`: asserts vanish under `python -O`.
        if project and not nemb:
            raise ValueError("If projection is True, must specify nemb!")

        self.ninp = ninp
        # Without a projection the output layer sees the raw input, so its
        # width must be ninp regardless of what was passed as nemb.
        self.nemb = nemb if project else ninp
        self.nout = ntokens

        self.dropout = dropout
        self.variational_dropout = VariationalDropout()

        # projector
        self.project = project
        if project:
            self.projection = nn.Linear(ninp, nemb)
        else:
            self.projection = nn.Identity()

        # word embedding decoder
        self.decoder = nn.Linear(self.nemb, self.nout)
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x):
        """
        Decode a batch of hidden states.
        :param x: 3-D tensor whose last dimension is ninp (unpacked below as bs, seq_len, ninp)
        :return: 2-D tensor with bs * seq_len rows and ntokens columns
        """
        bs, seq_len, ninp = x.shape
        if self.project:
            # reshape (not view) so non-contiguous inputs are accepted
            x = x.reshape(-1, ninp)
            x = F.relu(self.projection(x))
            # restore the 3-D layout before applying variational dropout
            x = x.reshape(bs, seq_len, self.nemb)
            x = self.variational_dropout(x, self.dropout)
        x = x.reshape(-1, self.nemb)
        x = self.decoder(x)
        return x


class LanguageModel(nn.Module):
    """
    Recurrent language model: token embedding -> stack of single-layer RNNs -> Decoder
    whose output weights are tied to the embedding matrix.
    """

    def __init__(self,
                 rnn_type,
                 nlayers,
                 emb_dim,
                 hidden_dim,
                 vocab_size,
                 projection,
                 dropout_words,
                 dropout_embedding,
                 dropout_connect,
                 dropout_forward,
                 alpha,
                 beta,
                 gamma,
                 **kwargs):
        """
        :param rnn_type: One of 'lstm', 'gru' or 'egrud'
        :param nlayers: Number of stacked recurrent layers
        :param emb_dim: Embedding dimension; also the output size of the last layer when projection is falsy
        :param hidden_dim: Hidden size of every other layer (and of the last layer when projection is truthy)
        :param vocab_size: Number of tokens
        :param projection: If truthy, the Decoder projects hidden_dim down to emb_dim before decoding
        :param dropout_words: Rate handed to embedded_dropout (only while training)
        :param dropout_embedding: Variational dropout rate applied to the embedded inputs
        :param dropout_connect: If > 0, every layer is wrapped in WeightDrop with this rate
                                ('weight_hh_l0' for LSTM/GRU, 'U' for EGRUD)
        :param dropout_forward: Variational dropout rate between layers; also the Decoder's projector dropout
        :param alpha: Coefficient of the activity term in activity_regularization
        :param beta: Coefficient of the temporal-difference term in activity_regularization
        :param gamma: Coefficient of the output-gate term (used for 'egrud' only)
        :param kwargs: Forwarded to ScriptEGRUD; unused for 'lstm' and 'gru'
        :raises NotImplementedError: if rnn_type is not one of the supported values
        """
        super(LanguageModel, self).__init__()

        # language model specifics
        self.nlayers = nlayers
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.projection = projection

        # dropout initializations
        self.dropout_words = dropout_words
        self.dropout_embedding = dropout_embedding
        self.dropout_connect = dropout_connect
        self.dropout_forward = dropout_forward

        # activity regularization specification
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.ar_loss = 0

        # input and output layers
        self.variational_dropout = VariationalDropout()
        self.embeddings = nn.Embedding(vocab_size, emb_dim)
        initrange = 0.1
        self.init_embedding(initrange=initrange)

        self.decoder = Decoder(ninp=hidden_dim if projection else emb_dim, ntokens=vocab_size,
                               project=projection, nemb=emb_dim,
                               dropout=dropout_forward)

        # Tie weights of embedding and decoder
        self.decoder.decoder.weight = self.embeddings.weight

        # RNN model definition
        self.rnn_type = rnn_type
        self.use_output_trace = [False] * self.nlayers
        if rnn_type in ('lstm', 'gru'):
            rnn_cls = nn.LSTM if rnn_type == 'lstm' else nn.GRU
            rnns = [rnn_cls(self._layer_input_dim(l), self._layer_output_dim(l),
                            num_layers=1, batch_first=True, dropout=0)
                    for l in range(nlayers)]
            dropped_weight = 'weight_hh_l0'
        elif rnn_type == 'egrud':
            rnns = [ScriptEGRUD(input_size=self._layer_input_dim(l),
                                hidden_size=self._layer_output_dim(l),
                                **kwargs) for l in range(nlayers)]
            # read the flag before the cells are (possibly) hidden behind WeightDrop
            self.use_output_trace = [rnn.use_output_trace for rnn in rnns]
            dropped_weight = 'U'
        else:
            raise NotImplementedError(f"Model '{rnn_type}' not implemented.")

        if dropout_connect > 0:
            rnns = [WeightDrop(rnn, [dropped_weight], dropout=dropout_connect) for rnn in rnns]
        self.rnns = nn.ModuleList(rnns)

        # Plain tensor attribute (not a registered buffer); one slot per layer,
        # overwritten on every EGRUD forward pass.
        self.backward_sparsity = torch.zeros(len(self.rnns))

    def _layer_input_dim(self, l):
        """Input size of layer l: the embedding size for the first layer, hidden_dim otherwise."""
        return self.emb_dim if l == 0 else self.hidden_dim

    def _layer_output_dim(self, l):
        """Output size of layer l: emb_dim for the last layer when no projection is used, else hidden_dim."""
        return self.emb_dim if l == self.nlayers - 1 and not self.projection else self.hidden_dim

    def init_embedding(self, initrange):
        """Fill the embedding matrix with values drawn uniformly from [-initrange, initrange]."""
        nn.init.uniform_(self.embeddings.weight, -initrange, initrange)

    def init_hidden(self, batch_size):
        """
        Build the initial recurrent state for every layer.
        :param batch_size: Batch size of the sequences that will be fed to forward
        :return: List with one entry per layer: an (h, c) tuple of zeros for 'lstm', a zero tensor
                 for 'gru', and whatever the cell's own init_hidden returns for 'egrud'
        :raises NotImplementedError: if rnn_type is not supported
        """
        # any parameter will do: new_zeros copies its dtype and device
        weight = next(self.parameters())
        if self.rnn_type == 'lstm':
            return [(weight.new_zeros(1, batch_size, self._layer_output_dim(l)),
                     weight.new_zeros(1, batch_size, self._layer_output_dim(l)))
                    for l in range(self.nlayers)]

        elif self.rnn_type == 'gru':
            return [weight.new_zeros(1, batch_size, self._layer_output_dim(l))
                    for l in range(self.nlayers)]

        elif self.rnn_type == 'egrud':
            # Unwrap by type, exactly as forward does, instead of re-deriving it from dropout_connect.
            return [(rnn.module if isinstance(rnn, WeightDrop) else rnn).init_hidden(batch_size)
                    for rnn in self.rnns]

        else:
            raise NotImplementedError(f"Model '{self.rnn_type}' not implemented.")

    def forward(self, inputs, state):
        """
        Run the model on a batch of token indices.
        :param inputs: Token indices looked up in the embedding table
        :param state: Per-layer recurrent state, indexable by layer (e.g. the result of init_hidden)
        :return: Tuple (log_probs, new_states, raw_hiddens, dropped_hiddens): log_probs is the
                 log-softmax over dim 1 of the Decoder output; the lists hold, per layer, the new
                 state, the raw layer output and the dropout-masked layer output
        Side effects: resets and accumulates self.ar_loss; for 'egrud' also writes self.backward_sparsity.
        """
        # reset activity regularization loss
        self.ar_loss = 0

        # embedding forward
        embedded = embedded_dropout(self.embeddings, inputs,
                                    dropout=self.dropout_words if self.training else 0)

        embedded = self.variational_dropout(embedded, dropout=self.dropout_embedding)

        # rnn forward
        new_states = []
        raw_hiddens = []
        dropped_hiddens = []
        hiddens = embedded
        for l, rnn in enumerate(self.rnns):
            hiddens, final_states = rnn(hiddens, state[l])

            raw_hiddens.append(hiddens)

            if self.rnn_type == 'egrud':
                # For EGRUD the second return value holds full traces, not just final states.
                c_tm, o_tm, i_tm, tr_tm = final_states

                self.ar_loss += self.activity_regularization(hidden_raw=c_tm, output_gates=o_tm)

                cell = rnn.module if isinstance(rnn, WeightDrop) else rnn
                thresholds = cell.thr
                epsilon = cell.pseudo_derivative_width
                # fraction of state entries lying further than 1/epsilon from their threshold
                self.backward_sparsity[l] = torch.mean(
                    torch.logical_or((c_tm - thresholds) > 1 / epsilon,
                                     (c_tm - thresholds) < - 1 / epsilon).float())

                if self.use_output_trace[l]:
                    # NOTE(review): squeeze() drops every size-1 axis, so this presumably
                    # misbehaves when batch or sequence length is 1 - confirm trace layout.
                    hiddens = torch.transpose(tr_tm.squeeze(), 0, 1)

                # keep only the last entry along dim 1 as the state carried forward
                final_states = (c_tm[:, -1], o_tm[:, -1], i_tm[:, -1], tr_tm[:, -1])

            new_states.append(final_states)

            if l != self.nlayers - 1:
                hiddens = self.variational_dropout(hiddens, dropout=self.dropout_forward)
                dropped_hiddens.append(hiddens)

            if l == self.nlayers - 1 and self.rnn_type != 'egrud':
                self.ar_loss += self.activity_regularization(hiddens)

        # decoder forward
        hiddens_ = hiddens.contiguous()
        # NOTE(review): no rate is passed here, so VariationalDropout's default applies - confirm intended.
        hiddens_ = self.variational_dropout(hiddens_)
        dropped_hiddens.append(hiddens_)

        decoded = self.decoder(hiddens_)
        return F.log_softmax(decoded, dim=1), new_states, raw_hiddens, dropped_hiddens

    def activity_regularization(self, hidden_raw, hidden_dropped=None, output_gates=None):
        """
        Compute the weighted sum of the activity regularization terms.
        :param hidden_raw: Raw states. For LSTM/GRU a (batch, time, features) tensor as produced by
                           the batch_first RNNs; for 'egrud' the state trace, differenced along dim 0
                           after squeezing
        :param hidden_dropped: Optional dropout-masked states used for the alpha term (LSTM/GRU only);
                               falls back to hidden_raw when not a tensor
        :param output_gates: Output gate trace, required for the beta and gamma terms of 'egrud'
        :return: alpha-term + beta-term + gamma-term (the int 0 when every term is disabled)
        """
        ar, tar, gates = 0, 0, 0

        # EGRU
        if self.rnn_type == 'egrud':
            # NOTE(review): squeeze() removes every size-1 axis; assumes the traces carry
            # exactly one singleton axis that should go - confirm against ScriptEGRUD.
            hidden_raw = hidden_raw.squeeze()
            output_gates = output_gates.squeeze() if torch.is_tensor(output_gates) else output_gates

            # regularize activity to approach its minimum (-1)
            if self.alpha:
                state_reg = torch.mean(hidden_raw)
                ar = self.alpha * state_reg

            # regularize differences in states
            if self.beta:
                state_diff = hidden_raw[1:] - hidden_raw[:-1]
                mask = output_gates[:-1].bool()
                # mean() over an empty selection is NaN; with no active gate the term is simply zero
                if mask.any():
                    temporal_state_reg = torch.mean(state_diff[mask].pow(2))
                    tar = self.beta * temporal_state_reg

            # regularize activity
            if self.gamma:
                activity_reg = torch.mean(output_gates)
                gates = self.gamma * activity_reg

        # LSTM, GRU etc
        else:
            # No squeeze here: removing a size-1 batch or time axis would make the
            # [:, 1:] slices below difference the feature axis instead of time.
            if not torch.is_tensor(hidden_dropped):
                hidden_dropped = hidden_raw
            if self.alpha:
                ar = self.alpha * hidden_dropped.pow(2).mean()
            # a single time step has no temporal difference (and would give NaN)
            if self.beta and hidden_raw.size(1) > 1:
                tar = self.beta * (hidden_raw[:, 1:] - hidden_raw[:, :-1]).pow(2).mean()

        return ar + tar + gates
