import math
from typing import Any, List, Optional, Tuple

import numpy as np  # type: ignore
import torch
from layers.spectral_norm import SpectralNorm
from mahalanobis.models.mixins import (CNN4ResidualSpectralMixin, DimTuple,
                                       FCMixin)
from mahalanobis.models.proto_mahalanobis import mahalanobis_distance
from mahalanobis.models.protonet import Protonet
from mahalanobis.models.set_xformer import \
    batched_sherman_morrison_rank_one_inverse
from torch import nn
from torch.distributions import Normal
from torch.nn import functional as F
from utils import softmax_log_softmax_of_sample

# Short alias used for tensor type annotations throughout this module.
T = torch.Tensor

# Names exported by `from <module> import *`: only the two factory functions.
__all__ = ["proto_sngp_linear", "proto_sngp_cnn"]


def random_features(rows: int, cols: int, stddev: float = 0.05, orthogonal: bool = False) -> T:
    """Sample a ``(rows, cols)`` random projection matrix.

    Args:
        rows: number of rows (input dimension of the projection).
        cols: number of columns (output dimension of the projection).
        stddev: scale applied to the sampled values.
        orthogonal: when True, build the matrix from orthogonal ``(rows, rows)``
            blocks (the Q factor of a QR decomposition of a Gaussian matrix),
            concatenated along the column axis, truncated to ``cols`` columns and
            rescaled by ``sqrt(cols)``. When False, return i.i.d. Gaussian
            entries scaled by ``stddev``.

    Returns:
        A tensor of shape ``(rows, cols)`` that does not require grad.

    Raises:
        ValueError: if ``orthogonal`` is True, ``cols > 0`` and ``rows < 1``.
            No square orthogonal block can be sampled in that case (the previous
            loop-based implementation never terminated for ``rows == 0``).
    """
    if not orthogonal or cols <= 0:
        # With no columns there is nothing to orthogonalise, so both modes share
        # the plain Gaussian path (which yields an empty tensor for cols == 0).
        return stddev * torch.randn(rows, cols)

    if rows < 1:
        raise ValueError(f"orthogonal random features need rows >= 1, got rows={rows}")

    # qr only returns orthogonal Q's as an (N, N) square matrix, so enough square
    # blocks are sampled to cover `cols` columns (ceil(cols / rows) of them).
    n_blocks = (cols + rows - 1) // rows
    blocks = [
        stddev * torch.linalg.qr(torch.randn(rows, rows), mode="complete")[0]
        for _ in range(n_blocks)
    ]

    w = torch.cat(blocks, dim=-1)[:, :cols]
    return w * math.sqrt(cols)


