import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from mawm.models import utils
from mawm.models.rmc import RelationalMemoryCell, InteractingRelationalMemoryCell

from einops import rearrange
from addict import Dict

# Legend:
#   n: batch
#   m: cell index
#   l: also cell index
#   a: agent index
#   h: hidden state index
#   c: also hidden or cell state index
#   s: positional encoding / cell embedding index
#   k: head index
#   v: value index
#   d: attention embedding index
#   t: time index


class TopoInputAttention(nn.Module):
    """Distribute agent inputs to cells with locality-modulated attention.

    `forward` returns one input vector per cell and batch element (shape ``mni``). That vector is a sigmoid-gated
    mix of two read-outs:

      * local: ``inputs`` summed over agents, weighted only by a kernel of the cosine similarity between each cell
        embedding and each agent's positional embedding (see `local_weights`);
      * non-local: multi-head dot-product attention with cells as keys and agents as queries/values (see
        `non_local_weights`). Its weights are multiplied elementwise by the local weights before use.
    """
    # Size of each agent's input vector (i).
    input_size: int
    # Size of each cell's hidden state (h).
    cell_hidden_size: int
    # Size of cell embeddings and agent positional embeddings (s).
    cell_embedding_size: int
    # Number of heads (k) of the non-local attention.
    num_heads: int
    # Key/query size (d) of the non-local attention.
    embedding_size: int
    # 'truncated_zonal' or 'sparsemax' (asserted in __init__); selects the kernel used in `local_weights`.
    locality_measure: str
    # Forwarded to utils.sparsemax when locality_measure == 'sparsemax'.
    topo_sparsity_topk: int
    fractional_noise: float
    # Forwarded to utils.zonal_kernel when locality_measure == 'truncated_zonal'.
    zonal_kernel_truncation_parameter: float
    zonal_kernel_bandwidth: float
    zonal_kernel_straight_through: bool
    # Stored but not read anywhere in this class.
    eps: float

    def __init__(self, input_size, cell_hidden_size, cell_embedding_size,
                 num_heads=1, embedding_size=32, locality_measure='truncated_zonal',
                 topo_sparsity_topk=-1, fractional_noise=0., zonal_kernel_truncation_parameter=0.,
                 zonal_kernel_bandwidth=1., zonal_kernel_straight_through=True,
                 eps=10e-5):
        """Create the attention parameters.

        Args:
            input_size: per-agent input size (i).
            cell_hidden_size: per-cell hidden size (h).
            cell_embedding_size: size (s) of cell and positional embeddings.
            num_heads: number of heads (k) of the non-local attention.
            embedding_size: key/query size (d) of the non-local attention.
            locality_measure: 'truncated_zonal' or 'sparsemax'.
            topo_sparsity_topk, fractional_noise: passed to ``utils.sparsemax``.
            zonal_kernel_truncation_parameter, zonal_kernel_bandwidth, zonal_kernel_straight_through:
                passed to ``utils.zonal_kernel``.
            eps: stored as ``self.eps`` only. NOTE(review): the default ``10e-5`` equals 1e-4; confirm that
                1e-5 was not intended.

        Raises:
            AssertionError: if ``locality_measure`` is not one of the two supported values.
        """
        super(TopoInputAttention, self).__init__()
        assert locality_measure in ['truncated_zonal', 'sparsemax']
        self.input_size = input_size
        self.cell_hidden_size = cell_hidden_size
        self.cell_embedding_size = cell_embedding_size
        self.num_heads = num_heads
        self.embedding_size = embedding_size
        self.locality_measure = locality_measure
        self.topo_sparsity_topk = topo_sparsity_topk
        self.fractional_noise = fractional_noise
        self.zonal_kernel_truncation_parameter = zonal_kernel_truncation_parameter
        self.zonal_kernel_bandwidth = zonal_kernel_bandwidth
        self.zonal_kernel_straight_through = zonal_kernel_straight_through
        self.eps = eps
        # Attention weights. The order of creation matters: each orthogonal_ init draws from the global RNG.
        # Keys are projected from [cell hidden state, cell embedding]: shape (h + s, k, d).
        # noinspection PyArgumentList
        self.key_weights = nn.Parameter(
            nn.init.orthogonal_(torch.empty(cell_hidden_size + cell_embedding_size,
                                            num_heads, embedding_size)))
        # Queries are projected from [agent input, positional embedding]: shape (i + s, k, d).
        # noinspection PyArgumentList
        self.query_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(input_size + cell_embedding_size,
                                                                          num_heads, embedding_size)))
        # Values are projected from [agent input, positional embedding] back to input size: shape (i + s, k, i).
        # noinspection PyArgumentList
        self.value_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(input_size + cell_embedding_size,
                                                                          num_heads, input_size)))
        # Maps the concatenated heads (k * i) back to i, followed by a ReLU.
        # NOTE(review): an earlier comment here said "Setting bias = False preserves ReLU-ness", but this
        # nn.Linear uses its default bias=True. Confirm which is intended before relying on either.
        self.value_to_input_projector = nn.Sequential(nn.Linear(num_heads * input_size, input_size),
                                                      nn.ReLU())
        # Scalar sigmoid gate computed from [local_outputs, non_local_outputs] (2 * i features).
        self.local_vs_nonlocal_gate = nn.Sequential(nn.Linear(input_size * 2, 1),
                                                    nn.Sigmoid())

    def local_weights(self, cell_embeddings: torch.Tensor, positional_embeddings: torch.Tensor):
        """Locality weights between every cell and every agent, shape (m, n, a).

        Computes the cosine similarity between each cell embedding (m, s) and each agent's positional embedding
        (n, a, s), then applies ``utils.zonal_kernel`` or ``utils.sparsemax`` (along the agent axis, dim=-1)
        depending on ``self.locality_measure``. The exact semantics of those helpers live in mawm.models.utils
        and are not visible here.
        """
        # Compute attention weights based on topological locality
        # Shorthands and shapes
        #   (inp) inputs.shape = nai
        #   (pe) positional_embeddings.shape = nas
        #   (ce) cell_embeddings.shape = ms
        _n, _a, _s = positional_embeddings.shape
        _m, _ = cell_embeddings.shape
        assert cell_embeddings.shape == (_m, _s)
        # Compute the cosine similarity along s by recasting both pe and ce to mnas tensors.
        # The resulting tensor.shape = mna
        similarities = F.cosine_similarity(cell_embeddings[:, None, None, :].expand(_m, _n, _a, _s),
                                           positional_embeddings[None, :, :, :].expand(_m, _n, _a, _s),
                                           dim=-1)
        if self.locality_measure == 'truncated_zonal':
            weights = utils.zonal_kernel(similarities, bandwidth=self.zonal_kernel_bandwidth,
                                         truncate_at=self.zonal_kernel_truncation_parameter,
                                         straight_through=self.zonal_kernel_straight_through)
        elif self.locality_measure == 'sparsemax':
            weights = utils.sparsemax(similarities, self.topo_sparsity_topk, noise=self.fractional_noise, dim=-1)
        else:
            raise NotImplementedError
        return weights

    def non_local_weights(self, inputs: torch.Tensor, positional_embeddings: torch.Tensor,
                          cell_hidden_states: torch.Tensor, cell_embeddings: torch.Tensor):
        """Multi-head attention weights of cells over agents, plus the per-agent values.

        Returns:
            ``(weights, values)``. ``weights`` has shape (m, n, a, k) and is softmax-normalised over the agent
            axis (dim=2). ``values`` has shape (n, a, k, i).
        """
        # (inp) inputs.shape: nai
        # (pe) positional_embeddings.shape: nas
        # (ce) cell_embeddings.shape: ms
        # (chs) cell_hidden_states.shape: mnh
        # The following implements MHDPA where the keys are the cells, queries are inp, and values a projection of inp.
        _m, _n, _h = cell_hidden_states.shape
        _, _s = cell_embeddings.shape
        _, _a, _i = inputs.shape
        assert positional_embeddings.shape == (_n, _a, _s)
        # Now we want the attn mech to be aware of the positions if we want it to learn that talk between far away
        # positions is not a good thing.
        # Concatenate ce with chs after expanding ce along the batches.
        cell_states = torch.cat([cell_hidden_states, cell_embeddings[:, None, :].expand(_m, _n, _s)], dim=-1)
        # Concatenate inp with pe.
        inputs_with_positions = torch.cat([inputs, positional_embeddings], dim=-1)
        # Use cell_states to compute keys (note: 'h' in the subscripts here spans h + s features).
        keys = torch.einsum('mnh,hkd->nmkd', [cell_states, self.key_weights])
        # Use inp to compute queries and associated values ('i' here spans i + s features).
        queries = torch.einsum('nai,ikd->nakd', [inputs_with_positions, self.query_weights])
        values = torch.einsum('nai,ikv->nakv', [inputs_with_positions, self.value_weights])
        # Compute attention weights, and normalize along a (agents). Why? Because we want to model which cell (m)
        # gets which agent (a).
        weights = torch.einsum('nmkd,nakd->mnak', [keys, queries])
        weights = torch.softmax(weights / math.sqrt(self.embedding_size), dim=2)
        return weights, values

    def forward(self, inputs: torch.Tensor, cell_hidden_states: torch.Tensor, cell_embeddings: torch.Tensor,
                positional_embeddings: torch.Tensor):
        """Compute the input for every cell.

        Args:
            inputs: agent inputs, shape (n, a, i).
            cell_hidden_states: cell hidden states, shape (m, n, h).
            cell_embeddings: cell embeddings, shape (m, s).
            positional_embeddings: agent positional embeddings, shape (n, a, s).

        Returns:
            Tensor of shape (m, n, i): a gated mix of the local and non-local read-outs.
        """

        # local_weights.shape = mna, with normalization implied along a. But if a zonal kernel is used,
        # don't expect weights to sum up to 1 along a.
        local_weights = self.local_weights(cell_embeddings=cell_embeddings,
                                           positional_embeddings=positional_embeddings)
        local_outputs = torch.einsum('mna,nai->mni', [local_weights, inputs])
        # non_local_weights.shape = mnak,
        # values.shape = nakv
        non_local_weights, values = self.non_local_weights(inputs=inputs,
                                                           positional_embeddings=positional_embeddings,
                                                           cell_hidden_states=cell_hidden_states,
                                                           cell_embeddings=cell_embeddings)
        # Modulate non-local weights to obtain a tensor weights.shape = mnak
        weights = local_weights[:, :, :, None] * non_local_weights
        # Compute attention output from weights and values, and project to input size
        non_local_outputs = torch.einsum('mnak,nakv->mnkv', [weights, values])
        non_local_outputs = self.value_to_input_projector(rearrange(non_local_outputs, 'm n k v -> m n (k v)'))
        # Output is a gated sum. Note that the gate does not destroy locality, because even the `non_local_outputs`
        # is local in the sense that it's weighted with local weights.
        # gate_values.shape = mn1
        gate_values = self.local_vs_nonlocal_gate(torch.cat([local_outputs, non_local_outputs], dim=-1))
        # noinspection PyTypeChecker
        outputs = gate_values * local_outputs + (1 - gate_values) * non_local_outputs
        return outputs


