from typing import Optional, Tuple

import torch
import torch.nn as nn

from models.set_encoder import LatentPerturber
from utils import ConvArgs, conv_bn_relu


class MCDropoutHeteroConv(nn.Module):
    """VGG-13 style convolutional network followed by a dropout MLP head.

    This model was referenced from the deep ensemble paper, which in turn
    referenced http://torch.ch/blog/2015/07/30/cifar.html
    """

    def __init__(
        self,
        in_dim: Tuple[int, int],
        h_dim: int,
        p: Tuple[float, ...],
        out_dim: int,
        conv_args: ConvArgs,
    ) -> None:
        """Build the convolutional trunk and the dropout MLP head.

        Args:
            in_dim: input spatial size. Only ``in_dim[0]`` is read, and the
                flattened feature size is computed as ``z_dim ** 2 * dim``,
                so a square input is assumed.
            h_dim: width of the two hidden linear layers in the head.
            p: dropout probabilities; ``p[0]`` and ``p[1]`` are used for the
                head's two dropout layers.
            out_dim: number of outputs of the final linear layer.
            conv_args: sequence of per-layer argument tuples, each unpacked
                into ``conv_bn_relu``. Element ``[1]`` of the last entry is
                used as the channel count of the final feature map
                (presumably the output-channel argument -- TODO confirm), and
                a truthy last element halves the spatial size.
        """
        super(MCDropoutHeteroConv, self).__init__()

        self.in_dim = in_dim
        self.h_dim = h_dim
        self.p = p
        self.out_dim = out_dim

        conv_layers = []
        for arg in conv_args:
            conv_layers += conv_bn_relu(*arg)

        self.layers = nn.Sequential(*conv_layers)

        # Spatial size / channel count of the trunk's last feature map, used to
        # size the first linear layer. Each entry whose last element is truthy
        # halves the spatial size (presumably a 2x pooling flag -- TODO confirm
        # against conv_bn_relu).
        z_dim, dim = in_dim[0], conv_args[-1][1]
        for lyr in conv_args:
            if lyr[-1]:
                z_dim = z_dim // 2

        self.out = nn.Sequential(
            nn.Linear(z_dim ** 2 * dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=p[0]),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=p[1]),
            nn.Linear(h_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Run the conv trunk, flatten per sample, and apply the MLP head."""
        x = self.layers(x)
        # flatten(1) instead of view(): view() raises on non-contiguous feature
        # maps (e.g. channels_last) and on an empty batch, flatten does not.
        return self.out(torch.flatten(x, 1))

    def phi(
        self, x: torch.Tensor, perturber: LatentPerturber, theta: bool = True
    ) -> Tuple[torch.Tensor, ...]:
        """Encode ``x``, perturb the features, and run the head on the result.

        Args:
            x: input batch passed through the convolutional trunk.
            perturber: called on the flattened features and expected to return
                ``(x_prime, h, subsets)``.
            theta: if True, ``x_prime`` is detached and made a new leaf that
                requires grad. Gradients from the head then stop at
                ``x_prime`` instead of flowing back through the perturber or
                the conv trunk.

        Returns:
            ``(self.out(x_prime), h, subsets, dist, x, x_prime)``. Here ``x`` is
            the flattened features, and ``dist[i, j]`` is the squared Euclidean
            distance between ``x_prime[i]`` and ``x[j]``. This assumes
            ``x_prime`` is 2-D -- TODO confirm against LatentPerturber.
            ``dist`` is computed before the detach, so it keeps its gradient
            path.
        """
        x = torch.flatten(self.layers(x), 1)
        x_prime, h, subsets = perturber(x)

        dist = ((x.unsqueeze(0) - x_prime.unsqueeze(1)) ** 2).sum(dim=2)

        if theta:
            # We only want x_prime to apply to the top layers of the network,
            # as in regression, where we only operate on the input space.
            x_prime = x_prime.detach()
            x_prime.requires_grad_(True)

        return self.out(x_prime), h, subsets, dist, x, x_prime

    def mc(self, x: torch.Tensor, samples: int) -> torch.Tensor:
        """Run ``samples`` forward passes and stack the outputs.

        Returns a tensor of shape ``(samples, x.size(0), out_dim)``. It uses
        the default float dtype, on ``x``'s device. ``nn.Dropout`` is only
        stochastic in training mode, so in eval mode all samples are
        identical.
        """
        mus = torch.zeros(samples, x.size(0), self.out_dim, device=x.device)
        for i in range(samples):
            mus[i] = self(x)

        return mus


class MCDropoutHetero(nn.Module):
    """Two-hidden-layer dropout MLP with separate mean and log-variance heads."""

    def __init__(
        self, x_dim: int, h_dim: int, p: Tuple[float, ...], out_dim: int
    ) -> None:
        """Build the shared dropout trunk and the two linear output heads.

        Args:
            x_dim: number of input features.
            h_dim: width of both hidden layers.
            p: dropout probabilities; ``p[0]`` follows the first hidden layer
                and ``p[1]`` the second.
            out_dim: size of each of the ``mu`` and ``logvar`` outputs.
        """
        super().__init__()

        self.x_dim, self.h_dim, self.p, self.out_dim = x_dim, h_dim, p, out_dim

        # Registration order (trunk, then mu, then logvar) fixes both the order
        # in which parameters draw from the RNG at init and the state_dict layout.
        trunk = [
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=p[0]),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=p[1]),
        ]
        self.layers = nn.Sequential(*trunk)

        self.mu = nn.Linear(h_dim, out_dim)
        self.logvar = nn.Linear(h_dim, out_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore
        """Return ``(mu, logvar)`` computed from a shared hidden representation."""
        features = self.layers(x)
        mean = self.mu(features)
        log_variance = self.logvar(features)
        return mean, log_variance

    def mc(self, x: torch.Tensor, samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``samples`` forward passes and collect both heads' outputs.

        Returns two tensors of shape ``(samples, x.size(0), out_dim)``, in the
        default float dtype on ``x``'s device. ``nn.Dropout`` is only
        stochastic in training mode, so in eval mode every sample is identical.
        """
        out_shape = (samples, x.size(0), self.out_dim)
        mus = torch.zeros(out_shape, device=x.device)
        logvars = torch.zeros(out_shape, device=x.device)
        for draw in range(samples):
            mean, log_variance = self(x)
            mus[draw] = mean
            logvars[draw] = log_variance
        return mus, logvars
