"""
Code built from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html
"""
import numpy as np
import torch
import torch.nn as nn

import models._density_estimators as estims
from utils.constants import Cte
from utils.likelihoods import get_likelihood
from utils.probabilistic_model import ProbabilisticModelSCM


class HVCAUSEModule(nn.Module):
    """
    Heterogeneous VCAUSE.

    Variational auto-encoder over graph nodes in which every node may have its
    own observed dimensionality and likelihood. The module is made of:

    * one bias-free linear embedding per node that maps the node's observed
      features to a common width (``node_dim_max``),
    * an encoder GNN producing the parameters of ``q(z | x)``,
    * a decoder GNN producing the logits consumed by the heterogeneous
      likelihood ``p(x | z)``.

    Args:
        x_dim_list: Per-node list of feature dimensions, e.g.
            ``[[2], [3, 4], [3, 4, 5]]``. Should be in the same order as in
            the dataset.
        h_dim_list_dec: Hidden sizes of the decoder GNN.
        h_dim_list_enc: Hidden sizes of the encoder GNN.
        z_dim: Size of the latent variable of each node.
        m_layers: Number of layers for the message MLP of the decoder
            (ignored by the ``'pna'`` architecture, which takes no such
            argument here).
        lambda_kld: Forwarded to ``ProbabilisticModelSCM``.
        deg: Forwarded to the PNA-based architectures.
        edge_dim: Forwarded to the GNN modules.
        num_nodes: Forwarded to the disjoint architectures. Note that
            ``self.num_nodes`` is always ``len(x_dim_list)``.
        residual: Use residual connections in message passing.
        drop_rate: Dropout rate forwarded to the GNN modules.
        drop_rate_i: Additional dropout rate, only forwarded to ``'dgnn'``.
        act_name: Name of the activation function.
        distr_x_list: Likelihood name per node; must not be ``None`` and
            should be the same size as ``x_dim_list``.
        distr_z: Name of the latent distribution / prior.
        architecture: One of ``'pna'``, ``'dgnn'`` (disjoint GNN) or
            ``'dpna'`` (disjoint PNA).
        norm_categorical: Forwarded to ``ProbabilisticModelSCM``.
        K: Number of importance samples used by the IWAE estimators.

    Raises:
        NotImplementedError: If ``architecture`` or ``distr_z`` is not
            supported.
    """

    def __init__(self, x_dim_list,
                 # [[2], [3,4], [3,4,5]], e.g. [node_1, node_2, node_3] Should be in the same order as in the dataset
                 h_dim_list_dec,
                 h_dim_list_enc,
                 z_dim,
                 m_layers,  # Number of layers for the message MLP of the decoder
                 lambda_kld=1.0,
                 deg=None,
                 edge_dim=None,
                 num_nodes=None,
                 residual=0,  # Use residual network in message passing
                 drop_rate=0.0,
                 drop_rate_i=0.0,
                 act_name=Cte.RELU,
                 distr_x_list=None,  # Should be the same size as x_dim_list
                 distr_z='normal',
                 architecture=None,
                 norm_categorical=False,
                 K=1):
        super().__init__()

        assert distr_x_list is not None

        self.K = K

        likelihood_z = get_likelihood(distr_z, z_dim)

        likelihood_x = ProbabilisticModelSCM(x_dim_list=x_dim_list,
                                             distr_x_list=distr_x_list,
                                             lambda_kld=lambda_kld,
                                             norm_categorical=norm_categorical)

        # All nodes are embedded to the width of the largest node.
        node_dim_max = max(likelihood_x.node_dim_list)

        # Encoder embeddings: one bias-free linear map per node.
        self._encoder_embeddings = nn.ModuleList([
            nn.Linear(int(np.sum(x_dim_i_list)), node_dim_max, bias=False)
            for x_dim_i_list in x_dim_list
        ])

        # Options shared by the encoder and the decoder GNN.
        gnn_kwargs = dict(architecture=architecture,
                          deg=deg,
                          edge_dim=edge_dim,
                          num_nodes=num_nodes,
                          drop_rate=drop_rate,
                          drop_rate_i=drop_rate_i,
                          residual=residual,
                          act_name=act_name)

        # Encoder module: node embedding -> parameters of q(z | x).
        c_list_enc = [node_dim_max, *h_dim_list_enc, likelihood_z.params_size]
        self.encoder_module = self._build_gnn(
            c_list=c_list_enc,
            # We can only have 1 message passing step, so every layer of the
            # encoder belongs to the message MLP.
            m_layers=len(c_list_enc) - 1,
            **gnn_kwargs)

        # Decoder module: z -> logits consumed by the likelihood of x.
        c_list_dec = [z_dim, *h_dim_list_dec, likelihood_x.embedding_size]
        self.decoder_module = self._build_gnn(c_list=c_list_dec,
                                              m_layers=m_layers,
                                              **gnn_kwargs)

        self.z_dim = z_dim

        self.num_nodes = len(x_dim_list)

        # Width of one flattened graph once every node is padded to node_dim_max.
        self.x0_size = self.num_nodes * node_dim_max

        self.node_dim_max = node_dim_max

        self.x_dim_list = x_dim_list

        self.likelihood_z = likelihood_z
        self.likelihood_x = likelihood_x

        self.distr_z = distr_z
        self.set_z_prior_distr()

    @staticmethod
    def _build_gnn(architecture, c_list, m_layers, deg, edge_dim, num_nodes,
                   drop_rate, drop_rate_i, residual, act_name):
        """Instantiate the GNN selected by ``architecture`` for the given layer sizes.

        The project modules are imported lazily so that only the module of the
        requested architecture has to be importable.

        Raises:
            NotImplementedError: If ``architecture`` is not ``'pna'``,
                ``'dgnn'`` or ``'dpna'``.
        """
        if architecture == 'pna':
            from modules.pna import PNAModule

            return PNAModule(c_list=c_list,
                             deg=deg,
                             edge_dim=edge_dim,
                             drop_rate=drop_rate,
                             act_name=act_name,
                             aggregators=None,
                             scalers=None,
                             residual=residual)

        if architecture == 'dgnn':  # Disjoint GNN
            from modules.disjoint_gnn import DisjointGNN

            return DisjointGNN(c_list=c_list,
                               m_layers=m_layers,
                               edge_dim=edge_dim,
                               num_nodes=num_nodes,
                               drop_rate=drop_rate,
                               drop_rate_i=drop_rate_i,
                               residual=residual,
                               act_name=act_name,
                               aggr='add')

        if architecture == 'dpna':  # Disjoint PNA
            from modules.disjoint_pna import DisjointPNA

            return DisjointPNA(c_list=c_list,
                               m_layers=m_layers,
                               edge_dim=edge_dim,
                               deg=deg,
                               num_nodes=num_nodes,
                               aggregators=None,
                               scalers=None,
                               drop_rate=drop_rate,
                               act_name=act_name,
                               residual=residual)

        raise NotImplementedError

    def encoder_params(self):
        """Return the parameters of the encoder GNN and of the node embeddings."""
        return list(self.encoder_module.parameters()) + list(self._encoder_embeddings.parameters())

    def decoder_params(self):
        """Return the parameters of the decoder GNN and of the likelihood of x."""
        return list(self.decoder_module.parameters()) + list(self.likelihood_x.parameters())

    def set_z_prior_distr(self):
        """Create ``self.z_prior_distr``, the prior over one node's latent vector.

        The prior has ``self.z_dim`` components for every supported
        distribution.

        Raises:
            NotImplementedError: If ``self.distr_z`` is not supported.
        """
        # NOTE(review): the prior is stored as a plain attribute, not as
        # registered buffers, so `module.to(device)` does not move its tensors.
        if self.distr_z == Cte.CONTINOUS_BERN:  # Continuous Bernoulli
            self.z_prior_distr = torch.distributions.ContinuousBernoulli(
                probs=0.5 * torch.ones(self.z_dim))
        elif self.distr_z == Cte.EXPONENTIAL:  # Exponential
            self.z_prior_distr = torch.distributions.Exponential(
                rate=0.2 * torch.ones(self.z_dim))
        elif self.distr_z == Cte.BETA:  # Beta
            self.z_prior_distr = torch.distributions.Beta(
                concentration0=torch.ones(self.z_dim),
                concentration1=torch.ones(self.z_dim))
        elif self.distr_z == Cte.GAUSSIAN:
            self.z_prior_distr = torch.distributions.Normal(torch.zeros(self.z_dim),
                                                            torch.ones(self.z_dim))
        else:
            raise NotImplementedError

    def encoder_embeddings(self, X):
        """Embed the observed features of every node to ``node_dim_max`` features.

        ``X`` is viewed as ``[-1, num_nodes * node_dim_max]``, i.e. one row per
        graph with each node occupying a slot of width ``node_dim_max``. Only
        the first ``likelihood_x.node_dim_list[i]`` entries of slot ``i`` are
        fed to the embedding of node ``i``.

        Returns:
            Tensor viewed as ``[-1, node_dim_max]`` (one row per node).
        """
        X_0 = X.view(-1, self.x0_size)

        embeddings = []
        for i, embed_i in enumerate(self._encoder_embeddings):
            start = i * self.node_dim_max
            X_0_i = X_0[:, start:start + self.node_dim_max]
            # Drop the padding of nodes narrower than node_dim_max.
            # (assumes node_dim_list[i] == sum(x_dim_list[i]) -- TODO confirm)
            H_i = embed_i(X_0_i[:, :self.likelihood_x.node_dim_list[i]])
            embeddings.append(H_i)

        return torch.cat(embeddings, dim=1).view(-1, self.node_dim_max)

    def encoder(self, X, edge_index, edge_attr=None, return_mean=False, **kwargs):
        """Compute the approximate posterior ``q(z | x)``.

        Args:
            X: Node features, embedded with :meth:`encoder_embeddings`.
            edge_index: Forwarded to the encoder GNN.
            edge_attr: Forwarded to the encoder GNN.
            return_mean: If ``True`` also return the mean of ``q(z | x)``.
            **kwargs: Extra arguments forwarded to the encoder GNN.

        Returns:
            ``qz_x``, or ``(mean, qz_x)`` when ``return_mean`` is ``True``.
        """
        logits = self.encoder_module(self.encoder_embeddings(X),
                                     edge_index,
                                     edge_attr=edge_attr, **kwargs)
        if return_mean:
            mean, qz_x = self.likelihood_z(logits, return_mean=True)
            return mean, qz_x

        return self.likelihood_z(logits)

    def decoder(self, Z, edge_index, edge_attr=None, return_type=None, **kwargs):
        """Compute the likelihood ``p(x | z)``.

        Args:
            Z: Latent variables.
            edge_index: Forwarded to the decoder GNN.
            edge_attr: Forwarded to the decoder GNN.
            return_type: ``'mean'`` to also return the mean of ``p(x | z)``,
                ``'sample'`` to also return a sample of it; anything else
                returns the distribution only.
            **kwargs: Extra arguments forwarded to the decoder GNN.

        Returns:
            ``px_z``, ``(mean, px_z)`` or ``(sample, px_z)`` depending on
            ``return_type``.
        """
        logits = self.decoder_module(Z, edge_index, edge_attr, **kwargs)

        if return_type == 'mean':
            mean, px_z = self.likelihood_x(logits, return_mean=True)
            return mean, px_z

        if return_type == 'sample':
            _, px_z = self.likelihood_x(logits, return_mean=True)
            return px_z.sample(), px_z

        return self.likelihood_x(logits)

    def compute_log_w(self, data, K, mask=None):
        """Compute the log importance weights used by the IWAE objective.

        IWAE: log(1/K * sum_k w_k) with w_k = p(x, z_k) / q(z_k | x), so
            log w_k = log p(x | z_k) + log p(z_k) - log q(z_k | x)

        Args:
            data: Batch of graphs; ``x``, ``edge_index``, ``edge_attr``,
                ``node_ids``, ``mask`` and ``num_graphs`` are read.
            K: Number of importance samples.
            mask: Must be ``None``; input masking is not supported here.

        Returns:
            Tensor of shape ``[num_graphs, K]``.
        """
        assert mask is None

        log_w = []
        for _ in range(K):
            # The encoder is re-run for every sample, as in the original code.
            qz_x = self.encoder(data.x, data.edge_index, edge_attr=data.edge_attr, node_ids=data.node_ids)
            z = qz_x.rsample()

            px_z_k = self.decoder(z, data.edge_index, edge_attr=data.edge_attr, node_ids=data.node_ids)

            log_prob_qz_x = qz_x.log_prob(z).view(data.num_graphs, -1).sum(-1)  # Summing over dim(z)*num_nodes
            log_prob_pz = self.z_prior_distr.log_prob(z).view(data.num_graphs, -1).sum(-1)
            log_prob_px_z = px_z_k.log_prob(self.get_x_graph(data, 'x')).sum(-1)

            log_w.append(log_prob_px_z + log_prob_pz - log_prob_qz_x)

        # Stacked as [K, num_graphs]; transposed to [num_graphs, K].
        return torch.stack(log_w, dim=0).T

    def get_x_graph(self, data, attr):
        """Return ``data.<attr>`` with one row per graph, keeping masked-in columns only.

        The column mask is taken from the first graph of the batch and applied
        to every graph.
        """
        x = getattr(data, attr)
        mask = data.mask.view(data.num_graphs, -1)[0]
        return x.view(data.num_graphs, -1)[:, mask]

    def forward(self, data, estimator, beta=1.0, dropout_input_rate=0.0):
        """Evaluate the training objective (to be maximised) on a batch.

        Args:
            data: Batch of graphs.
            estimator: ``'elbo'``, ``'iwae'`` or ``'iwaedreg'``.
            beta: Weight of the KL term of the ELBO.
            dropout_input_rate: Must be ``0.0``; input dropout is disabled by
                an assertion.

        Returns:
            ``(objective, stats)``. ``stats`` holds ``'log_prob_x'`` and
            ``'kl_z'`` for the ELBO and is empty for the other estimators.

        Raises:
            NotImplementedError: If ``estimator`` is not supported.
        """
        x_input = data.x.clone()

        mask = None

        assert dropout_input_rate == 0.0

        # Only reachable when assertions are stripped (python -O).
        if dropout_input_rate > 0:
            num_nodes = data.x.shape[0]
            mask = torch.rand(num_nodes) > dropout_input_rate  # 1 if we use the sample
            x_input[~mask] = 0.0

        if estimator == 'elbo':

            qz_x = self.encoder(x_input,
                                data.edge_index,
                                edge_attr=data.edge_attr,
                                node_ids=data.node_ids)
            z = qz_x.rsample()

            px_z = self.decoder(z, data.edge_index, edge_attr=data.edge_attr, node_ids=data.node_ids)

            if dropout_input_rate > 0:
                log_prob_x = px_z.log_prob(data.x)[mask].flatten(1).sum(1).mean()
                kl_z = torch.distributions.kl.kl_divergence(qz_x, self.z_prior_distr)[mask].flatten(1).sum(1).mean()
            else:
                log_prob_x = px_z.log_prob(self.get_x_graph(data, 'x')).sum(1).mean()
                kl_z = torch.distributions.kl.kl_divergence(qz_x, self.z_prior_distr).view(data.num_graphs, -1).sum(
                    1).mean()

            elbo = log_prob_x - beta * kl_z

            stats = {'log_prob_x': log_prob_x,
                     'kl_z': kl_z}

            return elbo, stats

        if estimator == 'iwae':
            log_w = self.compute_log_w(data=data, K=self.K, mask=mask)
            objective, _ = estims.IWAE(log_w, trick=True)
            return objective.mean(), {}

        if estimator == 'iwaedreg':
            # NOTE(review): `compute_log_w_dreg` is not defined in this class;
            # this branch fails with AttributeError unless a subclass provides it.
            log_w, zs = self.compute_log_w_dreg(data=data, K=self.K)
            objective, _ = estims.IWAE_dreg(log_w, zs)
            return objective.mean(), {}

        raise NotImplementedError

    @torch.no_grad()
    def reconstruct(self, data, use_mean_encoder=True):
        """Reconstruct ``data.x`` by encoding and then decoding it.

        Args:
            data: Batch of graphs.
            use_mean_encoder: Decode from the mean of ``q(z | x)`` when
                ``True``, from a sample of it otherwise.

        Returns:
            ``(z_mean, x_hat)``. ``z_mean`` is always the posterior mean, even
            when a sample was decoded.
        """
        z_mean, qz_x = self.encoder(data.x, data.edge_index, edge_attr=data.edge_attr,
                                    return_mean=True, node_ids=data.node_ids)

        z = z_mean if use_mean_encoder else qz_x.rsample()
        x_hat, _ = self.decoder(z, data.edge_index, edge_attr=data.edge_attr,
                                return_type='mean', node_ids=data.node_ids)

        # Shape of x_hat: [num_graphs, total_dim_nodes] (per the original author's note)
        return z_mean, x_hat
