from typing import Tuple

import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import functional as F

from utils import ConvArgs, conv_bn_relu


class Decoder(nn.Module):
    """Decode target inputs plus a context representation into (mu, sigma).

    Every target point is concatenated with the representation ``r`` of its
    batch element and passed through a two-layer ReLU MLP. Two separate linear
    heads then produce the mean and the pre-activation of the standard
    deviation.

    Args:
        x_dim: Size of the last dimension of the target inputs.
        r_dim: Size of the context representation.
        h_dim: Width of both hidden layers.
        y_dim: Size of the last dimension of the outputs.
    """

    def __init__(self, x_dim: int, r_dim: int, h_dim: int, y_dim: int) -> None:
        super(Decoder, self).__init__()

        self.x_dim = x_dim
        self.r_dim = r_dim
        self.h_dim = h_dim
        self.y_dim = y_dim

        layers = [
            nn.Linear(x_dim + r_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
        ]

        self.xr_to_hidden = nn.Sequential(*layers)
        self.hidden_to_mu = nn.Linear(h_dim, y_dim)
        self.hidden_to_sigma = nn.Linear(h_dim, y_dim)

    def forward(  # type: ignore
        self, x: torch.Tensor, r: torch.Tensor
    ) -> Tuple[torch.Tensor, ...]:
        """Predict a mean and standard deviation for every target point.

        Args:
            x: Target inputs of shape (batch_size, num_points, x_dim). May be
                non-contiguous (e.g. a transposed or sliced tensor).
            r: Context representation of shape (batch_size, r_dim).

        Returns:
            A ``(mu, sigma)`` pair, each of shape
            (batch_size, num_points, y_dim). ``sigma`` is bounded below by 0.1.
        """
        batch_size, num_points, _ = x.size()
        # Repeat r so it can be concatenated with every x. This changes shape
        # from (batch_size, r_dim) to (batch_size, num_points, r_dim).
        r = r.unsqueeze(1).repeat(1, num_points, 1)

        # Flatten x and r to fit the linear layers. reshape (rather than view)
        # is used so that non-contiguous inputs do not raise.
        x_flat = x.reshape(batch_size * num_points, self.x_dim)
        r_flat = r.reshape(batch_size * num_points, self.r_dim)

        # Input is the concatenation of r with every row of x
        input_pairs = torch.cat((x_flat, r_flat), dim=1)

        hidden = self.xr_to_hidden(input_pairs)
        mu = self.hidden_to_mu(hidden)
        pre_sigma = self.hidden_to_sigma(hidden)

        # Reshape outputs back into (batch_size, num_points, y_dim)
        mu = mu.view(batch_size, num_points, self.y_dim)
        pre_sigma = pre_sigma.view(batch_size, num_points, self.y_dim)

        # softplus keeps sigma positive; the affine map floors it at 0.1
        sigma = 0.1 + 0.9 * F.softplus(pre_sigma)
        return mu, sigma


class DeterministicEncoder(nn.Module):
    """Embed individual (x, y) pairs into per-point representations.

    The inputs are joined along dimension 1 and fed through a small MLP
    (Linear -> ReLU -> Linear) whose output has ``r_dim`` features.

    Args:
        x_dim: Number of features in each x row.
        y_dim: Number of features in each y row.
        h_dim: Width of the hidden layer.
        r_dim: Number of features in the produced representation.
    """

    def __init__(self, x_dim: int, y_dim: int, h_dim: int, r_dim: int) -> None:
        super().__init__()

        self.x_dim, self.y_dim = x_dim, y_dim
        self.h_dim, self.r_dim = h_dim, r_dim

        # Layers are created in the same order as they are applied.
        self.input_to_hidden = nn.Sequential(
            nn.Linear(x_dim + y_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, r_dim),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Return the representation of each (x, y) row pair."""
        xy_rows = torch.cat([x, y], dim=1)
        representation = self.input_to_hidden(xy_rows)
        return representation


class ConditionalNeuralProcess(nn.Module):
    """Conditional Neural Process built from a point encoder and a decoder.

    Context pairs are encoded point-wise, mean-pooled into one representation
    per batch element, and that representation conditions the decoder's
    Gaussian prediction at the target inputs.

    Args:
        x_dim: Size of the last dimension of the inputs.
        y_dim: Size of the last dimension of the outputs.
        r_dim: Size of the aggregated context representation.
        z_dim: Stored but otherwise unused here; accepted so that all NP
            variants share the same constructor interface.
        h_dim: Hidden width passed to the encoder and the decoder.
    """

    def __init__(
        self, x_dim: int, y_dim: int, r_dim: int, z_dim: int, h_dim: int
    ) -> None:
        super(ConditionalNeuralProcess, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.r_dim = r_dim
        # z_dim not used in CNP, but adding to keep all NP interfaces the same
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.training: bool

        # Initialize networks
        self.det_encoder = DeterministicEncoder(x_dim, y_dim, h_dim, r_dim)
        self.xr_to_y = Decoder(x_dim, r_dim, h_dim, y_dim)

    def aggregate(self, r_i: torch.Tensor) -> torch.Tensor:
        """Mean-pool per-point representations over dimension 1."""
        return torch.mean(r_i, dim=1)

    def xy_to_det_r(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Map x and y to a deterministic aggregated representation r.

        Args:
            x: Inputs of shape (batch_size, num_points, x_dim).
            y: Outputs with batch_size * num_points * y_dim elements.

        Returns:
            Tensor of shape (batch_size, r_dim).
        """
        batch_size, num_points, _ = x.size()
        # Flatten tensors, as the encoder expects one row per point. reshape
        # handles non-contiguous x and y alike (view would raise on them).
        x_flat = x.reshape(batch_size * num_points, self.x_dim)
        y_flat = y.reshape(batch_size * num_points, self.y_dim)
        # Encode each point into a representation r_i
        r_i_flat = self.det_encoder(x_flat, y_flat)
        # Reshape tensors into batches
        r_i = r_i_flat.view(batch_size, num_points, self.r_dim)
        # Aggregate representations r_i into a single representation r
        return self.aggregate(r_i)

    def forward(  # type: ignore
        self, x_ctx: torch.Tensor, y_ctx: torch.Tensor, x_tgt: torch.Tensor, y_tgt=None
    ) -> Normal:
        """Return the predictive distribution at the target inputs.

        Args:
            x_ctx: Context inputs, 3-D (batch_size, num_ctx, x_dim).
            y_ctx: Context outputs, 3-D (batch_size, num_ctx, y_dim).
            x_tgt: Target inputs, 3-D (batch_size, num_target, x_dim).
            y_tgt: Ignored; accepted so the call signature matches other NP
                variants.

        Returns:
            A ``Normal`` whose mean and scale come from the decoder.

        Raises:
            ValueError: If ``x_ctx``, ``x_tgt`` or ``y_ctx`` is not 3-D.
        """
        # Explicit rank check (same exception type the former tuple-unpacking
        # produced, but with a message that names the offending argument).
        for name, tensor in (("x_ctx", x_ctx), ("x_tgt", x_tgt), ("y_ctx", y_ctx)):
            if tensor.dim() != 3:
                raise ValueError(
                    f"{name} must be 3-D (batch, points, features), "
                    f"got shape {tuple(tensor.shape)}"
                )

        r = self.xy_to_det_r(x_ctx, y_ctx)
        y_pred_mu, y_pred_sigma = self.xr_to_y(x_tgt, r)
        return Normal(y_pred_mu, y_pred_sigma)


class ConvConditionalNeuralProcess(nn.Module):
    """Convolutional CNP-style model over sets of context and target images.

    Context images are encoded by a conv stack, averaged per label, and the
    per-label means are concatenated into one context vector. Each decoded
    image is passed through a second conv stack (same architecture, separate
    weights), concatenated with the context vector, and mapped by an MLP to
    ``y_dim`` outputs.

    Args:
        in_ch: Not read by this constructor; kept for interface compatibility.
        x_dim: Stored as ``self.x_dim``; not otherwise read by this class.
        z_dim: Spatial side length assumed for the conv output; each conv
            feature vector is taken to have ``filters * z_dim ** 2`` elements.
        filters: Channel count assumed for the conv output.
        conv_args: Iterable of argument tuples, each unpacked into
            ``conv_bn_relu`` to build one stage of both conv stacks.
        h_dim: Width of the decoder's last hidden layer.
        y_dim: Number of outputs per decoded image.
        n_way: Number of distinct labels expected in ``y_ctx``. Sizes the
            decoder input and the per-label reshape. Defaults to 5, the value
            that was previously hard-coded.
    """

    def __init__(
        self,
        in_ch: int,
        x_dim: Tuple[int, int],
        z_dim: int,
        filters: int,
        conv_args: ConvArgs,
        h_dim: int,
        y_dim: int,
        n_way: int = 5,
    ) -> None:
        super(ConvConditionalNeuralProcess, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        # z_dim sizes the decoder input below; stored for parity with the
        # other NP models, which also keep it as an attribute.
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.n_way = n_way
        self.training: bool

        conv_layers = []
        tgt_extractor = []
        for arg in conv_args:
            conv_layers += conv_bn_relu(*arg)
            tgt_extractor += conv_bn_relu(*arg)

        self.encoder = nn.Sequential(*conv_layers)
        self.tgt_extractor = nn.Sequential(*tgt_extractor)

        # Decoder input = n_way concatenated class means + one image feature.
        feat = filters * z_dim ** 2
        d = (n_way * feat) + feat
        self.decoder = nn.Sequential(
            nn.Linear(d, d),
            nn.ReLU(),
            nn.Dropout(p=0.05),
            nn.Linear(d, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.05),
        )

        self.out = nn.Linear(h_dim, y_dim)

    def _class_means(self, r: torch.Tensor, y_ctx: torch.Tensor) -> torch.Tensor:
        """Average context features per label and concatenate the means.

        Args:
            r: Encoded context features of shape (b, n_ctx, feat).
            y_ctx: Context labels; assumed to be (b, n_ctx) with every label
                occurring exactly ``n_ctx // n_way`` times in each batch row
                (the reshape below requires this) -- TODO confirm with caller.

        Returns:
            Tensor of shape (b, num_labels * feat), labels in the sorted order
            returned by ``torch.unique``.
        """
        b, n_ctx = r.size(0), r.size(1)
        per_class = n_ctx // self.n_way

        means = []
        for label in torch.unique(y_ctx):
            nz = (y_ctx == label).nonzero(as_tuple=True)
            means.append(r[nz].view(b, per_class, -1).mean(dim=1))
        # Single concatenation keeps r's dtype/device and avoids re-copying
        # a growing accumulator on every iteration.
        return torch.cat(means, dim=1)

    def _decode(self, ctx: torch.Tensor, feats: torch.Tensor) -> torch.Tensor:
        """Pair the context vector with every feature row and run the MLP head."""
        ctx = ctx.unsqueeze(1).repeat(1, feats.size(1), 1)
        input_pairs = torch.cat((ctx, feats), dim=2)
        hidden = self.decoder(input_pairs)
        return self.out(hidden)

    def forward(
        self, x_ctx: torch.Tensor, y_ctx: torch.Tensor, x_tgt: torch.Tensor
    ) -> torch.Tensor:
        """Decode both context and target images.

        Args:
            x_ctx: Context images, 5-D (b, n_ctx, c, x, x); square images are
                assumed because the same side length is used for both axes.
            y_ctx: Context labels (see ``_class_means``).
            x_tgt: Target images, 5-D (b, n_tgt, c, x, x).

        Returns:
            Tensor of shape (b, n_ctx + n_tgt, y_dim); context rows come first.
        """
        b, n_tgt, c, x, _ = x_tgt.size()
        _, n_ctx, _, _, _ = x_ctx.size()

        tgt = torch.cat((x_ctx, x_tgt), dim=1)
        tgt = self.tgt_extractor(tgt.view(-1, c, x, x)).view(b, n_tgt + n_ctx, -1)

        # reshape tolerates non-contiguous caller tensors
        r = self.encoder(x_ctx.reshape(-1, c, x, x)).view(b, n_ctx, -1)

        ctx = self._class_means(r, y_ctx)
        return self._decode(ctx, tgt)

    def forward_eval(
        self, x_ctx: torch.Tensor, y_ctx: torch.Tensor, x_tgt: torch.Tensor
    ) -> torch.Tensor:
        """Decode only the target images.

        The eval loop doesn't bother decoding the context points since only
        the target labels are of interest.

        Args:
            x_ctx: Context images, 5-D (b, n_ctx, c, x, x).
            y_ctx: Context labels (see ``_class_means``).
            x_tgt: Target images, 5-D (b, n_tgt, c, x, x).

        Returns:
            Tensor of shape (b, n_tgt, y_dim).
        """
        b, n_tgt, c, x, _ = x_tgt.size()
        _, n_ctx, _, _, _ = x_ctx.size()

        tgt = self.tgt_extractor(x_tgt.reshape(-1, c, x, x)).view(b, n_tgt, -1)

        r = self.encoder(x_ctx.reshape(-1, c, x, x)).view(b, n_ctx, -1)

        ctx = self._class_means(r, y_ctx)
        return self._decode(ctx, tgt)
