from typing import Any, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from layers.spectral_norm import SpectralNorm

# Short alias for torch.Tensor, used in the type annotations below.
T = torch.Tensor


def NoOp(x: T, *args: Any, **kwargs: Any) -> T:
    """Identity function: hand back ``x`` unchanged.

    Any extra positional or keyword arguments are accepted and discarded,
    so this can stand in for a callable that is invoked with more arguments
    than just the input.
    """
    del args, kwargs  # accepted only for call-signature compatibility
    return x


def maxpool(x: T) -> T:
    """Downsample ``x`` with a non-overlapping 2x2 max-pool.

    Uses a 2x2 window with stride 2 and no padding, which halves each of
    the two trailing (spatial) dimensions, flooring odd sizes.
    """
    window = 2
    return F.max_pool2d(x, kernel_size=window, stride=window, padding=0)  # type: ignore


class UpProjection(nn.Module):
    """Map the last dimension of an input from ``in_dim`` to ``out_dim``.

    Two modes are supported:

    * ``static=True`` (default): a fixed random matrix of shape
      ``(in_dim, out_dim)`` is drawn from a standard normal, scaled by
      ``std``, and registered as a buffer named ``layer``. Buffers are not
      returned by ``parameters()``, so an optimizer built from them leaves it
      untouched. ``bias`` has no effect in this mode.
    * ``static=False``: a learnable ``nn.Linear(in_dim, out_dim, bias=bias)``
      stored as ``layer``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        static: bool = True,
        std: float = 0.05,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.in_dim, self.out_dim, self.static = in_dim, out_dim, static

        if not static:
            self.layer = nn.Linear(in_dim, out_dim, bias=bias)
            return

        # Fixed random projection. Registering it as a buffer means it is
        # saved in state_dict and follows .to()/.cuda(), but it is not trained.
        projection = torch.randn(in_dim, out_dim, requires_grad=False) * std
        self.register_buffer("layer", projection)

    def __str__(self) -> str:
        return f"UpProjection(in_dim: {self.in_dim}, out_dim: {self.out_dim})"

    def forward(self, x: T) -> T:
        """Project ``x``. Its last dimension must equal ``in_dim``."""
        projector = self.layer
        return x @ projector if self.static else projector(x)  # type: ignore


class LinearResidual(nn.Module):
    """Residual block computing ``x + block(x)`` with a square linear layer.

    ``block`` is, in order: ``nn.Linear(dim, dim)`` (optionally wrapped in
    ``SpectralNorm``), an optional ``nn.LayerNorm(dim)``, ``activation()``,
    and an optional ``nn.Dropout(p)``. Because the linear layer is square,
    ``block(x)`` has the same shape as ``x``, so the residual sum is valid.
    """

    def __init__(
        self,
        dim: int,
        p: float = 0.1,
        activation: Any = nn.ReLU,
        spectral: bool = False,
        c: float = 1.0,
        layernorm: bool = False,
        bias: bool = True,
        ctype: str = "error",
    ) -> None:
        """Build the residual branch.

        Args:
            dim: feature size of the input and output (last dimension).
            p: dropout probability. A falsy value (e.g. ``0``) omits the
                dropout layer entirely.
            activation: zero-argument callable (e.g. an ``nn.Module`` class)
                called once to create the non-linearity.
            spectral: wrap the linear layer in ``SpectralNorm``.
            c: forwarded to ``SpectralNorm`` as ``c``. Presumably the
                spectral-norm bound; confirm in ``layers.spectral_norm``.
                Unused when ``spectral`` is False.
            layernorm: insert ``nn.LayerNorm(dim)`` after the linear layer.
            bias: whether the linear layer has a bias term. This applies to
                both the spectral and non-spectral configurations.
            ctype: forwarded to ``SpectralNorm`` as ``ctype``. Its meaning is
                defined there. Unused when ``spectral`` is False.
        """
        super().__init__()

        self.dim = dim
        self.spectral = spectral

        # Honour ``bias`` in both branches. Previously the non-spectral path
        # always created a biased layer regardless of this argument.
        linear: Any = nn.Linear(dim, dim, bias=bias)
        if spectral:
            linear = SpectralNorm(linear, ctype=ctype, c=c)

        lyr: List[Any] = [linear]

        if layernorm:
            lyr.append(nn.LayerNorm(dim))

        lyr.append(activation())

        if p:
            lyr.append(nn.Dropout(p))

        self.layer = nn.Sequential(*lyr)

    def forward(self, x: T) -> T:
        """Return ``x + block(x)``. The output has the same shape as ``x``."""
        # The residual path is identical whether or not spectral norm is used.
        return x + self.layer(x)  # type: ignore
