import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Union

from .utils import PolarPad
from .attn import EntityExtractionLayer, EntityPropagationLayer, EntityIntegrationLayer, SpatialProjectionLayer


class ResidualDecoder(nn.Module):
    """Convolutional decoder that grows a representation vector into a spatial map.

    Each representation vector is turned into a small feature map by ``conv1``
    (1x1 for most layouts, 5x5 for the 11x11 layout), tiled over the ``A`` entries
    of the context, concatenated with the context vectors along the channel axis
    and then upsampled to ``output_size`` by transposed convolutions interleaved
    with residual convolutions.

    Supported layouts: ``(5, 5)``, ``(11, 11)``, ``(20, 20)`` and, only together
    with ``polar=True``, ``(14, 36)``.

    Args:
        output_size: Target ``(height, width)`` of the decoded map.
        out_channels: Number of channels of the decoded map.
        context_channels: Size ``C'`` of every context vector.
        capacity: Number of channels used by the hidden layers.
        repr_dim: Size ``C`` of every representation vector.
        polar: Selects the polar ``(14, 36)`` layout. For the other layouts no
            polar layers exist, so the flag has no effect on the computation.

    Raises:
        NotImplementedError: If the ``output_size`` / ``polar`` combination is
            not one of the supported layouts.
    """

    def __init__(self, output_size, out_channels, context_channels, capacity=128, repr_dim=128, polar=False):
        super().__init__()
        self.output_size = output_size
        self.out_channels = out_channels
        self.context_channels = context_channels
        self.capacity = capacity
        self.repr_dim = repr_dim
        self.polar = polar
        size = tuple(output_size)
        # conv2 is always the first layer that sees representation + context channels.
        fused = capacity + context_channels
        if size == (5, 5):
            self.conv1 = nn.Conv2d(repr_dim, capacity, kernel_size=1)
            self.conv2 = nn.ConvTranspose2d(fused, capacity, kernel_size=3)
            self.conv3 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv4 = nn.ConvTranspose2d(capacity, capacity, kernel_size=3)
            self.conv5 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv6 = nn.Conv2d(capacity, out_channels, kernel_size=3, padding=1)
        elif size == (11, 11):
            self.conv1 = nn.ConvTranspose2d(repr_dim, capacity, kernel_size=5)
            self.conv2 = nn.ConvTranspose2d(fused, capacity, kernel_size=3)
            self.conv3 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv4 = nn.ConvTranspose2d(capacity, capacity, kernel_size=3)
            self.conv5 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv6 = nn.ConvTranspose2d(capacity, capacity, kernel_size=3)
            self.conv7 = nn.Conv2d(capacity, out_channels, kernel_size=3, padding=1)
        elif size == (20, 20):
            self.conv1 = nn.Conv2d(repr_dim, capacity, kernel_size=1)
            self.conv2 = nn.ConvTranspose2d(fused, capacity, kernel_size=3, stride=2)
            self.conv3 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv4 = nn.ConvTranspose2d(capacity, capacity, kernel_size=3, stride=2)
            self.conv5 = nn.ConvTranspose2d(capacity, capacity, kernel_size=3, stride=2)
            self.conv6 = nn.ConvTranspose2d(capacity, capacity, kernel_size=6, stride=1)
            self.conv7 = nn.Conv2d(capacity, out_channels, kernel_size=3, padding=1)
        elif size == (14, 36) and polar:
            # NOTE(review): PolarPad is presumably a padding layer that keeps (or, with
            # the extra argument on conv7, trims) the spatial size around the following
            # convolution -- confirm against .utils.PolarPad.
            self.conv1 = nn.Conv2d(repr_dim, capacity, kernel_size=1)
            self.conv2 = nn.ConvTranspose2d(fused, capacity, kernel_size=(3, 5))
            self.conv3 = nn.Sequential(PolarPad((3, 5)),
                                       nn.Conv2d(capacity, capacity, kernel_size=(3, 5)))
            self.conv4 = nn.ConvTranspose2d(capacity, capacity, kernel_size=(3, 5), stride=(2, 3))
            self.conv5 = nn.Sequential(PolarPad((3, 5)),
                                       nn.Conv2d(capacity, capacity, kernel_size=(3, 5)))
            self.conv6 = nn.ConvTranspose2d(capacity, capacity, kernel_size=(3, 5), stride=(2, 2))
            self.conv7 = nn.Sequential(PolarPad((3, 5), ((0, -1), (0, -1))),
                                       nn.Conv2d(capacity, out_channels, kernel_size=(3, 5)))
        else:
            raise NotImplementedError

    def forward(self, repr_: torch.Tensor, context: Union[torch.Tensor, None] = None):
        """Decode ``repr_`` (and ``context``) with the forward matching ``output_size``.

        Args:
            repr_: Representation tensor; shape ``(N, T, C)`` for this class.
            context: Context tensor of shape ``(N, T, A, C')``. Required by this
                class; subclasses that override the size-specific forwards may
                ignore it.

        Returns:
            Tensor of shape ``(N, T, A, out_channels, H, W)``.
        """
        name = f"_forward_{self.output_size[0]}x{self.output_size[1]}"
        impl = getattr(self, name + "p", None) if self.polar else None
        if impl is None:
            # Either polar is off, or the layers of this layout are not polar-specific
            # (the constructor accepts polar=True for them), so use the plain forward
            # instead of failing on a non-existent '<size>p' method.
            impl = getattr(self, name)
        return impl(repr_, context)

    def _check_shapes(self, repr_: torch.Tensor, context: torch.Tensor):
        """Validate the inputs and return ``(N, T, RC, A, CC)``.

        Raises:
            ValueError: If ``context`` is ``None`` (or the ranks do not unpack).
            AssertionError: If the leading ``(N, T)`` axes of both inputs differ.
        """
        if context is None:
            raise ValueError(f"{type(self).__name__} requires a context tensor of shape (N, T, A, C').")
        # repr_.shape = (NTC)
        # context.shape = (NTAC')
        N, T, RC = repr_.shape
        _, _, A, CC = context.shape
        assert tuple(context.shape) == (N, T, A, CC), \
            f"Leading (N, T) axes of context {tuple(context.shape)} must match repr_ {tuple(repr_.shape)}."
        return N, T, RC, A, CC

    def _fuse_context(self, x: torch.Tensor, context: torch.Tensor, N, T, A, CC):
        """Tile ``x`` over the ``A`` axis and append the context as extra channels.

        ``x`` has shape ``(N*T, capacity, h, w)``; the result has shape
        ``(N*T*A, capacity + CC, h, w)``.
        """
        h, w = x.shape[-2], x.shape[-1]
        # (NT)Chw -> (NT)AChw; expand is enough because cat materialises the result.
        x = x[:, None].expand(N * T, A, self.capacity, h, w)
        # NTAC' -> (NT)AC'11 -> (NT)AC'hw
        c = context.reshape(N * T, A, CC, 1, 1).expand(N * T, A, CC, h, w)
        return torch.cat([x, c], dim=2).reshape(N * T * A, self.capacity + CC, h, w)

    def _to_ntachw(self, x: torch.Tensor, N, T, A):
        """Reshape a ``(N*T*A, out_channels, H, W)`` batch back to ``(N, T, A, C, H, W)``."""
        return x.reshape(N, T, A, self.out_channels, self.output_size[0], self.output_size[1])

    def _forward_5x5(self, repr_: torch.Tensor, context: torch.Tensor):
        """Forward for the 5x5 layout (spatial size 1 -> 3 -> 5)."""
        N, T, RC, A, CC = self._check_shapes(repr_, context)
        x = F.relu(self.conv1(repr_.reshape(N * T, RC, 1, 1)))
        x = self._fuse_context(x, context, N, T, A, CC)
        x = F.relu(self.conv2(x))          # 1x1 -> 3x3
        x = F.relu(self.conv3(x)) + x      # residual, 3x3
        x = F.relu(self.conv4(x))          # 3x3 -> 5x5
        x = F.relu(self.conv5(x)) + x      # residual, 5x5
        x = self.conv6(x)
        return self._to_ntachw(x, N, T, A)

    def _forward_11x11(self, repr_: torch.Tensor, context: torch.Tensor):
        """Forward for the 11x11 layout (spatial size 5 -> 7 -> 9 -> 11)."""
        N, T, RC, A, CC = self._check_shapes(repr_, context)
        # conv1 is a transposed 5x5 convolution, so the context has to be broadcast
        # over a 5x5 map before it can be concatenated.
        x = F.relu(self.conv1(repr_.reshape(N * T, RC, 1, 1)))
        x = self._fuse_context(x, context, N, T, A, CC)
        x = F.relu(self.conv2(x))          # 5x5 -> 7x7
        x = F.relu(self.conv3(x)) + x      # residual, 7x7
        x = F.relu(self.conv4(x))          # 7x7 -> 9x9
        x = F.relu(self.conv5(x)) + x      # residual, 9x9
        x = F.relu(self.conv6(x))          # 9x9 -> 11x11
        x = self.conv7(x)
        return self._to_ntachw(x, N, T, A)

    def _forward_20x20(self, repr_: torch.Tensor, context: torch.Tensor):
        """Forward for the 20x20 layout (spatial size 1 -> 3 -> 7 -> 15 -> 20).

        Uses the layers in the same sequence as the contextless 20x20 forward,
        after fusing the context into the 1x1 feature map.
        """
        N, T, RC, A, CC = self._check_shapes(repr_, context)
        x = F.relu(self.conv1(repr_.reshape(N * T, RC, 1, 1)))
        x = self._fuse_context(x, context, N, T, A, CC)
        x = F.relu(self.conv2(x))          # 1x1 -> 3x3
        x = F.relu(self.conv3(x)) + x      # residual, 3x3
        x = F.relu(self.conv4(x))          # 3x3 -> 7x7
        x = F.relu(self.conv5(x))          # 7x7 -> 15x15
        x = F.relu(self.conv6(x))          # 15x15 -> 20x20
        x = self.conv7(x)
        return self._to_ntachw(x, N, T, A)

    def _forward_14x36p(self, repr_: torch.Tensor, context: torch.Tensor):
        """Forward for the polar 14x36 layout.

        The residual additions require conv3/conv5 to preserve the spatial size and
        the final reshape requires conv7 to emit 14x36 maps; both depend on PolarPad
        (TODO confirm its padding behaviour).
        """
        N, T, RC, A, CC = self._check_shapes(repr_, context)
        x = F.relu(self.conv1(repr_.reshape(N * T, RC, 1, 1)))
        x = self._fuse_context(x, context, N, T, A, CC)
        x = F.relu(self.conv2(x))          # 1x1 -> 3x5
        x = F.relu(self.conv3(x)) + x
        x = F.relu(self.conv4(x))          # 3x5 -> 7x17
        x = F.relu(self.conv5(x)) + x
        x = F.relu(self.conv6(x))          # 7x17 -> 15x37
        x = self.conv7(x)
        return self._to_ntachw(x, N, T, A)


