from typing import Tuple

import torch
from torch import nn
from torch.distributions import Normal

from utils import ConvArgs, conv_bn_relu


class Perturber(nn.Module):
    """Input-conditioned Gaussian noise generator.

    Two independent ``Linear -> LeakyReLU`` heads map the input to the mean
    and to the log standard deviation of a diagonal Normal distribution.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim

        # Keep the creation order (mu first, then log_sigma) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.mu = nn.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU())
        self.log_sigma = nn.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU())

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample noise for ``x``.

        Returns:
            A pair ``(noise, entropy)``: ``noise`` is ``1 +`` a reparameterised
            sample from ``Normal(mu(x), exp(log_sigma(x)))`` and ``entropy`` is
            the element-wise entropy of that distribution.
        """
        loc = self.mu(x)
        scale = torch.exp(self.log_sigma(x))
        dist = Normal(loc, scale)
        sample = dist.rsample()  # reparameterised, so gradients reach both heads
        return 1 + sample, dist.entropy()


class ConvConditionalNeuralProcess(nn.Module):
    """Convolutional conditional neural process over labelled context images.

    Context images are encoded by ``encoder``, mean-pooled per label and the
    per-label means are concatenated into one context vector per batch
    element. Every image to be decoded is embedded by ``tgt_extractor``,
    paired with that context vector and passed through an MLP decoder.

    Args:
        in_ch: Accepted for interface parity with the other NP models; not
            read by this class.
        x_dim: Stored on the instance; not otherwise used here.
        z_dim: Side length used to size the decoder input. Assumes each conv
            stack emits ``filters`` channels on a ``z_dim x z_dim`` map
            (NOTE(review): confirm against ``conv_args``).
        filters: Channel count used to size the decoder input.
        conv_args: One argument tuple per ``conv_bn_relu`` call; the same
            arguments build both the context encoder and the target extractor
            (as two separate stacks).
        h_dim: Width of the decoder's hidden output.
        y_dim: Width of the final linear output.
        perturber: Module returning ``(noise, entropy)`` for a feature tensor;
            used by :meth:`forward_phi` only.
        n_way: Number of distinct labels expected in ``y_ctx``. Defaults to 5,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``n_way`` is smaller than 1.
    """

    def __init__(
        self,
        in_ch: int,
        x_dim: Tuple[int, int],
        z_dim: int,
        filters: int,
        conv_args: ConvArgs,
        h_dim: int,
        y_dim: int,
        perturber: Perturber,
        n_way: int = 5,
    ) -> None:
        super(ConvConditionalNeuralProcess, self).__init__()
        if n_way < 1:
            raise ValueError(f"n_way must be >= 1, got {n_way}")

        self.x_dim = x_dim
        self.y_dim = y_dim
        self.h_dim = h_dim
        self.n_way = n_way
        self.perturber = perturber
        self.training: bool

        # Placeholders kept for code outside this class that may read them;
        # nothing in this class updates them.
        self.tgt = torch.Tensor()
        self.p_tgt = torch.Tensor()

        # Two independent stacks built from the same arguments. The calls stay
        # interleaved so parameter initialisation order is unchanged.
        conv_layers = []
        tgt_extractor = []
        for arg in conv_args:
            conv_layers += conv_bn_relu(*arg)
            tgt_extractor += conv_bn_relu(*arg)

        self.encoder = nn.Sequential(*conv_layers)
        self.tgt_extractor = nn.Sequential(*tgt_extractor)

        # Decoder input = n_way pooled context vectors + one target vector,
        # each of size filters * z_dim ** 2.
        d = (filters * n_way * z_dim ** 2) + (filters * z_dim ** 2)
        self.decoder = nn.Sequential(
            nn.Linear(d, d),
            nn.LeakyReLU(),
            nn.Dropout(p=0.05),
            nn.Linear(d, h_dim),
            nn.LeakyReLU(),
            nn.Dropout(p=0.05),
        )

        self.out = nn.Linear(h_dim, y_dim)

    def _embed(
        self, extractor: nn.Module, imgs: torch.Tensor, b: int, chw: Tuple[int, int, int]
    ) -> torch.Tensor:
        """Run ``extractor`` on ``(b, n, c, h, w)`` images and return ``(b, n, -1)`` features."""
        n = imgs.size(1)
        return extractor(imgs.view(-1, *chw)).view(b, n, -1)

    def _class_context(
        self, x_ctx: torch.Tensor, y_ctx: torch.Tensor, b: int, chw: Tuple[int, int, int]
    ) -> torch.Tensor:
        """Encode the context set and concatenate the per-label mean features.

        Assumes ``y_ctx`` is ``(b, n_ctx)`` and that every label occurs exactly
        ``n_ctx // n_way`` times in every batch element; only the total row
        count per label is verified.
        """
        r = self._embed(self.encoder, x_ctx, b, chw)
        per_class = x_ctx.size(1) // self.n_way

        class_means = []
        for label in torch.unique(y_ctx):
            rows = r[(y_ctx == label).nonzero(as_tuple=True)]
            if rows.size(0) != b * per_class:
                # A bare view() could silently reshape across samples here.
                raise ValueError(
                    f"label {label.item()} has {rows.size(0)} context rows, "
                    f"expected {b * per_class} (batch {b} x {per_class} per class)"
                )
            class_means.append(rows.view(b, per_class, -1).mean(dim=1))

        # Single cat over a list: no empty-tensor seed, dtype follows the features.
        return torch.cat(class_means, dim=1)

    def _decode(self, ctx: torch.Tensor, feats: torch.Tensor) -> torch.Tensor:
        """Pair the ``(b, d_ctx)`` context with each ``(b, n, d_f)`` feature row and decode."""
        # expand() broadcasts without copying; cat() materialises the result.
        ctx = ctx.unsqueeze(1).expand(-1, feats.size(1), -1)
        hidden = self.decoder(torch.cat((ctx, feats), dim=2))
        return self.out(hidden)

    def forward(
        self, x_ctx: torch.Tensor, y_ctx: torch.Tensor, x_tgt: torch.Tensor
    ) -> torch.Tensor:
        """Decode the context points followed by the target points.

        Args:
            x_ctx: Context images, ``(b, n_ctx, c, h, w)``.
            y_ctx: Context labels, presumably ``(b, n_ctx)``.
            x_tgt: Target images, ``(b, n_tgt, c, h, w)``.

        Returns:
            Tensor of shape ``(b, n_ctx + n_tgt, y_dim)``, context rows first.
        """
        b, _, c, h, w = x_tgt.size()
        _, _, _, _, _ = x_ctx.size()  # keeps the 5-D requirement on x_ctx
        chw = (c, h, w)

        tgt = self._embed(self.tgt_extractor, torch.cat((x_ctx, x_tgt), dim=1), b, chw)
        ctx = self._class_context(x_ctx, y_ctx, b, chw)
        return self._decode(ctx, tgt)

    def forward_phi(
        self, x_ctx: torch.Tensor, y_ctx: torch.Tensor, x_tgt: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """Like :meth:`forward`, but decode perturbed features.

        The extracted features are multiplied element-wise by the noise
        returned from ``self.perturber`` before decoding.

        Returns:
            ``(out, h, d, tgt, p_tgt)`` where ``out`` is the decoder output for
            the perturbed features, ``h`` is the second value returned by the
            perturber, ``d[b, i, j]`` is the squared Euclidean distance between
            clean feature ``i`` and perturbed feature ``j``, ``tgt`` are the
            clean features and ``p_tgt`` the perturbed ones.
        """
        b, _, c, h_img, w_img = x_tgt.size()
        _, _, _, _, _ = x_ctx.size()  # keeps the 5-D requirement on x_ctx
        chw = (c, h_img, w_img)

        tgt = self._embed(self.tgt_extractor, torch.cat((x_ctx, x_tgt), dim=1), b, chw)
        ctx = self._class_context(x_ctx, y_ctx, b, chw)

        noise, h = self.perturber(tgt)
        p_tgt = tgt * noise

        # (b, 1, n, f) - (b, n, 1, f) -> pairwise squared distances (b, n, n).
        d = ((p_tgt.unsqueeze(1) - tgt.unsqueeze(2)) ** 2).sum(dim=3)

        return self._decode(ctx, p_tgt), h, d, tgt, p_tgt

    def forward_eval(
        self, x_ctx: torch.Tensor, y_ctx: torch.Tensor, x_tgt: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Decode the target points only.

        The eval loop doesn't bother decoding the context points since we only
        care about the target labels.

        Returns:
            ``(out, tgt)``: decoder output of shape ``(b, n_tgt, y_dim)`` and the
            extracted target features.
        """
        b, _, c, h, w = x_tgt.size()
        _, _, _, _, _ = x_ctx.size()  # keeps the 5-D requirement on x_ctx
        chw = (c, h, w)

        tgt = self._embed(self.tgt_extractor, x_tgt, b, chw)
        ctx = self._class_context(x_ctx, y_ctx, b, chw)
        return self._decode(ctx, tgt), tgt
