from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Tensor

import math

from torch.nn import Module, Linear, ReLU, Softmax, Sequential
from torch.nn import init
import torch

class MLPMUP(Module):
    """Implement a basic MLP using maximal update parameterization.

    The network is a ``Sequential`` stack of ``ScaledLayer`` blocks separated
    by ``ReLU`` activations:

    * an input layer ``in_dim -> widths[0]`` whose output is multiplied by
      ``sqrt(in_dim)``,
    * one hidden layer per consecutive pair of entries in ``widths``, with
      scale ``1``,
    * an output layer ``widths[-1] -> out_dim`` whose output is multiplied by
      ``sqrt(1 / widths[-1])``,
    * optionally a final ``Softmax(dim=1)``.

    All layer weights are drawn from a zero-mean normal distribution with
    standard deviation ``1 / sqrt(fan_in)`` (see ``reset_parameters``).

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        widths: Hidden width, or a sequence of hidden widths (one entry per
            hidden activation). Must contain at least one entry.
        is_bias: Whether every linear layer carries a bias term.
        is_classification: When true, append ``Softmax(dim=1)`` after the
            output layer.

    Raises:
        ValueError: If ``widths`` is an empty sequence.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            widths: int|list[int],
            *,
            is_bias: bool=True,
            is_classification: bool=False,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.is_bias = is_bias

        # Normalise to a private list so a single int is accepted and later
        # mutation of the caller's sequence cannot desynchronise the model.
        self.widths = [widths] if isinstance(widths, int) else list(widths)
        if not self.widths:
            raise ValueError(
                "widths must contain at least one hidden width",
            )

        self.mods = self._build()

        # Output layer: scaled down by sqrt(1 / last hidden width).
        out_scale = torch.sqrt(
            torch.tensor(
                1/self.widths[-1],
            ),
        )
        self.mods.append(
            ScaledLayer(
                scale=out_scale,
                in_features=self.widths[-1],
                out_features=self.out_dim,
                is_bias=self.is_bias,
            )
        )
        if is_classification:
            self.mods.append(
                Softmax(dim=1),
            )
        self.net = Sequential(*self.mods)

        self.reset_parameters()

    def forward(
            self,
            x: Tensor,
    ) -> Tensor:
        """Run ``x`` through the sequential stack and return its output."""
        return self.net(x)

    def reset_parameters(
            self,
    ) -> None:
        """Re-draw the weight of every scaled layer from a normal distribution.

        Each weight matrix is sampled with mean 0 and standard deviation
        ``1 / sqrt(fan_in)``, where ``fan_in`` is the second dimension of the
        weight. Only weights are touched; biases keep their current values.
        """
        for module in self.modules():
            # Select layers by type rather than by a substring of the module
            # name, so unrelated submodules that merely have "layer" in their
            # name are never mistaken for a scaled linear layer.
            if isinstance(module, ScaledLayer):
                weight = module.layer.weight
                init.normal_(
                    weight,
                    std=1/math.sqrt(weight.shape[1]),
                )

    def _build(
            self,
    )-> list[Module]:
        """Return the input and hidden layers, each followed by a ``ReLU``.

        The output layer (and optional softmax) is appended by ``__init__``.
        """
        # Input layer: scaled up by sqrt(in_dim).
        in_scale = torch.sqrt(
            torch.tensor(
                self.in_dim,
            ),
        )
        mods: list[Module] = [
            ScaledLayer(
                scale=in_scale,
                in_features=self.in_dim,
                out_features=self.widths[0],
                is_bias=self.is_bias,
            ),
            ReLU(),
        ]

        # One unscaled hidden layer per consecutive pair of widths.
        for fan_in, fan_out in zip(self.widths, self.widths[1:]):
            mods.append(
                ScaledLayer(
                    scale=torch.tensor(1.),
                    in_features=fan_in,
                    out_features=fan_out,
                    is_bias=self.is_bias,
                ),
            )
            mods.append(ReLU())

        return mods


class ScaledLayer(Module):
    """Linear layer whose output is multiplied by a fixed scale factor.

    Inputs with more than two dimensions are flattened to
    ``(x.shape[0], -1)`` before the linear map is applied, so image-like
    batches can be fed directly.

    Args:
        scale: Factor the linear layer's output is multiplied by.
        in_features: Size of each (flattened) input sample.
        out_features: Size of each output sample.
        is_bias: Whether the linear layer has a bias term.
    """

    def __init__(
            self,
            scale: Tensor,
            in_features: int,
            out_features: int,
            *,
            is_bias: bool,
    )-> None:
        super().__init__()
        # Kept as a plain attribute: it is neither a parameter nor a
        # registered buffer of this module.
        self.scale = scale
        self.layer = Linear(
            in_features=in_features,
            out_features=out_features,
            bias=is_bias,
        )

    def forward(
            self,
            x: Tensor,
    ) -> Tensor:
        """Return ``layer(x) * scale``, flattening ``x`` to 2-D first if needed."""
        if x.dim() > 2:
            # Collapse every non-batch dimension. Unlike squeeze + view this
            # works for any rank, any channel count and non-contiguous input,
            # while giving the same result for (B, 1, H, W) batches.
            x = x.flatten(start_dim=1)
        y = self.layer(x)
        # In-place is safe here: y was just produced by the linear layer.
        y *= self.scale
        return y