from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from args import ModelArguments
from .model_utils import logit
from .sampling import (
    reparameterize_beta,
    reparameterize_bernoulli,
    reparameterize_normal,
    kl_beta,
    kl_bernoulli,
    kl_normal,
)


class LFRMLatentEncoder(nn.Module):
    """Map a hidden representation to the three posterior parameter heads.

    Three independent linear layers read the same input ``H`` and each
    produce an output whose last dimension is ``latent_dim``:

    * ``pi_logit``  -- output of ``linear_pi_logit``
    * ``w_mean``    -- output of ``linear_w_mean``
    * ``w_logstd``  -- output of ``linear_w_logstd``
    """

    def __init__(self, hidden_dim: int, latent_dim: int) -> None:
        """Create the three linear heads.

        Args:
            hidden_dim: Size of the last dimension of the input features.
            latent_dim: Size of the last dimension of every output.
        """
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.linear_pi_logit = nn.Linear(hidden_dim, latent_dim)
        self.linear_w_mean = nn.Linear(hidden_dim, latent_dim)
        self.linear_w_logstd = nn.Linear(hidden_dim, latent_dim)

    def forward(
        self, H: torch.FloatTensor
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Apply the three heads to ``H``.

        Args:
            H: Hidden features whose last dimension is ``hidden_dim``.

        Returns:
            The tuple ``(pi_logit, w_mean, w_logstd)``, in that order.
        """
        pi_logit = self.linear_pi_logit(H)
        w_mean = self.linear_w_mean(H)
        w_logstd = self.linear_w_logstd(H)
        return pi_logit, w_mean, w_logstd


class LFRMVAE(nn.Module):
    def __init__(self, args: ModelArguments) -> None:
        super().__init__()
        hidden_dims = [int(h) for h in args.hidden_dims.split("_")]
        self.dropout = args.dropout
        self.temp_prior = args.temp_prior
        self.temp_post = args.temp_post
        self.register_buffer("beta_a_prior_query", torch.FloatTensor([args.beta_a_query]))
        self.register_buffer("beta_a_prior_entity", torch.FloatTensor([args.beta_a_entity]))
        self.register_buffer("beta_b_prior", torch.FloatTensor([args.beta_b]))
        self.register_buffer("normal_mean_prior", torch.FloatTensor([0.0]))
        self.register_buffer("normal_std_prior", torch.FloatTensor([1.0]))
        self.kl_weight = args.kl_weight
        self.recon_weight = args.recon_weight

        # build model parameters
        # a and b for beta prior
        beta_a_query = np.log(np.exp(args.beta_a_query) - 1)  # inverse softplus
        beta_a_query = beta_a_query + torch.zeros(args.latent_dim)
        beta_a_entity = np.log(np.exp(args.beta_a_entity) - 1)
        beta_a_entity = beta_a_entity + torch.zeros(args.latent_dim)
        beta_b = np.log(np.exp(args.beta_b) - 1)
        beta_b = beta_b + torch.zeros(args.latent_dim)
        # shape: (K, )
        self.beta_a_query = nn.Parameter(beta_a_query)
        self.beta_a_entity = nn.Parameter(beta_a_entity)
        self.beta_b = nn.Parameter(beta_b)
        # MLP encoder
        self.encoder_feature = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Linear(dim_in, dim_out),
                    nn.LeakyReLU(),
                    nn.Dropout(self.dropout),
                )
                for dim_in, dim_out in zip(hidden_dims[:-1], hidden_dims[1:])
            ]
        )
        self.encoder_latent = LFRMLatentEncoder(hidden_dims[-1], args.latent_dim)

        # MLP decoder
        hidden_dims.reverse()
        self.decoder_latent = nn.Sequential(
            nn.Linear(args.latent_dim, hidden_dims[0]),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout),
        )

        self.decoder_feature = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Linear(dim_in, dim_out),
                    nn.LeakyReLU(),
                    nn.Dropout(self.dropout),
                )
                for dim_in, dim_out in zip(hidden_dims[1:-2], hidden_dims[2:-1])
            ]
        )
        self.decoder_feature.append(nn.Linear(hidden_dims[-2], hidden_dims[-1]))

    def forward(self, X: torch.FloatTensor, is_query: bool = True) -> Dict[str, torch.Tensor]:
        beta_a, beta_b, pi_logit, w_mean, w_logstd = self.encode(X, is_query)
        # all latent variables are in shape (N, K)
        pi_logit_prior, pi_logit_post, z, y_sample, w = self.reparameterize(
            beta_a, beta_b, pi_logit, w_mean, w_logstd
        )
        # round to 0-1 vector during evaluation
        if not self.training:
            z = torch.round(z)
        X_hat = self.decode(z, w)

        return {
            "X": X,
            "X_recon": X_hat,
            "z": z,
            "w": w,
            "kl_dict": {
                "beta_a": beta_a,
                "beta_b": beta_b,
                "pi_logit_prior": pi_logit_prior,
                "pi_logit_post": pi_logit_post,
                "y_sample": y_sample,
                "w_mean": w_mean,
                "w_logstd": w_logstd,
            },
        }

    def encode(self, X: torch.FloatTensor, is_query: bool = True) -> Tuple[torch.FloatTensor]:
        h = self.encoder_feature(X)
        pi_logit, w_mean, w_logstd = self.encoder_latent(h)
        beta_a = self.beta_a_query if is_query else self.beta_a_entity
        beta_a = F.softplus(beta_a).expand(*pi_logit.size())
        beta_b = F.softplus(self.beta_b).expand(*pi_logit.size())
        return beta_a, beta_b, pi_logit, w_mean, w_logstd

    def reparameterize(
        self,
        beta_a: torch.FloatTensor,
        beta_b: torch.FloatTensor,
        pi_logit: torch.FloatTensor,
        w_mean: torch.FloatTensor,
        w_logstd: torch.FloatTensor,
        eps: float = 1e-7,
    ) -> Tuple[torch.FloatTensor]:
        v = reparameterize_beta(beta_a, beta_b)
        pi_log_prior = torch.cumsum(torch.log(v + eps), dim=1)
        pi_logit_prior = logit(torch.exp(pi_log_prior), eps)
        pi_logit_post = pi_logit

        y_sample = reparameterize_bernoulli(pi_logit_post, self.temp_post, eps)
        z = F.sigmoid(y_sample)
        w = reparameterize_normal(w_mean, w_logstd)

        return pi_logit_prior, pi_logit_post, z, y_sample, w

    def decode(self, z: torch.FloatTensor, w: torch.FloatTensor) -> Tuple[torch.FloatTensor]:
        latent = z * w
        h = self.decoder_latent(latent)
        return self.decoder_feature(h)

    def compute_loss(
        self,
        X: torch.FloatTensor,
        X_recon: torch.FloatTensor,
        beta_a: torch.FloatTensor,
        beta_b: torch.FloatTensor,
        y_sample: torch.FloatTensor,
        pi_logit_prior: torch.FloatTensor,
        pi_logit_post: torch.FloatTensor,
        w_mean: torch.FloatTensor,
        w_logstd: torch.FloatTensor,
        is_query: bool = True,
    ) -> Dict[str, torch.FloatTensor]:
        loss_kl = self.kl_weight * self.compute_kl_loss(
            beta_a,
            beta_b,
            y_sample,
            pi_logit_prior,
            pi_logit_post,
            w_mean,
            w_logstd,
            is_query,
        )
        X, X_recon = F.normalize(X, dim=1), F.normalize(X_recon, dim=1)
        loss_recon = -self.recon_weight * (X * X_recon).sum(dim=1).mean()

        return {
            "loss": loss_kl + loss_recon,
            "loss_kl": loss_kl.detach(),
            "loss_recon": loss_recon.detach(),
        }

    def compute_kl_loss(
        self,
        beta_a_post: torch.FloatTensor,
        beta_b_post: torch.FloatTensor,
        y_sample: torch.FloatTensor,
        pi_logit_prior: torch.FloatTensor,
        pi_logit_post: torch.FloatTensor,
        w_mean_post: torch.FloatTensor,
        w_logstd_post: torch.FloatTensor,
        is_query: bool = True,
    ) -> torch.FloatTensor:
        beta_a_prior = self.beta_a_prior_query if is_query else self.beta_a_prior_entity
        return (
            kl_beta(
                beta_a_prior,
                self.beta_b_prior,
                beta_a_post,
                beta_b_post,
            )
            + kl_bernoulli(
                pi_logit_prior,
                pi_logit_post,
                y_sample,
                self.temp_prior,
                self.temp_post,
            )
            + kl_normal(
                self.normal_mean_prior,
                self.normal_std_prior,
                w_mean_post,
                w_logstd_post.exp(),
            )
        )
