import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from ..._layers import GCN_layer
from ..._utils import log_max


class EPM_VAE(nn.Module):
    """Graph VAE: a two-layer GCN encoder producing Weibull parameters and an
    inner-product style decoder ``1 - exp(-(phi * r) (phi * r)^T)``.

    The encoder's last layer emits ``ft_h2 + 1`` channels per node, which are
    split into the Weibull scale ``lbd`` (``ft_h2`` channels) and a single
    Weibull shape channel ``k`` (shifted by ``+0.1``).
    """

    def __init__(self, ft_in: int, ft_h1: int, ft_h2: int, dropout=0., bias=False, act=nn.Softplus()):
        """
        Inputs:
            ft_in: [int] dimensionality of input features
            ft_h1: [int] dimensionality of 1st hidden layer
            ft_h2: [int] dimensionality of 2nd hidden layer (phi)
            dropout: dropout rate handed to the 2nd GCN layer and applied to phi
                     in the decoder
            bias: forwarded to both GCN layers
            act: activation module forwarded to both GCN layers
        """
        super().__init__()

        # hyperparameters
        self.ft_in = ft_in
        self.ft_h1 = ft_h1
        self.ft_h2 = ft_h2
        self.SMALL = 1e-16  # floor passed to log_max when sampling
        self.dropout = dropout
        # when True, `reparameterize` never samples (even in training mode)
        self.deterministic_encoder = True

        # learnable parameter: a single scale multiplied onto phi in the decoder
        self.r_rtsq = nn.Parameter(torch.ones(1))

        # layers
        # NOTE(review): GCN_layer's positional arguments are assumed to be
        # (in_features, out_features, dropout, act, bias) -- confirm against _layers.
        self.GCN_layers = nn.ModuleList(
            [GCN_layer(ft_in, ft_h1, 0., act, bias), GCN_layer(ft_h1, ft_h2 + 1, dropout, act, bias)]
        )

    def reparameterize(self, lbd, k):
        """Turn Weibull parameters into the latent representation phi.

        Inputs:
            lbd: Weibull scale tensor
            k: Weibull shape tensor (broadcast against ``lbd``)
        Output:
            a Weibull sample when the module is in training mode and
            ``deterministic_encoder`` is False; otherwise the Weibull mean
            ``lbd * Gamma(1 + 1/k)``.
        """
        if self.training and not self.deterministic_encoder:
            # weibull reparameterization: phi = lbd * (- ln(1 - u)) ^ (1/k), u ~ uniform(0,1)
            u = torch.rand_like(lbd)
            return lbd * (- log_max(1 - u, self.SMALL)).pow(1 / k)

        # deterministic path: exp(lgamma(x)) == Gamma(x)
        return lbd * torch.exp(torch.lgamma(1 + k.pow(-1)))

    def encoder(self, batch_adj, batch_fts):
        """Run the GCN stack and split its output into Weibull parameters.

        Inputs:
            batch_adj: adjacency input passed unchanged to every GCN layer
            batch_fts: node features fed to the first GCN layer
        Outputs:
            (phi, k, lbd) where ``k`` already includes the ``+0.1`` offset.
        """
        h = batch_fts
        for layer in self.GCN_layers:
            h = layer(batch_adj, h)
        lbd, k = h.split([self.ft_h2, 1], dim=-1)

        # offset the shape parameter once and reuse it for both outputs
        k = k + 0.1

        return self.reparameterize(lbd, k), k, lbd

    def decoder(self, phi):
        """Decode latent representations into edge probabilities.

        Inputs:
            phi: tensor of shape (..., N, K) with at least two dimensions
                 (2-D for a single graph, 3-D for a padded batch; extra leading
                 batch dimensions are also accepted).
        Output:
            tensor of shape (..., N, N) holding ``1 - exp(-(phi r)(phi r)^T)``.
        Raises:
            ValueError: if ``phi`` has fewer than two dimensions.
        """
        if phi.dim() < 2:
            raise ValueError(
                "decoder expects phi with at least 2 dimensions, got shape {}".format(tuple(phi.shape))
            )

        phi = F.dropout(phi, self.dropout, self.training)
        # r_rtsq broadcasts over the trailing (feature) dimension
        phir = phi * self.r_rtsq
        # matmul performs mm for 2-D input and a batched product otherwise
        return 1 - torch.exp(- torch.matmul(phir, phir.transpose(-1, -2)))

    def forward(self, batch_adj, batch_fts):
        """Encode then decode; returns (edge probabilities, k, lbd)."""
        phi, k, lbd = self.encoder(batch_adj, batch_fts)
        return self.decoder(phi), k, lbd


