import math
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np  # type: ignore
import torch
from mahalanobis.models.mixins import (CNN4ResidualSpectralMixin, DimTuple,
                                       FCMixin)
from mahalanobis.models.protonet import Protonet
from mahalanobis.models.set_xformer import DiagCovEncoder, LowRankCovEncoder
from torch import nn
from utils import softmax_log_softmax_of_sample

# Short alias for ``torch.Tensor`` used in the type annotations of this module.
T = torch.Tensor


def mahalanobis_distance(x: T, centroids: T, prec: T) -> T:
    """Return the squared Mahalanobis distance of every input to every centroid.

    For input ``b`` and centroid ``c`` this evaluates
    ``(x_b - mu_c)^T @ prec_c @ (x_b - mu_c)``.

    The quadratic form is reduced directly per (input, centroid) pair, so no
    intermediate ``(batch, classes, classes)`` tensor is materialised.

    Args:
        x: inputs of shape ``(batch, dim)``.
        centroids: centroids of shape ``(classes, dim)``.
        prec: one precision (inverse covariance) matrix per centroid, shape
            ``(classes, dim, dim)``.

    Returns:
        Tensor of shape ``(batch, classes)`` holding the distances.
    """
    diff = x.unsqueeze(1) - centroids.unsqueeze(0)  # (batch, classes, dim)
    # (x - mu)^T @ \Sigma^{-1}, one row vector per (input, centroid) pair
    projected = torch.einsum("cij,bci->bcj", prec, diff)
    # ... @ (x - mu): contract the feature axis only
    return (projected * diff).sum(dim=-1)  # type: ignore


