from typing import Any, List, Tuple

import torch
from layers.spectral_norm import SpectralNorm
from mahalanobis.models.mixins import (CNN4ResidualSpectralMixin, DimTuple,
                                       FCMixin)
from mahalanobis.models.proto_mahalanobis import mahalanobis_distance
from mahalanobis.models.protonet import Protonet
from mahalanobis.models.set_xformer import \
    batched_sherman_morrison_rank_one_inverse
from torch import nn
from torch.nn import functional as F
from utils import softmax_log_softmax_of_sample

# Short alias for the tensor type used in annotations throughout this module.
T = torch.Tensor


# Only the two factory functions are exported by ``from <module> import *``.
__all__ = [
    "proto_ddu_linear",
    "proto_ddu_cnn",
]


class DDUProto(Protonet):
    """Prototypical network that scores queries by (negated) Mahalanobis distance.

    Each class gets its own precision matrix, produced by
    ``batched_sherman_morrison_rank_one_inverse`` from a learned diagonal prior
    (``softplus(prior_diag)``) and that class's support embeddings.
    ``get_logits`` / ``forward`` only support ``forward_type == "softmax"``.
    """

    def __init__(
        self,
        n_layers: int = 6,
        in_ch: int = 1,
        h_dim: int = 64,
        classes: int = 5,
        p: float = 0.01,
        ctype: str = "error",
        forward_type: str = "softmax",
        in_dim: int = 2,
        cov_dim: int = 64,
        **kwargs: Any
    ):
        """
        Args:
            n_layers: accepted but not forwarded to ``super().__init__``.
            in_ch, h_dim, classes, p, ctype, forward_type: forwarded to ``Protonet``.
            in_dim: accepted but not forwarded to ``super().__init__``.
            cov_dim: width of the learned diagonal prior and size of the square
                matrices built by ``init_sigma_lambda``.
            **kwargs: forwarded unchanged to ``super().__init__``.
        """
        # NOTE(review): n_layers and in_dim are consumed here and never passed on,
        # so later bases in the MRO (e.g. FCMixin in ProtoDDUFC) cannot receive the
        # values a subclass supplies through this call — confirm this is intended.
        super().__init__(in_ch=in_ch, h_dim=h_dim, classes=classes, p=p, ctype=ctype, forward_type=forward_type, **kwargs)  # type: ignore
        # Unconstrained parameter; softplus() maps it to a positive diagonal.
        self.prior_diag = nn.Parameter(torch.ones(1, cov_dim), requires_grad=True)
        self.cov_dim = cov_dim

    def init_sigma_lambda(self) -> None:
        """Reset ``self.prec`` to ``self.s * I`` and ``self.cov`` to zeros, one matrix per class.

        Reads ``self.prec`` / ``self.cov`` (for their device), ``self.s`` and
        ``self.classes``; none are created in this class — presumably provided
        by ``Protonet`` or a subclass (TODO confirm).
        """
        eye = torch.eye(self.cov_dim, self.cov_dim, device=self.prec.device)  # type: ignore
        self.prec = eye.repeat(self.classes, 1, 1)  # (classes, cov_dim, cov_dim)  # type: ignore
        self.prec *= self.s  # type: ignore
        self.cov = torch.zeros((self.classes, self.cov_dim, self.cov_dim), device=self.cov.device, requires_grad=False)  # type: ignore

    def _class_precisions(self, sx_phi: T, n_way: int, k_shot: int) -> T:
        """Per-class precision matrices from the diagonal prior and the support embeddings.

        ``sx_phi`` is viewed as ``(n_way, k_shot, -1)``: consecutive blocks of
        ``k_shot`` rows are treated as one class. Assumes the support set is
        grouped class-by-class in the same order used by ``compute_centroids``
        — TODO confirm with the data loader.
        """
        prior = F.softplus(self.prior_diag).repeat(n_way, 1)  # (n_way, cov_dim), positive
        factors = sx_phi.view(n_way, k_shot, -1)
        precisions, _ = batched_sherman_morrison_rank_one_inverse(A_diag=prior, B_factors=factors)
        return precisions

    def get_logits(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Return negated Mahalanobis distances from every query to every class centroid.

        Raises:
            ValueError: if ``forward_type`` is not ``"softmax"``.
        """
        if self.forward_type != "softmax":
            raise ValueError("only supports softmax forward")

        sx_phi, qx_phi = self.base(sx), self.base(qx)
        centroids = self.compute_centroids(sx_phi, sy)
        precisions = self._class_precisions(sx_phi, n_way, k_shot)
        return -mahalanobis_distance(qx_phi, centroids, precisions)

    def forward(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Return log-softmax over classes of ``get_logits``.

        Raises:
            ValueError: if ``forward_type`` is not ``"softmax"``.
        """
        if self.forward_type != "softmax":
            raise ValueError("only supports softmax forward")

        logits = self.get_logits(sx, sy, qx, n_way=n_way, k_shot=k_shot)
        return torch.log_softmax(logits, dim=-1)

    def class_counts(self, y: T, way: int) -> Tuple[T, T]:
        """Count labels per class and compute the empirical log class prior.

        Args:
            y: 1-D label tensor; values are used as column indices, so each must
                lie in ``[0, way)``.
            way: number of classes.

        Returns:
            ``(cnt, log_class_prior)``, both of shape ``(1, way)``. ``cnt`` holds
            per-class counts, with counts of exactly 1 bumped to 1.01.
            ``log_class_prior`` is ``log(count / len(y))``, which is ``-inf`` for
            classes that never occur.
        """
        one_hot = torch.zeros(y.size(0), way, device=y.device)
        one_hot[torch.arange(y.size(0)), y] = 1

        cnt = one_hot.sum(dim=0, keepdim=True)
        # Computed from cnt *before* the in-place bump below (division makes a copy).
        log_class_prior = torch.log(cnt / y.size(0))

        # Bump singleton counts so that a later (cnt - 1) divisor is non-zero.
        # Presumably used for an unbiased covariance estimate — TODO confirm;
        # nothing in this class consumes the counts.
        cnt[cnt == 1] += 1e-2
        return cnt, log_class_prior

    def compute_covariance(self, phi: T, n_way: int, k_shot: int) -> T:
        """Per-class Gram matrix ``phi_cᵀ phi_c`` of the support embeddings.

        ``phi`` is viewed as ``(n_way, k_shot, D)``; the result has shape
        ``(n_way, D, D)``. No mean-centering or normalisation is applied.
        """
        phi = phi.view(n_way, k_shot, -1)
        gram = phi.transpose(1, 2).bmm(phi)
        return gram

    def get_cov(self, sx: T, ys: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Identity-regularised per-class Gram matrices ``I + phi_cᵀ phi_c``.

        ``ys`` is accepted for interface compatibility but unused. The identity
        size comes from ``sx_phi.size(-1)``; this assumes ``base`` returns
        ``(N, D)`` embeddings — TODO confirm.
        """
        sx_phi = self.base(sx)
        # (D, D) identity broadcasts across the leading n_way dimension.
        eye = torch.eye(sx_phi.size(-1), device=sx_phi.device)
        return eye + self.compute_covariance(sx_phi, n_way, k_shot)

    def log_px(self, phi: T, centroids: T, precisions: T, logdets: T, log_prior: T) -> T:
        """Return the negated Mahalanobis distance of ``phi`` to each centroid.

        The Gaussian normalising term ``(D/2)·log(2π) + ½·logdet`` is not
        included, and ``logdets`` / ``log_prior`` are currently unused; they
        are kept for interface compatibility.
        """
        mahalanobis = mahalanobis_distance(phi, centroids, precisions)
        return -mahalanobis  # type: ignore

    def inference(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5, inference_style: str = "distance") -> Tuple[T, T, T]:
        """Score queries against the support set.

        Args:
            inference_style: ``"distance"`` passes the negated distances through
                ``self.tmp_layer`` and applies softmax; ``"softmax-sample"`` draws
                ``self.samples`` Gaussian samples around the negated distances and
                reduces them with ``softmax_log_softmax_of_sample``.
                (``tmp_layer`` / ``samples`` are not defined in this class —
                presumably provided by ``Protonet``.)

        Returns:
            ``(probabilities, log_probabilities, -logsumexp(-d))`` where ``d`` are
            the (possibly ``tmp_layer``-transformed) negated distances.

        Raises:
            NotImplementedError: for an unknown ``inference_style``.
        """
        sx_phi, qx_phi = self.base(sx), self.base(qx)
        centroids = self.compute_centroids(sx_phi, sy)
        precisions = self._class_precisions(sx_phi, n_way, k_shot)

        if inference_style == "distance":
            d = -mahalanobis_distance(qx_phi, centroids, precisions)
            d = self.tmp_layer(d)
            return d.softmax(dim=-1), d.log_softmax(dim=-1), -torch.logsumexp(-d, dim=1)
        elif inference_style == "softmax-sample":
            d = -mahalanobis_distance(qx_phi, centroids, precisions)
            # Normal requires scale > 0. -d (the raw distance) is exactly 0 when a
            # query sits on a centroid, and can dip below 0 from rounding; clamp it.
            scale = (-d).clamp_min(torch.finfo(d.dtype).tiny)
            samples = torch.distributions.Normal(d, scale).sample((self.samples,))
            pred, log_pred = softmax_log_softmax_of_sample(samples)
            return pred, log_pred, -torch.logsumexp(-d, dim=-1)
        else:
            raise NotImplementedError(f"inference style: {inference_style} not implemented")


class ProtoDDUFC(DDUProto, FCMixin):
    """Fully-connected variant of ``DDUProto``.

    The embedding network is ``self.layers`` (not created in this class —
    presumably built by ``FCMixin``; TODO confirm), applied one layer at a time.
    """

    def __init__(
        self,
        n_layers: int = 6,
        in_dim: int = 2,
        h_dim: int = 128,
        classes: int = 2,
        p: float = 0.01,
        ctype: str = "error",
        c: float = 1.0,
        forward_type: str = "softmax",
        spectral: bool = True,
        cov_dim: int = 64,
        **kwargs: Any
    ):
        super().__init__(  # type: ignore
            n_layers=n_layers,
            in_dim=in_dim,
            h_dim=h_dim,
            classes=classes,
            p=p,
            ctype=ctype,
            c=c,
            spectral=spectral,
            forward_type=forward_type,
            cov_dim=cov_dim,
            **kwargs
        )
        # Re-assigned after the cooperative super().__init__ chain has run.
        self.classes = classes
        self.h_dim = h_dim
        self.cov_dim = cov_dim
        # When save_features is True, base() records each layer's output here.
        self.save_features = False
        self.features: List[T] = []
        self.name = "ProtoDDUFC"

    def down_project(self, x: T) -> T:
        """Embed ``x`` with ``base`` and multiply by a random ``(D, 2)`` matrix.

        Intended for 2-D visualisation of the embedding. A fresh ``torch.randn``
        projection is drawn on every call, so separate calls use different
        projections unless the RNG is re-seeded.
        """
        embedded = self.base(x)
        projection = torch.randn(embedded.size(-1), 2, device=embedded.device)
        return embedded @ projection

    def base(self, x: T) -> T:
        """Run ``x`` through ``self.layers`` in order, optionally recording each output."""
        recorded: List[T] = []
        self.features = recorded  # cleared on every call, even when not recording
        for layer in self.layers:  # type: ignore
            x = layer(x)
            if self.save_features:
                recorded.append(x)
        return x  # type: ignore


class ProtoDDUCNN4(DDUProto, CNN4ResidualSpectralMixin):
    """Convolutional variant of ``DDUProto``.

    The embedding network is ``self.layers``, called as a single module (not
    created in this class — presumably built by ``CNN4ResidualSpectralMixin``;
    TODO confirm).
    """

    def __init__(
        self,
        dims: DimTuple,
        in_ch: int = 1,
        h_dim: int = 64,
        classes: int = 5,
        p: float = 0.01,
        ctype: str = "error",
        forward_type: str = "softmax",
        c: float = 3.0,
        spectral: bool = False,
        cov_dim: int = 64,
        **kwargs: Any
    ):
        super().__init__(  # type: ignore
            dims=dims,
            in_ch=in_ch,
            h_dim=h_dim,
            classes=classes,
            p=p,
            ctype=ctype,
            spectral=spectral,
            c=c,
            forward_type=forward_type,
            cov_dim=cov_dim,
            **kwargs
        )
        # Re-assigned after the cooperative super().__init__ chain has run.
        self.classes = classes
        self.h_dim = h_dim
        self.cov_dim = cov_dim
        self.name = "ProtoDDUCNN4"

    def base(self, x: T) -> T:
        """Embed ``x`` by calling ``self.layers`` on it."""
        network = self.layers  # type: ignore
        return network(x)  # type: ignore


def proto_ddu_linear(
    n_layers: int,
    in_dim: int,
    h_dim: int,
    classes: int,
    p: float,
    ctype: str,
    spectral: bool,
    forward_type: str,
    cov_dim: int,
) -> ProtoDDUFC:
    """Build a ``ProtoDDUFC`` from explicit hyper-parameters.

    ``c`` is not exposed here, so ``ProtoDDUFC``'s own default applies.
    """
    model = ProtoDDUFC(
        n_layers=n_layers,
        in_dim=in_dim,
        h_dim=h_dim,
        classes=classes,
        p=p,
        ctype=ctype,
        spectral=spectral,
        forward_type=forward_type,
        cov_dim=cov_dim,
    )
    return model


def proto_ddu_cnn(
    dims: DimTuple,
    in_ch: int,
    h_dim: int,
    classes: int,
    p: float,
    ctype: str,
    spectral: bool,
    forward_type: str,
    cov_dim: int,
) -> ProtoDDUCNN4:
    """Build a ``ProtoDDUCNN4`` from explicit hyper-parameters.

    ``c`` is not exposed here, so ``ProtoDDUCNN4``'s own default applies.
    """
    model = ProtoDDUCNN4(
        dims=dims,
        in_ch=in_ch,
        h_dim=h_dim,
        classes=classes,
        p=p,
        ctype=ctype,
        spectral=spectral,
        forward_type=forward_type,
        cov_dim=cov_dim,
    )
    return model
