import math
from typing import Optional, Tuple

import torch
import torch.nn as nn


class GatedActivation(nn.Module):
    """Optionally activate a tensor and/or scale it by a learned sigmoid gate.

    The activation is any ``torch.nn`` module class named by ``activation`` (instantiated without
    arguments). The gate is ``sigmoid(Linear(input))``, computed from the *raw* input rather than
    the activated one, and multiplied element-wise into the (possibly activated) result.
    """
    gate: bool
    activate: bool

    def __init__(self, input_size, gate=True, activate=True, activation='ReLU'):
        super().__init__()
        # Gate network: square linear map on the last axis followed by a sigmoid.
        self.gater = nn.Sequential(nn.Linear(input_size, input_size), nn.Sigmoid())
        self.activation = getattr(nn, activation)()
        self.gate = gate
        self.activate = activate

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``gate(input) * act(input)``; either factor is skipped when disabled.

        When gating is enabled the last axis of ``input`` must have size ``input_size``.
        """
        result = self.activation(input) if self.activate else input
        if self.gate:
            # The gate looks at the un-activated input.
            result = self.gater(input) * result
        return result


class InputAttention(nn.Module):
    """Multi-head attention through which every RIM reads from the current input.

    Each scalar input feature ``input[:, i]`` gets its own key and value by scaling a learned
    per-feature embedding. Each RIM ``m`` builds a query from its hidden state with its own query
    weights. Attention is softmax-normalised over the input features. The per-head outputs are
    concatenated and passed through a ``GatedActivation`` with a ReLU.
    """
    embedding_size: int
    normalize: bool

    def __init__(self, input_size, num_rims, rim_hidden_size, output_size, num_heads, embedding_size,
                 activate=False, gate=False, normalize=False):
        super().__init__()
        assert (output_size % num_heads) == 0
        head_size = output_size // num_heads
        self.embedding_size = embedding_size
        self.normalize = normalize
        # Projection weights, all orthogonally initialised.
        # noinspection PyArgumentList
        self.key_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(input_size, num_heads, embedding_size)))
        # noinspection PyArgumentList
        self.query_weights = nn.Parameter(
            nn.init.orthogonal_(torch.empty(num_rims, rim_hidden_size, num_heads, embedding_size)))
        # noinspection PyArgumentList
        self.value_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(input_size, num_heads, head_size)))
        # Optional LayerNorms over every non-batch axis of keys / queries / values.
        self.key_norm = nn.LayerNorm([input_size, num_heads, embedding_size])
        self.query_norm = nn.LayerNorm([num_rims, num_heads, embedding_size])
        self.value_norm = nn.LayerNorm([input_size, num_heads, head_size])
        # Post-attention activation / gating on the concatenated heads.
        self.actigator = GatedActivation(output_size, gate=gate, activate=activate, activation='ReLU')

    def forward(self, input: torch.Tensor, rims_hidden_states: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Attend from every RIM over the input features.

        Index letters: n batch, m RIM, i input feature, h hidden unit, k head, d key/query
        embedding, v per-head value.

        Args:
            input: shape (n, i).
            rims_hidden_states: shape (m, n, h).

        Returns:
            ``(outputs, weights, values, queries, keys)`` with shapes (m, n, k*v), (m, n, k, i),
            (n, i, k, v), (n, m, k, d) and (n, i, k, d).
        """
        keys = torch.einsum('ni,ikd->nikd', [input, self.key_weights])
        queries = torch.einsum('mnh,mhkd->nmkd', [rims_hidden_states, self.query_weights])
        values = torch.einsum('ni,ikv->nikv', [input, self.value_weights])
        if self.normalize:
            keys = self.key_norm(keys)
            queries = self.query_norm(queries)
            values = self.value_norm(values)
        # Scaled dot-product scores, softmax over the input features (last axis).
        scale = math.sqrt(self.embedding_size)
        scores = torch.einsum('nmkd,nikd->mnki', [queries, keys]) / scale
        weights = torch.softmax(scores, dim=-1)
        per_head = torch.einsum('mnki,nikv->mnkv', [weights, values])
        # Concatenate heads: (m, n, k, v) -> (m, n, k * v) == (m, n, output_size).
        outputs = self.actigator(per_head.flatten(start_dim=2))
        return outputs, weights, values, queries, keys