class TopoIntercellAttention(nn.Module):
    """Let cells read each other's hidden states with locality-modulated attention.

    `forward` returns an updated hidden state per cell (shape ``mnh``). It is a sigmoid-gated mix of two read-outs:

      * local: hidden states of all cells ``l`` weighted by a kernel of the cosine similarity between the cell
        embeddings of ``m`` and ``l``;
      * non-local: multi-head dot-product attention between cells (softmax over the sending cell ``l``), whose
        weights are multiplied elementwise by the local weights, projected back to ``h`` and passed through tanh.
    """
    # Size of each cell's hidden state (h).
    cell_hidden_size: int
    # Size of the cell embeddings (s).
    cell_embedding_size: int
    # Number of heads (k) of the non-local attention.
    num_heads: int
    # Key/query size (d) of the non-local attention.
    embedding_size: int
    # 'truncated_zonal' or 'sparsemax' (asserted in __init__); selects the kernel used in `local_weights`.
    locality_measure: str
    # Forwarded to utils.sparsemax when locality_measure == 'sparsemax'.
    topo_sparsity_topk: int
    fractional_noise: float
    # Forwarded to utils.zonal_kernel when locality_measure == 'truncated_zonal'.
    zonal_kernel_truncation_parameter: float
    zonal_kernel_bandwidth: float
    zonal_kernel_straight_through: bool
    # Stored but not read anywhere in this class.
    eps: float

    def __init__(self, cell_hidden_size, cell_embedding_size, num_heads=1, embedding_size=32,
                 locality_measure='truncated_zonal', topo_sparsity_topk=-1, fractional_noise=0.,
                 zonal_kernel_truncation_parameter=0., zonal_kernel_bandwidth=1.,
                 zonal_kernel_straight_through=True, eps=10e-5):
        """Create the attention parameters.

        Args:
            cell_hidden_size: per-cell hidden size (h).
            cell_embedding_size: size (s) of the cell embeddings.
            num_heads: number of heads (k) of the non-local attention.
            embedding_size: key/query size (d) of the non-local attention.
            locality_measure: 'truncated_zonal' or 'sparsemax'.
            topo_sparsity_topk, fractional_noise: passed to ``utils.sparsemax``.
            zonal_kernel_truncation_parameter, zonal_kernel_bandwidth, zonal_kernel_straight_through:
                passed to ``utils.zonal_kernel``.
            eps: stored as ``self.eps`` only (default ``10e-5`` == 1e-4).

        Raises:
            AssertionError: if ``locality_measure`` is not one of the two supported values.
        """
        super(TopoIntercellAttention, self).__init__()
        assert locality_measure in ['truncated_zonal', 'sparsemax']
        self.cell_hidden_size = cell_hidden_size
        self.cell_embedding_size = cell_embedding_size
        self.num_heads = num_heads
        self.embedding_size = embedding_size
        self.locality_measure = locality_measure
        self.topo_sparsity_topk = topo_sparsity_topk
        self.fractional_noise = fractional_noise
        self.zonal_kernel_truncation_parameter = zonal_kernel_truncation_parameter
        self.zonal_kernel_bandwidth = zonal_kernel_bandwidth
        self.zonal_kernel_straight_through = zonal_kernel_straight_through
        self.eps = eps
        # Attention parameters. All three project [cell hidden state, cell embedding] (h + s features).
        # Creation order matters because each orthogonal_ init draws from the global RNG.
        # noinspection PyArgumentList
        self.key_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(cell_hidden_size + cell_embedding_size,
                                                                        num_heads, embedding_size)))
        # noinspection PyArgumentList
        self.query_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(cell_hidden_size + cell_embedding_size,
                                                                          num_heads, embedding_size)))
        # noinspection PyArgumentList
        self.value_weights = nn.Parameter(nn.init.orthogonal_(torch.empty(cell_hidden_size + cell_embedding_size,
                                                                          num_heads, cell_hidden_size)))
        # Maps the concatenated heads (k * h) back to h, squashed by tanh.
        self.value_to_hidden_projector = nn.Sequential(nn.Linear(cell_hidden_size * num_heads, cell_hidden_size),
                                                       nn.Tanh())
        # Scalar sigmoid gate computed from [local_outputs, non_local_outputs] (2 * h features).
        self.local_vs_nonlocal_gate = nn.Sequential(nn.Linear(cell_hidden_size * 2, 1),
                                                    nn.Sigmoid())

    def local_weights(self, cell_embeddings: torch.Tensor):
        """Cell-to-cell locality weights, shape (m, l) with l indexing the same cells as m.

        Cosine similarity between every pair of cell embeddings, passed through ``utils.zonal_kernel`` or
        ``utils.sparsemax`` (along dim=-1) depending on ``self.locality_measure``.
        """
        # (ce) cell_embeddings.shape: ms, ls
        # similarities.shape = ml
        similarities = F.cosine_similarity(cell_embeddings[None, :, :], cell_embeddings[:, None, :], dim=-1)
        if self.locality_measure == 'truncated_zonal':
            weights = utils.zonal_kernel(similarities, bandwidth=self.zonal_kernel_bandwidth,
                                         truncate_at=self.zonal_kernel_truncation_parameter,
                                         straight_through=self.zonal_kernel_straight_through)
        elif self.locality_measure == 'sparsemax':
            weights = utils.sparsemax(similarities, self.topo_sparsity_topk, noise=self.fractional_noise, dim=-1)
        else:
            raise NotImplementedError
        return weights

    def non_local_weights(self, cell_hidden_states: torch.Tensor, cell_embeddings: torch.Tensor):
        """Multi-head attention weights between cells, plus per-cell values.

        The receiving cell ``m`` supplies the keys and the sending cell ``l`` the queries.

        Returns:
            ``(weights, values)``. ``weights`` has shape (m, n, l, k) and is softmax-normalised over ``l``
            (dim=2). ``values`` has shape (l, n, k, h).
        """
        _m, _n, _h = cell_hidden_states.shape
        _, _s = cell_embeddings.shape
        assert cell_embeddings.shape == (_m, _s)
        # Append each cell's embedding to its hidden state, broadcast over the batch: shape (m, n, h + s).
        cell_states = torch.cat([cell_hidden_states, cell_embeddings[:, None, :].expand(_m, _n, _s)], dim=-1)
        keys = torch.einsum('mnh,hkd->mnkd', [cell_states, self.key_weights])
        queries = torch.einsum('lnh,hkd->lnkd', [cell_states, self.query_weights])
        values = torch.einsum('lnh,hkv->lnkv', [cell_states, self.value_weights])
        # Compute weights
        weights = torch.einsum('mnkd,lnkd->mnlk', [keys, queries])
        weights = torch.softmax(weights / math.sqrt(self.embedding_size), dim=2)
        return weights, values

    def forward(self, cell_hidden_states: torch.Tensor, cell_embeddings: torch.Tensor):
        """Compute the intercell update for every cell.

        Args:
            cell_hidden_states: shape (m, n, h).
            cell_embeddings: shape (m, s).

        Returns:
            Tensor of shape (m, n, h): a gated mix of the local and non-local read-outs.
        """
        # cell_hidden_states.shape: mnh
        # local_weights.shape: ml
        local_weights = self.local_weights(cell_embeddings=cell_embeddings)
        local_outputs = torch.einsum('lnh,ml->mnh', [cell_hidden_states, local_weights])
        # non_local_weights.shape: mnlk
        non_local_weights, values = self.non_local_weights(cell_hidden_states=cell_hidden_states,
                                                           cell_embeddings=cell_embeddings)
        # Modulate non-local weights with local weights (ml) to obtain a global weight tensor of shape mnlk
        weights = local_weights[:, None, :, None] * non_local_weights
        # Compute non-local outputs
        non_local_outputs = torch.einsum('mnlk,lnkv->mnkv', [weights, values])
        # Rearrange and project
        non_local_outputs = self.value_to_hidden_projector(rearrange(non_local_outputs, 'm n k v -> m n (k v)'))
        # Gate local against non-local (gate_values.shape = mn1)
        gate_values = self.local_vs_nonlocal_gate(torch.cat([local_outputs, non_local_outputs], dim=-1))
        # noinspection PyTypeChecker
        outputs = gate_values * local_outputs + (1 - gate_values) * non_local_outputs
        # Done
        return outputs


