import torch
import torch.nn as nn
from components import LossFunctions, InferenceNet, GenerationNet


class StackedVAGT(nn.Module):
    """Variational model pairing an ``InferenceNet`` with a ``GenerationNet``.

    The module owns five learnable matrices (``gamma0``..``gamma2``,
    ``alpha0``, ``alpha1``) that are handed to the two sub-networks on every
    forward pass, and provides helpers for the reconstruction and KL terms of
    the training objective.
    """

    def __init__(self, layer_xz=1, layer_h=3, n_head=8, c_dim=20, x_dim=36, z_dim=15, h_dim=20, embd_h=256, embd_s=256, q_len=1,
                 vocab_len=128, win_len=20, dropout=0.1, beta=0.1, anneal_rate=1, max_beta=1, device=torch.device('cpu'),
                 is_train=True):
        """Store the hyper-parameters and build sub-networks and parameters.

        All arguments are kept as attributes of the same name, except
        ``device`` which is only forwarded to ``InferenceNet``.

        Args:
            layer_xz, layer_h, n_head, c_dim, embd_h, q_len, vocab_len,
            dropout, device, is_train: forwarded to ``InferenceNet``; their
                meaning is defined there.
            x_dim, z_dim, h_dim, win_len: forwarded to both ``InferenceNet``
                and ``GenerationNet``; ``x_dim``, ``z_dim`` and ``h_dim`` also
                size the learnable matrices created here.
            embd_s: second/first dimension of the ``alpha*``/``gamma*``
                matrices.
            beta, anneal_rate, max_beta: stored only; not used inside this
                class (presumably a KL-weight annealing schedule driven by
                the training loop -- confirm against the caller).
        """
        super().__init__()
        self.beta = beta
        self.max_beta = max_beta
        self.anneal_rate = anneal_rate
        self.layer_xz = layer_xz
        self.layer_h = layer_h
        self.n_head = n_head
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.embd_h = embd_h
        self.embd_s = embd_s
        self.q_len = q_len
        self.vocab_len = vocab_len
        self.win_len = win_len
        self.dropout = dropout
        self.c_dim = c_dim
        self.is_train = is_train

        self.losses = LossFunctions()
        self.inference = InferenceNet(c_dim, z_dim, x_dim, h_dim, embd_h, layer_xz, layer_h, n_head,
                                      vocab_len, dropout, q_len, win_len, device, is_train)
        self.generation = GenerationNet(h_dim, z_dim, x_dim, win_len)

        # Assigning an nn.Parameter to an attribute already registers it with
        # the module, and nn.Parameter defaults to requires_grad=True, so no
        # explicit register_parameter call is needed. The creation order is
        # kept so parameter ordering and random draws stay the same.
        self.gamma0 = nn.Parameter(torch.randn((embd_s, self.h_dim)))
        self.gamma1 = nn.Parameter(torch.randn((embd_s, self.z_dim)))
        self.gamma2 = nn.Parameter(torch.randn((embd_s, self.h_dim)))
        self.alpha0 = nn.Parameter(torch.randn((self.x_dim, embd_s)))
        self.alpha1 = nn.Parameter(torch.randn((self.z_dim, embd_s)))

    def loss_LLH(self, x, x_mu, x_logsigma):
        """Evaluate ``LossFunctions.log_normal`` for ``x`` under the decoder output.

        The three inputs are cast to float, and the log standard deviation is
        converted to a variance as ``exp(x_logsigma) ** 2`` before the call.

        Args:
            x: observed values.
            x_mu: predicted mean.
            x_logsigma: predicted log standard deviation.

        Returns:
            Whatever ``log_normal`` returns (presumably the Gaussian
            log-likelihood -- confirm in ``components.LossFunctions``).
        """
        x_var = torch.pow(torch.exp(x_logsigma.float()), 2)
        return self.losses.log_normal(x.float(), x_mu.float(), x_var)

    def loss_KL(self, z_mean_posterior_forward, z_logvar_posterior_forward, z_mean_prior_forward,
                z_logvar_prior_forward):
        """Return KL(posterior || prior) for diagonal Gaussians.

        Computes ``0.5 * sum(logvar_p - logvar_q + (var_q + (mu_q - mu_p)^2)
        / var_p - 1)`` where ``q`` is the posterior and ``p`` the prior. The
        sum runs over every element of the inputs, so the result is a scalar
        tensor (no per-sample or per-step averaging is applied here).

        Args:
            z_mean_posterior_forward: posterior mean.
            z_logvar_posterior_forward: posterior log-variance.
            z_mean_prior_forward: prior mean.
            z_logvar_prior_forward: prior log-variance.

        Returns:
            Scalar tensor holding the summed KL divergence.
        """
        z_var_posterior_forward = torch.exp(z_logvar_posterior_forward)
        z_var_prior_forward = torch.exp(z_logvar_prior_forward)
        mean_gap_sq = torch.pow(z_mean_posterior_forward - z_mean_prior_forward, 2)
        kld_z_forward = 0.5 * torch.sum(
            z_logvar_prior_forward - z_logvar_posterior_forward
            + ((z_var_posterior_forward + mean_gap_sq) / z_var_prior_forward)
            - 1
        )
        return kld_z_forward

    def forward(self, x):
        """Run the inference network, then the generation network.

        Args:
            x: input passed unchanged to ``InferenceNet`` (expected layout is
                defined there).

        Returns:
            Tuple ``(z_posterior_forward, z_mean_posterior_forward,
            z_logvar_posterior_forward, z_mean_prior_forward,
            z_logvar_prior_forward, x_mu, x_logsigma)``: the first three (plus
            an internal ``h_out``) come from the inference network, the last
            four from the generation network.
        """
        z_posterior_forward, z_mean_posterior_forward, z_logvar_posterior_forward, h_out = \
            self.inference(x, self.alpha1, self.alpha0)
        z_mean_prior_forward, z_logvar_prior_forward, x_mu, x_logsigma = \
            self.generation(h_out, z_posterior_forward, self.alpha1, self.alpha0, self.gamma0,
                            self.gamma1, self.gamma2)
        return (z_posterior_forward, z_mean_posterior_forward, z_logvar_posterior_forward,
                z_mean_prior_forward, z_logvar_prior_forward, x_mu, x_logsigma)
