import math
import torch
import torch.nn as nn
import torch.nn.functional as F


from einops import rearrange
from .encoder import PositionalEncoding


class MAB(nn.Module):
    """Multihead attention block.

    Implementation from https://github.com/juho-lee/set_transformer/blob/master/modules.py

    Queries are projected from ``Q``; keys and values are both projected from
    ``K``. The projected queries are added back to the attention output
    (residual), followed by an optional LayerNorm, a residual feed-forward
    step and a second optional LayerNorm.

    Args:
        dim_Q: feature size of the query input.
        dim_K: feature size of the key/value input.
        dim_V: feature size of the projections and of the output.
        num_heads: number of attention heads; must divide ``dim_V``.
        capacity: if given, every projection is a two-layer ReLU MLP with this
            hidden width instead of a single ``nn.Linear``.
        ln: if true, apply LayerNorm after the attention and after the
            feed-forward residual.

    Raises:
        ValueError: if ``num_heads`` is smaller than 1 or does not divide
            ``dim_V``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, capacity=None, ln=False):
        super(MAB, self).__init__()
        # forward() splits the feature axis into chunks of dim_V // num_heads.
        # If that does not divide evenly, the chunks either cannot be stacked
        # or silently yield a different number of heads, so reject it up front.
        if num_heads < 1 or dim_V % num_heads != 0:
            raise ValueError(
                "dim_V ({}) must be divisible by num_heads ({}), and num_heads must be >= 1".format(
                    dim_V, num_heads))
        self.dim_V = dim_V
        self.num_heads = num_heads
        # Creation order (q, k, v, o) is kept stable so seeded initialisation
        # stays reproducible.
        self.fc_q = self._make_projection(dim_Q, dim_V, capacity)
        self.fc_k = self._make_projection(dim_K, dim_V, capacity)
        self.fc_v = self._make_projection(dim_K, dim_V, capacity)
        self.fc_o = self._make_projection(dim_V, dim_V, capacity)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)

    @staticmethod
    def _make_projection(in_features, out_features, capacity):
        """Build a Linear layer, or a two-layer ReLU MLP when ``capacity`` is set."""
        if capacity is None:
            return nn.Linear(in_features, out_features)
        return nn.Sequential(nn.Linear(in_features, capacity),
                             nn.ReLU(),
                             nn.Linear(capacity, out_features))

    def forward(self, Q, K):
        """Attend ``Q`` over ``K``.

        Args:
            Q: tensor of shape (batch, n_queries, dim_Q).
            K: tensor of shape (batch, n_keys, dim_K).

        Returns:
            Tensor of shape (batch, n_queries, dim_V).
        """
        queries = self.fc_q(Q)
        keys = self.fc_k(K)
        values = self.fc_v(K)
        batch_size = queries.size(0)

        # Fold the heads into the batch axis: (batch, n, dim_V) -> (heads * batch, n, head_dim).
        head_dim = self.dim_V // self.num_heads
        q_heads = torch.cat(queries.split(head_dim, 2), 0)
        k_heads = torch.cat(keys.split(head_dim, 2), 0)
        v_heads = torch.cat(values.split(head_dim, 2), 0)

        # The scores are scaled by sqrt(dim_V) (the full width, not the per-head
        # width), exactly as in the referenced implementation.
        scores = q_heads.bmm(k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)
        attention = torch.softmax(scores, 2)
        attended = q_heads + attention.bmm(v_heads)

        # Unfold the heads back onto the feature axis.
        out = torch.cat(attended.split(batch_size, 0), 2)

        # The LayerNorm modules only exist when the block was built with ln=True.
        norm0 = getattr(self, 'ln0', None)
        if norm0 is not None:
            out = norm0(out)
        out = out + F.relu(self.fc_o(out))
        norm1 = getattr(self, 'ln1', None)
        if norm1 is not None:
            out = norm1(out)
        return out


class SAB(nn.Module):
    """Self-attention block: one MAB whose query and key inputs are the same set.

    Args:
        dim_in: feature size of the input set.
        dim_out: feature size of the output.
        num_heads: number of attention heads, forwarded to MAB.
        capacity: optional hidden width, forwarded to MAB.
        ln: whether MAB applies LayerNorm.
    """

    def __init__(self, dim_in, dim_out, num_heads, capacity=None, ln=False):
        super().__init__()
        # Queries and keys come from the same tensor, so both input widths are dim_in.
        self.mab = MAB(dim_Q=dim_in, dim_K=dim_in, dim_V=dim_out, num_heads=num_heads,
                       capacity=capacity, ln=ln)

    def forward(self, X):
        """Return the result of attending ``X`` to itself."""
        attended = self.mab(X, X)
        return attended


class ISAB(nn.Module):
    """Two-stage attention through a learned set of ``num_inds`` vectors.

    The learned vectors ``I`` first attend over the input (``mab0``); the input
    then attends over that summary (``mab1``).

    Args:
        dim_in: feature size of the input set.
        dim_out: feature size of the learned vectors and of the output.
        num_heads: number of attention heads in both MABs.
        num_inds: number of learned vectors.
        capacity: optional hidden width, forwarded to both MABs.
        ln: whether the MABs apply LayerNorm.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, capacity=None, ln=False):
        super().__init__()
        # Allocate the learned vectors with a leading singleton batch axis and
        # Xavier-initialise them before registering as a parameter.
        # noinspection PyArgumentList
        learned = torch.Tensor(1, num_inds, dim_out)
        nn.init.xavier_uniform_(learned)
        self.I = nn.Parameter(learned)
        self.mab0 = MAB(dim_Q=dim_out, dim_K=dim_in, dim_V=dim_out, num_heads=num_heads,
                        capacity=capacity, ln=ln)
        self.mab1 = MAB(dim_Q=dim_in, dim_K=dim_out, dim_V=dim_out, num_heads=num_heads,
                        capacity=capacity, ln=ln)

    def forward(self, X):
        """Return ``X`` attended over the summary produced by the learned vectors."""
        batch_size = X.size(0)
        # Tile the learned vectors so every batch element gets its own copy.
        tiled = self.I.repeat(batch_size, 1, 1)
        summary = self.mab0(tiled, X)
        return self.mab1(X, summary)