class ProtoSNGP(Protonet):
    """Prototypical network with an SNGP-style random Fourier feature (RFF) head.

    Embeddings produced by ``base`` (implemented by subclasses) are projected,
    layer-normalised and mapped through fixed random cosine features. Class
    scores are negative Mahalanobis distances between query features and class
    centroids, using per-class precision matrices derived from the support set.

    ``compute_centroids``, ``tmp_layer``, ``samples`` and ``forward_type`` are
    not defined in this class; presumably they are provided by ``Protonet`` —
    confirm against the parent class.
    """

    def __init__(
        self,
        n_layers: int = 6,
        in_dim: int = 2,
        h_dim: int = 64,
        classes: int = 2,
        p: float = 0.01,
        ctype: str = "error",
        gp_in_dim: int = 64,
        forward_type: str = "softmax",
        gp_h_dim: int = 64,
        s: float = 0.001,
        m: float = 0.999,
        **kwargs: Any
    ):
        """Build the RFF head on top of the parent ``Protonet``.

        Args:
            n_layers, in_dim, h_dim, classes, p, ctype, forward_type, **kwargs:
                forwarded unchanged to the parent constructor.
            gp_in_dim: dimensionality of the features entering the RFF head.
            gp_h_dim: number of random Fourier features.
            s: stored as ``self.s``; not read by any method of this class.
            m: stored as ``self.m``; not read by any method of this class.
        """
        super().__init__(n_layers=n_layers, in_dim=in_dim, h_dim=h_dim, classes=classes, p=p, ctype=ctype, forward_type=forward_type, **kwargs)  # type: ignore
        self.gp_in_dim = gp_in_dim
        self.gp_h_dim = gp_h_dim
        self.s = s
        self.initializer_std = 0.05
        self.m = m

        # Fixed (non-trainable) random projections; registered as buffers so they
        # follow the module across devices and are saved in the state dict.
        self.register_buffer("rff_weight", random_features(gp_in_dim, gp_h_dim, stddev=self.initializer_std, orthogonal=True))
        self.register_buffer("rff_bias", torch.rand(gp_h_dim, requires_grad=False) * 2 * math.pi)
        self.register_buffer("gp_input_projection", random_features(gp_in_dim, gp_in_dim, stddev=self.initializer_std))

        # Trainable diagonal prior; passed through softplus before use so the
        # effective diagonal is positive.
        self.prior_diag = nn.Parameter(torch.ones(1, gp_h_dim), requires_grad=True)

        self.ln = nn.LayerNorm(gp_in_dim)

    def rff(self, x: T) -> T:
        """Map features ``x`` (last dim ``gp_in_dim``) to ``gp_h_dim`` random cosine features."""
        x = x @ self.gp_input_projection
        x = self.ln(x)
        # GP input scale: https://github.com/google/uncertainty-baselines/blob/main/baselines/cifar/sngp.py#L134
        # in the GP layer code: https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L131
        x = x * (1 / math.sqrt(2))
        x = x @ self.rff_weight + self.rff_bias
        return math.sqrt(2 / self.gp_h_dim) * torch.cos(x)

    def _support_precisions(self, s_phi: T, n_way: int, k_shot: int) -> T:
        """Compute per-class precision matrices from the support features.

        Uses ``softplus(prior_diag)`` (repeated once per class) as the diagonal
        term and the support features, viewed as ``(n_way, k_shot, dim)``, as the
        rank-one update factors. Only the first value returned by
        ``batched_sherman_morrison_rank_one_inverse`` is used.

        NOTE(review): the ``view`` assumes support rows are grouped by class in
        blocks of ``k_shot`` — confirm against the episode sampler.
        """
        precisions, _ = batched_sherman_morrison_rank_one_inverse(
            A_diag=F.softplus(self.prior_diag).repeat(n_way, 1),
            B_factors=s_phi.view(n_way, k_shot, s_phi.size(-1)),
        )
        return precisions

    def beta(self, s_phi: T, sy: T, q_phi: T, n_way: int, k_shot: int, precisions: Optional[T] = None) -> T:
        """Return negative Mahalanobis distances from queries to class centroids.

        Args:
            s_phi: support features.
            sy: support labels, used to compute the centroids.
            q_phi: query features.
            n_way: number of classes in the episode.
            k_shot: number of support examples per class.
            precisions: optional precomputed precisions; when None they are
                computed from the support features.
        """
        centroids = self.compute_centroids(s_phi, sy)
        if precisions is None:
            precisions = self._support_precisions(s_phi, n_way, k_shot)

        return -mahalanobis_distance(q_phi, centroids, precisions)

    def base(self, x: T) -> T:
        """Feature extractor applied before the RFF head; subclasses must implement it."""
        raise NotImplementedError()

    def features_extract(self, sx: T, qx: T) -> Tuple[T, T]:
        """Return the RFF features of the support and query inputs, in that order."""
        return self.rff(self.base(sx)), self.rff(self.base(qx))

    def forward(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:  # type: ignore
        """Return log-softmax class scores for the query set.

        No precision state is stored between calls; precisions are recomputed
        from the support set each time.

        Raises:
            ValueError: if ``self.forward_type`` is not ``"softmax"``.
        """
        return self.get_logits(sx, sy, qx, n_way, k_shot).log_softmax(dim=-1)

    def compute_posterior(self, s_phi: T, sy: T, q_phi: T, n_way: int, k_shot: int, precisions: T) -> T:
        """Return negative Mahalanobis distances using the given ``precisions``.

        Note that this is essentially the same as the original SNGP except for the fact that we are
        not using Phi(x) @ beta for the logits. The logits and the variance both come from the
        Mahalanobis distance. ``n_way`` and ``k_shot`` are accepted but not used here.
        """
        centroids = self.compute_centroids(s_phi, sy)
        return -mahalanobis_distance(q_phi, centroids, precisions)

    def get_cov(self, sx: T, sy: T, n_way: int, k_shot: int) -> T:  # type: ignore
        """Return ``I + Phi^T Phi`` per class, shape ``(n_way, dim, dim)``.

        ``Phi`` is the RFF feature matrix of the support inputs viewed as
        ``(n_way, k_shot, dim)``. An identity prior is used here (not
        ``prior_diag``), and ``sy`` is not used.
        """
        s_phi = self.rff(self.base(sx))

        prior = torch.stack([torch.eye(s_phi.size(-1), device=s_phi.device) for _ in range(n_way)])
        s_phi = s_phi.view(n_way, k_shot, -1)

        return prior + torch.bmm(s_phi.transpose(1, 2), s_phi)

    def get_logits(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Return unnormalised class scores (negative Mahalanobis distances) for the query set.

        Raises:
            ValueError: if ``self.forward_type`` is not ``"softmax"``.
        """
        if self.forward_type != "softmax":
            raise ValueError("only supports softmax forward")

        sx_phi, qx_phi = self.features_extract(sx, qx)
        return self.beta(sx_phi, sy, qx_phi, n_way, k_shot)

    def inference(self, sx: T, sy: T, qx: T, n_way: int, k_shot: int, inference_style: str = "distance") -> Tuple[T, T, T]:  # type: ignore
        """Predict on the query set and return ``(probs, log_probs, score)``.

        Args:
            sx, sy: support inputs and labels.
            qx: query inputs.
            n_way: number of classes in the episode.
            k_shot: number of support examples per class.
            inference_style: ``"distance"`` applies softmax directly to the
                scores; ``"softmax-sample"`` draws ``self.samples`` samples from
                ``Normal(mu, -mu)`` and passes them to
                ``softmax_log_softmax_of_sample``.

        Returns:
            A tuple of predicted probabilities, log-probabilities and a
            negative-logsumexp score per query.

        Raises:
            NotImplementedError: for any other ``inference_style``.
        """
        s_phi, q_phi = self.features_extract(sx, qx)
        precisions = self._support_precisions(s_phi, n_way, k_shot)

        mu = self.compute_posterior(s_phi, sy, q_phi, n_way, k_shot, precisions)
        mu = self.tmp_layer(mu)

        if inference_style == "distance":
            return mu.softmax(dim=-1), mu.log_softmax(dim=-1), -torch.logsumexp(mu, dim=-1)
        elif inference_style == "softmax-sample":
            # NOTE(review): -mu is used as the Normal scale, which must be
            # strictly positive; this relies on mu staying negative after
            # tmp_layer — confirm.
            samples = Normal(mu, -mu).sample((self.samples,))
            pred, log_pred = softmax_log_softmax_of_sample(samples)
            # NOTE(review): this branch takes logsumexp of -mu whereas the
            # "distance" branch uses mu — confirm the sign difference is intended.
            return pred, log_pred, -torch.logsumexp(-mu, dim=-1)
        else:
            raise NotImplementedError(f"{inference_style=}")


class ProtoSNGPFC(ProtoSNGP, FCMixin):
    """ProtoSNGP variant whose feature extractor iterates over ``self.layers``.

    ``self.layers`` is not created in this class; presumably it is built by
    ``FCMixin`` / the parent constructors from the arguments forwarded below —
    confirm against those classes.
    """

    def __init__(
        self,
        n_layers: int,
        in_dim: int,
        h_dim: int,
        classes: int,
        p: float = 0.01,
        ctype: str = "error",
        s: float = 0.1,
        spectral: bool = True,
        forward_type: str = "softmax",
        gp_in_dim: int = 64,
        gp_h_dim: int = 64
    ):
        """Forward all arguments to the parent constructors and set bookkeeping attributes."""
        super().__init__(  # type: ignore
            n_layers=n_layers,
            in_dim=in_dim,
            h_dim=h_dim,
            classes=classes,
            p=p,
            ctype=ctype,
            s=s,
            spectral=spectral,
            gp_in_dim=gp_in_dim,
            gp_h_dim=gp_h_dim,
            forward_type=forward_type,
        )
        self.name = "SNGPProtoFC"
        self.classes = classes
        self.h_dim = h_dim
        self.p = p

        # When save_features is True, `base` records every layer's output in
        # `features` (reset at the start of each `base` call).
        self.save_features = False
        self.features: List[T] = []

    def down_project(self, x: T) -> T:
        """Embed ``x`` with ``base`` and project the result to 2 dimensions.

        The projection is a Gaussian matrix sampled anew on every call, so
        outputs of separate calls do not share a projection. Written as an
        attempt to visualize the embedding.
        """
        embedded = self.base(x)
        projection = torch.randn(embedded.size(-1), 2, device=embedded.device)
        return embedded @ projection

    def base(self, x: T) -> T:
        """Apply each module in ``self.layers`` in order and return the final output.

        ``self.features`` is cleared on entry; intermediate outputs are appended
        to it only while ``self.save_features`` is True.
        """
        self.features = []
        out = x
        for layer in self.layers:  # type: ignore
            out = layer(out)
            if self.save_features:
                self.features.append(out)
        return out  # type: ignore


class ProtoSNGPCNN4(ProtoSNGP, CNN4ResidualSpectralMixin):
    """ProtoSNGP variant whose feature extractor is a single call to ``self.layers``.

    ``self.layers`` is not created in this class; presumably it is built by
    ``CNN4ResidualSpectralMixin`` / the parent constructors — confirm.
    """

    def __init__(
        self,
        dims: DimTuple = (()),
        in_ch: int = 1,
        h_dim: int = 64,
        classes: int = 5,
        p: float = 0.01,
        ctype: str = "error",
        c: float = 3.0,
        forward_type: str = "softmax",
        spectral: bool = True,
        gp_in_dim: int = 64,
        gp_h_dim: int = 64,
        **kwargs: Any
    ):
        """Forward all arguments (including extra ``kwargs``) to the parent constructors."""
        super().__init__(  # type: ignore
            dims=dims,
            in_ch=in_ch,
            h_dim=h_dim,
            classes=classes,
            p=p,
            ctype=ctype,
            c=c,
            spectral=spectral,
            gp_in_dim=gp_in_dim,
            gp_h_dim=gp_h_dim,
            forward_type=forward_type,
            **kwargs,
        )
        self.name = "ProtoSNGPCNN4"
        self.h_dim = h_dim
        self.classes = classes

    def base(self, x: T) -> T:
        """Return ``self.layers(x)``, the embedding fed to the RFF head."""
        embedded = self.layers(x)  # type: ignore
        return embedded  # type: ignore


def proto_sngp_linear(n_layers: int, in_dim: int, h_dim: int, classes: int, p: float, ctype: str, spectral: bool, forward_type: str, gp_h_dim: int, gp_in_dim: int) -> ProtoSNGPFC:
    """Factory for ``ProtoSNGPFC``; every argument is passed through by keyword."""
    model = ProtoSNGPFC(
        n_layers=n_layers,
        in_dim=in_dim,
        h_dim=h_dim,
        classes=classes,
        p=p,
        ctype=ctype,
        spectral=spectral,
        forward_type=forward_type,
        gp_h_dim=gp_h_dim,
        gp_in_dim=gp_in_dim,
    )
    return model


def proto_sngp_cnn(dims: DimTuple, in_ch: int, h_dim: int, classes: int, p: float, ctype: str, spectral: bool, forward_type: str, gp_h_dim: int, gp_in_dim: int) -> ProtoSNGPCNN4:
    """Factory for ``ProtoSNGPCNN4``; every argument is passed through by keyword."""
    model = ProtoSNGPCNN4(
        dims=dims,
        in_ch=in_ch,
        h_dim=h_dim,
        classes=classes,
        p=p,
        ctype=ctype,
        spectral=spectral,
        forward_type=forward_type,
        gp_h_dim=gp_h_dim,
        gp_in_dim=gp_in_dim,
    )
    return model