class ProtoMahalanobis(Protonet):
    """Prototypical network that scores queries by per-class Mahalanobis distance.

    This class holds the shared scoring logic. The subclasses in this file
    implement ``base`` (the feature extractor) and assign ``phi``, the encoder
    that turns an embedded support set into centroids, precision matrices and
    log-determinants.
    """

    def __init__(self, t: float = 1.0, forward_type: str = "none", **kwargs: Any) -> None:
        """Initialise the shared state.

        Args:
            t: initial value stored in the ``t`` temperature buffer.
            forward_type: output style used by ``forward``; must not be ``"none"``.
            **kwargs: forwarded unchanged to the parent constructor. The optional
                ``beta`` entry (default ``False``) enables a learnable bias that is
                added before the sigmoid in ``forward``.

        Raises:
            ValueError: if ``forward_type`` is ``"none"``.
        """
        super().__init__(forward_type=forward_type, **kwargs)  # type: ignore
        if forward_type == "none":
            raise ValueError("proto Mahalanobis needs to have a forward type set")

        # Declared here for type checkers; the concrete subclasses assign them.
        self.classes: int
        self.h_dim: int
        self.phi: nn.Module
        self.beta = kwargs.get("beta", False)
        self.train_samples = 1000

        self.t: T
        # Honour the requested temperature (it used to be silently replaced by 1)
        # and store it as a float so later arithmetic never starts from an int tensor.
        self.register_buffer("t", torch.tensor(float(t), requires_grad=False))

        if self.beta:
            self.out_bias = nn.Parameter(torch.zeros(1), requires_grad=True)

    def set_temperature(self, temp: float) -> None:
        """Overwrite the temperature buffer with ``temp``."""
        self.t = torch.ones_like(self.t) * temp

    def base(self, x: T) -> T:
        """Embed raw inputs; must be provided by subclasses."""
        raise NotImplementedError()

    def get_statistics(self, sx_phi: T, n_way: int = 5, k_shot: int = 5, compute_cov: bool = False) -> Tuple[T, T, Optional[T], T]:
        """Run ``phi`` on the embedded support set.

        The result is unpacked by callers in this class as
        ``(centroids, precision, covariance, logdet)``.
        """
        return self.phi(sx_phi, n_way=n_way, k_shot=k_shot, compute_cov=compute_cov)  # type: ignore

    def temp(self) -> T:
        """Return ``exp(t)`` for the temperature buffer ``t``."""
        return torch.exp(self.t)  # type: ignore

    def sigma(self, x: T) -> T:
        """Return the soft-minimum of ``x`` over the last axis, clamped below at 1e-8.

        The reduced axis is kept with size 1.
        """
        return torch.clamp(-torch.logsumexp(-x, dim=-1, keepdim=True), 1e-8)

    def _query_scores(self, sx: T, qx: T, n_way: int, k_shot: int) -> Tuple[T, T]:
        """Embed both sets and return ``(-(d + logdet) / 2, d)`` for the queries.

        ``d`` is the Mahalanobis distance of each embedded query to each centroid
        produced by ``get_statistics`` from the embedded support set.
        """
        sx_phi, qx_phi = self.base(sx), self.base(qx)
        qx_phi = self.phi.centroid_enc_lyr(qx_phi)  # type: ignore

        centroids, precision, _, logdet = self.get_statistics(sx_phi, n_way=n_way, k_shot=k_shot)

        d = mahalanobis_distance(qx_phi, centroids, precision)
        return -(d + logdet) / 2, d

    def forward(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Score the queries ``qx`` against the classes described by the support set ``sx``.

        Args:
            sx: support inputs.
            sy: support labels (not used by this method).
            qx: query inputs.
            n_way: forwarded to ``get_statistics``.
            k_shot: forwarded to ``get_statistics``.

        Returns:
            ``sigmoid`` of the class scores (plus ``out_bias`` when ``beta`` is set)
            for ``forward_type == "sigmoid"``; ``log_softmax`` of the class scores
            over the last axis for ``forward_type == "softmax"``.

        Raises:
            NotImplementedError: for ``forward_type == "exp"``.
            ValueError: for any other ``forward_type``.
        """
        clog_px, _ = self._query_scores(sx, qx, n_way=n_way, k_shot=k_shot)

        if self.forward_type == "sigmoid":
            # ``out_bias`` only exists when ``beta`` is enabled; without it the
            # plain sigmoid is used instead of failing on the missing attribute.
            if self.beta:
                clog_px = clog_px + self.out_bias
            return torch.sigmoid(clog_px)
        if self.forward_type == "exp":
            # The exp-based output that used to follow this raise could never run.
            raise NotImplementedError("there is a problem with using this inference style with exp because exp needs to be over negative values in order to work with BCE")
        if self.forward_type == "softmax":
            return clog_px.log_softmax(dim=-1)
        raise ValueError(f"forward type: {self.forward_type} has not been implemented")

    def log_px(self, x: T, centroids: T, lambd: T, logdet: T, log_prior: T, normalized: bool = True) -> Tuple[T, T, T]:
        """Return per-class scores, their log-sum-exp over classes, and the distances.

        Args:
            x: embedded inputs.
            centroids: class centroids.
            lambd: per-class precision matrices.
            logdet: log-determinant term added to the distance.
            log_prior: log class prior, only used when ``normalized`` is true.
            normalized: when true, include the prior and the
                ``(h_dim / 2) * log(2 * pi)`` constant in the per-class scores.

        Returns:
            ``(per_class, logsumexp(per_class, dim=-1), d)``.
        """
        d = mahalanobis_distance(x, centroids, lambd)
        if normalized:
            conditional_log_probs = log_prior - (0.5 * d) - (0.5 * logdet) - (self.h_dim / 2) * math.log(2 * math.pi)
            marginal_log_probs = torch.logsumexp(conditional_log_probs, dim=-1)
            return conditional_log_probs, marginal_log_probs, d  # type: ignore

        out = -(d + logdet) / 2
        return out, torch.logsumexp(out, dim=-1), d

    def log_class_prior(self, y: T, way: int) -> T:
        """Return the log empirical class frequencies of ``y`` with shape ``(1, way)``.

        Classes that do not occur in ``y`` get ``-inf`` (log of zero).
        """
        one_hot = torch.zeros(y.size(0), way, device=y.device)
        # Build the row index on the same device as the labels.
        one_hot[torch.arange(y.size(0), device=y.device), y] = 1
        return torch.log(one_hot.sum(dim=0, keepdim=True) / y.size(0))

    def get_cov(self, sx: T, sy: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Return the covariance that ``phi`` computes for the support set ``sx``.

        ``sy`` is accepted but not used by this method.
        """
        sx_phi = self.base(sx)
        _, _, covariance, _ = self.get_statistics(sx_phi, n_way=n_way, k_shot=k_shot, compute_cov=True)
        assert covariance is not None
        return covariance

    def get_logits(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> Tuple[T, T]:  # type: ignore
        """Return the un-normalised class scores and the raw distances for ``qx``.

        Raises:
            ValueError: if ``forward_type`` is not ``"softmax"``.
        """
        if self.forward_type != "softmax":
            raise ValueError("only supports softmax forward")

        return self._query_scores(sx, qx, n_way=n_way, k_shot=k_shot)

    def inference(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5, inference_style: str = "distance") -> Tuple[T, T, T]:
        """Predict class probabilities for ``qx`` together with an energy score.

        Args:
            sx: support inputs.
            sy: support labels (not used by this method).
            qx: query inputs.
            n_way: forwarded to ``get_statistics``.
            k_shot: forwarded to ``get_statistics``.
            inference_style: ``"distance"`` takes the softmax of the class scores
                directly; ``"softmax-sample"`` samples scores from a Normal centred
                on the class scores, with scale ``sigma(d) / t``, and passes the
                samples to ``softmax_log_softmax_of_sample``.

        Returns:
            ``(probabilities, log_probabilities, energy)``.

        Raises:
            NotImplementedError: for an unknown ``inference_style``.
        """
        mu, d = self._query_scores(sx, qx, n_way=n_way, k_shot=k_shot)

        if inference_style == "distance":
            return mu.softmax(dim=-1), mu.log_softmax(dim=-1), -torch.logsumexp(-d, dim=-1)
        if inference_style == "softmax-sample":
            energy = self.sigma(d) / self.t
            # NOTE(review): ``self.samples`` is not assigned in this class (only
            # ``train_samples`` is) -- presumably the parent provides it; confirm.
            samples = torch.distributions.Normal(mu, energy).sample((self.samples,))
            pred, log_pred = softmax_log_softmax_of_sample(samples)
            return pred, log_pred, energy[:, 0].squeeze(-1)
        raise NotImplementedError(f"{inference_style=}")


class ProtoMahalanobisFC(ProtoMahalanobis, FCMixin):
    """ProtoMahalanobis whose feature extractor is the layer stack ``self.layers``.

    ``encoder`` selects the set encoder stored in ``phi``: ``"diag"`` builds a
    ``DiagCovEncoder`` and ``"low-rank"`` a ``LowRankCovEncoder`` with ``rank``
    outputs. The choice is also appended to ``self.name``.
    """

    def __init__(
        self,
        n_layers: int = 6,
        in_dim: int = 2,
        h_dim: int = 64,
        classes: int = 5,
        p: float = 0.1,
        c: float = 1.0,
        ctype: str = "error",
        forward_type: str = "none",
        spectral: bool = True,
        encoder: str = "low-rank",
        rank: int = 1,
        t: float = 1.0,
        num_heads: int = 16,
        ln: bool = False,
        beta: bool = False,  # add a bias into the sigmoid function
        pma_type: str = "no-residual"
    ):
        super().__init__(
            n_layers=n_layers,
            in_dim=in_dim,
            h_dim=h_dim,
            classes=classes,
            p=p,
            t=t,
            ctype=ctype,
            c=c,
            spectral=spectral,
            beta=beta,
            forward_type=forward_type,
        )
        self.classes = classes
        self.h_dim = h_dim

        # Arguments common to both encoder flavours; the set encoder itself is
        # always built with spectral=False.
        shared_kwargs: Dict[str, Any] = {
            "dim_input": h_dim,
            "dim_hidden": h_dim,
            "dim_output": h_dim,
            "num_heads": num_heads,
            "ln": ln,
            "spectral": False,
            "c": c,
            "p": p,
            "pma_type": pma_type,
        }
        encoders: Dict[str, Callable] = {
            "diag": partial(DiagCovEncoder, num_outputs=1, **shared_kwargs),
            "low-rank": partial(LowRankCovEncoder, num_outputs=rank, **shared_kwargs),
        }
        self.phi = encoders[encoder]()

        self.path = ""
        suffixes = {"diag": "Diag", "low-rank": f"Rank-{rank}"}
        self.name = "ProtoMahalanobisFC" + suffixes[encoder]

        self.save_features = False
        self.features: List[T] = []

    def down_project(self, x: T) -> T:
        """
        Embed ``x`` with ``base`` and project the result onto 2 dimensions.

        The projection matrix is drawn from ``torch.randn`` on every call.
        This was made in an attempt to visualize the embedding.
        """
        embedded = self.base(x)
        projection = torch.randn(embedded.size(-1), 2, device=embedded.device)
        return embedded @ projection

    def base(self, x: T) -> T:
        """Apply every layer in ``self.layers`` in order and return the final output.

        ``self.features`` is reset on each call; when ``save_features`` is set,
        the output of every layer is appended to it.
        """
        self.features = []
        out = x
        for layer in self.layers:  # type: ignore
            out = layer(out)
            if self.save_features:
                self.features.append(out)
        return out  # type: ignore


class ProtoMahalanobisCNN4(ProtoMahalanobis, CNN4ResidualSpectralMixin):
    """ProtoMahalanobis whose feature extractor is ``self.layers`` called as one module.

    ``encoder`` selects the set encoder stored in ``phi``: ``"diag"`` builds a
    ``DiagCovEncoder`` and ``"low-rank"`` a ``LowRankCovEncoder`` with ``rank``
    outputs. The choice is also appended to ``self.name``.
    """

    def __init__(
        self,
        dims: DimTuple = (()),
        in_ch: int = 1,
        h_dim: int = 64,
        classes: int = 5,
        p: float = 0.1,
        ctype: str = "error",
        forward_type: str = "none",
        spectral: bool = True,
        c: float = 3.0,
        encoder: str = "low-rank",
        rank: int = 8,
        t: float = 1.0,
        num_heads: int = 16,
        ln: bool = False,
        beta: bool = False,  # add a bias into the sigmoid function
        pma_type: str = "no-residual"
    ):
        super().__init__(
            dims=dims,
            in_ch=in_ch,
            h_dim=h_dim,
            classes=classes,
            p=p,
            t=t,
            ctype=ctype,
            c=c,
            spectral=spectral,
            beta=beta,
            forward_type=forward_type,
        )
        self.classes = classes
        # NOTE(review): ``self.h_dim`` keeps the constructor value even when the
        # encoder below is built with width 1600 -- confirm that is intended.
        self.h_dim = h_dim
        self.rank = rank
        self.encoder = encoder

        # Width used for the encoder's input, hidden and output sizes:
        # 1600 for 3-channel inputs, ``h_dim`` otherwise.
        enc_dim = 1600 if in_ch == 3 else h_dim

        # Arguments common to both encoder flavours; the set encoder itself is
        # always built with spectral=False.
        shared_kwargs: Dict[str, Any] = {
            "dim_input": enc_dim,
            "dim_hidden": enc_dim,
            "dim_output": enc_dim,
            "num_heads": num_heads,
            "ln": ln,
            "spectral": False,
            "c": c,
            "p": p,
            "pma_type": pma_type,
        }
        encoders: Dict[str, Callable] = {
            "diag": partial(DiagCovEncoder, num_outputs=1, **shared_kwargs),
            "low-rank": partial(LowRankCovEncoder, num_outputs=rank, **shared_kwargs),
        }
        self.phi = encoders[encoder]()

        self.path = ""
        suffixes = {"diag": "Diag", "low-rank": f"Rank-{rank}"}
        self.name = "ProtoMahalanobisCNN4" + suffixes[encoder]

        self.save_features = False
        self.features: List[T] = []

    def base(self, x: T) -> T:
        """Return ``self.layers(x)``."""
        return self.layers(x)


def proto_mahalanobis_linear(
    n_layers: int,
    in_dim: int,
    h_dim: int,
    classes: int,
    p: float,
    ctype: str,
    spectral: bool,
    encoder: str,
    rank: int,
    beta: bool,
    forward_type: str,
    pma_type: str,
    t: float
) -> ProtoMahalanobisFC:
    """Build a ``ProtoMahalanobisFC`` from explicit hyper-parameters.

    Every argument is passed to the constructor by keyword; constructor
    parameters that are not arguments of this function keep their defaults.
    """
    config: Dict[str, Any] = {
        "n_layers": n_layers,
        "in_dim": in_dim,
        "h_dim": h_dim,
        "classes": classes,
        "p": p,
        "t": t,
        "ctype": ctype,
        "spectral": spectral,
        "encoder": encoder,
        "rank": rank,
        "beta": beta,
        "forward_type": forward_type,
        "pma_type": pma_type,
    }
    return ProtoMahalanobisFC(**config)


def proto_mahalanobis_cnn(
    dims: DimTuple,
    in_ch: int,
    h_dim: int,
    classes: int,
    p: float,
    ctype: str,
    spectral: bool,
    encoder: str,
    rank: int,
    beta: bool,
    forward_type: str,
    pma_type: str,
    t: float
) -> ProtoMahalanobisCNN4:
    """Build a ``ProtoMahalanobisCNN4`` from explicit hyper-parameters.

    Every argument is passed to the constructor by keyword; constructor
    parameters that are not arguments of this function keep their defaults.
    """
    config: Dict[str, Any] = {
        "dims": dims,
        "in_ch": in_ch,
        "h_dim": h_dim,
        "classes": classes,
        "p": p,
        "t": t,
        "ctype": ctype,
        "spectral": spectral,
        "encoder": encoder,
        "rank": rank,
        "beta": beta,
        "forward_type": forward_type,
        "pma_type": pma_type,
    }
    return ProtoMahalanobisCNN4(**config)