class TopoOutputAttention(nn.Module):
    """Read cell hidden states out to agents using purely topological weights.

    The module owns no learnable parameters. The weight linking cell ``m`` to agent ``a`` is a kernel (zonal or
    sparsemax, from mawm.models.utils) of the cosine similarity between the cell embedding and the agent's
    positional embedding.
    """
    locality_measure: str
    topo_sparsity_topk: int
    fractional_noise: float
    zonal_kernel_truncation_parameter: float
    zonal_kernel_bandwidth: float
    zonal_kernel_straight_through: bool
    eps: float

    def __init__(self, locality_measure='truncated_zonal',
                 topo_sparsity_topk=-1, fractional_noise=0., zonal_kernel_truncation_parameter=0.,
                 zonal_kernel_bandwidth=1., zonal_kernel_straight_through=True,
                 eps=10e-5):
        """Store the kernel configuration.

        Raises:
            AssertionError: if ``locality_measure`` is neither 'truncated_zonal' nor 'sparsemax'.
        """
        super(TopoOutputAttention, self).__init__()
        supported_measures = ('truncated_zonal', 'sparsemax')
        assert locality_measure in supported_measures
        self.locality_measure = locality_measure
        # Sparsemax settings.
        self.topo_sparsity_topk = topo_sparsity_topk
        self.fractional_noise = fractional_noise
        # Zonal-kernel settings.
        self.zonal_kernel_truncation_parameter = zonal_kernel_truncation_parameter
        self.zonal_kernel_bandwidth = zonal_kernel_bandwidth
        self.zonal_kernel_straight_through = zonal_kernel_straight_through
        # Kept for interface parity with the other attention modules; not read here.
        self.eps = eps

    def forward(self, hidden: torch.Tensor, positional_embeddings: torch.Tensor,
                cell_embeddings: torch.Tensor):
        """Aggregate cell hidden states for every agent.

        Args:
            hidden: cell hidden states, shape (m, n, h).
            positional_embeddings: agent positional embeddings, shape (n, a, s).
            cell_embeddings: cell embeddings, shape (m, s).

        Returns:
            Tensor of shape (n, a, h): for each agent, the locality-weighted sum of hidden states over cells.

        Raises:
            AssertionError: on inconsistent shapes.
            NotImplementedError: if ``self.locality_measure`` is not a supported kernel.
        """
        batch_size, num_agents, embed_dim = positional_embeddings.shape
        num_cells, _ = cell_embeddings.shape
        _, _, hidden_dim = hidden.shape
        assert cell_embeddings.shape == (num_cells, embed_dim)
        assert hidden.shape == (num_cells, batch_size, hidden_dim)

        # Broadcast both embeddings onto a shared (m, n, a, s) grid and compare along s -> (m, n, a).
        grid = (num_cells, batch_size, num_agents, embed_dim)
        agent_side = positional_embeddings.unsqueeze(0).expand(grid)
        cell_side = cell_embeddings.unsqueeze(1).unsqueeze(2).expand(grid)
        similarities = F.cosine_similarity(agent_side, cell_side, dim=-1)

        # Similarities -> locality weights of shape (m, n, a).
        # NOTE(review): sparsemax normalises along the agent axis (dim=-1) while the read-out below sums over
        # cells; confirm this axis is intended.
        measure = self.locality_measure
        if measure == 'sparsemax':
            weights = utils.sparsemax(similarities, self.topo_sparsity_topk, noise=self.fractional_noise, dim=-1)
        elif measure == 'truncated_zonal':
            weights = utils.zonal_kernel(similarities, bandwidth=self.zonal_kernel_bandwidth,
                                         truncate_at=self.zonal_kernel_truncation_parameter,
                                         straight_through=self.zonal_kernel_straight_through)
        else:
            raise NotImplementedError

        # Contract over cells: (m, n, a) x (m, n, h) -> (n, a, h).
        return torch.einsum('mna,mnh->nah', weights, hidden)


