from typing import Any, Tuple

import torch
from layers import TemperatureScaler
from mahalanobis.models.mixins import (CNN4ResidualSpectralMixin, DimTuple,
                                       FCMixin)
from utils import softmax_log_softmax_of_sample

from protonet.model import Protonet as _Protonet

# Short alias used for tensor annotations throughout this module.
T = torch.Tensor

# Names exported by a star-import: the base class and the two factory functions.
# NOTE(review): ProtonetFC and ProtonetCNN4 are not listed; they are only reachable
# through the factories when star-importing -- confirm this is intentional.
__all__ = ["protonet_linear", "Protonet", "protonet_cnn"]


class Protonet(_Protonet):
    """Protonet variant that scores queries by the (undivided) squared distance to class centroids.

    Relies on ``self.base`` (the embedding function) and ``self.compute_centroids``, which are not
    defined in this class and must be supplied by the parent class or a subclass/mixin.
    """

    def __init__(self, forward_type: str = "softmax", **kwargs: Any) -> None:
        """Forward ``kwargs`` to the parent constructor and set up the inference state.

        Args:
            forward_type: only ``"softmax"`` is accepted by ``forward`` and ``get_logits``; the value
                is stored here and checked when those methods are called, not at construction.
            **kwargs: passed unchanged to the parent ``__init__``.
        """
        super().__init__(**kwargs)
        self.tmp_layer = TemperatureScaler()
        # number of samples drawn per query in the "softmax-sample" inference style
        self.samples = 10000
        # NOTE(review): never read in this class -- presumably flags whether the temperature layer
        # has been tuned; confirm against the callers.
        self.tuned = False
        self.forward_type = forward_type

    def _check_forward_type(self) -> None:
        """Raise ``ValueError`` unless the configured forward type is ``"softmax"``."""
        if self.forward_type != "softmax":
            raise ValueError("only supports softmax forward")

    def temp_scale(self, logits: T) -> T:
        """Return ``logits`` passed through the temperature layer."""
        return self.tmp_layer(logits)  # type: ignore

    def get_logits(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Return the negated squared distances of the queries to the centroids.

        Raises:
            ValueError: if ``forward_type`` is not ``"softmax"``.
        """
        self._check_forward_type()
        return -self._forward(sx, sy, qx, n_way=n_way, k_shot=k_shot)

    def compute_dist(self, phi: T, centroids: T) -> T:
        """Return the squared difference between ``phi`` and ``centroids`` summed over the last dim.

        ``phi`` gets a new axis at position 1 so it broadcasts against ``centroids``
        (assumes phi is (n_query, dim) and centroids is (n_class, dim) -- TODO confirm).
        """
        # redefining this function here because the original module divides the distance by 2 which
        # the other models do not do in this module
        return torch.pow(phi.unsqueeze(1) - centroids, 2).sum(dim=-1)

    def _forward(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Embed support and query inputs and return the query-to-centroid distances.

        ``n_way`` and ``k_shot`` are accepted so the signature matches the public methods; they are
        not used in this implementation.
        """
        sx_phi, qx_phi = self.base(sx), self.base(qx)
        centroids = self.compute_centroids(sx_phi, sy)
        return self.compute_dist(qx_phi, centroids)

    def forward(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Return the log-softmax over the negated distances for each query.

        Raises:
            ValueError: if ``forward_type`` is not ``"softmax"``.
        """
        self._check_forward_type()
        d = self._forward(sx, sy, qx, n_way=n_way, k_shot=k_shot)
        return torch.log_softmax(-d, dim=-1)

    def inference(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5, inference_style: str = "distance") -> Tuple[T, T, T]:
        """Return ``(probabilities, log-probabilities, -logsumexp(logits))`` for the queries.

        Args:
            inference_style: ``"distance"`` applies the temperature layer to the negated distances
                and takes a plain softmax; ``"softmax-sample"`` draws ``self.samples`` Gaussian
                samples around the (unscaled) negated distances and reduces them with
                ``softmax_log_softmax_of_sample``.

        Raises:
            NotImplementedError: for any other ``inference_style``.
        """
        # the original protonet implementation in the module divides the distance by two in the forward
        d = -self._forward(sx, sy, qx, n_way=n_way, k_shot=k_shot)

        if inference_style == "distance":
            d = self.tmp_layer(d)
            return d.softmax(dim=-1), d.log_softmax(dim=-1), -torch.logsumexp(d, dim=-1)
        elif inference_style == "softmax-sample":
            # The scale is the squared distance itself (-d >= 0). It is exactly zero when a query
            # embedding coincides with a centroid, which Normal rejects (scale must be > 0), so it
            # is floored at machine epsilon.
            scale = (-d).clamp_min(torch.finfo(d.dtype).eps)
            samples = torch.distributions.Normal(d, scale).sample((self.samples,))
            pred, log_pred = softmax_log_softmax_of_sample(samples)
            return pred, log_pred, -torch.logsumexp(d, dim=-1)
        else:
            raise NotImplementedError(f"inference style: {inference_style} not implemented")


class ProtonetFC(Protonet, FCMixin):
    """Protonet whose embedding function applies ``self.layers`` one after another.

    ``self.layers`` is not created in this class (presumably built by ``FCMixin`` -- confirm).
    """

    def __init__(self, n_layers: int, in_dim: int, h_dim: int, classes: int, p: float = 0.1, ctype: str = "error", spectral: bool = True, forward_type: str = "softmax") -> None:
        """Forward all arguments by keyword to the parent constructors and set the model name.

        The name is ``"ProtonetFC"`` with an ``"SN"`` suffix when ``spectral`` is true.
        """
        super().__init__(n_layers=n_layers, in_dim=in_dim, h_dim=h_dim, classes=classes, p=p, ctype=ctype, spectral=spectral, forward_type=forward_type)
        self.classes = classes
        # when True, ``base`` records the output of every layer in ``self.features``
        self.save_features = False
        self.features = []  # type: ignore
        self.name = "ProtonetFC" + ("SN" if spectral else "")

    def down_project(self, x: T) -> T:
        """take the input and project it down to 2 dimensions which should preserve the spatial distance

        A fresh random projection matrix is drawn on every call, so outputs of separate calls are
        not expressed in the same 2-d basis.
        """
        x = self.base(x)
        # match the feature dtype as well as the device so the matmul also works when the model
        # does not run in torch's default dtype (e.g. float64)
        projection = torch.randn(x.size(-1), 2, device=x.device, dtype=x.dtype)
        return x @ projection

    def base(self, x: T) -> T:
        """Run ``x`` through every layer in order and return the final output.

        ``self.features`` is reset on every call; it is refilled with each layer's output only
        when ``self.save_features`` is set.
        """
        self.features = []
        for lyr in self.layers:
            x = lyr(x)
            if self.save_features:
                self.features.append(x)
        return x  # type: ignore


class ProtonetCNN4(Protonet, CNN4ResidualSpectralMixin):
    """Protonet whose embedding function is a single call to ``self.layers``.

    ``self.layers`` is not created in this class (presumably built by the CNN mixin -- confirm).
    """

    def __init__(
        self,
        dims: DimTuple = (()),
        in_ch: int = 1,
        h_dim: int = 64,
        classes: int = 5,
        p: float = 0.1,
        ctype: str = "error",
        spectral: bool = True,
        forward_type: str = "softmax"
    ) -> None:
        """Forward the arguments by keyword to the parent constructors, fixing ``n_layers`` at 4.

        The model name is ``"ProtonetCNN4"`` with an ``"SN"`` suffix when ``spectral`` is true.
        """
        super().__init__(  # type: ignore
            dims=dims,
            n_layers=4,
            in_ch=in_ch,
            h_dim=h_dim,
            classes=classes,
            p=p,
            ctype=ctype,
            spectral=spectral,
            forward_type=forward_type,
        )
        self.classes = classes
        suffix = "SN" if spectral else ""
        self.name = f"ProtonetCNN4{suffix}"

    def base(self, x: T) -> T:
        """Return the output of ``self.layers`` applied to ``x``."""
        embedded = self.layers(x)
        return embedded  # type: ignore


def protonet_linear(n_layers: int, in_dim: int, h_dim: int, classes: int, p: float, ctype: str, spectral: bool, forward_type: str) -> ProtonetFC:
    """Build a ``ProtonetFC`` from the given arguments; every argument is required here."""
    model = ProtonetFC(
        n_layers=n_layers,
        in_dim=in_dim,
        h_dim=h_dim,
        classes=classes,
        p=p,
        ctype=ctype,
        spectral=spectral,
        forward_type=forward_type,
    )
    return model


def protonet_cnn(dims: DimTuple, in_ch: int, h_dim: int, classes: int, p: float, ctype: str, spectral: bool, forward_type: str) -> ProtonetCNN4:
    """Build a ``ProtonetCNN4`` from the given arguments; every argument is required here."""
    model = ProtonetCNN4(
        dims=dims,
        in_ch=in_ch,
        h_dim=h_dim,
        classes=classes,
        p=p,
        ctype=ctype,
        spectral=spectral,
        forward_type=forward_type,
    )
    return model
