import torch
import torch.nn as nn
import torch.nn.functional as F


class Map2Vec(nn.Module):
    """Map a feature map to a vector using orthogonally parameterized layers.

    The input (expected shape ``(batch, feat_dim, feat_size, feat_size)``) goes
    through a 4x4, stride-4 convolution with ``feat_dim`` output channels. The
    result is flattened to ``feat_dim * (feat_size // 4) ** 2`` features, passed
    through a linear layer to ``out_dim`` features, and finally through
    ``activation``.

    Both the convolution kernel and the linear weight are re-parameterized on
    every call to :meth:`get_weight`, so that the effective matrices are
    (approximately) row-orthonormal:

    * ``"bro"``: the block reflector ``I - 2 V (VᵀV + εI)⁻¹ Vᵀ`` built from a
      tall parameter ``V``, then cropped to the required shape.
    * ``"cholesky"``: ``L⁻¹ W``, where ``L Lᵀ = W Wᵀ + εI``.

    Args:
        feat_dim: Number of input (and intermediate) channels.
        feat_size: Spatial side length of the input map; must be divisible by 4.
        out_dim: Size of the output vector.
        activation: Module applied to the linear output.
        conv_type: Orthogonalization for the conv kernel, ``"bro"`` or
            ``"cholesky"``.
        linear_type: Orthogonalization for the linear weight, ``"bro"`` or
            ``"cholesky"``.
        input_size: Stored on the module; not used by the computations here.

    Raises:
        ValueError: If ``feat_size`` is not divisible by 4, if a type is
            unsupported, or if ``linear_type="bro"`` is requested with
            ``feat_dim * (feat_size // 4) ** 2 <= out_dim`` (BRO needs a tall
            parameter matrix).
    """

    _SUPPORTED_TYPES = ("bro", "cholesky")

    def __init__(
        self,
        feat_dim: int,
        feat_size: int,
        out_dim: int,
        activation: nn.Module,
        conv_type: str = "bro",
        linear_type: str = "bro",
        input_size: int = 16,
    ) -> None:
        super().__init__()
        # Validate eagerly (and without `assert`, which `-O` strips) so that a
        # bad configuration fails at construction rather than on first forward.
        if feat_size % 4 != 0:
            raise ValueError(f"feat_size must be divisible by 4, got {feat_size}")
        if conv_type not in self._SUPPORTED_TYPES:
            raise ValueError("Unsupported `conv` type!")
        if linear_type not in self._SUPPORTED_TYPES:
            raise ValueError("Unsupported `linear` type!")

        self.conv_type = conv_type
        self.linear_type = linear_type
        mid_size = feat_size // 4  # spatial side after the 4x4 / stride-4 conv
        mid_dim = feat_dim * mid_size**2  # flattened size fed to the linear layer

        if linear_type == "bro" and mid_dim <= out_dim:
            raise ValueError(
                "`bro` linear requires feat_dim * (feat_size // 4) ** 2 > out_dim, "
                f"got {mid_dim} <= {out_dim}"
            )

        # BRO needs a tall (rows > cols) parameter; cholesky uses the final
        # (out, in) layout directly.
        if conv_type == "bro":
            kernel = torch.randn(feat_dim * 16, feat_dim)
        else:
            kernel = torch.randn(feat_dim, feat_dim * 16)
        self.kernel = nn.Parameter(kernel / feat_dim**0.5 / 4)

        if linear_type == "bro":
            weight = torch.randn(mid_dim, out_dim) / mid_dim**0.5
        else:
            weight = torch.randn(out_dim, mid_dim) / mid_dim**0.5
        self.linear_out_dim = out_dim
        self.linear_in_dim = mid_dim
        self.weight = nn.Parameter(weight)

        self.bias = nn.Parameter(torch.zeros(out_dim))

        self.feat_dim = feat_dim
        self.mid_size = mid_size
        self.activation = activation
        self.input_size = input_size

    @staticmethod
    def _orthogonalize(param, mode, rows, cols, label):
        """Return a ``(rows, cols)`` approximately row-orthonormal matrix from ``param``.

        ``label`` ("conv" or "linear") is used only in error messages.
        """
        if mode == "cholesky":
            sigma = param @ param.T
            # Small relative ridge keeps the Cholesky factorization stable. It
            # is detached so it acts as a constant, and it stays on-device to
            # avoid a host sync.
            eps = sigma.diag().mean().detach() / 1000.0
            sigma = sigma + eps * torch.eye(
                sigma.shape[0], device=sigma.device, dtype=sigma.dtype
            )
            chol = torch.linalg.cholesky(sigma)
            # L^{-1} W  =>  (L^{-1} W)(L^{-1} W)^T ≈ I
            return torch.linalg.solve_triangular(chol, param, upper=False)
        if mode == "bro":
            if param.shape[0] <= param.shape[1]:
                raise ValueError(
                    f"`bro` {label} needs a tall parameter, got shape {tuple(param.shape)}"
                )
            sigma = param.T @ param
            eps = sigma.diag().mean().detach() * 1e-5
            sigma = sigma + eps * torch.eye(
                sigma.shape[0], device=sigma.device, dtype=sigma.dtype
            )
            eye = torch.eye(param.shape[0], device=param.device, dtype=param.dtype)
            # Block reflector I - 2 V (VᵀV)⁻¹ Vᵀ is orthogonal; any row crop of
            # it is row-orthonormal.
            reflector = eye - 2 * param @ torch.linalg.solve(sigma, param.T)
            return reflector[:rows, :cols]
        raise ValueError(f"Unsupported `{label}` type!")

    def get_weight(self):
        """Build the effective (orthogonalized) conv kernel and linear weight.

        Returns:
            A tuple ``(kernel, weight)``. ``kernel`` has shape
            ``(feat_dim, feat_dim, 4, 4)`` and ``weight`` has shape
            ``(out_dim, feat_dim * (feat_size // 4) ** 2)``.

        Raises:
            ValueError: If ``conv_type`` or ``linear_type`` is unsupported.
        """
        kernel_ = self._orthogonalize(
            self.kernel, self.conv_type, self.feat_dim, self.feat_dim * 16, "conv"
        )
        kernel_ = kernel_.reshape(self.feat_dim, self.feat_dim, 4, 4)
        weight_ = self._orthogonalize(
            self.weight,
            self.linear_type,
            self.linear_out_dim,
            self.linear_in_dim,
            "linear",
        )
        return kernel_, weight_

    def forward(self, x):
        """Apply conv -> flatten -> linear -> activation.

        ``x`` is expected to be ``(batch, feat_dim, feat_size, feat_size)``, so
        that the flattened conv output has ``linear_in_dim`` features. The
        result has shape ``(batch, out_dim)`` before the activation.
        """
        kernel, weight = self.get_weight()
        x = F.conv2d(x, kernel, stride=4)
        x = x.reshape(x.shape[0], -1)
        x = F.linear(x, weight, self.bias)
        return self.activation(x)

    def lipschitz(self):
        """Return an upper bound on the Lipschitz constant of conv + linear.

        In training mode this returns ``1.0`` without computation, presumably
        relying on the orthogonal parameterization. In eval mode it returns the
        product of the spectral norms of the conv kernel matrix and the linear
        weight.

        The conv uses non-overlapping 4x4 patches (kernel 4, stride 4, no
        padding), so its operator norm equals the spectral norm of the
        ``(feat_dim, 16 * feat_dim)`` kernel matrix. The bias does not affect
        the constant.

        NOTE(review): the activation's Lipschitz constant is not included;
        this assumes it is 1-Lipschitz.
        """
        if self.training:
            return 1.0

        with torch.no_grad():
            kernel, weight = self.get_weight()
            conv_lc = torch.linalg.matrix_norm(
                kernel.reshape(self.feat_dim, -1), ord=2
            )
            linear_lc = torch.linalg.matrix_norm(weight, ord=2)
        return (conv_lc * linear_lc).item()

    def extra_repr(self) -> str:
        return (
            f"conv_type={self.conv_type}, channel={self.feat_dim}, "
            f"linear_type={self.linear_type}, in={self.linear_in_dim}, "
            f"out={self.linear_out_dim}"
        )