class SpaceCells(nn.Module):
    """A bank of recurrent cells that exchange information through attention.

    One call to `forward` performs a single step:
      1. ``input_attn`` turns the agents' inputs into one input per cell;
      2. ``intercell_attn`` lets the cells read each other's hidden states, and a learned sigmoid gate blends that
         update with the previous hidden state;
      3. every cell is stepped on its own input and blended hidden state.
    `query` reads the cells back out to agents through ``output_attn``.
    """
    num_cells: int
    cell_hidden_size: int
    embedding_size: int

    def __init__(self, cells: nn.ModuleList, input_attn: nn.Module, intercell_attn: nn.Module, output_attn: nn.Module,
                 cell_hidden_size: int, embedding_size: int, cell_state_sizes: list = None,
                 cell_embedding_init: str = None):
        """Register the cells and attention modules and create one learnable embedding per cell.

        Args:
            cells: the recurrent cells; a plain list/tuple is wrapped into an ``nn.ModuleList``.
            input_attn, intercell_attn, output_attn: attention modules used by `forward` and `query`.
            cell_hidden_size: hidden size (h) of every cell.
            embedding_size: size (s) of the per-cell embeddings.
            cell_state_sizes: sizes of any extra per-cell states (empty when None).
            cell_embedding_init: None/'orthogonal', 'uniform' (in [-1, 1]) or 'normal'.

        Raises:
            AssertionError: if ``cells`` is neither a list/tuple nor an ``nn.ModuleList``.
            NotImplementedError: for an unknown ``cell_embedding_init``.
        """
        super(SpaceCells, self).__init__()
        # Wrap plain sequences so that the cells' parameters are registered with this module.
        if isinstance(cells, (list, tuple)):
            self.cells = nn.ModuleList(list(cells))
        else:
            self.cells = cells
        assert isinstance(self.cells, nn.ModuleList)
        # Attention submodules
        self.input_attn = input_attn
        self.intercell_attn = intercell_attn
        self.output_attn = output_attn
        # Attributes
        self.num_cells = len(cells)
        self.cell_hidden_size = cell_hidden_size
        self.embedding_size = embedding_size
        self.cell_state_sizes = list(cell_state_sizes) if cell_state_sizes is not None else []
        self.cell_embedding_init = cell_embedding_init
        # Learnable cell embeddings, shape (num_cells, embedding_size).
        init_scheme = self.cell_embedding_init
        embeddings = torch.empty(self.num_cells, self.embedding_size)
        if init_scheme is None or init_scheme == 'orthogonal':
            # noinspection PyArgumentList
            embeddings = nn.init.orthogonal_(embeddings)
        elif init_scheme == 'uniform':
            # noinspection PyArgumentList
            embeddings = nn.init.uniform_(embeddings, -1, 1)
        elif init_scheme == 'normal':
            # noinspection PyArgumentList
            embeddings = nn.init.normal_(embeddings)
        else:
            raise NotImplementedError
        self.cell_embeddings = nn.Parameter(embeddings)
        # Scalar sigmoid gate deciding how much of the intercell update replaces the previous hidden state.
        self.intercell_update_gate = nn.Sequential(nn.Linear(cell_hidden_size * 2, 1),
                                                   nn.Sigmoid())

    def initial_hidden(self, batch_size):
        """Zero hidden state of shape (num_cells, batch_size, cell_hidden_size) on the default device."""
        shape = (self.num_cells, batch_size, self.cell_hidden_size)
        return torch.zeros(shape)

    def initial_cell_states(self, batch_size):
        """One zero tensor of shape (num_cells, batch_size, size) per entry of ``cell_state_sizes``."""
        return [torch.zeros((self.num_cells, batch_size, size)) for size in self.cell_state_sizes]

    def forward(self, inputs: torch.Tensor, positional_embeddings: torch.Tensor, hidden: torch.Tensor, *cell_states):
        """Advance every cell by one step.

        Args:
            inputs: agent inputs, passed unchanged to ``input_attn``. For example, (n, a, i) when driven by
                ``SpaceRNN`` with ``TopoInputAttention``.
            positional_embeddings: agent positional embeddings, passed unchanged to ``input_attn``. For example,
                (n, a, s) in that setting.
            hidden: hidden states, shape (m, n, h).
            *cell_states: optional extra states, each indexed by cell along dim 0 (e.g. (m, n, c)).

        Returns:
            ``[hidden, *cell_states]`` after the step. Each entry stacks the per-cell results along a new leading
            cell axis. Extra states are present only when the cells return a list/tuple.
        """
        embeddings = self.cell_embeddings
        # One input per cell, e.g. (m, n, i).
        per_cell_inputs = self.input_attn(inputs, hidden, embeddings, positional_embeddings)
        # Hidden states after reading the other cells, (m, n, h).
        mixed_hidden = self.intercell_attn(hidden, embeddings)
        # Blend with the previous hidden state; gate has shape (m, n, 1).
        gate = self.intercell_update_gate(torch.cat([mixed_hidden, hidden], dim=-1))
        hidden_in = gate * mixed_hidden + (1 - gate) * hidden

        outputs = []
        states_per_cell = []
        for idx, cell in enumerate(self.cells):
            result = cell(per_cell_inputs[idx], hidden_in[idx], *(state[idx] for state in cell_states))
            # GRU-like cells return a single tensor; LSTM/RMC-like cells return (output, *states).
            if isinstance(result, (list, tuple)):
                out, *extra_states = result
                states_per_cell.append(extra_states)
            else:
                assert isinstance(result, torch.Tensor)
                out = result
            outputs.append(out)

        stacked_hidden = torch.stack(outputs)
        if states_per_cell:
            # Transpose cell-major -> state-major, then stack each state over cells.
            states_per_cell = [torch.stack(group) for group in zip(*states_per_cell)]
        return [stacked_hidden, *states_per_cell]

    def query(self, hidden: torch.Tensor, positional_embeddings: torch.Tensor):
        """Read the cells out to agents through ``output_attn`` using this module's cell embeddings."""
        return self.output_attn(hidden, positional_embeddings, self.cell_embeddings)


