import os
from typing import Any, List

import torch
from layers import LinearResidual, UpProjection
from torch import nn

from sngp.model import SNGP

T = torch.Tensor


class LinearSNGP(nn.Module):
    """Fully-connected feature extractor: one linear up-projection followed by a stack of residual blocks.

    The network is ``nn.Sequential(nn.Linear(in_dim, h_dim), LinearResidual(...) x n_layers)``.
    Each residual block is built with ``spectral=True``. It is presumably spectrally normalised,
    as SNGP requires; see ``layers.LinearResidual`` to confirm.

    Args:
        n_layers: number of ``LinearResidual`` blocks stacked after the input projection.
        in_dim: size of the last dimension of the input.
        h_dim: hidden width. It is also the size of the last dimension of the output.
        p: forwarded to every ``LinearResidual`` as ``p``.
        num_classes: accepted for interface compatibility; not used by this module.
        ctype: forwarded to every ``LinearResidual`` and stored as ``self.ctype``.
        c: stored as ``self.c``; not used by any code in this class.
    """

    def __init__(self, n_layers: int, in_dim: int, h_dim: int, p: float = 0.01, num_classes: int = 10, ctype: str = "none", c: float = 1.0) -> None:
        super().__init__()
        self.name = os.path.join("LinearSNGP", f"{n_layers}-layers")
        self.ctype = ctype
        self.c = c

        # Input projection first, then n_layers identical residual blocks. The order is kept
        # stable so that state-dict keys (layers.0, layers.1, ...) match existing checkpoints.
        blocks: List[Any] = [nn.Linear(in_dim, h_dim, bias=True)]
        blocks.extend(LinearResidual(h_dim, p=p, spectral=True, ctype=ctype) for _ in range(n_layers))

        self.layers = nn.Sequential(*blocks)

    def down_project(self, x: T) -> T:
        """Run the feature extractor, then project the features to 2 dimensions with a random matrix.

        Intended for visualisation. The projection matrix is freshly sampled with ``torch.randn`` on
        every call and is not rescaled. Two calls therefore use different projections, and results
        are only comparable across calls if the RNG is seeded identically.
        NOTE(review): distances are preserved only approximately and up to scale, as with any
        unscaled Gaussian random projection.
        """
        feats = self.layers(x)
        projection = torch.randn(feats.size(-1), 2, device=feats.device)
        return feats @ projection

    def forward(self, x: T) -> T:
        """Return the features of ``x``; the last dimension has size ``h_dim``."""
        return self.layers(x)  # type: ignore


def sngp_linear(n_layers: int, in_dim: int, h_dim: int, p: float, num_classes: int, ctype: str, s: float, c: float = 1.0) -> SNGP:
    """Build an ``SNGP`` model whose feature extractor is a ``LinearSNGP`` stack.

    Args:
        n_layers: number of residual blocks in the ``LinearSNGP`` backbone.
        in_dim: size of the last dimension of the input.
        h_dim: hidden width of the backbone. It is also passed twice, positionally, to ``SNGP``.
        p: forwarded to ``LinearSNGP`` (a float, matching ``LinearSNGP``'s own annotation).
        num_classes: forwarded to both ``LinearSNGP`` and ``SNGP``.
        ctype: forwarded to ``LinearSNGP``.
        s: forwarded to ``SNGP`` as the keyword argument ``s``.
        c: forwarded to ``LinearSNGP``. Defaults to 1.0, the same as ``LinearSNGP``'s default,
            so callers that omit it get the previous behaviour.

    Returns:
        The constructed ``SNGP`` instance.
    """
    backbone = LinearSNGP(n_layers=n_layers, in_dim=in_dim, h_dim=h_dim, p=p, num_classes=num_classes, ctype=ctype, c=c)
    # NOTE(review): the two h_dim arguments are positional SNGP parameters. Their meaning is
    # defined in sngp.model.SNGP; they are presumably the feature input and hidden sizes, so verify there.
    return SNGP(backbone, num_classes, h_dim, h_dim, s=s)
