from functools import reduce
import inspect
from typing import Tuple

import torch
from torch.nn import *
from torch.nn.modules.flatten import Flatten


class View(Module):
    """Reshape every sample to ``shape`` while keeping the leading batch dim.

    Args:
        shape: per-sample target shape (excluding the batch dimension); it is
            concatenated onto ``(batch,)``, so it must be a tuple.
    """

    def __init__(self, shape:tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target = (x.shape[0],) + self.shape
        return x.view(*target)


class FlattenedCoding(Module):
    """Base class for coders that operate on a fully flattened signal.

    Args:
        shape: per-sample signal shape; ``self.flattened`` is the product of
            all its entries.
    """

    def __init__(self, shape:tuple):
        super().__init__()
        self.flattened = reduce(lambda acc, extent: acc * extent, shape)

    def forward(self, _):
        # Subclasses provide the actual coding step.
        raise NotImplementedError


class UnknownEncoder(Exception):
    """Raised when no encoder class exists for the requested input rank."""


class UnknownDecoder(Exception):
    """Raised when no decoder class exists for the requested input rank."""


class SensorEncoder(Module):
    """Facade that picks a convolutional encoder by the rank of ``shape``.

    ``shape`` is ``(channels, *spatial)``; with N spatial dims the class
    ``SensorConv{N}DEncoder`` is looked up in this module's globals and
    instantiated as ``self.delegate``.

    Args:
        shape: per-sample input shape, channels first.
        code_size: length of the produced code.
        with_residuals: forwarded to the delegate encoder.

    Raises:
        UnknownEncoder: if no ``SensorConv{N}DEncoder`` is defined.
    """

    def __init__(self, shape:tuple, code_size:int, with_residuals:bool=False):
        super().__init__()
        dims = len(shape[1:])
        # Only the class lookup is guarded: a KeyError raised while building
        # the delegate is a real error and must not be reported as
        # "unknown encoder".
        try:
            delegate_cls = globals()[f'SensorConv{dims}DEncoder']
        except KeyError as e:
            raise UnknownEncoder(e) from e
        self.delegate = delegate_cls(shape, code_size, with_residuals=with_residuals)
        # Decoders need the pre-flatten activation shape to rebuild it.
        self.unflatten_output_size = self.delegate.unflatten_output_size

    def forward(self, signal):
        # Call the module (not .forward) so registered hooks on the delegate run.
        return self.delegate(signal)


class SensorDecoder(Module):
    """Facade that picks a convolutional decoder by the rank of ``shape``.

    ``shape`` is ``(channels, *spatial)``; with N spatial dims the class
    ``SensorConv{N}DDecoder`` is looked up in this module's globals and
    instantiated as ``self.delegate``.

    Args:
        shape: per-sample output shape, channels first.
        code_size: length of the incoming code.
        encoder_unflatten_output_size: pre-flatten activation shape recorded
            by the matching encoder.
        with_residuals: forwarded to the delegate decoder.

    Raises:
        UnknownDecoder: if no ``SensorConv{N}DDecoder`` is defined.
    """

    def __init__(self, shape:tuple, code_size:int, encoder_unflatten_output_size:tuple, with_residuals:bool=False):
        super().__init__()
        dims = len(shape[1:])
        # Only the class lookup is guarded, so KeyErrors from the constructor
        # surface unchanged instead of being misreported.
        try:
            delegate_cls = globals()[f'SensorConv{dims}DDecoder']
        except KeyError as e:
            raise UnknownDecoder(e) from e
        self.delegate = delegate_cls(shape, code_size, encoder_unflatten_output_size, with_residuals=with_residuals)

    def forward(self, signal):
        # Call the module (not .forward) so registered hooks on the delegate run.
        return self.delegate(signal)


class _SensorConvXDCoder(Module):
    """Shared base for the convolutional sensor encoders and decoders.

    Subclasses assign ``self.coder``; ``forward`` just runs it. Encoders call
    ``_compute_dense_in`` to size their bottleneck ``Linear`` layer, which also
    records the pre-flatten activation shape in ``unflatten_output_size``.
    """

    def __init__(self):
        super().__init__()
        # Filled in by _compute_dense_in; remains None for decoders.
        self.unflatten_output_size = None

    def forward(self, signal):
        return self.coder(signal)

    def _compute_out_size(self, channels:int, dims:Tuple[int, ...], convs:list, pools:list) -> tuple:
        """Return ``(channels, *spatial)`` after each (conv, pool) pair.

        Every spec is a dict with ``'p'`` (padding), ``'d'`` (dilation),
        ``'ks'`` (kernel size) and ``'s'`` (stride), applied with the
        floor-mode output-size formula. Pairs are taken with ``zip``, so
        surplus entries in the longer list are ignored.
        """
        def shrink(size, spec):
            receptive = spec['d'] * (spec['ks'] - 1) + 1
            return (size + 2 * spec['p'] - receptive) // spec['s'] + 1

        spatial = dims
        for conv, pool in zip(convs, pools):
            spatial = tuple(shrink(shrink(n, conv), pool) for n in spatial)
        return (channels,) + spatial

    def _compute_dense_in(self, channels:int, dims:Tuple[int, ...], convs:list, pools:list) -> int:
        """Record the pre-flatten shape and return its element count."""
        out_size = self._compute_out_size(channels, dims, convs, pools)
        self.unflatten_output_size = out_size
        return reduce(lambda acc, n: acc * n, out_size)


class SensorConv3DEncoder(_SensorConvXDCoder):
    """Four-stage 3D convolutional encoder ending in a bias-free linear code.

    Each stage is ``Conv3d -> ReLU -> MaxPool3d``; channels go
    ``C -> 32C -> 16C -> 8C -> 4C`` with ``C = shape[0]``. The pools are built
    with ``return_indices=True`` so the matching decoder can unpool.

    Args:
        shape: per-sample input shape ``(C, D, H, W)``.
        code_size: length of the produced code.
        with_residuals: forwarded to ``IndexTrackingSequential``.
    """

    def __init__(self, shape:tuple, code_size:int, with_residuals:bool=False):
        super().__init__()
        channels = shape[0]
        # These specs must mirror the hyper-parameters used in _conv/_pool.
        conv_spec = {'d':1,'ks':3,'p':1,'s':1}
        pool_spec = {'d':1,'s':2,'p':1,'ks':2}
        widths = [channels, channels * 32, channels * 16, channels * 8, channels * 4]
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [self._conv(c_in, c_out), ReLU(True), self._pool()]
        dense_in = self._compute_dense_in(channels * 4, shape[1:], [conv_spec] * 4, [pool_spec] * 4)
        layers += [Flatten(), Linear(dense_in, code_size, bias=False)]
        self.coder = IndexTrackingSequential(*layers, with_residuals=with_residuals)

    def _conv(self, i, o):
        """Size-preserving 3x3x3 convolution (stride 1, padding 1)."""
        return Conv3d(
            i, o, 3,
            stride=1, padding=1, dilation=1, groups=1,
            bias=True, padding_mode='zeros',
        )

    def _pool(self):
        """2x2x2 max-pool (stride defaults to the kernel) returning indices."""
        return MaxPool3d(
            2,
            stride=None, padding=1, dilation=1,
            return_indices=True, ceil_mode=False,
        )


class SensorConv3DDecoder(_SensorConvXDCoder):
    """Mirror of ``SensorConv3DEncoder`` mapping a code back to a 3D signal.

    A bias-free ``Linear`` expands the code to the element count of
    ``encoder_unflatten_output_size``, ``View`` restores that shape, then four
    ``MaxUnpool3d`` stages interleaved with transposed convolutions take the
    channels ``4C -> 8C -> 16C -> 32C -> C`` (``C = shape[0]``). A ``Sigmoid``
    squashes the output into (0, 1).

    NOTE(review): the unpool layers need pooling indices, so forward presumably
    receives ``(code, index_map, target_shape)`` as accepted by
    ``IndexTrackingSequential``; confirm against callers.
    """

    def __init__(self, shape:tuple, code_size:int, encoder_unflatten_output_size:tuple, with_residuals:bool=False):
        super().__init__()
        channels = shape[0]
        expand = Linear(
            code_size,
            reduce(lambda acc, n: acc * n, encoder_unflatten_output_size),
            bias=False,
        )
        layers = [expand, View(encoder_unflatten_output_size)]
        widths = (channels * 4, channels * 8, channels * 16, channels * 32)
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend((self._pool(), self._conv(c_in, c_out), ReLU(True)))
        # The last stage projects back to the input channel count, without ReLU.
        layers.extend((self._pool(), self._conv(channels * 32, channels), Sigmoid()))
        self.coder = IndexTrackingSequential(*layers, with_residuals=with_residuals)

    def _conv(self, i, o):
        """Size-preserving 3x3x3 transposed convolution (stride 1, padding 1)."""
        return ConvTranspose3d(
            i, o, 3,
            stride=1, padding=1, output_padding=0, groups=1,
            bias=True, dilation=1, padding_mode='zeros',
        )

    def _pool(self):
        """2x2x2 max-unpool matching the encoder's padding=1 pooling."""
        return MaxUnpool3d(2, stride=None, padding=1)


class SensorConv2DEncoder(_SensorConvXDCoder):
    """Three-stage 2D convolutional encoder ending in a bias-free linear code.

    Each stage is ``Conv2d -> ReLU -> MaxPool2d``; channels go
    ``C -> 16C -> 8C -> 4C`` with ``C = shape[0]``. The pools return indices
    so the matching decoder can unpool.

    Args:
        shape: per-sample input shape ``(C, H, W)``.
        code_size: length of the produced code.
        with_residuals: forwarded to ``IndexTrackingSequential``.
    """

    def __init__(self, shape:tuple, code_size:int, with_residuals:bool=False):
        super().__init__()
        channels = shape[0]

        def stage(c_in, c_out):
            # Size-preserving 3x3 conv followed by ReLU and a halving max-pool.
            return [
                Conv2d(c_in, c_out, 3, stride=1, padding=1, dilation=1,
                       groups=1, bias=True, padding_mode='zeros'),
                ReLU(True),
                self._pool(),
            ]

        layers = (stage(channels, channels * 16)
                  + stage(channels * 16, channels * 8)
                  + stage(channels * 8, channels * 4))
        # Specs must mirror the Conv2d / MaxPool2d hyper-parameters above.
        conv_spec = {'d':1,'ks':3,'p':1,'s':1}
        pool_spec = {'d':1,'s':2,'p':0,'ks':2}
        dense_in = self._compute_dense_in(channels * 4, shape[1:], [conv_spec] * 3, [pool_spec] * 3)
        layers += [Flatten(), Linear(dense_in, code_size, bias=False)]
        self.coder = IndexTrackingSequential(*layers, with_residuals=with_residuals)

    def _pool(self):
        """2x2 max-pool (stride defaults to the kernel) returning indices."""
        return MaxPool2d(
            2,
            stride=None, padding=0, dilation=1,
            return_indices=True, ceil_mode=False,
        )


class SensorConv2DDecoder(_SensorConvXDCoder):
    """Mirror of ``SensorConv2DEncoder`` mapping a code back to a 2D signal.

    A bias-free ``Linear`` expands the code to the element count of
    ``encoder_unflatten_output_size``, ``View`` restores that shape, then three
    ``MaxUnpool2d`` stages interleaved with transposed convolutions take the
    channels ``4C -> 8C -> 16C -> C`` (``C = shape[0]``). A ``Sigmoid``
    squashes the output into (0, 1).

    NOTE(review): the unpool layers need pooling indices, so forward presumably
    receives ``(code, index_map, target_shape)`` as accepted by
    ``IndexTrackingSequential``; confirm against callers.
    """

    def __init__(self, shape:tuple, code_size:int, encoder_unflatten_output_size:tuple, with_residuals:bool=False):
        super().__init__()
        channels = shape[0]

        def unpool():
            return MaxUnpool2d(2, stride=None, padding=0)

        def deconv(c_in, c_out):
            # Size-preserving 3x3 transposed convolution.
            return ConvTranspose2d(c_in, c_out, 3, stride=1, padding=1,
                                   output_padding=0, groups=1, bias=True,
                                   dilation=1, padding_mode='zeros')

        flat_size = reduce(lambda acc, n: acc * n, encoder_unflatten_output_size)
        self.coder = IndexTrackingSequential(
            Linear(code_size, flat_size, bias=False),
            View(encoder_unflatten_output_size),
            unpool(), deconv(channels * 4, channels * 8), ReLU(True),
            unpool(), deconv(channels * 8, channels * 16), ReLU(True),
            unpool(), deconv(channels * 16, channels), Sigmoid(),
            with_residuals=with_residuals,
        )


class FlatteningEncoder(FlattenedCoding):
    """Encode a signal by flattening it and applying a bias-free ``Linear``.

    Args:
        shape: per-sample signal shape; its element product is the ``Linear``
            input size.
        code_size: length of the produced code.
    """

    def __init__(self, shape:tuple, code_size:int):
        super().__init__(shape)
        self.flatten = Flatten()
        self.encoder = Linear(
            self.flattened,
            code_size,
            bias=False
        )

    def forward(self, signal):
        # Apply the registered submodules directly rather than wrapping them
        # in a fresh Sequential container on every call.
        return self.encoder(self.flatten(signal))


class FlatteningDecoder(FlattenedCoding):
    """Decode a code with a bias-free ``Linear``, reshape it, and squash to (0, 1).

    Args:
        shape: per-sample output shape; its element product is the ``Linear``
            output size.
        code_size: length of the incoming code.
    """

    def __init__(self, shape:tuple, code_size:int):
        super().__init__(shape)
        self.decoder = Linear(
            code_size,
            self.flattened,
            bias=False
        )
        self.unflatten = View(shape)

    def forward(self, signal):
        # Same computation as Sequential(decoder, unflatten, Sigmoid()), without
        # allocating a new container and Sigmoid module on every call.
        return torch.sigmoid(self.unflatten(self.decoder(signal)))


class IndexTrackingSequential(Sequential):
    """``Sequential`` that threads max-pool indices from encoder to decoder.

    Encoder use: call with a tensor. Every module whose output is a tuple
    (e.g. a max-pool built with ``return_indices=True``) has its second element
    appended to the returned index list.

    Decoder use: call with ``(data, index_map, target_shape)``. Each module
    whose ``forward`` takes an ``indices`` parameter consumes the last entry
    of ``index_map``. Its ``output_size`` is the size of the next remaining
    index tensor, or ``target_shape`` for the final one. The caller's
    ``index_map`` is copied, not consumed.

    Returns ``(data, index)``, where ``index`` is the list of collected indices.
    """

    def __init__(self, *seq, with_residuals=False):
        super().__init__(*seq)
        self.with_residuals = with_residuals

    def forward(self, feed) -> Tuple[torch.Tensor, list]:
        if isinstance(feed, torch.Tensor):
            # Never tuple-unpack a tensor: that iterates over the batch
            # dimension and would silently split a batch of exactly 3 samples.
            data, index_map, target_shape = feed, [], None
        else:
            data, index_map, target_shape = feed
            # Work on a copy so the caller's index list is not drained by pop().
            index_map = list(index_map)
        index = []
        if self.with_residuals:
            residual = data.detach().clone()
        for module in self:
            if self._requires_indices(module):
                if not index_map:
                    raise ValueError(
                        f'{type(module).__name__} needs pooling indices but none are left; '
                        'call forward with (data, index_map, target_shape)'
                    )
                indices = index_map.pop()
                output_size = index_map[-1].size() if index_map else target_shape
                data = module(data, indices, output_size=output_size)
            else:
                data = module(data)
                if type(data) is tuple:
                    data, indices = data
                    index.append(indices)
            if self.with_residuals:
                # NOTE(review): these checks compare a module *instance* with a
                # class (and `stride is int` compares a value with a type), so
                # they are always False and this residual branch is currently
                # inert. Switching to isinstance() is not enough: the residual
                # keeps the input's channel count and size, which generally do
                # not match `data` after the first conv/pool. It also relies on
                # `functional` being exported by `from torch.nn import *`, which
                # is unverified. Needs a redesign before use.
                if module is ReLU:
                    data = data + residual
                elif module is torch.nn.modules.pooling._MaxPoolNd:
                    sf = module.kernel_size
                    # TODO Lazy---pragmatic?---does not cover tuples, and ignores
                    # dilatation, etc at this point.
                    if module.stride is int:
                        sf /= module.stride
                    residual = functional.interpolate(residual, scale_factor=sf)
        return data, index

    def _requires_indices(self, mod):
        '''
        Return True when ``mod.forward`` has a parameter named ``indices``
        (as the max-unpool modules' ``forward`` does).

        Naive: This will break anytime the signature changes.
        '''
        sig = inspect.signature(mod.forward)
        return 'indices' in sig.parameters
