from typing import Any, Tuple

import torch
import torch.nn as nn
from layers import LinearResidual, SpectralNorm, SpectralNormConv, UpProjection

# Short alias used for tensor annotations in this module.
T = torch.Tensor

# Names exported by `from <module> import *`.
__all__ = [
    "FCMixin",
    "ConvSpectralResidualProto",
    "CNN4ResidualSpectralMixin",
    "DimTuple",
]


class FCMixin(nn.Module):
    """Build a fully-connected residual trunk: an input projection plus residual blocks.

    The layers are stored, in order, in ``self.layers`` (an ``nn.ModuleList``):
    one ``UpProjection`` from ``in_dim`` to ``h_dim``, followed by ``n_layers``
    ``LinearResidual`` blocks of width ``h_dim``. This class defines no
    ``forward``; presumably a subclass / co-mixin iterates over
    ``self.layers`` (TODO confirm against users of this mixin).

    Args:
        n_layers: Number of ``LinearResidual`` blocks after the projection
            (a value <= 0 yields only the projection).
        in_dim: Input feature dimension of the projection.
        h_dim: Output width of the projection and width of every residual block.
        classes: Accepted but not used by this class.
        p: Forwarded to every ``LinearResidual``.
        ctype: Forwarded to every ``LinearResidual``.
        spectral: Forwarded to every ``LinearResidual``.
        c: Forwarded to every ``LinearResidual``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        n_layers: int = 6,
        in_dim: int = 2,
        h_dim: int = 128,
        classes: int = 2,
        p: float = 0.01,
        ctype: str = "error",
        spectral: bool = True,
        c: float = 1.0,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # Input projection in_dim -> h_dim (built first, as before, so module
        # construction order is unchanged).
        projection = UpProjection(in_dim, h_dim, static=False, bias=True)
        residual_blocks = [
            LinearResidual(h_dim, p=p, spectral=spectral, c=c, activation=nn.ReLU, ctype=ctype)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList([projection, *residual_blocks])


class ConvSpectralResidualProto(nn.Module):
    """Apply a residual 3x3 conv block, then 2x2 average pooling.

    Main branch: ``Conv2d(3x3, no bias) -> BatchNorm2d -> activation -> Dropout2d``.
    Skip branch: ``nn.Identity`` when ``in_ch == filters``, otherwise a bias-free
    1x1 ``Conv2d``. The branches are summed and passed through
    ``nn.AvgPool2d(2)``, which halves the spatial size.

    When ``spectral`` is True, both convolutions are wrapped in
    ``SpectralNormConv``; otherwise plain ``nn.Conv2d`` layers are used.

    Args:
        dims: Forwarded as ``dim`` to ``SpectralNormConv`` for both convs.
            Presumably the input shape used for the spectral-norm estimate
            (TODO confirm against ``layers.SpectralNormConv``). Unused when
            ``spectral`` is False.
        in_ch: Number of input channels.
        filters: Number of output channels.
        padding: Padding of the 3x3 conv. The skip branch preserves H and W,
            while the 3x3 conv yields ``H + 2 * padding - 2``, so the residual
            sum only has matching shapes when ``padding == 1``.
        p: ``Dropout2d`` probability.
        activation: Zero-argument callable returning the activation module
            (e.g. an ``nn.Module`` class such as the default ``nn.ReLU``).
        ctype: Forwarded to ``SpectralNormConv``.
        c: Forwarded to ``SpectralNormConv``.
        spectral: Whether to wrap the convolutions in ``SpectralNormConv``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dims: Tuple[int, ...],
        in_ch: int = 1,
        filters: int = 64,
        padding: int = 1,
        p: float = 0.01,
        activation: Any = nn.ReLU,
        ctype: str = "error",
        c: int = 3,
        spectral: bool = True,
        **kwargs: Any
    ):
        super().__init__()
        self.padding = padding

        # Build only the variant that is actually used. Previously the spectral
        # wrappers were always constructed and then discarded when spectral=False.
        conv: nn.Module = nn.Conv2d(in_ch, filters, 3, padding=self.padding, bias=False)
        if spectral:
            conv = SpectralNormConv(conv, dim=dims, c=c, ctype=ctype)

        self.layer = nn.Sequential(
            conv,
            nn.BatchNorm2d(filters),
            activation(),  # previously hard-coded to nn.ReLU(), ignoring `activation`
            nn.Dropout2d(p=p),
        )

        # The skip branch must match the main branch's channel count.
        skip: nn.Module
        if in_ch == filters:
            skip = nn.Identity()
        else:
            skip = nn.Conv2d(in_ch, filters, 1, bias=False)
            if spectral:
                skip = SpectralNormConv(skip, dim=dims, c=c, ctype=ctype)
        self.upsample = skip

        self.pool = nn.AvgPool2d(2)

    def forward(self, x: T) -> T:
        """Return ``pool(upsample(x) + layer(x))``."""
        identity = self.upsample(x)
        fx = self.layer(x)
        return self.pool(identity + fx)  # type: ignore


# One dims tuple per layer; DimTuple is a sequence of those per-layer tuples.
_LayerDims = Tuple[int, ...]
DimTuple = Tuple[_LayerDims, ...]


class CNN4ResidualSpectralMixin(nn.Module):
    """Build four ``ConvSpectralResidualProto`` blocks followed by ``nn.Flatten``.

    The layers are stored in ``self.layers`` as an ``nn.Sequential``. The first
    block maps ``in_ch`` to ``h_dim`` channels and the other three keep
    ``h_dim`` channels. Each block is constructed with its default dropout
    probability; no ``p`` is forwarded from here.

    Args:
        dims: Exactly four per-block ``dims`` tuples, passed in order to the four
            blocks. The default ``()`` has length 0 and therefore always raises,
            so callers must effectively always supply this argument.
        in_ch: Input channels of the first block.
        h_dim: Output channels of every block.
        ctype: Forwarded to every block.
        c: Forwarded to every block.
        spectral: Forwarded to every block.
        **kwargs: Accepted and ignored. Note that e.g. a ``p=`` passed here is
            silently dropped.

    Raises:
        ValueError: If ``len(dims) != 4``.
    """

    def __init__(self, dims: DimTuple = (), in_ch: int = 1, h_dim: int = 64, ctype: str = "error", c: int = 3, spectral: bool = True, **kwargs: Any) -> None:
        super().__init__()
        if len(dims) != 4:
            raise ValueError(f"need 4 dims for ConvSpectralResidualProto: ({len(dims)=})")

        # Only the first block changes the channel count (in_ch -> h_dim).
        block_in_channels = [in_ch] + [h_dim] * 3
        blocks = [
            ConvSpectralResidualProto(block_dims, ch, h_dim, c=c, ctype=ctype, spectral=spectral)
            for block_dims, ch in zip(dims, block_in_channels)
        ]
        # Children keep the same Sequential indices (0-3 blocks, 4 flatten) as before.
        self.layers = nn.Sequential(*blocks, nn.Flatten())