class ContextlessResidualDecoder(ResidualDecoder):
    """Variant of ``ResidualDecoder`` that decodes without a context input.

    The parent is built with ``context_channels=0`` and every size-specific
    forward is overridden to consume ``repr_`` of shape ``(N, T, A, C)`` directly.
    The result has shape ``(N, T, A, out_channels, H, W)``. The ``context``
    argument of the forwards is accepted for signature compatibility and ignored.
    """

    def __init__(self, output_size, out_channels, capacity=128, repr_dim=128, polar=False):
        super().__init__(
            output_size=output_size,
            out_channels=out_channels,
            context_channels=0,
            capacity=capacity,
            repr_dim=repr_dim,
            polar=polar,
        )

    def _stem(self, repr_: torch.Tensor):
        """Flatten ``(N, T, A, C)`` into a batch of 1x1 maps and apply conv1 + ReLU.

        Returns the activated features together with the ``(N, T, A)`` sizes needed
        to undo the flattening.
        """
        n, t, a, width = repr_.shape
        feats = self.conv1(repr_.reshape(n * t * a, width, 1, 1))
        return F.relu(feats), (n, t, a)

    def _skip(self, conv, feats: torch.Tensor):
        """Residual step: ``relu(conv(feats)) + feats``."""
        return F.relu(conv(feats)) + feats

    def _restore_axes(self, feats: torch.Tensor, lead):
        """Reshape the flat batch back to ``(N, T, A, out_channels, H, W)``."""
        n, t, a = lead
        rows, cols = self.output_size[0], self.output_size[1]
        return feats.reshape(n, t, a, self.out_channels, rows, cols)

    def _forward_5x5(self, repr_: torch.Tensor, context: type(None) = None):
        feats, lead = self._stem(repr_)
        feats = F.relu(self.conv2(feats))
        feats = self._skip(self.conv3, feats)
        feats = F.relu(self.conv4(feats))
        feats = self._skip(self.conv5, feats)
        # conv6 is the output layer of this layout, hence no activation.
        return self._restore_axes(self.conv6(feats), lead)

    def _forward_11x11(self, repr_: torch.Tensor, context: type(None) = None):
        feats, lead = self._stem(repr_)
        feats = F.relu(self.conv2(feats))
        feats = self._skip(self.conv3, feats)
        feats = F.relu(self.conv4(feats))
        feats = self._skip(self.conv5, feats)
        feats = F.relu(self.conv6(feats))
        return self._restore_axes(self.conv7(feats), lead)

    def _forward_20x20(self, repr_: torch.Tensor, context: type(None) = None):
        feats, lead = self._stem(repr_)
        feats = F.relu(self.conv2(feats))
        feats = self._skip(self.conv3, feats)
        # The remaining transposed convolutions are applied without residuals.
        for upsample in (self.conv4, self.conv5, self.conv6):
            feats = F.relu(upsample(feats))
        return self._restore_axes(self.conv7(feats), lead)

    def _forward_14x36p(self, repr_: torch.Tensor, context: type(None) = None):
        feats, lead = self._stem(repr_)
        feats = F.relu(self.conv2(feats))
        feats = self._skip(self.conv3, feats)
        feats = F.relu(self.conv4(feats))
        feats = self._skip(self.conv5, feats)
        feats = F.relu(self.conv6(feats))
        return self._restore_axes(self.conv7(feats), lead)