class SpaceRNN(nn.Module):
    """Unroll a ``SpaceCells`` module over time for batch-first sequences.

    `forward` steps ``space_cells`` once per time step. It returns the stacked hidden states together with the
    final recurrent state packed as ``[hidden, *cell_state]``. That packed list can be passed back as ``hidden`` to
    continue the sequence.
    """

    def __init__(self, space_cells: 'SpaceCells'):
        super(SpaceRNN, self).__init__()
        self.space_cells = space_cells
        # Inputs are laid out as (batch, time, ...); see `forward`.
        self.batch_first = True

    def forward(self, inputs: torch.Tensor, positional_embeddings: torch.Tensor, hidden=None, cell_state=None):
        """Run the cells over every time step.

        Args:
            inputs: shape (n, t, a, i).
            positional_embeddings: shape (n, t, a, s).
            hidden: one of three forms:
                * the initial hidden state (m, n, h);
                * the packed ``[hidden, *cell_state]`` list/tuple returned by a previous call;
                * None, to use ``space_cells.initial_hidden``.
            cell_state: list of initial extra states as expected by ``space_cells``. If None, they are taken from a
                packed ``hidden`` when one is given, otherwise from ``space_cells.initial_cell_states``.

        Returns:
            ``(output, [hidden, *cell_state])``. ``output`` has shape (m, n, t, h). The list holds the state after
            the last step.
        """
        # inputs.shape: ntai
        # positional_embeddings.shape: ntas
        _n, _t, _a, _i = inputs.shape
        _, _, _, _s = positional_embeddings.shape
        assert positional_embeddings.shape == (_n, _t, _a, _s)
        # Accept the packed [hidden, *cell_state] returned by this method. The cell states must come from the
        # packed list itself (previously they were sliced off the freshly initialised states, which silently
        # dropped the recurrent memory). An explicitly passed `cell_state` still takes precedence.
        if isinstance(hidden, (list, tuple)):
            hidden, *packed_cell_state = hidden
            if cell_state is None:
                cell_state = packed_cell_state
        # Fill in initial states for whatever is still missing.
        if hidden is None:
            # noinspection PyUnresolvedReferences
            hidden = self.space_cells.initial_hidden(_n).to(inputs.device)
        if cell_state is None:
            # noinspection PyUnresolvedReferences
            cell_state = [_cell_state.to(inputs.device)
                          for _cell_state in self.space_cells.initial_cell_states(_n)]
        hiddens = []
        # Loop over time
        for t in range(_t):
            hidden, *cell_state = self.space_cells(inputs[:, t], positional_embeddings[:, t], hidden, *cell_state)
            hiddens.append(hidden)
        # output.shape = mnth
        output = torch.stack(hiddens, dim=2)
        return output, [hidden, *cell_state]

    def query(self, hiddens: torch.Tensor, positional_embeddings: torch.Tensor):
        """Read per-agent outputs for every time step.

        Args:
            hiddens: shape (m, n, t, h), e.g. the ``output`` of `forward`.
            positional_embeddings: shape (n, t, a, s).

        Returns:
            Tensor of shape (n, t, a, h).
        """
        _m, _n, _t, _h = hiddens.shape
        _, _, _a, _s = positional_embeddings.shape
        assert positional_embeddings.shape == (_n, _t, _a, _s)
        # Fold the batch and time axes together so a single query call covers every step.
        # noinspection PyUnresolvedReferences
        weighted_hiddens = self.space_cells.query(rearrange(hiddens, 'm n t h -> m (n t) h'),
                                                  rearrange(positional_embeddings, 'n t a s -> (n t) a s'))
        weighted_hiddens = rearrange(weighted_hiddens, '(n t) a h -> n t a h', n=_n, t=_t)
        # weighted_hiddens.shape = ntah
        return weighted_hiddens

    def flatten_hiddens(self, hiddens: torch.Tensor):
        """Reshape hiddens from (m, n, t, h) to (n, t, m * h)."""
        return rearrange(hiddens, 'm n t h -> n t (m h)')


