import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli

from typing import Union

from .utils import PolarPad


class TowerEncoder(nn.Module):
    """Convolutional tower that encodes per-agent patches together with a context vector.

    The layer stack is selected by ``input_size``; only ``(5, 5)``, ``(11, 11)``,
    ``(20, 20)`` and, with ``polar=True``, ``(14, 36)`` are implemented. Every
    variant reduces an ``H x W`` patch to a ``1 x 1`` map with ``repr_dim``
    channels, and concatenates the (spatially tiled) context vector to the
    feature map part-way through the stack.

    Args:
        input_size: ``(H, W)`` of the input patches.
        in_channels: Number of channels of the input patches.
        context_channels: Number of channels of the context vector.
        capacity: Number of channels of the hidden convolutions.
        repr_dim: Number of channels of the output representation.
        polar: Selects the polar variant built from ``PolarPad`` + convolution.

    Raises:
        NotImplementedError: If no layer stack exists for ``input_size``
            (including ``(14, 36)`` with ``polar=False``).

    NOTE: ``polar=True`` combined with one of the non-polar sizes builds the
    layers, but ``forward`` then looks up ``_forward_<H>x<W>p``, which this
    class only defines for ``14x36``.
    """

    def __init__(self, input_size, in_channels, context_channels, capacity=128, repr_dim=128, polar=False):
        super(TowerEncoder, self).__init__()
        self.input_size = input_size
        self.in_channels = in_channels
        self.context_channels = context_channels
        self.capacity = capacity
        self.repr_dim = repr_dim
        self.polar = polar
        # In every variant, the convolution that consumes the concatenated
        # context is the one with `capacity + context_channels` input channels.
        if tuple(input_size) == (5, 5):
            self.conv1 = nn.Conv2d(in_channels, capacity, kernel_size=3)
            self.conv2 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(capacity, capacity, kernel_size=3)
            self.conv4 = nn.Conv2d(capacity + context_channels, capacity, kernel_size=1)
            self.conv5 = nn.Conv2d(capacity, capacity, kernel_size=1)
            self.conv6 = nn.Conv2d(capacity, capacity, kernel_size=1)
            self.conv7 = nn.Conv2d(capacity, repr_dim, kernel_size=1)
        elif tuple(input_size) == (11, 11):
            self.conv1 = nn.Conv2d(in_channels, capacity, kernel_size=5)
            self.conv2 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(capacity + context_channels, capacity, kernel_size=3)
            self.conv4 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv5 = nn.Conv2d(capacity, capacity, kernel_size=3)
            self.conv6 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv7 = nn.Conv2d(capacity, repr_dim, kernel_size=3)
        elif tuple(input_size) == (20, 20):
            self.conv1 = nn.Conv2d(in_channels, capacity, kernel_size=3, stride=2, padding=1)
            self.conv2 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(capacity + context_channels, capacity, kernel_size=3, stride=2, padding=1)
            self.conv4 = nn.Conv2d(capacity, capacity, kernel_size=3, padding=1)
            self.conv5 = nn.Conv2d(capacity, capacity, kernel_size=3, stride=2)
            self.conv6 = nn.Conv2d(capacity, capacity, kernel_size=1)
            self.conv7 = nn.Conv2d(capacity, repr_dim, kernel_size=2)
        elif tuple(input_size) == (14, 36) and polar:
            # NOTE(review): PolarPad is defined in .utils; its exact padding
            # semantics are not visible from this file.
            self.conv1 = nn.Sequential(PolarPad((3, 5)),
                                       nn.Conv2d(in_channels, capacity, kernel_size=(3, 5), stride=(2, 3)))
            self.conv2 = nn.Sequential(PolarPad((3, 5)),
                                       nn.Conv2d(capacity, capacity, kernel_size=(3, 5)))
            self.conv3 = nn.Sequential(PolarPad((1, 5)),
                                       nn.Conv2d(capacity + context_channels, capacity, kernel_size=(3, 5), stride=(2, 2)))
            self.conv4 = nn.Sequential(PolarPad((3, 5)),
                                       nn.Conv2d(capacity, capacity, kernel_size=(3, 5)))
            self.conv5 = nn.Sequential(PolarPad((1, 3)),
                                       nn.Conv2d(capacity, capacity, kernel_size=(3, 3), stride=(1, 2)))
            self.conv6 = nn.Sequential(PolarPad((1, 3)),
                                       nn.Conv2d(capacity, capacity, kernel_size=(1, 3)))
            self.conv7 = nn.Conv2d(capacity, repr_dim, kernel_size=(1, 3))
        else:
            raise NotImplementedError

    def forward(self, input: torch.Tensor, context: Union[torch.Tensor, None] = None):
        """Encode ``input`` (NTACHW) with ``context`` (NTAC') into an NTA(repr_dim)11 tensor.

        Dispatches to the ``_forward_<H>x<W>[p]`` method matching the configured
        input size. ``context`` defaults to ``None`` only so that subclasses can
        omit it; the forward methods of this class read ``context.shape`` and
        therefore need an actual tensor.
        """
        assert list(input.shape[-2:]) == list(self.input_size)
        polar = 'p' if self.polar else ''
        return getattr(self, f'_forward_{self.input_size[0]}x{self.input_size[1]}{polar}')(input, context)

    def _check_shapes(self, input: torch.Tensor, context: torch.Tensor):
        """Validate input/context shapes and return ``(N, T, A, IC, H, W, CC)``."""
        # We expect input.shape to be NTACHW, and context.shape to be (NTAC')
        N, T, A, IC, H, W = input.shape
        _, _, _, CC = context.shape
        assert IC == self.in_channels
        assert CC == self.context_channels
        assert tuple(context.shape) == (N, T, A, CC)
        return N, T, A, IC, H, W, CC

    @staticmethod
    def _concat_context(x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Tile ``context`` (NTAC') over the spatial dims of ``x`` ((NTA)CHW) and concatenate on channels."""
        batch, _, height, width = x.shape
        channels = context.shape[-1]
        # reshape (rather than view) so that a non-contiguous context, e.g. one
        # obtained by permuting or slicing, is accepted by every variant.
        tiled = context.reshape(batch, channels, 1, 1).expand(batch, channels, height, width)
        return torch.cat([x, tiled], dim=1)

    def _forward_5x5(self, input: torch.Tensor, context: torch.Tensor):
        """Forward pass for 5x5 patches; context joins after conv3, once the map is 1x1."""
        N, T, A, IC, H, W, _ = self._check_shapes(input, context)
        x = input.reshape(N * T * A, IC, H, W)
        preskip = x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x + preskip
        x = F.relu(self.conv3(x))
        x = self._concat_context(x, context)
        preskip = x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        preskip = x = x + preskip
        x = F.relu(self.conv6(x))
        x = x + preskip
        x = F.relu(self.conv7(x))
        # Reshape back to NTA*11 format
        x = x.reshape(N, T, A, self.repr_dim, 1, 1)
        return x

    def _forward_11x11(self, input: torch.Tensor, context: torch.Tensor):
        """Forward pass for 11x11 patches; context joins after the first residual pair."""
        N, T, A, IC, H, W, _ = self._check_shapes(input, context)
        x = input.reshape(N * T * A, IC, H, W)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x)) + x
        x = self._concat_context(x, context)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x)) + x
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x)) + x
        x = F.relu(self.conv7(x))
        # Reshape back to NTA*11
        x = x.reshape(N, T, A, self.repr_dim, 1, 1)
        return x

    def _forward_20x20(self, input: torch.Tensor, context: torch.Tensor):
        """Forward pass for 20x20 patches; context joins after the first residual pair."""
        N, T, A, IC, H, W, _ = self._check_shapes(input, context)
        x = input.reshape(N * T * A, IC, H, W)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x)) + x
        x = self._concat_context(x, context)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x)) + x
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x)) + x
        x = F.relu(self.conv7(x))
        # Reshape back to NTA*11
        x = x.reshape(N, T, A, self.repr_dim, 1, 1)
        return x

    def _forward_14x36p(self, input: torch.Tensor, context: torch.Tensor):
        """Forward pass for polar 14x36 patches; context joins after the first residual pair."""
        N, T, A, IC, H, W, _ = self._check_shapes(input, context)
        x = input.reshape(N * T * A, IC, H, W)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x)) + x
        x = self._concat_context(x, context)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x)) + x
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x)) + x
        x = F.relu(self.conv7(x))
        # Reshape back to NTA*11
        x = x.reshape(N, T, A, self.repr_dim, 1, 1)
        return x


