"""
Code built from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html
"""
import os

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.preprocessing import FunctionTransformer
from torch.distributions import Laplace, Uniform, TransformedDistribution, SigmoidTransform

from models._evaluator import MyEvaluator
from utils.args_parser import mkdir
from utils.metrics.time import Timer
from utils.optimizers import get_optimizer, get_scheduler
from .nflib import AffineCL, NormalizingFlowModel, MLP1layer, MAF, NSF_AR, ARMLP, MLP4


class CAREFL(pl.LightningModule):
    '''
    Causal Autoregressive Flow
    '''

    def __init__(self,
                 num_nodes,
                 distr_z='laplace',
                 flow_net_class='mlp',
                 flow_architecture='spline',
                 n_layers=1,
                 n_hidden=1,
                 parity=False,
                 intervention_list=None,  # In distribution interventions
                 scaler=None,
                 init=None,
                 scale_base=True,
                 shift_base=True,
                 scale=True):
        '''
        Build the prior, the stack of flow layers and the data scaler.

        Parameters
        ----------
        num_nodes : int
            Dimension of the prior and of every flow layer.
        distr_z : str
            'laplace' selects a Laplace(0, 1) prior; any other value selects a
            Uniform(0, 1) pushed through the inverse sigmoid transform.
        flow_net_class : str
            Network used for the flow parameters: 'mlp' (MLP1layer), 'mlp4' (MLP4)
            or 'armlp' (ARMLP).
        flow_architecture : str
            'cl' or 'realnvp' (AffineCL), 'maf' (MAF) or 'spline' (NSF_AR).
        n_layers : int
            Number of flow layers that are stacked.
        n_hidden : int
            Hidden size handed to every flow layer.
        parity : bool
            Forwarded to AffineCL and MAF.
        intervention_list : list, optional
            Stored on the module and handed to MyEvaluator.
        scaler : object, optional
            Object exposing transform / inverse_transform. When None an identity
            transformer is used.
        init : str, optional
            'ortho' applies an orthogonal initialisation to the weights of all
            linear layers; any other value leaves the default initialisation.
        scale_base, shift_base, scale : bool
            Forwarded to AffineCL for the 'cl' / 'realnvp' architectures only.
            NOTE(review): these used to be read from a `self.config` attribute that
            is never created; the defaults of True are an assumption - confirm
            against AffineCL.

        Raises
        ------
        NotImplementedError
            If flow_net_class is unknown, or if flow_architecture is unknown and at
            least one layer has to be built.
        '''
        super(CAREFL, self).__init__()

        self.save_hyperparameters()
        self.dim = num_nodes

        # prior
        zeros = torch.zeros(num_nodes).to(self.device)
        ones = torch.ones(num_nodes).to(self.device)
        if distr_z == 'laplace':
            prior = Laplace(zeros, ones)
        else:
            prior = TransformedDistribution(Uniform(zeros, ones), SigmoidTransform().inv)

        # net type for flow parameters
        net_classes = {'mlp': MLP1layer, 'mlp4': MLP4, 'armlp': ARMLP}
        if flow_net_class not in net_classes:
            # report the offending argument itself (self.config does not exist)
            raise NotImplementedError('net_class {} not understood.'.format(flow_net_class))
        net_class = net_classes[flow_net_class]

        # flow type
        def ar_flow(hidden_dim):
            if flow_architecture in ['cl', 'realnvp']:
                return AffineCL(dim=num_nodes, nh=hidden_dim, scale_base=scale_base,
                                shift_base=shift_base, net_class=net_class, parity=parity,
                                scale=scale)
            if flow_architecture == 'maf':
                return MAF(dim=num_nodes, nh=hidden_dim, net_class=net_class, parity=parity)
            if flow_architecture == 'spline':
                return NSF_AR(dim=num_nodes, hidden_dim=hidden_dim, base_network=net_class)
            raise NotImplementedError('Architecture {} not understood.'.format(flow_architecture))

        flow_list = [ar_flow(n_hidden) for _ in range(n_layers)]

        self.flow_model = NormalizingFlowModel(prior, flow_list).to(self.device)

        if scaler is None:
            # func=None / inverse_func=None are the identity; unlike lambdas this
            # keeps the module picklable
            self.scaler = FunctionTransformer()
        else:
            self.scaler = scaler

        self.timer = Timer()

        if init == 'ortho':
            # `init` is the string argument here, so it cannot provide the
            # initialiser; a local one is used instead.
            # NOTE(review): only the weights are touched, biases keep their
            # default initialisation - confirm this matches the intended helper.
            def _init_weights_orthogonal(module):
                if isinstance(module, torch.nn.Linear):
                    torch.nn.init.orthogonal_(module.weight)

            self.apply(_init_weights_orthogonal)

        self.intervention_list = intervention_list

        self.save_dir = None

    def monitor(self):
        return 'valid_objective'

    def monitor_mode(self):
        return 'max'

    def set_optim_params(self, optim_params, sched_params):
        self.optim_params = optim_params
        self.sched_params = sched_params

    def configure_optimizers(self):
        '''
        Build the optimizer (and, when configured, the scheduler) from the stored settings.

        Returns
        -------
        tuple
            ([optimizer], scheduler) when self.sched_params is a dict, otherwise
            ([optimizer], []).
        '''
        optimizer_cls = get_optimizer(self.optim_params['name'])
        optimizer = optimizer_cls(self.parameters(), **self.optim_params['params'])

        # without a dict of scheduler settings no scheduler is created
        if not isinstance(self.sched_params, dict):
            return [optimizer], []

        scheduler_cls = get_scheduler(self.sched_params['name'])
        scheduler = scheduler_cls(optimizer, **self.sched_params['params'])
        return [optimizer], scheduler

    def forward(self, data, *args, **kwargs):
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        X = batch.x.view(batch.num_graphs, -1)  # shape [1000, 3] # works because 1 Dim X

        # compute loss
        _, prior_logprob, log_det = self.flow_model(X)
        objective = torch.mean(prior_logprob + log_det)

        self.log('train_objective', objective.item(), prog_bar=True)
        return -objective

    def on_train_epoch_start(self) -> None:
        self.timer.tic('train')

    def on_train_epoch_end(self, outputs) -> None:
        time = self.timer.toc('train')
        self.logger.experiment.add_scalar('train_time', time, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        X = batch.x.view(batch.num_graphs, -1)  # shape [1000, 3] # works because 1 Dim X
        # compute loss
        _, prior_logprob, log_det = self.flow_model(X)
        objective = torch.mean(prior_logprob + log_det)

        self.log('valid_objective', objective.item(), prog_bar=True)

        return {f'valid_objective': objective.item()}

    def on_test_epoch_start(self) -> None:
        self.x_test = []
        self.x_hat = []
        return

    def test_step(self, batch, batch_idx):
        X = batch.x.view(batch.num_graphs, -1)  # shape [1000, 3] # works because 1 Dim X
        # compute loss
        _, prior_logprob, log_det = self.flow_model(X)
        objective = torch.mean(prior_logprob + log_det)

        self.log('test_objective', objective.item(), prog_bar=True)

    def on_test_epoch_end(self) -> None:
        return

    def on_epoch_end(self) -> None:
        self.eval()
        # log sampled images

        if self.current_epoch % 100 == 0:
            my_test_object = MyEvaluator(model=self,
                                         save_dir=self.logger.save_dir,
                                         logger=self.logger,
                                         intervention_list=self.intervention_list,
                                         current_epoch=self.current_epoch)
            my_test_object.complete_logs(data_loader=self.test_dataloader(), name='test', plots=False)

        self.train()  # IMPORTANT: Set back to train mode!
        return

    def on_fit_end(self):
        '''Run MyEvaluator.complete_logs once on the test dataloader after fitting.'''
        evaluator_kwargs = {
            'model': self,
            'save_dir': self.logger.save_dir,
            'logger': self.logger,
            'intervention_list': self.intervention_list,
            'current_epoch': self.current_epoch,
        }
        evaluator = MyEvaluator(**evaluator_kwargs)
        evaluator.complete_logs(data_loader=self.test_dataloader(), name='test', plots=False)

        return None

    def on_fit_start(self):
        '''
        Put the module in eval mode and call the project helper mkdir on the 'images'
        sub-path of the logger save directory.
        '''
        self.eval()

        images_path = os.path.join(self.logger.save_dir, 'images')
        mkdir(images_path)

    def _forward_flow(self, data):

        return self.flow_model.forward(data.to(self.device))[0][-1].detach()

    def _backward_flow(self, latent):

        return self.flow_model.backward(latent.to(self.device))[0][-1].detach()

    @torch.no_grad()
    def my_evaluator(self, dataloader, name='test', save_dir=None, plots=False):
        '''
        Evaluate the model on `dataloader` with MyEvaluator and return its output.

        Parameters
        ----------
        dataloader : data loader
            Handed to MyEvaluator.evaluate.
        name : str
            Forwarded to MyEvaluator.evaluate.
        save_dir : str, optional
            Output directory; the logger save directory is used when None.
        plots : bool
            Forwarded to MyEvaluator.evaluate.
        '''
        if save_dir is None:
            save_dir = self.logger.save_dir

        evaluator = MyEvaluator(model=self,
                                save_dir=save_dir,
                                logger=self.logger,
                                intervention_list=self.intervention_list,
                                current_epoch=100000)
        return evaluator.evaluate(dataloader, name=name, plots=plots)

    @torch.no_grad()
    def get_x_gener_distribution(self, data_loader, use_links=True, use_aggregated_posterior=False):
        '''
        Get x generated distribution  w/o intervention or with diagonal adjacency.
        Parameters
        ----------
        use_links : bool
            If false, then uses an diagonal adjacency matrix to compute the distribution

        Returns
        -------
        z_list: torch.Tensor
            Latent code of the generated distribution
        x: torch.Tensor
            Generated distribution
        x_real: torch.Tensor
            distribution of the dataset (real data)
        '''

        iterator = iter(data_loader)
        self.eval()
        x, z_list = [], []
        x_real = []

        for idx, batch in enumerate(iterator):
            z = self.flow_model.prior.sample((batch.num_graphs,))
            x_hat = self._backward_flow(z)
            # x_real.append(self.scaler.inverse_transform(batch.x.view(batch.num_graphs, -1)))
            x_real.append(batch.x.view(batch.num_graphs, -1))
            # x.append(self.scaler.inverse_transform(x_hat.view(batch.num_graphs, -1)))
            x.append(x_hat)

            z_list.append(z)

        data_loader.dataset.clean_intervention()

        return torch.cat(z_list), torch.cat(x), torch.cat(x_real)

    @torch.no_grad()
    def get_objective_metrics(self, data_loader, name):
        output = {}

        metrics = {'log_px': []}
        for idx, batch in enumerate(iter(data_loader)):
            X = batch.x.view(batch.num_graphs, -1)  # shape [1000, 3] # works because 1 Dim X
            _, prior_logprob, log_det = self.flow_model(X)
            objective = torch.mean(prior_logprob + log_det)
            metrics['log_px'].append(objective.item())

        for key, value in metrics.items():
            output[f'{name}_{key}'] = np.mean(value)

        return output

    @torch.no_grad()
    def get_x_gener_I_distribution(self, data_loader, x_I, use_aggregated_posterior=False):
        '''
        Get x generated distribution  w/o intervention or with diagonal adjacency.
        Parameters
        ----------
        x_I : dict
             If x_I is None compute the distribution of the original SCM, if x_I is a dict
             then compute the interventional distribution. E.g. x_I = {'x1': 0} computes the
             interventional distribution with do(x1=0)
        Returns
        -------
        x_gener_dict_out: dict of torch.Tensor
            Generated distribution
        x_real_dict_out: dict of torch.Tensor
            distribution of the dataset (real data)
        '''
        assert isinstance(x_I, dict)
        assert len(x_I) == 1

        data_loader.dataset.set_intervention(x_I)
        for key, value in data_loader.dataset.x_I.items():
            x_id = key

        iterator = iter(data_loader)
        self.eval()
        x_gener_dict = {'intervened': [],
                        'children': [],
                        'all': []}
        x_real_dict = {'intervened': [],
                       'children': [],
                       'all': []}

        for idx, batch in enumerate(iterator):
            X = batch.x.view(batch.num_graphs, -1)
            X_i = batch.x_i.view(batch.num_graphs, -1)

            n_samples = X.shape[0]
            x_int = torch.zeros((1, self.dim))
            x_int[0, x_id] = X_i[0, x_id]  # Get the intervened normalized value
            z_int = self._forward_flow(x_int)[0, x_id]
            # sample from prior and ensure z_intervention_index = z_int
            z = self.flow_model.prior.sample((n_samples,))
            z_est = torch.zeros((1, self.dim))
            z[:, x_id] = z_est[:, x_id] = z_int

            # propagate the latent sample through flow
            x_hat = self._backward_flow(z)

            # Not normalized
            x_inter, set_nodes = data_loader.dataset.sample_intervention(x_I=x_I,
                                                                         n_samples=batch.num_graphs,
                                                                         return_set_nodes=True)
            x_inter = torch.tensor(self.scaler.transform(x_inter))

            x_gener_dict['intervened'].append(x_hat[:, set_nodes['intervened']])
            x_gener_dict['children'].append(x_hat[:, set_nodes['children']])
            x_gener_dict['all'].append(x_hat)

            x_real_dict['intervened'].append(x_inter[:, set_nodes['intervened']])
            x_real_dict['children'].append(x_inter[:, set_nodes['children']])
            x_real_dict['all'].append(x_inter)

        x_gener_dict_out = {}
        x_real_dict_out = {}
        for key, values in x_gener_dict.items():
            x_gener_dict_out[key] = torch.cat(values)
            x_real_dict_out[key] = torch.cat(x_real_dict[key])

        data_loader.dataset.clean_intervention()

        return x_gener_dict_out, x_real_dict_out

    @torch.no_grad()
    def compute_counterfactual(self, batch, x_I, z_I):

        for key, value in x_I.items():
            x_id = key
        X = batch.x.view(batch.num_graphs, -1)
        X_i = batch.x_i.view(batch.num_graphs, -1)
        # abduction:
        z_obs = self._forward_flow(X)
        z_cf_val = self._forward_flow(X_i)[:, x_id]
        z_obs[:, x_id] = z_cf_val
        # prediction (pass through the flow):
        x_CF = self._backward_flow(z_obs)

        return x_CF.view(batch.num_graphs, -1), z_obs.reshape(batch.num_graphs, -1), z_cf_val.reshape(
            batch.num_graphs, -1)

    @torch.no_grad()
    def get_x_cf_distribution(self, data_loader, x_I=None, is_noise=False):
        assert isinstance(x_I, dict)
        assert len(x_I) == 1
        data_loader.dataset.set_intervention(x_I, is_noise=is_noise)
        for key, value in data_loader.dataset.x_I.items():
            x_id = key

        iterator = iter(data_loader)
        self.eval()

        x_gener_dict = {'intervened': [],
                        'children': [],
                        'all': []}
        x_real_dict = {'intervened': [],
                       'children': [],
                       'all': []}

        for idx, batch in enumerate(iterator):
            X = batch.x.view(batch.num_graphs, -1)
            X_i = batch.x_i.view(batch.num_graphs, -1)
            # abduction:
            z_obs = self._forward_flow(X)
            z_cf_val = self._forward_flow(X_i)[:, x_id]
            z_obs[:, x_id] = z_cf_val
            # prediction (pass through the flow):
            x_CF = self._backward_flow(z_obs)

            # Not normalized
            x_cf_real, set_nodes = data_loader.dataset.get_counterfactual(
                x_factual=self.scaler.inverse_transform(batch.x.view(batch.num_graphs, -1)),
                u_factual=batch.u.view(batch.num_graphs, -1),
                x_I=x_I,
                is_noise=is_noise,
                return_set_nodes=True)
            x_cf_real = self.scaler.transform(torch.tensor(x_cf_real))
            x_gener_dict['intervened'].append(x_CF[:, set_nodes['intervened']])
            x_gener_dict['children'].append(x_CF[:, set_nodes['children']])
            x_gener_dict['all'].append(x_CF)

            x_real_dict['intervened'].append(x_cf_real[:, set_nodes['intervened']])
            x_real_dict['children'].append(x_cf_real[:, set_nodes['children']])
            x_real_dict['all'].append(x_cf_real)

        x_gener_dict_out = {}
        x_real_dict_out = {}
        for key, values in x_gener_dict.items():
            x_gener_dict_out[key] = torch.cat(values)
            x_real_dict_out[key] = torch.cat(x_real_dict[key])

        data_loader.dataset.clean_intervention()

        return x_gener_dict_out, x_real_dict_out, set_nodes

    def get_data_is_toy(self):
        return True

    @torch.no_grad()
    def get_recons_obs_distribution(self, data_loader):
        '''
        Reconstruct all the feates of all the  graphs in data loader, i.e.,
        Z \sim q(Z|X, A) and X_hat p(X | Z, A)
        ----------
        data_loader : Pytorch Geometric data loader
             Description here
        Returns
        -------
        z_list: torch.Tensor
            Latent code of the reconstructed distribution, i.e. q(z|x)
        x: torch.Tensor
            reconstructed samples
        x_real: torch.Tensor
            original  samples (real data)
        '''
        iterator = iter(data_loader)
        self.eval()
        x, z = [], []
        x_real = []
        for idx, batch in enumerate(iterator):
            X = batch.x.view(batch.num_graphs, -1)
            z_hat = self._forward_flow(X)
            # sample from prior and ensure z_intervention_index = z_int
            x_hat = self._backward_flow(z_hat)

            x.append(x_hat)
            z.append(z_hat)
            x_real.append(X)

        return torch.cat(z), torch.cat(x), torch.cat(x_real)
