"""
Code built from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html
"""
import os

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.preprocessing import FunctionTransformer

import models._density_estimators as estims
from models._evaluator import MyEvaluator
from utils.args_parser import mkdir
from utils.constants import Cte
from utils.dropout import dropout_adj, dropout_adj_parents
from utils.metrics.mmd import MMDLoss
from utils.metrics.time import Timer
from utils.optimizers import get_optimizer, get_scheduler
import json

class VCAUSE(pl.LightningModule):
    '''
    Feature Variational Graph Autoencoder.

    PyTorch Lightning wrapper around ``VCAUSEModule`` (or ``HVCAUSEModule``
    when ``is_heterogeneous`` is True). It implements the train / validation /
    test steps (ELBO and IWAE objectives, optional edge dropout, beta
    annealing), timing and logging hooks, and helpers that sample
    observational, interventional and counterfactual distributions from the
    trained model.
    '''

    # ``current_epoch`` value handed to ``MyEvaluator`` for standalone
    # (post-training) evaluations.
    _FINAL_EVAL_EPOCH = 100000

    def __init__(self, x_dim,
                 h_dim_list_dec,
                 h_dim_list_enc,
                 z_dim,
                 m_layers=1,
                 deg=None,  # Only PNA architecture
                 edge_dim=None,
                 num_nodes=None,
                 beta=1.0,
                 lambda_kld=1.0,
                 annealing_beta=False,
                 residual=0,  # Only PNA architecture
                 drop_rate=0.0,
                 drop_rate_i=0.0,
                 dropout_adj_rate=0.0,
                 dropout_adj_pa_rate=0.0,
                 dropout_adj_pa_prob_keep_self=0.0,
                 keep_self_loops=True,
                 dropout_input_rate=0.0,
                 dropout_adj_T=0,
                 act_name=Cte.RELU,
                 distr_x='normal',
                 distr_z='normal',
                 architecture=None,  # PNA, DGNN, DPNA
                 estimator='iwaedreg',
                 K=1,  # Only for IWAE estimator
                 intervention_list=None,  # In distribution interventions
                 scaler=None,
                 init=None,
                 data_is_toy=True,
                 is_heterogeneous=False,
                 norm_categorical=False):
        '''
        Build the underlying (H)VCAUSE module and training configuration.

        Most architecture arguments (``x_dim``, ``h_dim_list_*``, ``z_dim``,
        ``m_layers``, ``deg``, ``edge_dim``, ``residual``, ``drop_rate*``,
        ``act_name``, ``distr_*``, ``architecture``, ``K``, ...) are passed
        straight through to the underlying module.

        Parameters
        ----------
        beta : float
            Weight of the KL term. The sentinel value ``10`` means "use
            ``1 / z_dim``".
        annealing_beta : number or bool
            If > 0, number of epochs (after a 10-epoch warm-up) over which
            beta is linearly annealed from 0 to its full value.
        dropout_adj_rate, dropout_adj_pa_rate : float
            Edge-dropout rates applied during training once
            ``current_epoch >= dropout_adj_T``.
        estimator : str
            Training objective passed to the underlying module.
        scaler : object with ``transform`` / ``inverse_transform``, optional
            Defaults to an identity ``FunctionTransformer``.
        init : str, optional
            ``'ortho'`` applies orthogonal initialisation to every weight
            matrix. Any other value leaves the default initialisation.
        data_is_toy : bool
            Whether the dataset can sample ground-truth interventional and
            counterfactual data (used by the ``get_x_*`` helpers).
        is_heterogeneous : bool
            Selects ``HVCAUSEModule`` instead of ``VCAUSEModule``.
        '''
        super().__init__()

        # Must run first: Lightning inspects this frame's init arguments.
        self.save_hyperparameters()
        self.estimator = estimator

        self.num_nodes = num_nodes
        self.keep_self_loops = keep_self_loops

        self.random_train_sampler = None

        self.data_is_toy = data_is_toy

        if scaler is None:
            self.scaler = FunctionTransformer(func=lambda x: x,
                                              inverse_func=lambda x: x)
        else:
            self.scaler = scaler

        # beta == 10 is a sentinel meaning "use 1/dim(z)".
        self.beta = 1 / z_dim if beta == 10 else beta
        self.annealing_beta = annealing_beta

        self.dropout_input_rate = dropout_input_rate

        if is_heterogeneous:
            from models.vcause.hvcause_module import HVCAUSEModule

            self.model = HVCAUSEModule(x_dim_list=x_dim,
                                       h_dim_list_dec=h_dim_list_dec,  # Hidden layers in the generative network
                                       h_dim_list_enc=h_dim_list_enc,  # Hidden layers in the inference network
                                       z_dim=z_dim,
                                       m_layers=m_layers,
                                       lambda_kld=lambda_kld,
                                       deg=deg,
                                       edge_dim=edge_dim,
                                       num_nodes=num_nodes,
                                       residual=residual,
                                       drop_rate=drop_rate,
                                       drop_rate_i=drop_rate_i,
                                       act_name=act_name,
                                       distr_x_list=distr_x,
                                       distr_z=distr_z,
                                       architecture=architecture,
                                       norm_categorical=norm_categorical,
                                       K=K)
        else:
            from models.vcause.vcause_module import VCAUSEModule

            self.model = VCAUSEModule(x_dim=x_dim,
                                      h_dim_list_dec=h_dim_list_dec,  # Hidden layers in the generative network
                                      h_dim_list_enc=h_dim_list_enc,  # Hidden layers in the inference network
                                      z_dim=z_dim,
                                      m_layers=m_layers,
                                      lambda_kld=lambda_kld,
                                      deg=deg,
                                      edge_dim=edge_dim,
                                      num_nodes=num_nodes,
                                      residual=residual,
                                      drop_rate=drop_rate,
                                      drop_rate_i=drop_rate_i,
                                      act_name=act_name,
                                      distr_x=distr_x,
                                      distr_z=distr_z,
                                      architecture=architecture,
                                      K=K)

        self.is_heterogeneous = is_heterogeneous

        self.model.set_z_prior_distr()
        self.z_dim = z_dim
        self.timer = Timer()

        self.mmd1 = MMDLoss(kernel_mul=2.0, kernel_num=5, num_samples=1000)
        self.mmd2 = MMDLoss(kernel_mul=1.0, kernel_num=10, num_samples=1000)

        if init == 'ortho':
            # ``init`` is the string argument here, so the initialiser must
            # come from this class rather than from ``init.<...>``.
            self.apply(self._init_weights_orthogonal)

        self.intervention_list = intervention_list

        self.dropout_adj = dropout_adj_rate
        self.dropout_adj_pa_prob_keep_self = dropout_adj_pa_prob_keep_self
        self.dropout_adj_pa = dropout_adj_pa_rate
        self.dropout_adj_T = dropout_adj_T

        self.save_dir = None

        self.total_train_time = 0.0

    @staticmethod
    def _init_weights_orthogonal(module):
        '''Orthogonally initialise ``module.weight`` if it is a parameter with >= 2 dims.

        Meant to be used with ``nn.Module.apply``, which visits every
        sub-module. Biases and 1-D weights (e.g. norm layers) are left as-is.
        '''
        weight = getattr(module, 'weight', None)
        if isinstance(weight, torch.nn.Parameter) and weight.dim() >= 2:
            torch.nn.init.orthogonal_(weight)

    def _build_evaluator(self, save_dir=None, current_epoch=None):
        '''Create a ``MyEvaluator`` bound to this model.

        ``save_dir`` defaults to the logger's save dir and ``current_epoch``
        to ``self.current_epoch``.
        '''
        return MyEvaluator(model=self,
                           save_dir=save_dir if save_dir is not None else self.logger.save_dir,
                           logger=self.logger,
                           intervention_list=self.intervention_list,
                           current_epoch=self.current_epoch if current_epoch is None else current_epoch)

    def _splice_latents(self, z, z_source, node_indices):
        '''Overwrite, in place, the latent slots of ``node_indices`` in ``z`` with ``z_source``.

        ``z`` and ``z_source`` are per-graph latent matrices of shape
        ``(num_graphs, num_nodes * z_dim)``. Node ``i`` occupies columns
        ``[z_dim * i, z_dim * (i + 1))``. ``node_indices`` may be any iterable
        of integer node indices (a dict iterates over its keys). Returns ``z``.
        '''
        for i in node_indices:
            cols = slice(self.z_dim * i, self.z_dim * (i + 1))
            z[:, cols] = z_source[:, cols]
        return z

    def get_data_is_toy(self):
        '''Return whether the dataset provides ground-truth interventional samples.'''
        return self.data_is_toy

    def monitor(self):
        '''Name of the metric used for checkpoint selection / early stopping.'''
        return 'valid_iwae_100'

    def monitor_mode(self):
        '''The monitored metric is to be maximised.'''
        return 'max'

    def set_random_train_sampler(self, sampler):
        '''Set the callable ``sampler(num_samples) -> batch`` used by ``samples_aggregated_posterior``.'''
        self.random_train_sampler = sampler

    @torch.no_grad()
    def samples_aggregated_posterior(self, num_samples):
        '''Sample latents from q(z|x) evaluated on ``num_samples`` random training graphs.'''
        batch = self.random_train_sampler(num_samples)
        q_z_x = self.model.encoder(batch.x, batch.edge_index, edge_attr=batch.edge_attr,
                                   return_mean=False, node_ids=batch.node_ids)
        return q_z_x.sample()

    def get_x_graph(self, data, attr):
        '''Delegate to ``self.model.get_x_graph`` (per-graph view of attribute ``attr``).'''
        return self.model.get_x_graph(data, attr)

    def set_optim_params(self, optim_params, sched_params):
        '''Store optimizer/scheduler configs (dicts with ``'name'`` and ``'params'``).'''
        self.optim_params = optim_params
        self.sched_params = sched_params

    def configure_optimizers(self):
        '''Build the optimizer and, if ``sched_params`` is a dict, the LR scheduler.'''
        optim = get_optimizer(self.optim_params['name'])(self.parameters(), **self.optim_params['params'])
        if isinstance(self.sched_params, dict):
            sched = get_scheduler(self.sched_params['name'])(optim, **self.sched_params['params'])
        else:
            sched = []
        return [optim], sched

    def forward(self, data, *args, **kwargs):
        '''Not supported: the training/eval steps call ``self.model`` directly.'''
        raise NotImplementedError('VCAUSE.forward is not implemented; call self.model(...) instead.')

    def get_beta_annealing_factor(self, current_epoch):
        '''Return the multiplier in [0, 1] applied to beta at ``current_epoch``.

        With annealing enabled the factor is 0 for the first 10 epochs and then
        grows linearly to 1 over ``annealing_beta`` epochs. Without annealing
        it is always 1.
        '''
        if self.annealing_beta > 0:  # Do annealing
            return max(min((current_epoch - 10) / self.annealing_beta, 1.0), 0)
        return 1.0

    def training_step(self, batch, batch_idx):
        '''One optimisation step: optional edge dropout, then maximise the objective.

        Returns the negated objective because Lightning minimises the loss.
        '''
        if self.dropout_adj > 0.0 and self.current_epoch >= self.dropout_adj_T:
            # Clone so the dataset's stored graph is not modified in place.
            batch = batch.clone()
            batch.edge_index, batch.edge_attr = dropout_adj(batch.edge_index, batch.edge_attr,
                                                            p=self.dropout_adj, keep_self_loops=self.keep_self_loops)

        if self.dropout_adj_pa > 0.0 and self.current_epoch >= self.dropout_adj_T:
            batch = batch.clone()
            batch.edge_index, batch.edge_attr = dropout_adj_parents(batch.edge_index, batch.edge_attr,
                                                                    p=self.dropout_adj_pa,
                                                                    prob_keep_self=self.dropout_adj_pa_prob_keep_self)

        objective, data = self.model(batch,
                                     estimator=self.estimator,
                                     beta=self.beta * self.get_beta_annealing_factor(self.current_epoch),
                                     dropout_input_rate=self.dropout_input_rate)
        self.log('train_objective', objective.item(), prog_bar=True)
        for key, value in data.items():
            self.log(f'train_{key}', value.item(), prog_bar=True)
        return -objective

    def on_train_epoch_start(self) -> None:
        self.timer.tic('train')

    def on_train_epoch_end(self, outputs=None) -> None:
        '''Log the wall-clock time of the epoch.

        ``outputs`` is optional because only older Lightning versions pass it.
        '''
        time = self.timer.toc('train')
        self.logger.experiment.add_scalar('train_time', time, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        '''Log the ELBO (beta=1) and the IWAE bound with K=100 importance samples.'''
        objective, data = self.model(batch, estimator='elbo', beta=1.0)
        self.log('valid_elbo', objective.item(), prog_bar=True)
        for key, value in data.items():
            self.log(f'valid_{key}', value.item(), prog_bar=True)

        K = 100
        log_w = self.compute_log_w(batch, K=K)
        iwae, _ = estims.IWAE(log_w[:, :K], trick=False)
        iwae = iwae.mean()
        self.log(f'valid_iwae_{K}', iwae.item(), prog_bar=True)

        return {f'valid_iwae_{K}': iwae.item()}

    def on_validation_epoch_start(self) -> None:
        # Exclude validation time from the total training time.
        self.timer.stop('train_total')

    def on_validation_epoch_end(self) -> None:
        self.timer.resume('train_total')

    def on_test_epoch_start(self) -> None:
        self.x_test = []
        self.x_hat = []

    def test_step(self, batch, batch_idx):
        '''Log the test ELBO (beta=1) and the IWAE bound with K=20 importance samples.'''
        objective, data = self.model(batch, estimator='elbo', beta=1.0)
        self.log('test_elbo', objective.item(), prog_bar=True)
        for key, value in data.items():
            self.log(f'test_{key}', value.item(), prog_bar=True)

        K = 20
        log_w = self.compute_log_w(batch, K=K)
        iwae, _ = estims.IWAE(log_w, trick=False)
        iwae = iwae.mean()
        self.log(f'test_iwae_{K}', iwae.item(), prog_bar=True)

    def on_test_epoch_end(self) -> None:
        return

    def on_epoch_end(self) -> None:
        '''Every 100 epochs, run the full evaluator on the test loader (in eval mode).'''
        self.eval()
        if self.current_epoch % 100 == 0:
            self._build_evaluator().complete_logs(data_loader=self.test_dataloader(), name='test', plots=False)
        self.train()  # IMPORTANT: Set back to train mode!

    def on_fit_end(self):
        '''Write ``time.json`` with the total training time, then run the test evaluator.'''
        with open(os.path.join(self.logger.save_dir, 'time.json'), 'w', encoding='utf-8') as f:
            json.dump({'total_train_time': self.timer.toc('train_total')}, f)
        self._build_evaluator().complete_logs(data_loader=self.test_dataloader(), name='test', plots=False)

    def on_fit_start(self):
        '''Start the total-training timer and prepare output folders.'''
        self.eval()
        self.timer.tic('train_total')

        mkdir(os.path.join(self.logger.save_dir, 'images'))
        self.model.set_z_prior_distr()  # Just to move the prior to GPU if needed

    @torch.no_grad()
    def get_objective_metrics(self, data_loader, name):
        '''Average ELBO and IWAE (K=100) over ``data_loader``.

        Returns a dict ``{f'{name}_elbo': ..., f'{name}_iwae_100': ...}`` of
        per-batch means.
        '''
        K = 100

        metrics = {'elbo': [], f'iwae_{K}': []}
        for batch in data_loader:
            objective, _ = self.model(batch, estimator='elbo', beta=1)
            metrics['elbo'].append(objective.item())
            log_w = self.compute_log_w(batch, K=K)
            iwae, _ = estims.IWAE(log_w, trick=False)
            metrics[f'iwae_{K}'].append(iwae.mean().item())

        return {f'{name}_{key}': np.mean(values) for key, values in metrics.items()}

    @torch.no_grad()
    def my_evaluator(self, dataloader, name='test', save_dir=None, plots=False):
        '''Run ``MyEvaluator.evaluate`` on ``dataloader`` and return its output.'''
        evaluator = self._build_evaluator(save_dir=save_dir, current_epoch=self._FINAL_EVAL_EPOCH)
        return evaluator.evaluate(dataloader, name=name, plots=plots)

    def my_cf_fairness(self, data_module, save_dir=None):
        '''Run ``MyEvaluator.evaluate_cf_fairness``, print and return its output.'''
        evaluator = self._build_evaluator(save_dir=save_dir, current_epoch=self._FINAL_EVAL_EPOCH)
        output = evaluator.evaluate_cf_fairness(data_module)
        print(output)
        return output

    def compute_log_w(self, data, K):
        '''Delegate to ``self.model.compute_log_w`` (importance log-weights, K samples).'''
        return self.model.compute_log_w(data, K=K)

    def compute_log_w_dreg(self, data, K):
        '''Delegate to ``self.model.compute_log_w_dreg``.'''
        return self.model.compute_log_w_dreg(data, K=K)

    @torch.no_grad()
    def get_x_gener_distribution(self, data_loader, use_links=True, use_aggregated_posterior=False):
        '''
        Get the generated x distribution without intervention, optionally with a diagonal adjacency.

        Parameters
        ----------
        use_links : bool
            If False, the dataset is switched to a diagonal adjacency matrix
            (``diagonal_SCM``) before sampling.
        use_aggregated_posterior : bool
            If True, latents come from ``samples_aggregated_posterior``
            instead of the prior.

        Returns
        -------
        z_list: torch.Tensor
            Latent code of the generated distribution
        x: torch.Tensor
            Generated distribution
        x_real: torch.Tensor
            distribution of the dataset (real data)
        '''
        if use_links is False:
            data_loader.dataset.diagonal_SCM()
        self.eval()
        x, z_list, x_real = [], [], []

        for batch in data_loader:
            if use_aggregated_posterior:
                z = self.samples_aggregated_posterior(num_samples=batch.num_graphs).to(self.device)
            else:
                z = self.model.z_prior_distr.sample([batch.num_nodes]).to(self.device)

            x_hat, _ = self.model.decoder(z, batch.edge_index, edge_attr=batch.edge_attr,
                                          return_type='sample', node_ids=batch.node_ids)

            x_real.append(self.get_x_graph(batch, 'x'))
            x.append(x_hat.view(batch.num_graphs, -1))
            z_list.append(z.view(batch.num_graphs, -1))

        # Restores the dataset (also undoes ``diagonal_SCM``).
        data_loader.dataset.clean_intervention()

        return torch.cat(z_list), torch.cat(x), torch.cat(x_real)

    @torch.no_grad()
    def get_x_gener_I_distribution(self, data_loader, x_I, use_aggregated_posterior=False):
        '''
        Get the generated interventional distribution for do(x_I).

        Latents are drawn from the prior (or aggregated posterior). The latents
        of the intervened nodes are replaced by the posterior mean of the
        intervened graph, and the result is decoded on the intervened graph.

        Parameters
        ----------
        x_I : dict
            Intervention passed to ``data_loader.dataset.set_intervention``,
            e.g. ``{'x1': 0}`` computes do(x1=0). Must be a dict.

        Returns
        -------
        x_gener_dict_out: dict of torch.Tensor
            Generated samples under keys ``'intervened'``, ``'children'`` and ``'all'``.
        x_real_dict_out: dict of torch.Tensor
            Ground-truth interventional samples (same keys) when
            ``data_is_toy``; an empty dict otherwise.
        '''
        assert isinstance(x_I, dict)
        data_loader.dataset.set_intervention(x_I)
        self.eval()
        x_gener_dict = {'intervened': [], 'children': [], 'all': []}
        x_real_dict = {'intervened': [], 'children': [], 'all': []}

        for batch in data_loader:
            if use_aggregated_posterior:
                z = self.samples_aggregated_posterior(num_samples=batch.num_graphs).to(self.device)
            else:
                z = self.model.z_prior_distr.sample([batch.num_nodes]).to(self.device)
            z = z.view(batch.num_graphs, -1)

            z_mean, _ = self.model.encoder(batch.x_i, batch.edge_index_i, edge_attr=batch.edge_attr_i,
                                           return_mean=True, node_ids=batch.node_ids)
            z_mean = z_mean.reshape(batch.num_graphs, -1)
            z = self._splice_latents(z, z_mean, data_loader.dataset.x_I)
            z = z.view(-1, self.z_dim)

            x_hat, _ = self.model.decoder(z, batch.edge_index_i, edge_attr=batch.edge_attr_i,
                                          return_type='sample', node_ids=batch.node_ids)
            x_hat = x_hat.reshape(batch.num_graphs, -1)

            if self.data_is_toy:
                # Sampled in the original (not normalized) scale, so scale it.
                x_inter, set_nodes = data_loader.dataset.sample_intervention(x_I=x_I,
                                                                             n_samples=batch.num_graphs,
                                                                             return_set_nodes=True)
                x_inter = torch.tensor(self.scaler.transform(x_inter))
                x_real_dict['intervened'].append(x_inter[:, set_nodes['intervened']])
                x_real_dict['children'].append(x_inter[:, set_nodes['children']])
                x_real_dict['all'].append(x_inter)
            else:
                set_nodes = data_loader.dataset.get_set_nodes()

            x_gener_dict['intervened'].append(x_hat[:, set_nodes['intervened']])
            x_gener_dict['children'].append(x_hat[:, set_nodes['children']])
            x_gener_dict['all'].append(x_hat)

        x_gener_dict_out = {key: torch.cat(values) for key, values in x_gener_dict.items()}
        x_real_dict_out = {}
        if self.data_is_toy:
            x_real_dict_out = {key: torch.cat(x_real_dict[key]) for key in x_gener_dict}

        data_loader.dataset.clean_intervention()

        return x_gener_dict_out, x_real_dict_out

    @torch.no_grad()
    def compute_counterfactual(self, batch, x_I, z_I):
        '''Compute counterfactual samples for a single batch.

        The factual posterior means are used for all nodes, except the nodes
        in ``x_I``, which take the posterior means from the intervened graph.
        The nodes in ``z_I`` are set to the given values. The result is decoded
        on the intervened graph.

        Parameters
        ----------
        x_I : dict
            Keys are integer node indices whose latents are taken from the
            intervened encoding.
        z_I : dict
            Maps integer node index -> value assigned to that node's latent slot.

        Returns
        -------
        (x_CF, z_used, z_cf_I), each reshaped to ``(batch.num_graphs, -1)``.
        '''
        z_factual, _ = self.model.encoder(batch.x, batch.edge_index, edge_attr=batch.edge_attr,
                                          return_mean=True, node_ids=batch.node_ids)

        # Encoder pass 2 CounterFactual
        z_cf_I, _ = self.model.encoder(batch.x_i, batch.edge_index_i, edge_attr=batch.edge_attr_i,
                                       return_mean=True, node_ids=batch.node_ids)

        z_factual = z_factual.reshape(batch.num_graphs, -1)
        z_cf_I = z_cf_I.reshape(batch.num_graphs, -1)

        # Replace the factual latents of the intervened variables with z_cf_I.
        z_factual = self._splice_latents(z_factual, z_cf_I, x_I)
        for k, v in z_I.items():
            z_factual[:, self.z_dim * k:self.z_dim * (k + 1)] = v

        z_factual = z_factual.reshape(-1, self.z_dim)

        x_CF, _ = self.model.decoder(z_factual, batch.edge_index_i, edge_attr=batch.edge_attr_i,
                                     return_type='sample', node_ids=batch.node_ids)

        return (x_CF.view(batch.num_graphs, -1),
                z_factual.reshape(batch.num_graphs, -1),
                z_cf_I.reshape(batch.num_graphs, -1))

    @torch.no_grad()
    def get_x_cf_distribution(self, data_loader, x_I=None, is_noise=False, return_z=False):
        '''Compute counterfactual samples for every graph in ``data_loader`` under ``x_I``.

        Parameters
        ----------
        x_I : dict
            Intervention passed to ``data_loader.dataset.set_intervention``.
            Must be a dict.
        is_noise : bool
            Forwarded to the dataset's ``set_intervention`` / ``get_counterfactual``.
        return_z : bool
            Selects which tuple is returned (see below).

        Returns
        -------
        If ``return_z``: ``(x_gener_dict_out, z_counterfactual_dict_out,
        x_factual_dict_out, z_factual_dict_out)``.
        Otherwise: ``(x_gener_dict_out, x_real_dict_out, set_nodes)``, where
        ``x_real_dict_out`` is empty unless ``data_is_toy``.
        '''
        assert isinstance(x_I, dict)
        data_loader.dataset.set_intervention(x_I, is_noise=is_noise)
        self.eval()

        x_gener_dict = {'intervened': [], 'children': [], 'all': []}
        x_real_dict = {'intervened': [], 'children': [], 'all': []}
        x_factual_dict = {'all': []}
        z_factual_dict = {'all': []}
        z_counterfactual_dict = {'all': []}

        for batch in data_loader:
            # Encoder pass 1 with Factual
            z_factual, _ = self.model.encoder(batch.x, batch.edge_index, edge_attr=batch.edge_attr,
                                              return_mean=True, node_ids=batch.node_ids)
            z_factual = z_factual.reshape(batch.num_graphs, -1)
            z_factual_dict['all'].append(z_factual.clone())

            # Encoder pass 2 CounterFactual
            z_cf_I, _ = self.model.encoder(batch.x_i, batch.edge_index_i, edge_attr=batch.edge_attr_i,
                                           return_mean=True, node_ids=batch.node_ids)
            z_cf_I = z_cf_I.reshape(batch.num_graphs, -1)
            z_counterfactual_dict['all'].append(z_cf_I.clone())

            # Replace the factual latents of the intervened variables with z_cf_I.
            z_factual = self._splice_latents(z_factual, z_cf_I, data_loader.dataset.x_I)
            z_factual = z_factual.reshape(-1, self.z_dim)

            x_CF, _ = self.model.decoder(z_factual, batch.edge_index_i, edge_attr=batch.edge_attr_i,
                                         return_type='sample', node_ids=batch.node_ids)
            x_CF = x_CF.view(batch.num_graphs, -1)

            if self.data_is_toy:
                # The dataset works in the original (not normalized) scale.
                x_cf_real, set_nodes = data_loader.dataset.get_counterfactual(
                    x_factual=self.scaler.inverse_transform(self.get_x_graph(batch, 'x')),
                    u_factual=batch.u.view(batch.num_graphs, -1),
                    x_I=x_I,
                    is_noise=is_noise,
                    return_set_nodes=True)
                x_cf_real = self.scaler.transform(x_cf_real)

                x_real_dict['intervened'].append(x_cf_real[:, set_nodes['intervened']])
                x_real_dict['children'].append(x_cf_real[:, set_nodes['children']])
                x_real_dict['all'].append(x_cf_real)
            else:
                set_nodes = data_loader.dataset.get_set_nodes()

            x_gener_dict['intervened'].append(x_CF[:, set_nodes['intervened']])
            x_gener_dict['children'].append(x_CF[:, set_nodes['children']])
            x_gener_dict['all'].append(x_CF)

            x_factual_dict['all'].append(self.get_x_graph(batch, 'x'))

        x_gener_dict_out = {key: torch.cat(values) for key, values in x_gener_dict.items()}
        x_real_dict_out = {}
        if self.data_is_toy:
            x_real_dict_out = {key: torch.cat(x_real_dict[key]) for key in x_gener_dict}

        x_factual_dict_out = {key: torch.cat(values) for key, values in x_factual_dict.items()}
        z_factual_dict_out = {key: torch.cat(z_factual_dict[key]) for key in x_factual_dict}
        z_counterfactual_dict_out = {key: torch.cat(z_counterfactual_dict[key]) for key in x_factual_dict}

        data_loader.dataset.clean_intervention()

        if return_z:
            return x_gener_dict_out, z_counterfactual_dict_out, x_factual_dict_out, z_factual_dict_out
        return x_gener_dict_out, x_real_dict_out, set_nodes

    @torch.no_grad()
    def get_x(self, data_loader):
        '''Concatenate the per-graph features ``x`` of every graph in ``data_loader``.'''
        self.eval()
        return torch.cat([self.get_x_graph(batch, 'x') for batch in data_loader])

    @torch.no_grad()
    def get_recons_obs_distribution(self, data_loader):
        '''
        Reconstruct the features of every graph in ``data_loader``, i.e.
        Z ~ q(Z | X, A) and X_hat ~ p(X | Z, A).

        Parameters
        ----------
        data_loader : PyTorch Geometric data loader
            Graphs to reconstruct.

        Returns
        -------
        z_list: torch.Tensor
            Latent code of the reconstructed distribution, i.e. q(z|x)
        x: torch.Tensor
            reconstructed samples
        x_real: torch.Tensor
            original samples (real data)
        '''
        self.eval()
        x, z, x_real = [], [], []
        for batch in data_loader:
            z_hat, x_hat = self.model.reconstruct(batch)
            x.append(x_hat.view(batch.num_graphs, -1))
            z.append(z_hat.reshape(batch.num_graphs, -1))
            x_real.append(self.get_x_graph(batch, 'x'))

        return torch.cat(z), torch.cat(x), torch.cat(x_real)

    @torch.no_grad()
    def get_obs_distribution(self, data_loader):
        '''Concatenate raw ``batch.x`` of every graph, reshaped to ``(num_graphs, -1)`` per batch.'''
        self.eval()
        return torch.cat([batch.x.view(batch.num_graphs, -1) for batch in data_loader])