class SpaceGRUCells(SpaceCells):
    """SpaceCells whose per-cell recurrent units are ``nn.GRUCell``s.

    Builds the input, intercell and output topological attention modules from one shared locality configuration
    and hands everything to ``SpaceCells``. No ``cell_state_sizes`` are passed, so there are no extra cell states.
    """
    def __init__(self, input_size, num_cells, cell_hidden_size, cell_embedding_size, attn_embedding_size=32,
                 num_input_attn_heads=1, num_intercell_attn_heads=1, locality_measure='truncated_zonal',
                 topo_sparsity_topk=-1, fractional_noise=0., zonal_kernel_truncation_parameter=0.,
                 zonal_kernel_bandwidth=1., zonal_kernel_straight_through=True, cell_embedding_init=None,
                 eps=10e-5):
        input_attn = TopoInputAttention(input_size=input_size,
                                        cell_hidden_size=cell_hidden_size,
                                        cell_embedding_size=cell_embedding_size,
                                        num_heads=num_input_attn_heads,
                                        embedding_size=attn_embedding_size,
                                        locality_measure=locality_measure,
                                        topo_sparsity_topk=topo_sparsity_topk,
                                        fractional_noise=fractional_noise,
                                        zonal_kernel_truncation_parameter=zonal_kernel_truncation_parameter,
                                        zonal_kernel_bandwidth=zonal_kernel_bandwidth,
                                        zonal_kernel_straight_through=zonal_kernel_straight_through,
                                        eps=eps)
        # The truncation parameter is forwarded here as well, so intercell locality uses the same zonal kernel
        # configuration as the input/output attentions. It used to be omitted, which left it at the default 0.
        intercell_attn = TopoIntercellAttention(cell_hidden_size=cell_hidden_size,
                                                cell_embedding_size=cell_embedding_size,
                                                num_heads=num_intercell_attn_heads,
                                                embedding_size=attn_embedding_size,
                                                locality_measure=locality_measure,
                                                topo_sparsity_topk=topo_sparsity_topk,
                                                fractional_noise=fractional_noise,
                                                zonal_kernel_truncation_parameter=zonal_kernel_truncation_parameter,
                                                zonal_kernel_bandwidth=zonal_kernel_bandwidth,
                                                zonal_kernel_straight_through=zonal_kernel_straight_through,
                                                eps=eps)
        output_attn = TopoOutputAttention(locality_measure=locality_measure,
                                          topo_sparsity_topk=topo_sparsity_topk,
                                          fractional_noise=fractional_noise,
                                          zonal_kernel_truncation_parameter=zonal_kernel_truncation_parameter,
                                          zonal_kernel_bandwidth=zonal_kernel_bandwidth,
                                          zonal_kernel_straight_through=zonal_kernel_straight_through,
                                          eps=eps)
        cells = nn.ModuleList([nn.GRUCell(input_size, cell_hidden_size) for _ in range(num_cells)])
        super(SpaceGRUCells, self).__init__(cells=cells, input_attn=input_attn, intercell_attn=intercell_attn,
                                            output_attn=output_attn, cell_hidden_size=cell_hidden_size,
                                            embedding_size=cell_embedding_size,
                                            cell_embedding_init=cell_embedding_init)


