"""
Simple, smooth maps R^n -> R^m.
"""
import math
from abc import abstractmethod

import numpy as np
import torch.nn as nn

from .base import Map


class SmoothMap(Map):
    """Base class for simple smooth maps.

    Subclasses build their network in ``__init__`` and assign it to
    ``self.net``; :meth:`forward` applies that network to its input.
    """

    def _listify_hypers(self, num_layers, hyperparams: dict):
        """Store each hyperparameter on ``self`` as a per-layer sequence.

        Args:
            num_layers: Number of layers the hyperparameters describe.
            hyperparams: Maps attribute names to values. An integer (Python or
                NumPy) is repeated ``num_layers`` times; any other value must
                be a sequence of length ``num_layers`` and is stored unchanged.

        Raises:
            ValueError: If a non-integer value does not have exactly
                ``num_layers`` entries.
        """
        for name, value in hyperparams.items():
            # ``np.integer`` lets NumPy scalars be broadcast as well; they are
            # not ``int`` instances and would otherwise reach ``len()`` below.
            if isinstance(value, (int, np.integer)):
                setattr(self, name, [int(value)] * num_layers)
            else:
                # Explicit check rather than ``assert`` so it survives ``python -O``.
                if len(value) != num_layers:
                    raise ValueError(
                        f"Hyperparameter {name!r} has {len(value)} entries, "
                        f"expected num_layers={num_layers}"
                    )
                setattr(self, name, value)

    def _apply_spectral_norm(self):
        """Apply spectral normalization to every submodule with a ``weight`` tensor.

        Modules that register ``weight`` as ``None`` are skipped, because
        ``nn.utils.spectral_norm`` needs an actual weight tensor to normalize.
        """
        for module in self.modules():
            if module._parameters.get("weight") is not None:
                nn.utils.spectral_norm(module)

    def forward(self, x):
        """Apply the subclass-built network ``self.net`` to ``x``."""
        return self.net(x)


class FlatSmoothMap(SmoothMap):
    """Fully-connected smooth map from ``R^dom_dim`` to ``R^codom_dim``.

    The network is ``num_layers`` hidden ``Linear -> SiLU`` stages followed by
    a final ``Linear`` projection onto the codomain (no output activation).

    Args:
        dom_dim: Dimension of the input (domain) space.
        codom_dim: Dimension of the output (codomain) space.
        num_layers: Number of hidden ``Linear -> SiLU`` stages.
        hidden_size: Width of the hidden stages; a single int shared by all of
            them, or a sequence with one width per hidden stage.
        spectral_norm: If True, apply spectral normalization to every
            submodule that has a ``weight`` parameter.
    """

    def __init__(self, dom_dim, codom_dim, num_layers=3, hidden_size=32, spectral_norm=False):
        super().__init__()

        self.dom_dim = dom_dim
        self.codom_dim = codom_dim

        self._listify_hypers(num_layers, {"hidden_size": hidden_size})

        # Consecutive (fan_in, fan_out) pairs:
        # (dom_dim, h_0), (h_0, h_1), ..., (h_last, codom_dim).
        widths = [dom_dim, *self.hidden_size, codom_dim]
        *hidden_pairs, output_pair = zip(widths[:-1], widths[1:])

        stages = []
        for fan_in, fan_out in hidden_pairs:
            stages += [nn.Linear(fan_in, fan_out), nn.SiLU()]
        stages.append(nn.Linear(*output_pair))

        self.net = nn.Sequential(*stages)

        if spectral_norm:
            self._apply_spectral_norm()



class ImageSmoothMap(SmoothMap):
    """Convolutional smooth map from images of shape ``dom_shape`` to ``R^codom_dim``.

    The network is ``num_layers`` stages of ``Conv2d -> SiLU -> AvgPool2d``,
    followed by ``Flatten`` and a final ``Linear`` projection onto the
    codomain. Height and width are tracked independently, so non-square
    images are supported.

    Args:
        dom_shape: Input image shape as ``(channels, height, width)``.
        codom_dim: Dimension of the output (codomain) space.
        num_layers: Number of convolutional stages.
        hidden_channels: Output channels of each stage; an int shared by all
            stages, or a sequence with one entry per stage.
        kernel_size: Convolution kernel size per stage (int or sequence).
        pool_size: Average-pooling window per stage (int or sequence); the
            pooling stride equals the window.
        spectral_norm: If True, apply spectral normalization to every
            submodule that has a ``weight`` parameter.

    Raises:
        ValueError: If the conv/pool stack shrinks either spatial dimension
            below one pixel.
    """

    def __init__(self, dom_shape, codom_dim, num_layers=3, hidden_channels=32, kernel_size=3,
                 pool_size=2, spectral_norm=False):
        super().__init__()

        self.dom_shape = dom_shape
        self.dom_dim = int(np.prod(dom_shape))
        self.codom_dim = codom_dim

        self._listify_hypers(
            num_layers,
            {
                "hidden_channels": hidden_channels,
                "kernel_size": kernel_size,
                "pool_size": pool_size,
            }
        )

        prev_channels, height, width = dom_shape

        layers = []
        stage_hypers = zip(self.hidden_channels, self.kernel_size, self.pool_size)
        for stage, (channels, k, p) in enumerate(stage_hypers, start=1):
            # Use this stage's channel count, not the raw ``hidden_channels``
            # argument (which may be a per-layer list).
            layers.append(nn.Conv2d(prev_channels, channels, k))
            layers.append(nn.SiLU())
            layers.append(nn.AvgPool2d(p))

            # Conv: stride 1, no padding; AvgPool2d: stride defaults to its window.
            height = self._get_new_height(self._get_new_height(height, k, 1), p, p)
            width = self._get_new_height(self._get_new_height(width, k, 1), p, p)
            if height < 1 or width < 1:
                raise ValueError(
                    f"Image of shape {tuple(dom_shape)} shrinks to {height}x{width} "
                    f"after {stage} conv/pool stage(s); use a larger image, "
                    f"smaller kernels/pools, or fewer layers"
                )

            prev_channels = channels

        layers.extend([
            nn.Flatten(),
            # ``prev_channels`` is the input channel count when num_layers == 0.
            nn.Linear(prev_channels * height * width, codom_dim),
        ])

        self.net = nn.Sequential(*layers)

        if spectral_norm:
            self._apply_spectral_norm()

    @staticmethod
    def _get_new_height(height, kernel, stride):
        """Return the output size along one spatial axis of a conv/pool layer.

        Uses the ``Conv2d`` output-size formula with dilation 1 and padding 0,
        which also matches ``AvgPool2d`` with padding 0 and ``ceil_mode=False``.
        Despite the name, it applies equally to the width axis.
        """
        # cf. https://pytorch.org/docs/1.9.1/generated/torch.nn.Conv2d.html
        return math.floor((height - kernel) / stride + 1)