class TransformerDecoder(nn.Module):
    """Attention-based decoder skeleton with optional context integration.

    The constructor stores the configuration and builds the submodules (entity
    extractor, optional context integrator, a stack of propagation layers, a
    spatial projector and a 1x1 output convolution). The forward pass is not
    implemented yet.

    Args:
        output_size: Target ``(height, width)``; must have exactly two entries.
        out_channels: Number of channels produced by ``conv_reducer``.
        context_channels: Size of the context vectors; a falsy value disables
            the context integrator.
        capacity: Channel width shared by the submodules.
        repr_dim: Size of the input representation handed to the extractor.
        num_entities: Forwarded to ``EntityExtractionLayer``.
        num_heads: Number of heads forwarded to every attention submodule.
        num_propagators: Number of ``EntityPropagationLayer`` instances stacked.
        extractor_capacity: ``capacity`` argument of the extractor.
        integrator_capacity: ``capacity`` argument of the context integrator.
        propagator_capacity: Capacity argument of each propagation layer.
        spatial_projector_capacity: ``capacity`` argument of the spatial projector.
        spatial_positional_encoding_dim: ``positional_encoding_dim`` of the
            spatial projector.
    """

    def __init__(self, output_size, out_channels, context_channels, capacity=128, repr_dim=128,
                 num_entities=20, num_heads=4, num_propagators=3, extractor_capacity=None,
                 integrator_capacity=None, propagator_capacity=None,
                 spatial_projector_capacity=None, spatial_positional_encoding_dim=32):
        super().__init__()
        assert len(output_size) == 2
        self.output_size = output_size
        self.out_channels = out_channels
        self.context_channels = context_channels
        self.capacity = capacity
        self.repr_dim = repr_dim
        self.num_entities = num_entities
        self.num_heads = num_heads
        self.num_propagators = num_propagators
        self.extractor_capacity = extractor_capacity
        # Previously the only constructor argument that was not kept on the instance.
        self.integrator_capacity = integrator_capacity
        self.propagator_capacity = propagator_capacity
        self.spatial_projector_capacity = spatial_projector_capacity
        self.spatial_positional_encoding_dim = spatial_positional_encoding_dim
        # Build modules
        self.extractor = EntityExtractionLayer(repr_dim, capacity, num_entities, num_heads=num_heads,
                                               capacity=extractor_capacity)
        if context_channels:
            self.context_integrator = EntityIntegrationLayer(capacity, context_channels, capacity,
                                                             num_heads=num_heads, capacity=integrator_capacity)
        else:
            # Without context channels there is nothing to integrate.
            self.context_integrator = None
        self.attn = nn.Sequential(*[EntityPropagationLayer(capacity, capacity, num_heads, propagator_capacity)
                                    for _ in range(num_propagators)])
        self.spatial_projector = SpatialProjectionLayer(capacity, capacity, output_size,
                                                        num_heads=num_heads,
                                                        positional_encoding_dim=spatial_positional_encoding_dim,
                                                        capacity=spatial_projector_capacity)
        self.conv_reducer = nn.Conv2d(capacity, out_channels, kernel_size=1)

    def forward(self, repr_: torch.Tensor, context: torch.Tensor):
        """Not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        # TODO
        raise NotImplementedError


class ContextlessTransformerDecoder(nn.Module):
    """Attention-based decoder that maps per-entity representations to spatial maps.

    ``forward`` flattens ``repr_`` of shape ``(N, T, A, C)`` into a batch of
    vectors, runs it through the extractor, the propagation stack, the spatial
    projector and a 1x1 output convolution, and reshapes the result to
    ``(N, T, A, out_channels, H, W)``.
    """

    def __init__(self, output_size, out_channels, capacity=128, repr_dim=128, num_entities=20, num_heads=4,
                 num_propagators=3, extractor_capacity=None, propagator_capacity=None,
                 spatial_projector_capacity=None, spatial_positional_encoding_dim=32):
        super().__init__()
        assert len(output_size) == 2
        # Keep the configuration around for inspection.
        self.output_size = output_size
        self.out_channels = out_channels
        self.capacity = capacity
        self.repr_dim = repr_dim
        self.num_entities = num_entities
        self.num_heads = num_heads
        self.num_propagators = num_propagators
        self.extractor_capacity = extractor_capacity
        self.propagator_capacity = propagator_capacity
        self.spatial_projector_capacity = spatial_projector_capacity
        self.spatial_positional_encoding_dim = spatial_positional_encoding_dim
        # Submodules, created in pipeline order.
        self.extractor = EntityExtractionLayer(
            repr_dim, capacity, num_entities,
            num_heads=num_heads, capacity=extractor_capacity,
        )
        propagators = []
        for _ in range(num_propagators):
            propagators.append(EntityPropagationLayer(capacity, capacity, num_heads, propagator_capacity))
        self.attn = nn.Sequential(*propagators)
        self.spatial_projector = SpatialProjectionLayer(
            capacity, capacity, output_size,
            num_heads=num_heads,
            positional_encoding_dim=spatial_positional_encoding_dim,
            capacity=spatial_projector_capacity,
        )
        self.conv_reducer = nn.Conv2d(capacity, out_channels, kernel_size=1)

    def forward(self, repr_: torch.Tensor, context: type(None) = None):
        """Decode ``repr_``; ``context`` is accepted but not used.

        Args:
            repr_: Tensor of shape ``(N, T, A, C)``.
            context: Ignored.

        Returns:
            Tensor of shape ``(N, T, A, out_channels, H, W)``.
        """
        batch, steps, entities, width = repr_.shape
        # Fold N, T and A into a single batch axis for the submodules.
        flat = repr_.reshape(batch * steps * entities, width)
        entity_feats = self.attn(self.extractor(flat))
        maps = self.conv_reducer(self.spatial_projector(entity_feats))
        rows, cols = self.output_size[0], self.output_size[1]
        # Unfold the batch axis again: NTAChw.
        return maps.reshape(batch, steps, entities, self.out_channels, rows, cols)


if __name__ == '__main__':
    # Smoke test: decode random per-entity representations and print the output shape.
    # Other configurations that can be tried here:
    #   ContextlessResidualDecoder((14, 36), 21, polar=True)
    #   ContextlessTransformerDecoder((14, 36), 21)
    # The third positional parameter of ContextlessResidualDecoder is `capacity`
    # (there is no context size), so it is passed by keyword to avoid confusion.
    dec = ContextlessResidualDecoder((11, 11), out_channels=1, capacity=5)
    x = torch.randn(1, 12, 2, 128)  # (N, T, A, repr_dim) with the default repr_dim=128
    y = dec(x)
    print(y.shape)