class ContextlessTowerEncoder(TowerEncoder):
    """Tower encoder that takes no context vector.

    The parent is constructed with ``context_channels=0``. The 5x5 and 20x20
    variants run their own convolution stacks and ignore ``context``; the 11x11
    and polar 14x36 variants reuse the parent implementation by handing it a
    context tensor with zero channels.
    """

    def __init__(self, input_size, in_channels, capacity=128, repr_dim=128, polar=False):
        super().__init__(
            input_size=input_size,
            in_channels=in_channels,
            context_channels=0,
            capacity=capacity,
            repr_dim=repr_dim,
            polar=polar,
        )

    @staticmethod
    def _empty_context(input: torch.Tensor) -> torch.Tensor:
        """Build an NTA0 context matching the dtype and device of an NTACHW ``input``."""
        N, T, A, IC, H, W = input.shape
        return torch.empty(N, T, A, 0, dtype=input.dtype, device=input.device)

    def _forward_5x5(self, input: torch.Tensor, context: type(None) = None):
        N, T, A, IC, H, W = input.shape
        feats = F.relu(self.conv1(input.reshape(N * T * A, IC, H, W)))
        feats = F.relu(self.conv2(feats)) + feats
        feats = F.relu(self.conv3(feats))
        # conv4..conv6 are all wrapped in a residual connection here.
        for residual_conv in (self.conv4, self.conv5, self.conv6):
            feats = F.relu(residual_conv(feats)) + feats
        feats = F.relu(self.conv7(feats))
        # Back to NTA*11 layout
        return feats.reshape(N, T, A, self.repr_dim, 1, 1)

    def _forward_20x20(self, input: torch.Tensor, context: type(None) = None):
        N, T, A, IC, H, W = input.shape
        feats = input.reshape(N * T * A, IC, H, W)
        # Three (plain conv, residual conv) pairs followed by the output conv.
        stages = ((self.conv1, self.conv2), (self.conv3, self.conv4), (self.conv5, self.conv6))
        for plain_conv, residual_conv in stages:
            feats = F.relu(plain_conv(feats))
            feats = F.relu(residual_conv(feats)) + feats
        feats = F.relu(self.conv7(feats))
        # Back to NTA*11 layout
        return feats.reshape(N, T, A, self.repr_dim, 1, 1)

    def _forward_14x36p(self, input: torch.Tensor, context: type(None) = None):
        # Whatever was passed as context is replaced by a zero-channel tensor.
        return super()._forward_14x36p(input, self._empty_context(input))

    def _forward_11x11(self, input: torch.Tensor, context: type(None) = None):
        # Whatever was passed as context is replaced by a zero-channel tensor.
        return super()._forward_11x11(input, self._empty_context(input))


