from typing import Optional, Tuple

import torch
from torch import nn

from models.set_encoder import LatentPerturber
from utils import ConvArgs, conv_bn_relu


class ModelConv(nn.Module):
    """Convolutional feature extractor followed by a three-layer MLP head.

    VGG 13; this model was referenced from the deep ensemble paper, which in
    turn referenced this site: http://torch.ch/blog/2015/07/30/cifar.html
    """

    def __init__(
        self, in_dim: Tuple[int, int], h_dim: int, y_dim: int, conv_args: ConvArgs,
    ) -> None:
        """Build the conv stack from ``conv_args`` and the dense head.

        Args:
            in_dim: stored on the instance; not otherwise used in this class.
            h_dim: width of the two hidden linear layers of the head.
            y_dim: output size of the final linear layer.
            conv_args: sequence of argument tuples; each one is unpacked into
                ``conv_bn_relu`` and the returned modules are chained in order.
        """
        super().__init__()

        self.in_dim = in_dim
        self.h_dim = h_dim
        self.y_dim = y_dim

        conv_layers = []
        for arg in conv_args:
            conv_layers += conv_bn_relu(*arg)

        self.layers = nn.Sequential(*conv_layers)

        # The head's input width is taken from element 1 of the last conv
        # spec (presumably its out-channel count, which would imply the conv
        # stack reduces the spatial extent to 1x1 -- TODO confirm).
        self.out = nn.Sequential(
            nn.Linear(conv_args[-1][1], h_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(h_dim, y_dim),
        )

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv stack and flatten everything but the batch dimension."""
        # flatten (unlike view) also works for non-contiguous feature maps and
        # for a zero-sized batch, where view(n, -1) cannot infer the -1.
        return torch.flatten(self.layers(x), start_dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output for the flattened conv features of ``x``."""
        return self.out(self._features(x))  # type: ignore

    def phi(
        self, x: torch.Tensor, perturber: LatentPerturber, theta: bool = True
    ) -> Tuple[torch.Tensor, ...]:
        """Perturb the flattened conv features and run the head on the result.

        Args:
            x: input batch fed to the conv stack.
            perturber: callable applied to the flattened features; it must
                return a ``(x_prime, h, subsets)`` triple.
            theta: when True, ``x_prime`` is detached before entering the
                head, so the head's output does not backpropagate into the
                perturber or the conv layers through ``x_prime``.

        Returns:
            ``(head(x_prime), h, subsets, dist)`` where ``dist`` holds the
            squared Euclidean distances between every perturbed row and every
            feature row (shape ``(M, N)``, assuming ``x_prime`` is 2-D with
            ``M`` rows -- TODO confirm against ``LatentPerturber``).
        """
        x = self._features(x)
        x_prime, h, subsets = perturber(x)

        # Computed before the optional detach below, so dist stays connected
        # to the graph of both x and x_prime.
        dist = ((x.unsqueeze(0) - x_prime.unsqueeze(1)) ** 2).sum(dim=2)

        if theta:
            # we only want x_prime to apply to the top layers of the network,
            # as in regression when we only operate on the input space.
            x_prime = x_prime.detach()
            x_prime.requires_grad_(True)

        x_prime = self.out(x_prime)
        return x_prime, h, subsets, dist


class Model(nn.Module):
    """Small MLP trunk with two separate linear heads, ``mu`` and ``logvar``."""

    def __init__(self, x_dim: int, h_dim: int, y_dim: int) -> None:
        """Create the two-hidden-layer trunk and both output heads.

        Args:
            x_dim: input feature size.
            h_dim: width of both hidden layers.
            y_dim: output size of each head.
        """
        super().__init__()

        self.x_dim, self.h_dim, self.y_dim = x_dim, h_dim, y_dim

        # Shared trunk: Linear -> ReLU -> Linear -> ReLU.
        trunk = [
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
        ]
        self.layers = nn.Sequential(*trunk)

        # Two independent heads reading the same hidden representation.
        self.mu = nn.Linear(h_dim, y_dim)
        self.logvar = nn.Linear(h_dim, y_dim)

    def forward(  # type: ignore
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, logvar)`` head outputs for ``x``.

        Dimension 1 of each output is squeezed, which removes it only when
        its size is 1 (i.e. ``y_dim == 1``).
        """
        hidden = self.layers(x)
        mean = self.mu(hidden).squeeze(1)
        log_var = self.logvar(hidden).squeeze(1)
        return mean, log_var