class EGSAB(nn.Module):
    """Attend a learned set of ``num_inds`` query vectors over the input set.

    Unlike ISAB there is no second attention stage: the output has one row per
    learned vector.

    Args:
        dim_in: feature size of the input set.
        dim_out: feature size of the learned vectors and of the output.
        num_heads: number of attention heads.
        num_inds: number of learned query vectors.
        capacity: optional hidden width, forwarded to MAB.
        ln: whether MAB applies LayerNorm.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, capacity=None, ln=False):
        super().__init__()
        # Learned queries carry a leading singleton batch axis and are
        # Xavier-initialised before being registered as a parameter.
        # noinspection PyArgumentList
        learned = torch.Tensor(1, num_inds, dim_out)
        nn.init.xavier_uniform_(learned)
        self.I = nn.Parameter(learned)
        self.mab0 = MAB(dim_Q=dim_out, dim_K=dim_in, dim_V=dim_out, num_heads=num_heads,
                        capacity=capacity, ln=ln)

    def forward(self, X):
        """Return the learned queries attended over ``X``."""
        batch_size = X.size(0)
        # One copy of the learned queries per batch element.
        tiled = self.I.repeat(batch_size, 1, 1)
        return self.mab0(tiled, X)


class SpatialProjectionLayer(nn.Module):
    """Project a set of feature vectors onto a fixed 2D grid via attention.

    Each grid cell has a precomputed positional embedding that serves as the
    attention query; the input set provides keys and values.

    Args:
        in_features: feature size of the input set.
        out_channels: number of channels of the produced grid.
        output_size: two-element iterable (height, width) of the grid.
        num_heads: number of attention heads.
        positional_encoding_dim: size of each grid cell's positional embedding.
        capacity: optional hidden width, forwarded to MAB.
    """

    def __init__(self, in_features, out_channels, output_size, num_heads=1, positional_encoding_dim=32,
                 capacity=None):
        super().__init__()
        grid_size = tuple(output_size)
        assert len(grid_size) == 2
        self.output_size = grid_size
        self.positional_encoding_dim = positional_encoding_dim
        self.in_features = in_features
        self.out_channels = out_channels
        self.capacity = capacity
        # Build: the queries are positional embeddings, the keys/values are the input features.
        self.mab = MAB(dim_Q=positional_encoding_dim, dim_K=in_features, dim_V=out_channels,
                       num_heads=num_heads, capacity=capacity, ln=True)
        self._buffer_positional_encoding()

    def _buffer_positional_encoding(self):
        """Compute the per-cell positional embeddings once and store them as a buffer."""
        encoder = PositionalEncoding(position_dim=len(self.output_size),
                                     encoding_dim=self.positional_encoding_dim)
        height, width = self.output_size
        num_cells = height * width
        row_coords = torch.arange(height, dtype=torch.float)
        col_coords = torch.arange(width, dtype=torch.float)
        # noinspection PyTypeChecker
        ii, jj = torch.meshgrid(row_coords, col_coords)
        # One (row, col) coordinate pair per grid cell.
        positions = torch.stack((ii, jj), dim=-1).reshape(num_cells, 1, 1, 2)
        # grid_embeddings.shape = (hw)c
        embeddings = encoder(positions).reshape(num_cells, self.positional_encoding_dim)
        self.register_buffer('grid_embeddings', embeddings)

    def forward(self, x):
        """Map ``x`` of shape (n, t, c) to a grid of shape (n, out_channels, h, w)."""
        # x.shape = ntc
        batch_size, _, channels = x.shape
        assert channels == self.in_features
        # Add batch axis to buffer
        queries = self.grid_embeddings.repeat(batch_size, 1, 1)
        flat_grid = self.mab(queries, x)
        height, width = self.output_size
        return rearrange(flat_grid, 'n (h w) c -> n c h w', h=height, w=width)


class EntityExtractionLayer(nn.Module):
    """Turn one feature vector per sample into ``num_entities`` entity vectors.

    Args:
        in_features: feature size of the input vector.
        out_features: feature size of each produced entity.
        num_entities: number of entities produced per sample.
        num_heads: number of attention heads.
        capacity: optional hidden width, forwarded to EGSAB.
    """

    def __init__(self, in_features, out_features, num_entities, num_heads=1, capacity=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_entities = num_entities
        self.egsab = EGSAB(dim_in=in_features, dim_out=out_features, num_heads=num_heads,
                           num_inds=num_entities, capacity=capacity, ln=True)

    def forward(self, x):
        """Map ``x`` of shape (n, in_features) to the EGSAB output for a one-element set."""
        batch_size, channels = x.shape
        assert channels == self.in_features
        # Treat each sample as a set with a single element.
        singleton_set = x.reshape(batch_size, 1, channels)
        return self.egsab(singleton_set)


class EntityProductionLayer(nn.Module):
    """Produce ``num_entities`` entity vectors from one feature vector by projection.

    Args:
        in_features: feature size of the input vector.
        out_features: feature size of each produced entity.
        num_entities: number of entities produced per sample.
        capacity: if given, the projector is a two-layer ReLU MLP with this
            hidden width; otherwise it is a single linear layer.
    """

    def __init__(self, in_features, out_features, num_entities, capacity=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_entities = num_entities
        self.capacity = capacity
        # The projector emits all entities at once as one flat vector.
        flat_width = out_features * num_entities
        if capacity is not None:
            projector = nn.Sequential(nn.Linear(in_features, capacity),
                                      nn.ReLU(),
                                      nn.Linear(capacity, flat_width))
        else:
            projector = nn.Linear(in_features, flat_width)
        self.projector = projector

    def forward(self, x):
        """Map ``x`` of shape (n, in_features) to (n, num_entities, out_features)."""
        _, channels = x.shape
        assert channels == self.in_features
        flat = self.projector(x)
        # Split the flat projection into one row per entity.
        return rearrange(flat, 'n (e c) -> n e c', e=self.num_entities, c=self.out_features)


class EntityPropagationLayer(nn.Module):
    """Let a set of entities exchange information through self-attention.

    Args:
        in_features: feature size of each input entity.
        out_features: feature size of each output entity.
        num_heads: number of attention heads.
        entity_bottleneck: must be ``None``; any other value is not implemented.
        capacity: optional hidden width, forwarded to SAB.

    Raises:
        NotImplementedError: if ``entity_bottleneck`` is not ``None``.
    """

    def __init__(self, in_features, out_features, num_heads=1, entity_bottleneck=None, capacity=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.entity_bottleneck = entity_bottleneck
        # Only the plain (bottleneck-free) variant exists so far.
        if entity_bottleneck is not None:
            raise NotImplementedError
        self.sab = SAB(dim_in=in_features, dim_out=out_features, num_heads=num_heads,
                       capacity=capacity, ln=True)

    def forward(self, x):
        """Return the self-attended entities."""
        propagated = self.sab(x)
        return propagated


class EntityIntegrationLayer(nn.Module):
    """Integrate auxiliary entities into main entities via attention.

    The main entities act as queries; the auxiliary entities provide keys and
    values.

    Args:
        main_in_features: feature size of the main (query) entities.
        aux_in_features: feature size of the auxiliary (key/value) entities.
        out_features: feature size of the output entities.
        num_heads: number of attention heads.
        capacity: optional hidden width, forwarded to MAB.
    """

    def __init__(self, main_in_features, aux_in_features, out_features, num_heads=1, capacity=None):
        super().__init__()
        self.main_in_features = main_in_features
        self.aux_in_features = aux_in_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.capacity = capacity
        self.mab = MAB(dim_Q=main_in_features, dim_K=aux_in_features, dim_V=out_features,
                       num_heads=num_heads, capacity=capacity, ln=True)

    def forward(self, main_in, aux_in):
        """Return ``main_in`` attended over ``aux_in``."""
        integrated = self.mab(main_in, aux_in)
        return integrated


if __name__ == '__main__':
    # Manual smoke check: print the output shape of an integration layer.
    # The module uses a relative import, so run it as part of its package
    # (``python -m``).
    # eel = EntityExtractionLayer(5, 5, 8)
    # print(eel(torch.rand(2, 5)).shape)
    integration_layer = EntityIntegrationLayer(main_in_features=32, aux_in_features=16, out_features=48)
    main_entities = torch.rand(4, 10, 32)
    aux_entities = torch.rand(4, 20, 16)
    print(integration_layer(main_entities, aux_entities).shape)