class SumAggregator(nn.Module):
    """Collapse the agent axis by summation: NTAC11 or NTAC -> NTC."""

    def forward(self, input: torch.Tensor):
        rank = input.ndimension()
        if rank not in (4, 6):
            raise NotImplementedError
        if rank == 4:
            N, T, A, C = input.shape
        else:
            N, T, A, C, H, W = input.shape
            # The trailing spatial dims must be singletons to be dropped below.
            assert H == W == 1
        summed_over_agents = input.sum(2)
        return summed_over_agents.view(N, T, C)


class DropSumAggregator(nn.Module):
    """Sum over the agent axis, randomly dropping agents while training.

    Converts NTAC11 or NTAC input to NTC. In training mode each agent is kept
    with probability ``keep_proba`` (the kept values are not rescaled) before
    summing. In eval mode all agents are summed, and the sum is optionally
    scaled so that it matches the expected value of the training-time sum.

    Args:
        keep_proba: Probability of keeping an agent during training.
        full_trajectory_drop: If True, one keep/drop decision per agent is
            shared across all T steps; otherwise it is sampled per step.
        normalize_at_eval: If True, the eval-time sum is multiplied by
            ``keep_proba``.
    """

    def __init__(self, keep_proba=0.5, full_trajectory_drop=False, normalize_at_eval=True):
        super(DropSumAggregator, self).__init__()
        self.keep_proba = keep_proba
        self.full_trajectory_drop = full_trajectory_drop
        self.normalize_at_eval = normalize_at_eval

    def forward(self, input: torch.Tensor):
        """Aggregate ``input`` (NTAC11 or NTAC) over agents and return an NTC tensor.

        Raises:
            NotImplementedError: If ``input`` is neither 4- nor 6-dimensional.
        """
        if input.ndimension() == 6:
            N, T, A, C, H, W = input.shape
            assert H == W == 1
        elif input.ndimension() == 4:
            N, T, A, C = input.shape
        else:
            raise NotImplementedError
        # Convert to NTAC
        input = input.reshape(N, T, A, C)
        if self.training:
            # Keep-mask of shape N{T or 1}A1, broadcast over channels (and over
            # time when the whole trajectory shares one decision).
            mask_steps = 1 if self.full_trajectory_drop else T
            mask = (Bernoulli(probs=self.keep_proba)
                    .sample((N, mask_steps, A, 1))
                    .to(input.device))
            return (input * mask).sum(2).reshape(N, T, C)
        # The unscaled training mask makes the expected training sum equal to
        # keep_proba * (full sum), so the full eval sum must be scaled *down*
        # by keep_proba to match it. Dividing would enlarge the mismatch and
        # break for keep_proba == 0.
        scale = self.keep_proba if self.normalize_at_eval else 1.0
        return input.sum(2).reshape(N, T, C) * scale


class FoldAggregator(nn.Module):
    """Merge the agent axis into the channel axis: NTAC11 or NTAC -> NT(A*C)."""

    def forward(self, input: torch.Tensor):
        rank = input.ndimension()
        if rank == 4:
            N, T, A, C = input.shape
        elif rank == 6:
            N, T, A, C, H, W = input.shape
            # Only singleton spatial dims can be folded away.
            assert H == W == 1
        else:
            raise NotImplementedError
        folded_channels = A * C
        return input.reshape(N, T, folded_channels)


