import math
from typing import Any, Optional, Tuple

import numpy as np  # type: ignore
import torch
from torch import nn
from torch.distributions import Normal

from sngp.wide_resnet import wide_sn_resnet28_10_cifar

T = torch.Tensor


def random_features(rows: int, cols: int, stddev: float = 0.05, orthogonal: bool = False) -> T:
    """Sample a ``(rows, cols)`` random matrix that does not require gradients.

    Args:
        rows: number of rows of the returned matrix.
        cols: number of columns of the returned matrix.
        stddev: multiplier applied to the sampled entries.
        orthogonal: when False the entries are ``stddev`` times standard-normal
            draws. When True the matrix is assembled from square orthogonal
            blocks (the Q factor of Gaussian matrices), each scaled by
            ``stddev``, truncated to ``cols`` columns and finally multiplied
            by ``sqrt(cols)``.

    Returns:
        The sampled ``(rows, cols)`` tensor.
    """
    if not orthogonal:
        return stddev * torch.randn(rows, cols, requires_grad=False)

    # A complete QR of a (rows, rows) matrix yields a square orthogonal Q, so
    # keep drawing square blocks until at least `cols` columns are available.
    blocks = []
    columns_missing = cols
    while columns_missing > 0:
        gaussian = torch.randn(rows, rows, requires_grad=False)
        blocks.append(stddev * torch.linalg.qr(gaussian, mode="complete").Q)
        columns_missing -= rows

    truncated = torch.cat(blocks, dim=-1)[:, :cols]
    return truncated * np.sqrt(cols)


