import torch
from torch import nn
import torch.nn.functional as F

# this class largely follows the official sonnet implementation
# https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/relational_memory.py


class RelationalMemory(nn.Module):
    """
    Relational Memory Core (RMC).

    A recurrent module whose state is a matrix of memory slots. At every time
    step the projected input is concatenated to the memory as extra rows,
    multi-head self-attention followed by an MLP is applied over all rows, the
    input rows are dropped again, and (optionally) LSTM-like input/forget gates
    blend the attended memory with the previous memory.

    Args:
      input_size: The size of input per step. i.e. the dimension of each input vector
      mem_slots: The total number of memory slots to use.
      head_size: The size of an attention head.
      num_heads: The number of attention heads to use. Defaults to 1.
      num_blocks: Number of times to compute attention per time step. Defaults
        to 1.
      forget_bias: Bias to use for the forget gate, assuming we are using
        some form of gating. Defaults to 1.
      input_bias: Bias to use for the input gate, assuming we are using
        some form of gating. Defaults to 0.
      gate_style: Whether to use per-element gating ('unit'),
        per-memory slot gating ('memory'), or no gating at all (None).
        Defaults to `unit`.
      attention_mlp_layers: Number of layers to use in the post-attention
        MLP. Defaults to 2.
      key_size: Size of vector to use for key & query vectors in the attention
        computation. Defaults to None, in which case we use `head_size`.
      _num_input_slots: Number of rows that get concatenated to the memory
        before attention. It fixes the shape of the layer norms, so the tensor
        handed to `attend_over_memory` must have exactly
        `mem_slots + _num_input_slots` rows. Defaults to 1.
    Raises:
      ValueError: gate_style not one of [None, 'memory', 'unit'].
      ValueError: num_blocks is < 1.
      ValueError: attention_mlp_layers is < 1.
    """

    def __init__(self, input_size, mem_slots, head_size, num_heads=1, num_blocks=1, forget_bias=1.,
                 input_bias=0., gate_style='unit', attention_mlp_layers=2, key_size=None, _num_input_slots=1):
        super().__init__()

        ########## generic parameters for RMC ##########
        self.mem_slots = mem_slots
        self.head_size = head_size
        self.num_heads = num_heads
        self.mem_size = self.head_size * self.num_heads
        # Batch must come first
        self.batch_first = True

        # The input rows are concatenated to the memory before self-attention,
        # so attention runs over mem_slots + _num_input_slots rows
        # (e.g. mem_slots = 1 and one input slot gives 2 rows).
        self._num_input_slots = _num_input_slots
        self.mem_slots_plus_input = self.mem_slots + self._num_input_slots

        if num_blocks < 1:
            raise ValueError('num_blocks must be >=1. Got: {}.'.format(num_blocks))
        self.num_blocks = num_blocks

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. got: '
                '{}.'.format(gate_style))
        self.gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self.attention_mlp_layers = attention_mlp_layers

        self.key_size = key_size if key_size else self.head_size

        ########## parameters for multihead attention ##########
        # value_size is same as head_size
        self.value_size = self.head_size
        # total size for query-key-value
        self.qkv_size = 2 * self.key_size + self.value_size
        self.total_qkv_size = self.qkv_size * self.num_heads  # denoted as F

        # One big projection producing q, k and v for every head at once.
        self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size)
        self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size])

        # used for attend_over_memory function
        # Build each layer separately: `[layer] * n` would repeat ONE module
        # object and silently tie the weights of all MLP layers together.
        self.attention_mlp = nn.ModuleList(
            [nn.Linear(self.mem_size, self.mem_size) for _ in range(self.attention_mlp_layers)])
        self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])
        self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])

        ########## parameters for initial embedded input projection ##########
        self.input_size = input_size
        self.input_projector = nn.Linear(self.input_size, self.mem_size)

        ########## parameters for gating ##########
        # input gate and forget gate are produced together, hence the factor 2
        self.num_gates = 2 * self.calculate_gate_size()
        self.input_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        self.memory_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        # trainable scalar gate bias tensors
        # noinspection PyArgumentList
        self.forget_bias = nn.Parameter(torch.tensor(forget_bias, dtype=torch.float32))
        # noinspection PyArgumentList
        self.input_bias = nn.Parameter(torch.tensor(input_bias, dtype=torch.float32))

        # Compute capacity for model
        self.hidden_size = self.mem_slots * self.head_size * self.num_heads

    def repackage_hidden(self, h):
        """
        Detach hidden states from their autograd history.
        Needed for truncated BPTT; `forward` calls it on the incoming memory.
        Args:
          h: A tensor, or an iterable (possibly nested) of tensors.
        Returns:
          The detached tensor, or a tuple mirroring the structure of `h`.
        """
        if isinstance(h, torch.Tensor):
            return h.detach()
        return tuple(self.repackage_hidden(v) for v in h)

    def initial_state(self, batch_size, trainable=False, packed=False):
        """
        Creates the initial memory.
        Each row of the memory should be unique, so the memory is initialized
        to the identity, zero-padded or truncated along the last dimension to
        `mem_size` columns.
        Args:
          batch_size: The size of the batch.
          trainable: Currently ignored; the returned state is a plain tensor
            and not a trainable parameter.
          packed: Whether the initial states are flattened as (batch_size, mem_slots * mem_size)
        Returns:
          init_state: A tensor of size (batch_size, self.mem_slots, self.mem_size),
            or (batch_size, self.mem_slots * self.mem_size) when `packed`.
        """
        # A rectangular identity IS the padded (mem_size > mem_slots) or
        # truncated (mem_size < mem_slots) square identity.
        identity = torch.eye(self.mem_slots, self.mem_size)
        # repeat() materialises an independent copy per sample and also
        # handles batch_size == 0.
        init_state = identity.repeat(batch_size, 1, 1)

        if packed:
            init_state = init_state.reshape(batch_size, self.mem_slots * self.mem_size)
        return init_state

    def multihead_attention(self, memory):
        """
        Perform multi-head attention from 'Attention is All You Need'.
        Implementation of the attention mechanism from
        https://arxiv.org/abs/1706.03762.
        Args:
          memory: Memory tensor to perform attention on, [B, N, mem_size].
        Returns:
          new_memory: New memory tensor, [B, N, num_heads * value_size].
        """
        # A single linear projection builds queries, keys and values together.
        qkv = self.qkv_projector(memory)
        # apply layernorm for every dim except the batch dim
        qkv = self.qkv_layernorm(qkv)

        # Number of rows is read from the tensor because the input rows were
        # concatenated to the memory slots.
        mem_slots = memory.shape[1]  # denoted as N

        # split the qkv to multiple heads H
        # [B, N, F] => [B, N, H, F/H]
        qkv_reshape = qkv.view(qkv.shape[0], mem_slots, self.num_heads, self.qkv_size)

        # [B, N, H, F/H] => [B, H, N, F/H]
        qkv_transpose = qkv_reshape.permute(0, 2, 1, 3)

        # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size]
        q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1)

        # scale q with d_k, the dimensionality of the key vectors.
        # Out-of-place on purpose: `q` is a view produced by torch.split and
        # autograd rejects in-place modification of such views.
        q = q * (self.key_size ** -0.5)

        # make it [B, H, N, N]
        dot_product = torch.matmul(q, k.permute(0, 1, 3, 2))
        weights = F.softmax(dot_product, dim=-1)

        # output is [B, H, N, V]
        output = torch.matmul(weights, v)

        # [B, H, N, V] => [B, N, H, V] => [B, N, H*V]
        output_transpose = output.permute(0, 2, 1, 3).contiguous()
        new_memory = output_transpose.view((output_transpose.shape[0], output_transpose.shape[1], -1))

        return new_memory

    @property
    def state_size(self):
        """Shape of the per-sample memory: [mem_slots, mem_size]."""
        return [self.mem_slots, self.mem_size]

    @property
    def output_size(self):
        """Size of the flattened per-sample memory: mem_slots * mem_size."""
        return self.mem_slots * self.mem_size

    def calculate_gate_size(self):
        """
        Calculate the gate size from the gate_style.
        Returns:
          The per sample, per head parameter size of each gate.
        """
        if self.gate_style == 'unit':
            return self.mem_size
        if self.gate_style == 'memory':
            return 1
        # self.gate_style is None: no gating
        return 0

    def create_gates(self, inputs, memory):
        """
        Create input and forget gates for this step using `inputs` and `memory`.
        Args:
          inputs: Tensor input of shape (B, 1, features); it is flattened to
            (B, features) before the gate projection.
          memory: The current state of memory.
        Returns:
          input_gate: A LSTM-like insert gate.
          forget_gate: A LSTM-like forget gate.
        Raises:
          ValueError: `inputs` is not 3-dimensional, or its dim 1 is > 1.
        """
        # equation 8: since there is no output gate, h is just a tanh'ed m
        memory = torch.tanh(memory)

        # Sonnet flattens the input unconditionally, which is only correct for
        # a single time step; enforce that explicitly.
        if len(inputs.shape) != 3:
            raise ValueError("input shape of create_gate function is 2, expects 3")
        if inputs.shape[1] > 1:
            raise ValueError(
                "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1")
        inputs = inputs.reshape(inputs.shape[0], -1)

        # matmul for equation 4 and 5
        # there is no output gate, so equation 6 is not implemented
        gate_inputs = self.input_gate_projector(inputs).unsqueeze(dim=1)
        gate_memory = self.memory_gate_projector(memory)

        # this completes the equation 4 and 5; the input term broadcasts over slots
        gates = gate_memory + gate_inputs
        # first half is the input gate, second half the forget gate
        input_gate, forget_gate = torch.split(gates, split_size_or_sections=gates.shape[2] // 2, dim=2)

        # to be used for equation 7
        input_gate = torch.sigmoid(input_gate + self.input_bias)
        forget_gate = torch.sigmoid(forget_gate + self.forget_bias)

        return input_gate, forget_gate

    def attend_over_memory(self, memory):
        """
        Perform multiheaded attention over `memory`.
            Args:
              memory: Current relational memory (with the input rows appended).
            Returns:
              The attended-over memory.
        """
        for _ in range(self.num_blocks):
            attended_memory = self.multihead_attention(memory)

            # Add a skip connection to the multiheaded attention's input.
            memory = self.attended_memory_layernorm(memory + attended_memory)

            # add a skip connection to the attention_mlp's input.
            attention_mlp = memory
            for layer in self.attention_mlp:
                attention_mlp = F.relu(layer(attention_mlp))
            memory = self.attended_memory_layernorm2(memory + attention_mlp)

        return memory

    def forward_step(self, inputs, memory, treat_input_as_matrix=False):
        """
        Forward step of the relational memory core.
        Args:
          inputs: Tensor input.
          memory: Memory output from the previous time step.
          treat_input_as_matrix: Optional, whether to treat `input` as a sequence
            of matrices. Default to False, in which case the input is flattened
            into a vector.
        Returns:
          output: This time step's output, the next memory flattened to
            (B, mem_slots * mem_size).
          next_memory: The next version of memory to use.
        """
        if treat_input_as_matrix:
            # keep (Batch, Seq, ...) dim (0, 1), flatten starting from dim 2
            inputs_embed = inputs.reshape(inputs.shape[0], inputs.shape[1], -1)
            # apply linear layer for dim 2
            inputs_reshape = self.input_projector(inputs_embed)
        else:
            # keep (Batch, ...) dim (0), flatten starting from dim 1
            inputs_embed = inputs.reshape(inputs.shape[0], -1)
            # apply linear layer for dim 1, then unsqueeze the time step to dim 1
            inputs_reshape = self.input_projector(inputs_embed).unsqueeze(dim=1)

        memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)
        next_memory = self.attend_over_memory(memory_plus_input)

        # cut out the concatenated input vectors from the original memory slots
        n = inputs_reshape.shape[1]
        next_memory = next_memory[:, :-n, :]

        if self.gate_style == 'unit' or self.gate_style == 'memory':
            # these gates are sigmoid-applied ones for equation 7
            input_gate, forget_gate = self.create_gates(inputs_reshape, memory)
            # equation 7 calculation
            next_memory = input_gate * torch.tanh(next_memory) + forget_gate * memory

        output = next_memory.reshape(next_memory.shape[0], -1)
        return output, next_memory

    def forward(self, inputs, memory=None):
        """
        Run the core over a whole batch-first sequence.
        Args:
          inputs: Tensor whose dim 0 is the batch and dim 1 the time step (NTC).
          memory: Optional starting memory; when None the initial state is
            created on the device of `inputs`. It is detached before use.
        Returns:
          outputs: Per-step outputs stacked along dim 1,
            (B, T, mem_slots * mem_size).
          memory: The memory after the last step.
        """
        # Initialize memory if none available
        if memory is None:
            memory = self.initial_state(batch_size=inputs.shape[0]).to(inputs.device)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        memory = self.repackage_hidden(memory)

        # for loop implementation of (entire) recurrent forward pass of the model
        logits = []
        # shape[1] is seq_length T
        for idx_step in range(inputs.shape[1]):
            logit, memory = self.forward_step(inputs[:, idx_step], memory)
            logits.append(logit)

        if not logits:
            # Empty sequence: torch.stack cannot stack an empty list.
            return inputs.new_zeros((inputs.shape[0], 0, self.output_size)), memory

        # stack the list(seq_length) of [batch, features] into [batch, seq, features]
        logits = torch.stack(logits, 1)

        return logits, memory


