from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.modules.utils import _pair


class Conv2d(nn.Conv2d):
    """2-D convolution with a power-iteration estimate of its Lipschitz constant.

    ``lipschitz()`` estimates the spectral norm of the bias-free linear map
    ``x -> conv2d(x, weight)`` on inputs of spatial size ``input_size``. It does
    this by power iteration on ``W^T W``, where ``W^T`` is applied with
    ``conv_transpose2d`` using the same weight, stride and padding. The iterate
    is kept in the ``init_x`` buffer, so successive calls warm-start from the
    previous one.

    When ``use_fln`` is true, the estimate is stored in ``self.lc`` and
    ``forward`` divides the weight by it. ``lipschitz()`` then returns ones.

    Args:
        in_channels, out_channels, kernel_size, stride, groups, bias:
            Forwarded to :class:`torch.nn.Conv2d`.
        padding: Forwarded to :class:`torch.nn.Conv2d`. If ``None``, it
            defaults to ``k // 2`` for each spatial kernel size ``k``. String
            padding modes (e.g. ``"same"``) are rejected because the transposed
            convolution needs explicit integer padding.
        num_lc_iter: Number of power-iteration steps per ``lipschitz()`` call.
        input_size: Spatial size ``(H, W)`` (or a single int) of the inputs the
            estimate is computed for.
        orthogonal_init: Re-initialise ``weight`` with ``nn.init.orthogonal_``.
        use_fln: Divide the weight by the estimated constant in ``forward``.
        **kwargs: Accepted but ignored. They are NOT forwarded to
            :class:`torch.nn.Conv2d`, so e.g. ``dilation``, ``padding_mode``,
            ``device`` and ``dtype`` have no effect.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int]]] = None,
        groups: int = 1,
        bias: bool = True,
        num_lc_iter: int = 10,
        input_size: Union[int, Tuple[int, int]] = 32,
        orthogonal_init: bool = True,
        use_fln: bool = False,
        **kwargs,
    ) -> None:
        if padding is None:
            # k // 2 per spatial dim keeps the spatial size unchanged for odd
            # kernels at stride 1.
            padding = tuple(k // 2 for k in _pair(kernel_size))
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )

        if orthogonal_init:
            nn.init.orthogonal_(self.weight)

        self.num_lc_iter = num_lc_iter
        self.input_size = _pair(input_size)

        # Power-iteration state. Registered as a buffer so it moves with the
        # module (.to()/.cuda()/.double()) and is included in the state dict.
        init_x = torch.ones(1, self.in_channels, *self.input_size)
        self.register_buffer("init_x", init_x)

        # Overrides nn.Conv2d's (unused) output_padding; only the transposed
        # convolution in lipschitz() reads it.
        self.output_padding = self.compute_output_padding()
        self.use_fln = use_fln
        if use_fln:
            self.lc = None

    def lipschitz(self) -> Tensor:
        """Estimate the spectral norm of the bias-free convolution.

        Runs ``num_lc_iter`` power-iteration steps on ``W^T W`` starting from
        the ``init_x`` buffer and writes the final iterate back into
        ``init_x``. It then returns ``||W x||`` for that iterate. Gradients
        reach ``weight`` only through this last convolution.

        Returns:
            The estimate as a 0-dim tensor. When ``use_fln`` is set, the
            estimate is instead stored in ``self.lc`` (for ``get_weight``) and
            a 0-dim tensor of ones with the same device/dtype is returned.
        """
        # The iterate is treated as a constant for the gradient, so there is
        # no need for autograd to record the intermediate convolutions.
        with torch.no_grad():
            # Clone so the tensor saved by the final conv never aliases the
            # buffer that the next call overwrites in place.
            x = self.init_x.clone()
            for _ in range(self.num_lc_iter):
                x = F.conv2d(
                    x,
                    self.weight,
                    bias=None,
                    stride=self.stride,
                    padding=self.padding,
                    groups=self.groups,
                )
                # Adjoint of the conv above; output_padding restores input_size.
                x = F.conv_transpose2d(
                    x,
                    self.weight,
                    bias=None,
                    stride=self.stride,
                    padding=self.padding,
                    output_padding=self.output_padding,
                    groups=self.groups,
                )
                x = F.normalize(x, dim=(1, 2, 3))
            # Warm-start the next call from this iterate.
            self.init_x.copy_(x)

        lc = F.conv2d(
            x,
            self.weight,
            bias=None,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
        ).norm()
        if self.use_fln:
            self.lc = lc
            # Match the device/dtype of the non-FLN return value.
            return torch.ones_like(lc)
        return lc

    def compute_output_padding(self) -> Tuple[int, int]:
        """Return the ``output_padding`` that makes the adjoint map back to ``input_size``.

        ``conv2d`` floors ``(n + 2p - k) / s``, so several input sizes share one
        output size. ``conv_transpose2d`` with zero ``output_padding`` yields
        the smallest of them. The returned pair is the shortfall from
        ``input_size``, which is always smaller than the stride. Dilation is
        assumed to be 1.

        Raises:
            ValueError: If ``padding`` is a string mode such as ``"same"``.
        """
        if isinstance(self.padding, str):
            raise ValueError(
                f"String padding {self.padding!r} is not supported; "
                "pass explicit integer padding instead."
            )
        h, w = self.input_size
        s1, s2 = _pair(self.stride)
        k1, k2 = _pair(self.kernel_size)
        p1, p2 = _pair(self.padding)
        # Spatial size of conv_transpose2d(conv2d(x)) with output_padding = 0.
        o1 = (h + 2 * p1 - k1) // s1 * s1 - 2 * p1 + k1
        o2 = (w + 2 * p2 - k2) // s2 * s2 - 2 * p2 + k2
        return h - o1, w - o2

    def get_weight(self) -> Tensor:
        """Return the weight used by ``forward``.

        With ``use_fln`` the weight is divided by the constant stored by the
        last ``lipschitz()`` call. Otherwise the raw weight is returned.

        Raises:
            ValueError: If ``use_fln`` is set and ``lipschitz()`` has not run yet.
        """
        if not self.use_fln:
            return self.weight
        if self.lc is None:
            raise ValueError("The Lipschitz constant is not computed yet.")
        return self.weight / self.lc

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with ``get_weight()`` (normalised when ``use_fln``)."""
        return self._conv_forward(input, self.get_weight(), self.bias)