class SpaceRMCCells(SpaceCells):
    """SpaceCells whose per-cell recurrent units are relational memory cells.

    With ``interacting_cells=False``, each cell is a ``RelationalMemoryCell``. The hidden size is then
    ``cell_num_memory_slots * cell_head_size * cell_num_heads`` and there are no extra cell states.

    With ``interacting_cells=True``, each cell is an ``InteractingRelationalMemoryCell``. The hidden state is the
    cell's message (``cell_message_size``, defaulting to ``cell_head_size * cell_num_heads``). The memory of size
    ``cell_num_memory_slots * cell_head_size * cell_num_heads`` is carried as the single extra cell state.

    Extra keyword arguments are accepted and ignored (``**_``).
    """
    def __init__(self, input_size, num_cells, cell_num_memory_slots, cell_head_size, cell_num_heads,
                 cell_embedding_size, attn_embedding_size=32,
                 num_input_attn_heads=1, num_intercell_attn_heads=1, locality_measure='truncated_zonal',
                 topo_sparsity_topk=-1, fractional_noise=0., zonal_kernel_truncation_parameter=0.,
                 zonal_kernel_bandwidth=1., zonal_kernel_straight_through=True,
                 eps=10e-5, cell_num_attn_mlp_layers=3, cell_key_size=32, cell_message_size=None,
                 interacting_cells=False, cell_embedding_init=None, **_):
        if interacting_cells:
            cell_message_size = (cell_head_size * cell_num_heads) if cell_message_size is None else cell_message_size
            cell_hidden_size = cell_message_size
        else:
            cell_hidden_size = cell_num_memory_slots * cell_head_size * cell_num_heads
        input_attn = TopoInputAttention(input_size=input_size,
                                        cell_hidden_size=cell_hidden_size,
                                        cell_embedding_size=cell_embedding_size,
                                        num_heads=num_input_attn_heads,
                                        embedding_size=attn_embedding_size,
                                        locality_measure=locality_measure,
                                        topo_sparsity_topk=topo_sparsity_topk,
                                        fractional_noise=fractional_noise,
                                        zonal_kernel_truncation_parameter=zonal_kernel_truncation_parameter,
                                        zonal_kernel_bandwidth=zonal_kernel_bandwidth,
                                        zonal_kernel_straight_through=zonal_kernel_straight_through,
                                        eps=eps)
        # The truncation parameter is forwarded here as well, so intercell locality uses the same zonal kernel
        # configuration as the input/output attentions. It used to be omitted, which left it at the default 0.
        intercell_attn = TopoIntercellAttention(cell_hidden_size=cell_hidden_size,
                                                cell_embedding_size=cell_embedding_size,
                                                num_heads=num_intercell_attn_heads,
                                                embedding_size=attn_embedding_size,
                                                locality_measure=locality_measure,
                                                topo_sparsity_topk=topo_sparsity_topk,
                                                fractional_noise=fractional_noise,
                                                zonal_kernel_truncation_parameter=zonal_kernel_truncation_parameter,
                                                zonal_kernel_bandwidth=zonal_kernel_bandwidth,
                                                zonal_kernel_straight_through=zonal_kernel_straight_through,
                                                eps=eps)
        output_attn = TopoOutputAttention(locality_measure=locality_measure,
                                          topo_sparsity_topk=topo_sparsity_topk,
                                          fractional_noise=fractional_noise,
                                          zonal_kernel_truncation_parameter=zonal_kernel_truncation_parameter,
                                          zonal_kernel_bandwidth=zonal_kernel_bandwidth,
                                          zonal_kernel_straight_through=zonal_kernel_straight_through,
                                          eps=eps)
        if interacting_cells:
            cells = nn.ModuleList([InteractingRelationalMemoryCell(input_size, mem_slots=cell_num_memory_slots,
                                                                   head_size=cell_head_size, num_heads=cell_num_heads,
                                                                   attention_mlp_layers=cell_num_attn_mlp_layers,
                                                                   key_size=cell_key_size,
                                                                   message_size=cell_message_size)
                                   for _ in range(num_cells)])
            # The memory is the single extra per-cell state.
            cell_state_sizes = [cell_num_memory_slots * cell_head_size * cell_num_heads]
        else:
            cells = nn.ModuleList([RelationalMemoryCell(input_size, mem_slots=cell_num_memory_slots,
                                                        head_size=cell_head_size, num_heads=cell_num_heads,
                                                        attention_mlp_layers=cell_num_attn_mlp_layers,
                                                        key_size=cell_key_size)
                                   for _ in range(num_cells)])
            cell_state_sizes = None
        super(SpaceRMCCells, self).__init__(cells=cells, input_attn=input_attn, intercell_attn=intercell_attn,
                                            output_attn=output_attn, cell_hidden_size=cell_hidden_size,
                                            embedding_size=cell_embedding_size, cell_state_sizes=cell_state_sizes,
                                            cell_embedding_init=cell_embedding_init)

    def initial_cell_states(self, batch_size):
        """Initial extra cell states.

        For non-interacting cells this defers to ``SpaceCells`` (an empty list, since no state sizes are set).
        For interacting cells it returns ``[memories]``, where the per-cell
        ``initial_state(batch_size, packed=True)`` results are stacked along a new leading cell axis.
        """
        if not isinstance(self.cells[0], InteractingRelationalMemoryCell):
            return super(SpaceRMCCells, self).initial_cell_states(batch_size)
        # The initial memory is not all zeros; poll the cells
        # noinspection PyTypeChecker
        initial_memories = torch.stack([cell.initial_state(batch_size, packed=True) for cell in self.cells])
        # We put it in a list because it's the only cell state
        return [initial_memories]

    def initial_hidden(self, batch_size):
        """Initial hidden state.

        For non-interacting cells this defers to ``SpaceCells`` (zeros). For interacting cells it stacks each
        cell's ``initial_message(batch_size)`` along a new leading cell axis.
        """
        if not isinstance(self.cells[0], InteractingRelationalMemoryCell):
            return super(SpaceRMCCells, self).initial_hidden(batch_size)
        # Delegated to the cells so that learned initial messages, if ever introduced, are picked up here.
        # noinspection PyTypeChecker
        initial_messages = torch.stack([cell.initial_message(batch_size) for cell in self.cells])
        return initial_messages