class SNGP(nn.Module):
    """Random-Fourier-feature Gaussian-process output layer on top of a feature extractor.

    The output of ``base`` is layer-normalised over ``gp_in_dim`` features,
    pushed through a fixed (buffer, non-trainable) random cosine feature map of
    width ``gp_h_dim`` and then through a bias-free linear layer producing
    ``num_classes`` logits. One ``(gp_h_dim, gp_h_dim)`` precision matrix per
    class is kept in the ``prec`` buffer; ``compute_cov`` inverts it into the
    ``cov`` buffer, which ``compute_posterior`` and ``mc`` use for predictive
    uncertainty.
    """

    def __init__(self, base: nn.Module, num_classes: int, h_dim: int = 128, gp_in_dim: int = 128, gp_h_dim: int = 1024, s: float = 0.0000001, m: float = 0.999):
        """
        Args:
            base: feature extractor. It must expose a ``name`` attribute, and
                its output's last dimension must be ``gp_in_dim`` (it is fed
                to ``nn.LayerNorm(gp_in_dim)``).
            num_classes: number of output classes.
            h_dim: accepted for call-site compatibility; not read by this class.
            gp_in_dim: size of the feature vector entering the random-feature map.
            gp_h_dim: number of random Fourier features.
            s: the initial precision of every class is ``s`` times the identity.
            m: stored as ``self.m``; no method of this class reads it.
        """
        super().__init__()
        self.base = base
        self.classes = num_classes
        self.gp_in_dim = gp_in_dim
        self.gp_h_dim = gp_h_dim
        self.s = s

        self.initializer_std = 0.05
        # NOTE: `layers` is the same module object as `base`, registered under a second name.
        self.layers = base
        self.name = self.base.name  # type: ignore
        self.total = 0
        self.m = m

        self.register_buffer("rff_weight", random_features(gp_in_dim, gp_h_dim, stddev=self.initializer_std, orthogonal=True))
        self.register_buffer("rff_bias", torch.rand(gp_h_dim, requires_grad=False) * 2 * math.pi)
        self.beta = nn.Linear(gp_h_dim, num_classes, bias=False)

        self.prec: T
        self.cov: T

        self.register_buffer("prec", self._initial_precision())
        self.register_buffer("cov", torch.zeros(num_classes, gp_h_dim, gp_h_dim, requires_grad=False))
        self.ln = nn.LayerNorm(gp_in_dim)

    def _initial_precision(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> T:
        """Return ``s * I`` repeated once per class, shape ``(classes, gp_h_dim, gp_h_dim)``."""
        identity = torch.eye(self.gp_h_dim, device=device, dtype=dtype)
        return identity.repeat(self.classes, 1, 1) * self.s

    def down_project(self, x: T) -> T:
        """Delegate to ``base.down_project``."""
        return self.base.down_project(x)  # type: ignore

    def init_sigma_lambda(self) -> None:
        """Reset ``prec`` to ``s * I`` per class and ``cov`` to zeros.

        The new tensors are created on the device and with the dtype of
        ``rff_weight`` so they follow the module after ``.to(...)`` calls.
        """
        device = self.rff_weight.device  # type: ignore
        dtype = self.rff_weight.dtype  # type: ignore
        self.prec = self._initial_precision(device=device, dtype=dtype)
        self.cov = torch.zeros((self.classes, self.gp_h_dim, self.gp_h_dim), device=device, dtype=dtype)

    def rff(self, x: T) -> T:
        """Map features to random Fourier features ``sqrt(2 / gp_h_dim) * cos(x' W + b)``.

        ``x'`` is the layer-normalised input scaled by ``1 / sqrt(2)``; the
        output's last dimension is ``gp_h_dim``.
        """
        x = self.ln(x)
        # gp input scale: https://github.com/google/uncertainty-baselines/blob/main/baselines/cifar/sngp.py#L134
        # in the GP layer code: https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L131
        x = x * (1 / math.sqrt(2))
        x = x @ self.rff_weight + self.rff_bias
        # Plain Python floats keep the product a tensor without relying on
        # numpy-scalar/tensor operator dispatch.
        return math.sqrt(2 / self.gp_h_dim) * torch.cos(x)

    def forward(self, x: T, update_prec: bool = False, y: Optional[T] = None) -> T:
        """Return the class logits and optionally fold the batch into ``prec``.

        Args:
            x: input batch, passed to the base network.
            update_prec: when True, accumulate this batch into the per-class
                precision matrices via ``update_lambda``.
            y: labels of the batch. Must not be None when ``update_prec`` is
                True, although the current accumulation rule does not read it.

        Returns:
            Logits whose last dimension is the number of classes.

        Raises:
            ValueError: if ``update_prec`` is True and ``y`` is None.
        """
        x = self.layers(x)
        phi = self.rff(x)
        logits = self.beta(phi)

        if update_prec:
            if y is None:
                raise ValueError("y cannot be None when updating the Hessian")
            self.update_lambda(phi, y, logits)
        return logits  # type: ignore

    @torch.no_grad()
    def update_lambda(self, _phi: T, y: T, logits: T) -> None:
        """Accumulate ``sum_b p_bc (1 - p_bc) phi_b phi_b^T`` into ``prec[c]`` for every class.

        Runs without gradient tracking: ``_phi`` and ``logits`` usually come
        straight from ``forward`` and carry an autograd graph; adding them
        in place to the buffer would make ``prec`` require grad and keep every
        batch's graph alive.

        Args:
            _phi: random features of the batch, shape ``(batch, gp_h_dim)``.
            y: labels of the batch; not used by this accumulation rule.
            logits: logits of the batch, shape ``(batch, classes)``.
        """
        probs = logits.softmax(dim=-1)
        # sqrt(p (1 - p)) for all classes at once; column i scales the features for class i.
        scale = torch.sqrt(probs * (1 - probs))
        for i in range(self.classes):
            weighted = _phi * scale[:, i:i + 1]
            self.prec[i] += weighted.t() @ weighted

    @torch.no_grad()
    def compute_cov(self) -> None:
        """Store the inverse of every class precision matrix in ``cov``."""
        for i in range(self.classes):
            self.cov[i] = torch.inverse(self.prec[i])

    def compute_posterior(self, x: T) -> Tuple[T, T]:
        """Return the predictive mean and standard deviation of the logits.

        ``cov`` is all zeros until ``compute_cov`` has been called, in which
        case the returned standard deviation is zero everywhere.

        Returns:
            ``(mu, sigma)``, both of shape ``(batch, classes)``: ``mu`` are
            the logits and ``sigma[b, c] = sqrt(phi_b^T cov_c phi_b)``.
        """
        phi = self.rff(self.layers(x))
        mu = self.beta(phi)
        # (classes, batch, gp_h_dim): cov_c @ phi_b for every class/example pair.
        cov_phi = torch.einsum("cij,bj->cbi", self.cov, phi)
        # Only the per-example quadratic form is needed, so contract with phi
        # directly instead of forming a (classes, batch, batch) matrix and
        # reading its diagonal.
        variance = (cov_phi * phi).sum(dim=-1)
        return mu, torch.sqrt(variance).t()

    def mc(self, x: T, samples: int = 100) -> Tuple[T, T]:
        """Monte-Carlo estimate of the predictive distribution.

        Draws ``samples`` logit vectors from ``Normal(mu, sigma)`` as given by
        ``compute_posterior`` and moves them to the CPU.

        Returns:
            ``(probs, logits)``: ``probs`` is the softmax of the sampled logits
            averaged over the samples; ``logits`` is the log of the mean of the
            exponentiated sampled logits. Both are detached CPU tensors of
            shape ``(batch, classes)``.
        """
        mu, sigma = self.compute_posterior(x)
        draws = Normal(mu, sigma).sample((samples,)).detach().cpu()
        logits = torch.logsumexp(draws, dim=0) - math.log(samples)
        probs = draws.softmax(dim=-1).mean(dim=0)
        return probs, logits


def SNGP_WideResNet28_10_cifar(resnet_kwargs: Any = None, sngp_kwargs: Any = None) -> SNGP:
    """Build an ``SNGP`` model whose base network is ``wide_sn_resnet28_10_cifar``.

    Args:
        resnet_kwargs: keyword arguments forwarded to
            ``wide_sn_resnet28_10_cifar``; ``None`` means no extra arguments.
        sngp_kwargs: keyword arguments forwarded to ``SNGP``; ``None`` means no
            extra arguments. It has to provide ``num_classes`` (``SNGP`` has no
            default for it) and must not contain ``h_dim``, which is always
            passed as 640 here.

    Returns:
        The assembled ``SNGP`` module.
    """
    # None sentinels instead of `{}` defaults: a dict default is a single object
    # shared by every call.
    resnet_kwargs = {} if resnet_kwargs is None else resnet_kwargs
    sngp_kwargs = {} if sngp_kwargs is None else sngp_kwargs

    net = wide_sn_resnet28_10_cifar(**resnet_kwargs)
    return SNGP(net, h_dim=640, **sngp_kwargs)