class InterimAttention(nn.Module):
    """Multi-head attention through which RIMs exchange information (inter-RIM communication).

    RIM ``m`` builds a query from its own hidden state with its own query weights. Every RIM ``l``
    exposes a key and a value built from *its* hidden state with *its* key/value weights. RIM ``m``
    then attends over all RIMs ``l``, itself included, so its output depends on the other RIMs'
    states. Queries are broadcast along ``l`` and keys/values along ``m`` into a pairwise
    ``(n, m, l, ...)`` layout. That layout is what the optional LayerNorms normalise over and what
    is returned to callers.
    """
    embedding_size: int
    normalize: bool
    detach_inactive_rims: bool

    def __init__(self, rim_hidden_size, num_rims, num_heads, embedding_size, activate=False, gate=False,
                 normalize=False, detach_inactive_rims=True):
        super().__init__()
        assert (rim_hidden_size % num_heads) == 0
        self.embedding_size = embedding_size
        self.normalize = normalize
        self.detach_inactive_rims = detach_inactive_rims
        # Per-RIM projection weights, indexed by the RIM that owns them: (num_rims, h, k, d|v).
        # noinspection PyArgumentList
        self.key_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(num_rims, rim_hidden_size,
                                                                        num_heads, embedding_size)))
        # noinspection PyArgumentList
        self.query_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(num_rims, rim_hidden_size,
                                                                          num_heads, embedding_size)))
        # noinspection PyArgumentList
        self.value_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(num_rims, rim_hidden_size,
                                                                          num_heads, rim_hidden_size // num_heads)))
        # Normalizers over the pairwise (m, l, k, d|v) layout.
        self.key_norm = nn.LayerNorm([num_rims, num_rims, num_heads, embedding_size])
        self.query_norm = nn.LayerNorm([num_rims, num_rims, num_heads, embedding_size])
        self.value_norm = nn.LayerNorm([num_rims, num_rims, num_heads, rim_hidden_size // num_heads])
        # Output gating
        self.actigator = GatedActivation(rim_hidden_size, gate=gate, activate=activate, activation='Tanh')

    def forward(self, rims_hidden_states: torch.Tensor, active_rim_mask: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Let every RIM attend over the hidden states of all RIMs.

        Args:
            rims_hidden_states: shape (m, n, h).
            active_rim_mask: shape (m, n); 1 where RIM m is active in batch element n, else 0.

        Returns:
            ``(outputs, weights, values, queries, keys)`` with shapes (m, n, h), (m, n, k, l),
            (n, m, l, k, v), (n, m, l, k, d) and (n, m, l, k, d).
        """
        # Legend:
        #   n: batch
        #   m: rim index (the RIM doing the querying)
        #   l: also rim index (the RIM being read from)
        #   h: hidden state index
        #   d: embedding index
        #   v: value index
        #   k: head index
        # Before we compute attention, we block the gradients through inactive RIMs with the detach trick
        # (if required): values are unchanged, but gradients only reach the states of active RIMs.
        if self.detach_inactive_rims:
            # noinspection PyTypeChecker
            rims_hidden_states = ((active_rim_mask[:, :, None] * rims_hidden_states) +
                                  ((1 - active_rim_mask[:, :, None]) * rims_hidden_states).detach())
        num_rims = rims_hidden_states.shape[0]
        # Each RIM projects its OWN state with its OWN weights.
        queries = torch.einsum('mnh,mhkd->nmkd', [rims_hidden_states, self.query_weights])
        keys = torch.einsum('lnh,lhkd->nlkd', [rims_hidden_states, self.key_weights])
        values = torch.einsum('lnh,lhkv->nlkv', [rims_hidden_states, self.value_weights])
        # Broadcast to the pairwise layout: queries[n, m, l] = q_m, keys/values[n, m, l] = k_l / v_l.
        queries = queries.unsqueeze(2).expand(-1, -1, num_rims, -1, -1)
        keys = keys.unsqueeze(1).expand(-1, num_rims, -1, -1, -1)
        values = values.unsqueeze(1).expand(-1, num_rims, -1, -1, -1)
        if self.normalize:
            keys = self.key_norm(keys)
            queries = self.query_norm(queries)
            values = self.value_norm(values)
        # score[m, n, k, l] = <q_m, k_l> / sqrt(d), softmax over the attended-to RIMs l.
        scores = torch.einsum('nmlkd,nmlkd->mnkl', [queries, keys]) / math.sqrt(self.embedding_size)
        weights = torch.softmax(scores, dim=-1)
        outputs = torch.einsum('mnkl,nmlkv->mnkv', [weights, values])
        # Concatenate heads: (m, n, k, v) -> (m, n, k * v) == (m, n, rim_hidden_size).
        _m, _n, _k, _v = outputs.shape
        outputs = outputs.reshape(_m, _n, _k * _v)
        # Activate if required
        outputs = self.actigator(outputs)
        return outputs, weights, values, queries, keys


class RIMSelector(nn.Module):
    """Choose the active RIMs for each batch element.

    A RIM is active when it is among the ``k`` RIMs that put the least input-attention (summed
    over heads) on the last input slot. In this file that slot is the zero "null" feature
    appended by ``MultiRIMCell``.
    """
    k: int

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, input_attn_weights: torch.Tensor) -> torch.Tensor:
        """Map attention weights of shape (m, n, k_heads, i) to an (m, n) float 0/1 mask.

        ``mask[m, n] == 1`` means RIM ``m`` is active for batch element ``n``.
        """
        # Head-summed attention on the final input slot, one value per (RIM, batch) pair.
        null_attention = input_attn_weights.sum(-2)[:, :, -1]
        num_rims, batch_size = null_attention.shape[0], null_attention.shape[1]
        # Along the RIM axis, the k RIMs that look at the null slot the least win: shape (k, n).
        winners = null_attention.topk(self.k, dim=0, largest=False).indices
        mask = torch.zeros(num_rims, batch_size, dtype=torch.float, device=input_attn_weights.device)
        return mask.scatter_(0, winners, 1.)


class MultiRIMCell(nn.Module):
    """A single time step of a bank of Recurrent Independent Mechanisms (RIMs).

    Each RIM owns a ``GRUCell``. At every step:
      1. the RIMs attend over the input, with an extra all-zero "null" feature appended;
      2. ``RIMSelector`` marks the ``active_set_size`` RIMs that attend least to that null feature;
      3. active RIMs take their GRU update plus the inter-RIM attention term;
      4. inactive RIMs keep their previous hidden state.
    """

    def __init__(self, input_size, num_rims, rim_hidden_size, rim_input_size, active_set_size, num_input_attn_heads,
                 num_interim_attn_heads, attn_embedding_size, input_attn_gate=False, input_attn_activate=False,
                 input_attn_normalize=False, interim_attn_gate=False, interim_attn_activate=False,
                 interim_attn_normalize=False, interim_attn_detach_inactive_rims=True):
        super().__init__()
        # One independent recurrent cell per RIM.
        self.grus = nn.ModuleList([nn.GRUCell(rim_input_size, rim_hidden_size) for _ in range(num_rims)])
        # input_size + 1 features: forward() appends a zero "null" feature to every sample.
        self.input_attn = InputAttention(input_size=input_size + 1, num_rims=num_rims,
                                         rim_hidden_size=rim_hidden_size, output_size=rim_input_size,
                                         num_heads=num_input_attn_heads, embedding_size=attn_embedding_size,
                                         activate=input_attn_activate, gate=input_attn_gate,
                                         normalize=input_attn_normalize)
        self.interim_attn = InterimAttention(rim_hidden_size=rim_hidden_size, num_rims=num_rims,
                                             num_heads=num_interim_attn_heads, embedding_size=attn_embedding_size,
                                             activate=interim_attn_activate, gate=interim_attn_gate,
                                             normalize=interim_attn_normalize,
                                             detach_inactive_rims=interim_attn_detach_inactive_rims)
        self.rim_selector = RIMSelector(active_set_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Advance every RIM by one step.

        Args:
            x: input at this step, shape (n, input_size).
            h: previous hidden states, shape (num_rims, n, rim_hidden_size).

        Returns:
            The next hidden states, with the same shape as ``h``.
        """
        null_feature = torch.zeros(x.shape[0], 1, dtype=x.dtype, device=x.device)
        x = torch.cat([x, null_feature], dim=1)
        attn_result = self.input_attn(x, h)
        rim_inputs = attn_result[0]          # (m, n, rim_input_size)
        input_attn_weights = attn_result[1]  # (m, n, heads, input_size + 1)
        # (m, n) float mask, 1 where RIM m is active for batch element n.
        active_rim_mask = self.rim_selector(input_attn_weights)
        communication = self.interim_attn(h, active_rim_mask)[0]
        proposals = []
        # Explicit counter instead of enumerate/zip keeps the ModuleList loop TorchScript-friendly.
        idx = 0
        for cell in self.grus:
            proposals.append(cell(rim_inputs[idx], h[idx]))
            idx += 1
        candidate = torch.stack(proposals, dim=0) + communication
        # Default dynamics: inactive RIMs carry their old state forward unchanged.
        mask = active_rim_mask[:, :, None]
        return mask * candidate + (1 - mask) * h


class _MultiRIM(nn.Module):
    """Unrolls a ``MultiRIMCell`` over a time-major sequence.

    This module is kept separate so ``MultiRIM`` can optionally compile it with TorchScript.
    """

    def __init__(self, input_size, num_rims, rim_hidden_size, rim_input_size, active_set_size,
                 num_input_attn_heads, num_interim_attn_heads, attn_embedding_size, input_attn_gate=False,
                 input_attn_activate=False, input_attn_normalize=False, interim_attn_gate=False,
                 interim_attn_activate=False, interim_attn_normalize=False,
                 interim_attn_detach_inactive_rims=True):
        super().__init__()
        self.cell = MultiRIMCell(input_size, num_rims, rim_hidden_size, rim_input_size,
                                 active_set_size, num_input_attn_heads, num_interim_attn_heads,
                                 attn_embedding_size,
                                 input_attn_gate=input_attn_gate,
                                 input_attn_activate=input_attn_activate,
                                 input_attn_normalize=input_attn_normalize,
                                 interim_attn_gate=interim_attn_gate,
                                 interim_attn_activate=interim_attn_activate,
                                 interim_attn_normalize=interim_attn_normalize,
                                 interim_attn_detach_inactive_rims=interim_attn_detach_inactive_rims)

    def forward(self, input: torch.Tensor, hx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the cell over every time step.

        Args:
            input: shape (t, n, input_size).
            hx: initial hidden states, shape (m, n, h).

        Returns:
            ``(output, hx)``. ``output`` has shape (t, n, m*h) and holds every step's hidden states
            with the RIMs concatenated. ``hx`` is the final (m, n, h) state.
        """
        seq_len, batch_size = input.shape[0], input.shape[1]
        states = []
        for step in range(seq_len):
            hx = self.cell(input[step], hx)
            states.append(hx)
        stacked = torch.stack(states)  # (t, m, n, h)
        num_rims, hidden_size = hx.shape[0], hx.shape[2]
        # (t, m, n, h) -> (t, n, m, h) -> (t, n, m*h)
        output = stacked.permute(0, 2, 1, 3).reshape(seq_len, batch_size, num_rims * hidden_size)
        return output, hx


class MultiRIM(nn.Module):
    """Recurrent Independent Mechanisms layer with an ``nn.GRU``-like call signature.

    Set the class attribute ``JIT`` to True *before* construction to compile the inner recurrent
    module with ``torch.jit.script``.

    Raises:
        ValueError: if ``active_set_size`` is not in ``[0, num_rims]``. Such a config would
            otherwise only fail later, inside ``topk``, on the first forward pass.
    """
    JIT = False

    def __init__(self, input_size, num_rims, rim_hidden_size, rim_input_size, active_set_size, attn_embedding_size,
                 num_input_attn_heads=1, num_interim_attn_heads=1, input_attn_gate=False, input_attn_activate=False,
                 input_attn_normalize=False, interim_attn_gate=False, interim_attn_activate=False,
                 interim_attn_normalize=False, interim_attn_detach_inactive_rims=True, batch_first=False):
        super().__init__()
        if not 0 <= active_set_size <= num_rims:
            raise ValueError(f"active_set_size must be between 0 and num_rims ({num_rims}), "
                             f"got {active_set_size}")
        self.input_size = input_size
        self.num_rims = num_rims
        self.rim_hidden_size = rim_hidden_size
        self.rim_input_size = rim_input_size
        self.active_set_size = active_set_size
        self.attn_embedding_size = attn_embedding_size
        self.batch_first = batch_first
        _multi_rim = _MultiRIM(input_size, num_rims, rim_hidden_size, rim_input_size,
                               active_set_size, num_input_attn_heads, num_interim_attn_heads,
                               attn_embedding_size, interim_attn_gate=interim_attn_gate,
                               interim_attn_activate=interim_attn_activate,
                               interim_attn_normalize=interim_attn_normalize,
                               input_attn_gate=input_attn_gate, input_attn_activate=input_attn_activate,
                               input_attn_normalize=input_attn_normalize,
                               interim_attn_detach_inactive_rims=interim_attn_detach_inactive_rims)
        if self.JIT:
            self._multi_rim = torch.jit.script(_multi_rim)
        else:
            self._multi_rim = _multi_rim

    def forward(self, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the RIMs over a sequence.

        Args:
            input: shape (t, n, input_size), or (n, t, input_size) when ``batch_first``.
            hx: initial hidden states of shape (num_rims, n, rim_hidden_size), regardless of
                ``batch_first``. Defaults to zeros.

        Returns:
            ``(output, hx)``. ``output`` has shape (t, n, num_rims * rim_hidden_size), or
            (n, t, ...) when ``batch_first``. ``hx`` is the final hidden state.
        """
        if hx is None:
            batch_dim = 0 if self.batch_first else 1
            hx = torch.zeros(self.num_rims, input.shape[batch_dim], self.rim_hidden_size,
                             dtype=input.dtype, device=input.device)
        if self.batch_first:
            # Convert to time first
            input = torch.einsum('ntc->tnc', [input])
        # apply RNN
        output, hx = self._multi_rim(input, hx)
        if self.batch_first:
            # Output is tnc, convert to ntc
            output = torch.einsum('tnc->ntc', [output])
        return output, hx


if __name__ == '__main__':
    import time

    # Smoke test / rough timing of forward + backward passes on random data.
    MultiRIM.JIT = False
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    rnn = MultiRIM(input_size=144, num_rims=5, rim_hidden_size=64, rim_input_size=288,
                   active_set_size=3, attn_embedding_size=32, num_input_attn_heads=2,
                   num_interim_attn_heads=4, interim_attn_gate=True, interim_attn_activate=True,
                   input_attn_normalize=True, interim_attn_normalize=True,
                   interim_attn_detach_inactive_rims=False).to(dev)
    # Total parameter count.
    print(sum(p.numel() for p in rnn.parameters()))

    # x <- tni (time-major, since batch_first defaults to False)
    x = torch.rand(128, 32, 144, device=dev)

    print(dev)
    N = 1
    if dev == 'cuda':
        # CUDA kernels run asynchronously: drain pending work so it is not billed to the timed region.
        torch.cuda.synchronize()
    tic = time.perf_counter()
    for _ in range(N):
        y, hx = rnn(x)
        y.sum().backward()
    if dev == 'cuda':
        # Wait for the queued forward/backward kernels to finish before reading the clock.
        torch.cuda.synchronize()
    toc = time.perf_counter()
    print((toc - tic) / N)


