from typing import Tuple, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from vae.misc import get_trainable_params


class BaseModule(nn.Module):
    """Common base for modules that advertise their input/output shapes.

    Both shapes start out as ``None``; subclasses are expected to assign
    ``_input_shape`` and ``_output_shape`` in their own constructor.
    """

    def __init__(self):
        super().__init__()
        # Placeholders only -- the concrete values are set by subclasses.
        self._input_shape = None
        self._output_shape = None

    @property
    def input_shape(self) -> Tuple[int]:
        """Shape stored in ``_input_shape`` (``None`` until a subclass sets it)."""
        return self._input_shape

    @property
    def output_shape(self) -> Tuple[int]:
        """Shape stored in ``_output_shape`` (``None`` until a subclass sets it)."""
        return self._output_shape

    def __str__(self, additional_info=True):
        """Return the module's string form, optionally with a summary footer.

        :param additional_info: when true, append the value returned by
            ``get_trainable_params(self)`` (formatted with thousands
            separators) together with the input and output shapes.
        :return: the description string
        """
        description = super().__str__()
        if not additional_info:
            return description
        footer = (
            f'\nTrainable parameters: {get_trainable_params(self):,}\n'
            f'Input shape: {self.input_shape}\n'
            f'Output shape: {self.output_shape}\n'
        )
        return description + footer


class ProductOfExperts(nn.Module):
    """Combine independent Gaussian experts into a single Gaussian.
    See https://arxiv.org/pdf/1410.7827.pdf for equations.
    """

    def forward(self, mus: List[Tensor], scales: List[Tensor], eps=1e-8):
        """
        Precision-weighted fusion of M Gaussian experts, where M is the
        number of modalities plus the prior expert.
        :param mus: list of length M, each entry with shape N x D
        :param scales: list of length M, each entry with shape N x D
            (standard deviations)
        :param eps: small constant added to every variance before it is
            inverted, guarding against division by zero
        :return: tuple ``(pd_mus, pd_scales)`` holding the mean and the
            standard deviation of the product distribution; the stacked
            expert axis is reduced away
        """
        expert_mus = torch.stack(mus)  # M x N x D
        expert_vars = torch.stack(scales).pow(2)  # M x N x D

        # Precision (inverse variance) of every expert.
        precisions = 1 / (expert_vars + eps)
        # Reduce over the expert axis once and reuse the result.
        total_precision = precisions.sum(dim=0)

        pd_mus = (expert_mus * precisions).sum(dim=0) / total_precision
        pd_scales = torch.sqrt(1 / total_precision)

        return pd_mus, pd_scales


class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    https://arxiv.org/abs/1710.05941
    """

    def forward(self, x):
        """Apply the activation element-wise and return the result."""
        gate = torch.sigmoid(x)
        return x * gate


class UpsampleSpatialDimensions(BaseModule):
    """Doubles both spatial dimensions of its input.
    Inspired by https://arxiv.org/abs/2011.10650
    """

    def __init__(self, input_shape):
        """
        :param input_shape: ``(C, H, W)`` of a single sample; the spatial
            dimensions must be equal (enforced by ``check_input_shape``)
        """
        super().__init__()
        assert len(input_shape) == 3, "Assume C x H x W"
        self._input_shape = input_shape
        channels, size = check_input_shape(input_shape)
        doubled = size * 2
        self._output_shape = (channels, doubled, doubled)

    def forward(self, x):
        """Upsample ``x`` by a factor of two with ``F.interpolate``.

        The call is routed through
        ``forward_with_additional_batch_dimensions`` so inputs with extra
        leading batch dimensions are accepted as well.
        """

        def upsample(t):
            return F.interpolate(t, scale_factor=2)

        return forward_with_additional_batch_dimensions(upsample, x)


def check_input_shape(input_shape):
    """Validate an input shape for constructing convolutional models.

    :param input_shape: sequence of three entries, unpacked as
        (channels, first spatial dim, second spatial dim)
    :return: tuple ``(c, d)`` with the channel count and the shared
        spatial size
    :raises AssertionError: if the shape does not have exactly three
        entries or the two spatial dimensions differ
    """
    assert len(input_shape) == 3, 'Assume 2D images'
    channels, height, width = input_shape
    assert height == width, 'Assume spatial dimensions to be identical'
    return channels, height


def reshape(x, shape: tuple):
    """
    Reshapes the trailing (non-batch) dimensions of a tensor to ``shape``.

    The number of leading batch dimensions is inferred from the rank of
    ``x``:

    * rank 2 ``(bs, d)`` or rank 4 ``(bs, c, h, w)``: one batch dimension,
      result has shape ``(-1, *shape)``
    * rank 3 ``(k, bs, d)`` or rank 5 ``(k, bs, c, h, w)``: two batch
      dimensions, result has shape ``(k, bs, *shape)``

    ``Tensor.reshape`` is used instead of ``Tensor.view`` so that
    non-contiguous inputs (e.g. after ``transpose``/``permute``) work too;
    for inputs ``view`` can handle, ``reshape`` still returns a view that
    shares memory with ``x``.

    :param x: tensor of rank 2, 3, 4 or 5
    :param shape: target shape of the non-batch dimensions
    :return: the reshaped tensor
    :raises ValueError: if the rank of ``x`` is not between 2 and 5
    """
    rank = x.dim()
    if rank in (2, 4):
        # single batch dimension: (bs, d) or (bs, c, h, w)
        return x.reshape(-1, *shape)
    if rank in (3, 5):
        # two batch dimensions: (k, bs, d) or (k, bs, c, h, w)
        return x.reshape(x.size(0), x.size(1), *shape)
    raise ValueError(f'Please provide data with known length, '
                     f'current length is {rank}')


def forward_with_additional_batch_dimensions(layer, x):
    """
    Allows using additional batch dimensions for layer that expects N x C x H x W

    Inputs with more than four dimensions have all leading dimensions
    except the last three flattened into one batch dimension, are passed
    through ``layer`` and then get their leading dimensions restored.
    Inputs with four or fewer dimensions are handed to ``layer`` unchanged.

    ``Tensor.reshape`` is used instead of ``Tensor.view`` so that
    non-contiguous inputs (e.g. after ``transpose``/``permute``) and
    non-contiguous layer outputs do not raise.

    :param layer: callable applied to the (possibly flattened) tensor; its
        output must keep the flattened batch dimension as dimension 0
    :param x: input tensor
    :return: output of ``layer`` with the original leading dimensions
    """
    if x.dim() <= 4:
        # Nothing to flatten -- the layer can consume the input directly.
        return layer(x)

    # Flatten every leading dimension into a single batch dimension.
    batch_dims = x.shape[:-3]
    flat = x.reshape(-1, *x.shape[-3:])
    # Computation
    out = layer(flat)
    # Recover batch dimensions; trailing dims come from the layer output,
    # which may differ from the input (e.g. after upsampling).
    return out.reshape(*batch_dims, *out.shape[1:])