class RelationalMemoryCell(RelationalMemory):
    """
    Single-step relational memory with a GRUCell-like call signature.

    The hidden state `hx` is the memory flattened to
    (batch, mem_slots * head_size * num_heads); it is unflattened before the
    step and flattened again before being returned.
    """
    # noinspection PyMethodOverriding
    def forward(self, input, hx=None):
        batch = input.shape[0]
        flat_size = self.mem_slots * self.head_size * self.num_heads

        if hx is not None:
            assert hx.shape[0] == input.shape[0]
            # unpack the flat hidden state into (batch, slots, mem_size)
            memory = hx.reshape(batch, self.mem_slots, self.head_size * self.num_heads)
        else:
            # no state given: start from the initial memory on the input's device
            memory = self.initial_state(batch_size=batch).to(input.device)

        # the step's flat output is discarded; the new memory is the hidden state
        next_memory = self.forward_step(input, memory)[1]
        return next_memory.reshape(batch, flat_size)


class InteractingRelationalMemoryCell(RelationalMemory):
    """
    Interacting Relational Memory Cell (IRMC).

    A single-step relational memory that attends over its memory, the
    projected input and a projected incoming message, and that produces a new
    message from the updated memory.

    Args:
      message_size: Size of the message vector. Defaults to None, in which
        case the memory size (head_size * num_heads) is used.
      All other arguments are forwarded unchanged to `RelationalMemory`.
    """
    def __init__(self, input_size, mem_slots, head_size, num_heads=1, num_blocks=1, forget_bias=1.,
                 input_bias=0., gate_style='unit', attention_mlp_layers=2, key_size=None,
                 message_size=None):
        # Two extra slots are appended to the memory: one for the input, one for the message.
        super().__init__(
            input_size=input_size,
            mem_slots=mem_slots,
            head_size=head_size,
            num_heads=num_heads,
            num_blocks=num_blocks,
            forget_bias=forget_bias,
            input_bias=input_bias,
            gate_style=gate_style,
            attention_mlp_layers=attention_mlp_layers,
            key_size=key_size,
            _num_input_slots=2,
        )
        # Fall back to the memory size when no message size is given.
        if message_size is None:
            message_size = self.mem_size
        self.message_size = message_size
        # Flattened memory -> message (with ReLU), and message -> one memory-sized row.
        self.memory_to_message_projector = nn.Sequential(
            nn.Linear(self.mem_slots * self.mem_size, self.message_size),
            nn.ReLU(),
        )
        self.message_projector = nn.Linear(self.message_size, self.mem_size)
        # The gates now consume input and message concatenated on the feature
        # axis, so the inherited projector is replaced by a twice-as-wide one.
        self.input_gate_projector = nn.Linear(2 * self.mem_size, self.num_gates)

    # noinspection PyMethodOverriding
    def forward_step(self, inputs, message, memory, treat_input_as_matrix=False):
        """
        One step of the interacting cell.
        Args:
          inputs: Tensor input.
          message: Incoming message, projected to a single memory-sized row.
          memory: Memory from the previous step, (B, mem_slots, mem_size).
          treat_input_as_matrix: When True keep dims (0, 1) of `inputs` and
            flatten from dim 2; otherwise flatten everything after the batch dim.
        Returns:
          next_message: Message computed from the flattened new memory.
          next_memory: The updated memory.
        """
        batch = inputs.shape[0]
        if not treat_input_as_matrix:
            # flatten to (B, features), project, then add the step axis
            flat_inputs = inputs.view(batch, -1)
            input_rows = self.input_projector(flat_inputs).unsqueeze(dim=1)
        else:
            # flatten to (B, Seq, features) and project the last axis
            seq_inputs = inputs.view(batch, inputs.shape[1], -1)
            input_rows = self.input_projector(seq_inputs)

        # Project message and reshape to fit in memory
        message_row = self.message_projector(message)[:, None, :]

        stacked = torch.cat([memory, input_rows, message_row], dim=1)
        attended = self.attend_over_memory(stacked)

        # drop the appended input and message rows, keeping only the memory slots
        extra_rows = input_rows.shape[1] + message_row.shape[1]
        next_memory = attended[:, :-extra_rows, :]

        if self.gate_style in ('unit', 'memory'):
            # gates see input and message side by side on the feature axis
            gate_source = torch.cat([input_rows, message_row], dim=-1)
            input_gate, forget_gate = self.create_gates(gate_source, memory)
            # equation 7: gated candidate plus gated previous memory
            candidate = input_gate * torch.tanh(next_memory)
            next_memory = candidate + forget_gate * memory

        # make a new message by looking at the memory
        flat_memory = next_memory.reshape(next_memory.shape[0], -1)
        next_message = self.memory_to_message_projector(flat_memory)
        return next_message, next_memory

    def initial_message(self, batch_size):
        """Return an all-zero message of shape (batch_size, message_size)."""
        return torch.zeros(batch_size, self.message_size)

    def forward(self, input, message=None, memory=None):
        """
        Run one step with packed state.
        Args:
          input: Tensor input; dim 0 is the batch.
          message: Optional (batch, message_size) message. When None a zero
            message is used and `memory` must be None too.
          memory: Optional packed memory, reshaped to (batch, mem_slots, mem_size).
        Returns:
          next_message: The new message.
          next_memory: The new memory packed as (batch, mem_slots * mem_size).
        """
        batch = input.shape[0]

        if message is not None:
            assert tuple(message.shape) == (input.shape[0], self.message_size)
        else:
            assert memory is None
            message = self.initial_message(batch_size=batch).to(input.device)

        if memory is not None:
            # Unpack memory
            memory = memory.reshape(batch, self.mem_slots, self.mem_size)
        else:
            memory = self.initial_state(batch_size=batch).to(input.device)

        next_message, next_memory = self.forward_step(input, message, memory)
        # Pack memory
        packed_memory = next_memory.reshape(batch, self.mem_slots * self.mem_size)
        return next_message, packed_memory