class EPM_VAE_Loss(object):
    """Loss for EPM_VAE: class-balanced reconstruction NLL plus a KL term.

    The final value is ``(rec_loss + KL_COEF * kl_loss) * LOSS_SCALE``.
    """

    # Coefficient on the KL term in __call__. With 0. the KL value does not
    # contribute to the loss, but it is still evaluated (so NaN/inf in it
    # would still propagate).
    KL_COEF = 0.
    # Global scale applied to the combined loss in __call__.
    LOSS_SCALE = 0.01

    def __init__(self, kl_weights: torch.Tensor, alpha=1., beta=1., lite_mode=False, device='cuda:0'):
        """
        Inputs explained:
            log likelihood = reduce_sum(pos_labels_adj * log(preds) + neg_labels_adj * log(1 - preds)),
            kl_weights: 1-D tensor with one weight per row of the KL matrix;
                        presumably the concatenation
                        n1 * [1/n1] + n2 * [1/n2] + ... + nB * [1/nB]
                        (n_i copies of 1/n_i for graph i) -- confirm with caller,
            alpha, beta: parameters of phi's (gamma) prior distribution,
            lite_mode: compute _rec_loss individually for each graph in the batch,
            device: device on which the all-ones label masks are created.
        """
        super().__init__()
        self.lite_mode = lite_mode
        self.kl_weights = kl_weights
        self.prior_alpha = alpha
        self.prior_beta = beta
        self.device = device

    def _get_graph_stats(self, adj_labels_):
        """Count nodes per graph and derive per-graph label weights.

        Input:
            adj_labels_: an iterable of sparse square adjacency tensors.
        Outputs:
            num_nodes_: array with the number of nodes of each graph,
            pos_weight_: weight for positive (edge) entries of each graph,
            neg_weight_: weight for negative (non-edge) entries of each graph.

        When a graph has both edges and non-edges, each class gets a total
        weight of 0.5 (``0.5 / count``). When either class is empty, both
        weights fall back to the uniform ``1 / n^2`` so that the weights stay
        finite (an edge-less graph would otherwise yield ``inf`` and turn the
        loss into NaN via ``0 * inf``).
        """
        adj_labels_ = list(adj_labels_)
        num_nodes_ = np.array([adj.shape[0] for adj in adj_labels_])
        num_edges_ = np.array([torch.sparse.sum(adj).item() for adj in adj_labels_])
        num_negas_ = num_nodes_ ** 2 - num_edges_

        # uniform per-entry weight; the max() only guards a 0-node graph
        uniform_ = 1. / np.maximum(num_nodes_, 1) ** 2

        # both classes present -> balanced weights are well defined
        balanced_ = (num_edges_ > 0) & (num_negas_ > 0)
        # substitute a harmless denominator where the balanced weight is unused
        safe_edges_ = np.where(balanced_, num_edges_, 1.)
        safe_negas_ = np.where(balanced_, num_negas_, 1.)

        pos_weight_ = np.where(balanced_, .5 / safe_edges_, uniform_)
        neg_weight_ = np.where(balanced_, .5 / safe_negas_, uniform_)

        return num_nodes_, pos_weight_, neg_weight_

    def _get_labels(self, adj_labels_):
        """Build weighted positive and negative label matrices.

        Input:
            adj_labels_: a list of sparse binary adjs.
        Outputs (note the order: negatives first):
            in lite-mode, the lists ``(nlabs_, plabs_)``, both of length
            batch_size and holding unpadded per-graph dense matrices;
            otherwise, the block diagonal dense matrices
            ``(nlabs_adj, plabs_adj)``.
        """
        adj_labels_ = list(adj_labels_)
        num_nodes_, pos_weight_, neg_weight_ = self._get_graph_stats(adj_labels_)

        plabs_, nlabs_ = [], []
        for adj, n, pw, nw in zip(adj_labels_, num_nodes_, pos_weight_, neg_weight_):
            n = int(n)
            pos = adj.to_dense()
            # binary negative labels: all-ones mask minus the positive labels
            neg = torch.ones(n, n).to(self.device) - pos
            # apply the per-graph class weights
            plabs_.append(pos * pw)
            nlabs_.append(neg * nw)

        if self.lite_mode:
            return nlabs_, plabs_

        # block diagonal adj labels covering the whole batch
        return torch.block_diag(*nlabs_), torch.block_diag(*plabs_)

    def _weighted_nll(self, preds, pos_labels, neg_labels):
        """
        return weighted negative log-likelihood
        nll = - reduce_sum(pos_labels * log(preds) + neg_labels * log(1 - preds))
        """
        weighted_ll = pos_labels * log_max(preds) + neg_labels * log_max(1 - preds)
        return - weighted_ll.sum()

    def _rec_loss(self, preds_obj, adj_labels_):
        """graph reconstruction loss, weighted negative log-likelihood

        In lite-mode ``preds_obj`` is iterated in lockstep with the per-graph
        labels, so each element is expected to match its graph's label matrix;
        otherwise ``preds_obj`` is scored against the block diagonal labels.
        """
        nlabs_obj, plabs_obj = self._get_labels(adj_labels_)
        if not self.lite_mode:
            return self._weighted_nll(preds_obj, plabs_obj, nlabs_obj)

        nll_ = [
            self._weighted_nll(preds, plabs, nlabs)
            for (preds, plabs, nlabs) in zip(preds_obj, plabs_obj, nlabs_obj)
        ]
        return torch.stack(nll_).sum()

    def _kl_loss(self, k, lbd):
        """kl(weibull(k, lbd) || gamma(alpha, beta)), weighted per row by kl_weights

        Inputs:
            k: Weibull shape tensor
            lbd: Weibull scale tensor
        Output:
            scalar tensor, the sum of the row-weighted KL divergences.
        """
        # Euler-Mascheroni constant at full double precision
        eulergamma = np.euler_gamma
        # float() makes sure integer priors still produce floating tensors
        alpha = torch.tensor([float(self.prior_alpha)]).to(k.device)
        beta = torch.tensor([float(self.prior_beta)]).to(k.device)
        inv_k = k.pow(-1)

        # the three parts sum to the negative KL divergence
        KL_Part1 = eulergamma * (1 - inv_k) + log_max(lbd / k) + 1 + alpha * torch.log(beta)
        KL_Part2 = -torch.lgamma(alpha) + (alpha - 1) * (log_max(lbd) - eulergamma * inv_k)
        KL_Part3 = - beta * lbd * torch.exp(torch.lgamma(1 + inv_k))

        # kl_weights is 1-D; unsqueeze so it broadcasts across the latent dims
        kl_div = - (KL_Part1 + KL_Part2 + KL_Part3) * self.kl_weights.unsqueeze(1)

        return kl_div.sum()

    def __call__(self, preds_obj, adj_labels_, k, lbd):
        """Return ``(rec_loss + KL_COEF * kl_loss) * LOSS_SCALE``."""
        rec = self._rec_loss(preds_obj, adj_labels_)
        kl = self._kl_loss(k, lbd)
        return (rec + self.KL_COEF * kl) * self.LOSS_SCALE