class SpaceLSTMCells(SpaceCells):
    """A ``SpaceCells`` bank made of ``num_cells`` ``utils.LSTMCell`` instances.

    The bank is wired together with three attention modules:
    ``TopoInputAttention`` (input -> cells), ``TopoIntercellAttention``
    (cell <-> cell) and ``TopoOutputAttention`` (cells -> output). Each cell
    declares a single extra state of size ``cell_hidden_size``
    (``cell_state_sizes=[cell_hidden_size]``).
    """

    def __init__(self, input_size, num_cells, cell_hidden_size, cell_embedding_size, attn_embedding_size=32,
                 num_input_attn_heads=1, num_intercell_attn_heads=1, locality_measure='truncated_zonal',
                 topo_sparsity_topk=-1, fractional_noise=0., zonal_kernel_truncation_parameter=0.,
                 zonal_kernel_bandwidth=1., zonal_kernel_straight_through=True, cell_embedding_init=None,
                 eps=10e-5):
        # Locality / sparsity settings passed identically to every attention module.
        common_topo = dict(locality_measure=locality_measure,
                           topo_sparsity_topk=topo_sparsity_topk,
                           fractional_noise=fractional_noise,
                           zonal_kernel_bandwidth=zonal_kernel_bandwidth,
                           zonal_kernel_straight_through=zonal_kernel_straight_through,
                           eps=eps)
        # Input and output attention additionally receive the truncation parameter.
        # NOTE(review): the intercell attention is built without it; confirm whether
        # TopoIntercellAttention is meant to accept/use it as well.
        truncated_topo = dict(common_topo,
                              zonal_kernel_truncation_parameter=zonal_kernel_truncation_parameter)

        # Modules are created in the order input -> intercell -> output -> cells;
        # keep it so any randomly initialised parameters draw from the RNG in that order.
        input_attention = TopoInputAttention(input_size=input_size,
                                             cell_hidden_size=cell_hidden_size,
                                             cell_embedding_size=cell_embedding_size,
                                             num_heads=num_input_attn_heads,
                                             embedding_size=attn_embedding_size,
                                             **truncated_topo)
        intercell_attention = TopoIntercellAttention(cell_hidden_size=cell_hidden_size,
                                                     cell_embedding_size=cell_embedding_size,
                                                     num_heads=num_intercell_attn_heads,
                                                     embedding_size=attn_embedding_size,
                                                     **common_topo)
        output_attention = TopoOutputAttention(**truncated_topo)
        lstm_bank = nn.ModuleList(utils.LSTMCell(input_size, cell_hidden_size) for _ in range(num_cells))

        super(SpaceLSTMCells, self).__init__(cells=lstm_bank,
                                             input_attn=input_attention,
                                             intercell_attn=intercell_attention,
                                             output_attn=output_attention,
                                             cell_hidden_size=cell_hidden_size,
                                             embedding_size=cell_embedding_size,
                                             cell_state_sizes=[cell_hidden_size],
                                             cell_embedding_init=cell_embedding_init)


class SpaceGRU(SpaceRNN):
    """``SpaceRNN`` whose cell bank is a ``SpaceGRUCells``.

    Every positional and keyword argument is forwarded unchanged to
    ``SpaceGRUCells``.
    """

    def __init__(self, *args, **kwargs):
        super(SpaceGRU, self).__init__(SpaceGRUCells(*args, **kwargs))


class SpaceRMC(SpaceRNN):
    """``SpaceRNN`` whose cell bank is a ``SpaceRMCCells``.

    Every positional and keyword argument is forwarded unchanged to
    ``SpaceRMCCells``.
    """

    def __init__(self, *args, **kwargs):
        super(SpaceRMC, self).__init__(SpaceRMCCells(*args, **kwargs))


class SpaceLSTM(SpaceRNN):
    """``SpaceRNN`` whose cell bank is a ``SpaceLSTMCells``.

    Every positional and keyword argument is forwarded unchanged to
    ``SpaceLSTMCells``.
    """

    def __init__(self, *args, **kwargs):
        super(SpaceLSTM, self).__init__(SpaceLSTMCells(*args, **kwargs))


# Named presets of keyword arguments. The keys match keyword parameters of the
# SpaceLSTMCells constructor; size arguments such as input_size, num_cells and
# cell_hidden_size are not included and must be supplied separately.
KWARGSETS = Dict(
    DEFAULT=dict(
        cell_embedding_size=16,
        attn_embedding_size=16,
        num_input_attn_heads=2,
        num_intercell_attn_heads=4,
        locality_measure='truncated_zonal',
        topo_sparsity_topk=-1,
        fractional_noise=0.,
        zonal_kernel_truncation_parameter=0.,
        zonal_kernel_bandwidth=1.,
        zonal_kernel_straight_through=True,
    ),
)


if __name__ == '__main__':
    # Smoke test: build a small model and run one forward pass on random data.
    # N: batch, T: time steps, A: agents, I: input features, S: positional-encoding size.
    N, T, A, I = 3, 16, 10, 144
    # The module legend uses 's' for both positional encodings and cell embeddings,
    # so S is also passed as cell_embedding_size to keep the two in agreement.
    S = 32

    # Alternative cell types (swap in for the SpaceLSTM below):
    # srmc = SpaceGRU(input_size=I, num_cells=8, cell_hidden_size=128, cell_embedding_size=S,
    #                 attn_embedding_size=32, num_input_attn_heads=2, num_intercell_attn_heads=4,
    #                 locality_measure='truncated_zonal', topo_sparsity_topk=-1, fractional_noise=0.,
    #                 zonal_kernel_truncation_parameter=0., zonal_kernel_bandwidth=1.,
    #                 zonal_kernel_straight_through=True)
    # srmc = SpaceRMC(input_size=I, num_cells=8, cell_num_memory_slots=2, cell_head_size=64, cell_num_heads=2,
    #                 cell_embedding_size=S, attn_embedding_size=32, num_input_attn_heads=2,
    #                 num_intercell_attn_heads=4, locality_measure='truncated_zonal', topo_sparsity_topk=-1,
    #                 fractional_noise=0., zonal_kernel_truncation_parameter=0., zonal_kernel_bandwidth=1.,
    #                 zonal_kernel_straight_through=True, interacting_cells=True)

    srmc = SpaceLSTM(input_size=I, num_cells=8, cell_hidden_size=128, cell_embedding_size=S,
                     attn_embedding_size=32, num_input_attn_heads=2, num_intercell_attn_heads=4,
                     locality_measure='truncated_zonal', topo_sparsity_topk=-1, fractional_noise=0.,
                     zonal_kernel_truncation_parameter=0., zonal_kernel_bandwidth=1.,
                     zonal_kernel_straight_through=True)

    x = torch.randn(N, T, A, I)
    p = torch.rand(N, T, A, S)
    print(srmc.space_cells.cell_hidden_size)
    h = srmc(x, p)