def test_rnn():
    """Smoke test: run `RelationalMemory` over a random sequence and check output shapes."""
    input_size = 44
    seq_length = 10
    batch_size = 32
    mem_slots = 2
    head_size = 128
    num_heads = 4
    model = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=input_size,
                             num_heads=num_heads, attention_mlp_layers=3, key_size=64)
    model_memory = model.initial_state(batch_size=batch_size)

    # random input, sized from the same batch_size as the memory
    random_input = torch.randn((batch_size, seq_length, input_size))

    # run the full sequence forward
    logit, next_memory = model(random_input, model_memory)
    print(logit.shape)
    print(next_memory.shape)

    mem_size = head_size * num_heads
    # per-step output is the flattened memory, stacked along the time axis
    assert tuple(logit.shape) == (batch_size, seq_length, mem_slots * mem_size)
    assert tuple(next_memory.shape) == (batch_size, mem_slots, mem_size)


def test_rnn_cell():
    """Smoke test: step `RelationalMemoryCell` three times, feeding its state back in."""
    input_size = 44
    batch_size = 4
    mem_slots = 2
    head_size = 128
    num_heads = 4
    model = RelationalMemoryCell(mem_slots=mem_slots, head_size=head_size, input_size=input_size,
                                 num_heads=num_heads, attention_mlp_layers=3, key_size=64)
    input = torch.randn(batch_size, input_size)
    # first call starts from the initial state, later calls reuse the returned state
    hx = model(input)
    hx = model(input, hx)
    hx = model(input, hx)

    # the cell returns the memory flattened per sample
    assert tuple(hx.shape) == (batch_size, mem_slots * head_size * num_heads)


def test_irnn_cell():
    """Smoke test: step `InteractingRelationalMemoryCell` three times with message and memory fed back."""
    input_size = 44
    batch_size = 4
    mem_slots = 2
    head_size = 128
    num_heads = 4
    message_size = 256
    model = InteractingRelationalMemoryCell(mem_slots=mem_slots, head_size=head_size, input_size=input_size,
                                            num_heads=num_heads, attention_mlp_layers=3, key_size=64,
                                            message_size=message_size)
    input = torch.randn(batch_size, input_size)
    # hx is the message, cx the packed memory
    hx, cx = model(input)
    hx, cx = model(input, hx, cx)
    hx, cx = model(input, hx, cx)

    assert tuple(hx.shape) == (batch_size, message_size)
    assert tuple(cx.shape) == (batch_size, mem_slots * head_size * num_heads)


# Script entry point: running this file directly executes only the
# interacting-cell smoke test.
if __name__ == '__main__':
    test_irnn_cell()