class PositionalEncoding(nn.Module):
    """Sinusoidal encoding of D-dimensional positions (NTAD -> NTAC).

    Each of the ``position_dim`` coordinates receives
    ``encoding_dim // position_dim`` encoding channels: the first half are
    sines, the second half cosines, of the coordinate divided by
    ``max_frequency ** (exponent / channels_per_coordinate)``. With
    ``trainable=True`` the exponents are derived from learnable parameters
    instead of a fixed ``arange``.
    """

    def __init__(self, encoding_dim=16, position_dim=2, max_frequency=10000, normalize=False, trainable=False):
        super().__init__()
        assert encoding_dim % position_dim == 0, "Encoding dim must be divisible by the position dim."
        assert (encoding_dim // position_dim) % 2 == 0, "Encoding dim / postion dim must be even."
        self.encoding_dim = encoding_dim
        self.position_dim = position_dim
        self.max_frequency = max_frequency
        self.normalize = normalize
        self.trainable = trainable
        self._init_parameters()

    def _init_parameters(self):
        """Create the learnable exponent parameters; a no-op unless ``trainable``."""
        if not self.trainable:
            return
        per_coordinate = self.encoding_dim // self.position_dim
        largest_exponent = torch.arange(0, per_coordinate, 2, dtype=torch.float).max()
        # Assigning an nn.Parameter attribute registers it on the module;
        # the order (_intervals, _min_val, _delta_val) is kept on purpose.
        # noinspection PyArgumentList
        self._intervals = torch.nn.Parameter(torch.ones(per_coordinate // 2))
        # noinspection PyArgumentList
        self._min_val = torch.nn.Parameter(torch.tensor([0.], dtype=torch.float))
        # noinspection PyArgumentList
        self._delta_val = torch.nn.Parameter(torch.tensor([largest_exponent], dtype=torch.float))

    def get_exponents(self, device=None):
        """Return the exponents; ``device`` is only used for the fixed (non-trainable) case."""
        if not self.trainable:
            return torch.arange(0, self.encoding_dim // self.position_dim, 2, dtype=torch.float, device=device)
        # Keep the offset non-negative and the span strictly positive.
        offset = torch.relu(self._min_val)
        span = torch.clamp_min(self._delta_val, 1e-4)
        # Softmax + cumsum yields increasing fractions in (0, 1].
        fractions = torch.cumsum(torch.softmax(self._intervals, 0), 0)
        return offset + span * fractions

    def forward(self, positions):
        # positions is NTAD with D == self.position_dim
        N, T, A, D = positions.shape
        assert D == self.position_dim
        # Each coordinate gets C' = C // D encoding channels.
        per_coordinate = self.encoding_dim // D
        # The exponents play the role of `i` in Attention is All You Need.
        exponents = self.get_exponents(device=positions.device)
        # max_frequency^(i / C'), with four leading singleton dims for broadcasting
        divisors = torch.pow(self.max_frequency, (exponents / per_coordinate))[None, None, None, None]
        # NTAD(C'/2) arguments of the sinusoids
        angles = positions.unsqueeze(-1) / divisors
        # Sines followed by cosines along the last axis: NTADC'
        waves = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        # Flatten D and C' into C = D * C'
        encodings = waves.reshape(N, T, A, self.encoding_dim)
        if self.normalize:
            encodings = encodings / torch.norm(encodings, dim=-1, keepdim=True)
        return encodings


class ActionEncoding(nn.Linear):
    """Linear embedding of per-agent action vectors: NTAC -> NTAE.

    ``num_actions`` is the size of the incoming action vector and
    ``encoding_dim`` the size of the produced embedding.
    """

    def __init__(self, num_actions, encoding_dim=16):
        super().__init__(in_features=num_actions, out_features=encoding_dim)
        self.num_actions = num_actions
        self.encoding_dim = encoding_dim

    def forward(self, actions):
        # actions is NTAC; flatten the three leading axes for the linear layer.
        N, T, A, AC = actions.shape
        flat_actions = actions.reshape(N * T * A, AC)
        flat_embedding = super().forward(flat_actions)
        _, embedding_channels = flat_embedding.shape
        # Restore the leading NTA axes.
        return flat_embedding.reshape(N, T, A, embedding_channels)


if __name__ == '__main__':
    # Smoke test: push a random NTACHW batch of 11x11 single-channel patches
    # through the context-free encoder and print the resulting shape.
    encoder = ContextlessTowerEncoder((11, 11), 1)
    observations = torch.randn(32, 128, 10, 1, 11, 11)
    representation = encoder(observations)
    print(representation.shape)
