from typing import Any, List

import torch
from torch import nn

# Short alias for torch.Tensor, used in this module's type annotations.
T = torch.Tensor


class Protonet(nn.Module):
    """Prototypical-network style few-shot classifier base.

    Subclasses provide the embedding network by overriding :meth:`base` and must
    set ``self.classes`` (the number of centroids :meth:`compute_centroids`
    builds). :meth:`forward` embeds the support and query sets, averages the
    support embeddings of each class into a centroid, and scores every query
    against every centroid with the negative half squared Euclidean distance.
    """

    classes: int  # number of centroids built by compute_centroids
    layers: Any  # embedding network, set by subclasses (e.g. ProtonetCNN4)
    save_features: bool  # declared for subclasses/mixins; not read in this class
    features: List[T]  # declared for subclasses/mixins; not read in this class
    name: str  # human-readable model name, set by subclasses
    _temperature: nn.Parameter  # declared for subclasses/mixins; not read in this class

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)  # type: ignore
        # super() must be called with **kwargs because this class is used in mixins downstream; if the
        # kwargs are not forwarded through the whole MRO they are dropped and never reach the other
        # mixin classes.

    def base(self, x: T) -> T:
        """Embed a batch of inputs. Must be overridden by subclasses."""
        raise NotImplementedError()

    def compute_centroids(self, phi: T, y: T) -> T:
        """Return the mean embedding (centroid) of every class.

        Args:
            phi: support embeddings, expected shape ``(batch, dim)``.
            y: integer class labels of shape ``(batch,)`` with values in
                ``[0, self.classes)``.

        Returns:
            Tensor of shape ``(self.classes, dim)`` with the same dtype as
            ``phi``. A class that has no examples in ``y`` gets a row of NaNs
            (its sum and count are both zero).
        """
        n = y.size(0)
        # One-hot mask (batch, classes). It is created in phi's dtype so that
        # low-precision embeddings are not silently upcast by the product below.
        mask = torch.zeros(n, self.classes, device=phi.device, dtype=phi.dtype)
        mask[torch.arange(n, device=y.device), y] = 1

        # (classes, batch) @ (batch, dim): per-class sums of the embeddings,
        # without materialising a (batch, classes, dim) broadcast product.
        sums = mask.t() @ phi
        counts = mask.sum(dim=0).unsqueeze(-1)  # (classes, 1) examples per class
        return sums / counts  # type: ignore

    def count_parameters(self) -> int:
        """Return the total number of elements in parameters with ``requires_grad``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def compute_dist(self, phi: T, centroids: T) -> T:
        """Score each embedding against each centroid.

        Args:
            phi: query embeddings, expected shape ``(n_query, dim)``.
            centroids: class centroids, expected shape ``(classes, dim)``.

        Returns:
            ``(n_query, classes)`` tensor of ``-||phi_i - centroid_j||^2 / 2``;
            larger values mean closer centroids.
        """
        # (n_query, 1, dim) - (classes, dim) broadcasts to (n_query, classes, dim).
        diff = phi.unsqueeze(1) - centroids
        return -torch.pow(diff, 2).sum(dim=-1) / 2

    def forward(self, sx: T, sy: T, qx: T, n_way: int = 5, k_shot: int = 5) -> T:
        """Classify queries ``qx`` using the labelled support set ``(sx, sy)``.

        ``n_way`` and ``k_shot`` are accepted for interface compatibility but
        are not used; the number of centroids comes from ``self.classes``.

        Returns:
            ``(n_query, self.classes)`` scores from :meth:`compute_dist`.
        """
        sx_phi, qx_phi = self.base(sx), self.base(qx)
        centroids = self.compute_centroids(sx_phi, sy)
        return self.compute_dist(qx_phi, centroids)


class ProtonetCNN4(Protonet):
    """Protonet whose embedding network is four conv blocks followed by a flatten.

    Each block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    so every block halves (rounding down) the spatial resolution. The flattened
    embedding size therefore depends on the input's height and width.

    Args:
        in_ch: channels of the input images.
        h_dim: output channels of the first three blocks.
        z_dim: output channels of the last block.
        classes: number of centroids built by ``compute_centroids``.
    """

    def __init__(self, in_ch: int, h_dim: int, z_dim: int, classes: int) -> None:
        super().__init__()

        self.classes = classes
        self.name = "ProtonetCNN4"

        def make_block(c_in: int, c_out: int) -> nn.Module:
            # 3x3 conv with padding=1 keeps H/W; the 2x2 max-pool then halves them.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        widths = [in_ch, h_dim, h_dim, h_dim, z_dim]
        conv_stack = [make_block(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        # Conv blocks become layers.0 .. layers.3; Flatten is appended as layers.4.
        self.layers = nn.Sequential(*conv_stack, nn.Flatten())

    def base(self, x: T) -> T:
        """Embed ``x`` with the conv stack, flattened to ``(batch, features)``."""
        return self.layers(x)  # type: ignore


# TODO: the ResNet-12 Protonet variant below has not been implemented yet; it is kept commented out.
# class ProtonetResNet12(Protonet):
#     def __init__(self, in_ch: int, classes: int, ctype: str = "error", spectral: bool = False) -> None:
#         super().__init__()
#
#         self.classes = classes
#         self.path = os.path.join("resnet12")
#         self.layers = ResNet12(in_ch, spectral=spectral, ctype=ctype)
#
#     def base(self, x: T) -> T:
#         return self.layers(x).view(x.size(0), -1)  # type: ignore
