
import matplotlib.pyplot as plt
import torch
from tueplots import bundles

import gcip.utils.io as playbook_io
from gcip.models.model_base import BaseLightning
from gcip.utils.optimizers import build_optimizer, build_scheduler
from gcip.utils.graph import compute_stats_batch
import gcip.utils.wandb_local as wandb_local
import gcip.utils.graph as pb_graph
import time

# Apply tueplots' ICML 2022 style bundle to matplotlib's global rcParams.
# This runs at import time and affects every figure created afterwards in the process.
plt.rcParams.update(bundles.icml2022())

import psutil


class GNNLightning(BaseLightning):
    """Lightning module that trains/evaluates a GNN ``model`` on graph batches.

    Inputs are normalized with the preparator's scaler, optionally passed
    through a ``pooling_gnn`` and then fed to ``model``. Optimization is
    manual (``automatic_optimization`` is False).
    """

    # Noise levels swept in ``test_step`` when ``preparator.add_noise`` is set.
    # Override in a subclass to evaluate other levels.
    TEST_X_NOISE_LEVELS = (0.0, 0.25, 0.5, 0.75, 1.0)
    TEST_EDGE_NOISE_LEVELS = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5)

    def __init__(self, preparator,
                 model,
                 pooling_gnn=None,
                 init_fn=None,
                 plot=False):
        """
        Args:
            preparator: data preparator; provides the loss function, the output
                activation, the input scaler and the targets.
            model: network mapping a (normalized, optionally pooled) batch to
                logits, or to a dict with a ``"logits"`` key and optional
                auxiliary ``*loss*`` entries.
            pooling_gnn: optional module applied to the normalized batch
                before ``model``.
            init_fn: optional initializer forwarded to the base class and also
                applied to ``pooling_gnn`` in ``reset_parameters``.
            plot: forwarded to the base class.
        """
        super().__init__(preparator, init_fn=init_fn, plot=plot)

        self.model = model
        self.pooling_gnn = pooling_gnn

        self.loss_fn = self.preparator.get_loss_fn()
        self.act_out = self.preparator.get_output_act_fn()

        self.set_input_scaler()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize parameters via the base class, then ``pooling_gnn`` via ``init_fn``."""
        super().reset_parameters()
        if self.init_fn is not None and self.pooling_gnn is not None:
            self.pooling_gnn.apply(self.init_fn)

    def set_input_scaler(self):
        """Fetch a fitted input scaler from the preparator and print it."""
        self.input_scaler = self.preparator.get_scaler(fit=True)
        print(self.input_scaler)

    def get_batch_norm(self, batch, batch_size=None):
        """Return a normalized copy of ``batch`` on ``self.device``.

        ``batch_size`` is unused; it is kept for interface compatibility.
        A 3-D ``y`` on the normalized batch is squeezed along its last dim.
        """
        batch_norm = self.input_scaler.transform(batch.to(self.device), inplace=False)
        if batch_norm.y.ndim == 3:
            batch_norm.y = batch_norm.y.squeeze(-1)
        return batch_norm

    def forward(self, batch, **kwargs):
        """Compute the training loss for ``batch``.

        If ``model`` returns a dict, the main loss is computed on
        ``output["logits"]`` and every entry whose key contains ``"loss"`` is
        added to it.

        Args:
            batch: input graph batch.
            **kwargs: forwarded to ``self.model``.

        Returns:
            The loss from ``self.loss_fn`` (plus any auxiliary losses).

        Raises:
            TypeError: if ``model`` returns neither a tensor nor a dict.
        """
        target = self.preparator.get_target(batch, dtype=None)
        batch_norm = self.get_batch_norm(batch=batch)
        if self.pooling_gnn is not None:
            batch_norm = self.pooling_gnn(batch_norm)

        output = self.model(batch_norm, **kwargs)
        if isinstance(output, torch.Tensor):
            return self.loss_fn(output, target)
        if isinstance(output, dict):
            loss = self.loss_fn(output["logits"], target)
            for key, value in output.items():
                if "loss" in key:
                    # Out-of-place add: an in-place ``+=`` on a tensor that is
                    # part of the autograd graph can invalidate saved tensors.
                    loss = loss + value
            return loss
        raise TypeError(
            f"Unsupported model output type {type(output).__name__!r}; "
            "expected torch.Tensor or dict."
        )

    @torch.no_grad()
    def predict(self, batch, use_act_fn=False, return_batch_norm=False, x_noise=0.0, edge_noise=0.0, **kwargs):
        """Run the model on ``batch`` without gradients.

        Args:
            batch: input graph batch.
            use_act_fn: if True, apply ``self.act_out`` to the logits and
                return ``[activated_logits, stats]``.
            return_batch_norm: if True, append the (noised/pooled) normalized
                batch to the returned list.
            x_noise: if > 0, passed as ``eps`` to ``pb_graph.add_x_noise``.
            edge_noise: if > 0, passed as ``p`` to
                ``pb_graph.add_edge_noise_batch``.
            **kwargs: forwarded to ``self.model``.

        Returns:
            ``[logits, target, stats]`` (or ``[activated_logits, stats]`` when
            ``use_act_fn``), optionally followed by the normalized batch.
            ``stats`` maps ``"<name>_ratio"`` to pooled/original graph
            statistics when ``pooling_gnn`` is set, else it is empty.

        Raises:
            ValueError: if the model returns a dict and the 3-D target's last
                dimension is not 1.
        """
        target = self.preparator.get_target(batch, dtype=None)
        if self.pooling_gnn is not None:
            stats_original = compute_stats_batch(batch)
        batch_norm = self.get_batch_norm(batch=batch)
        if x_noise > 0.0:
            batch_norm = pb_graph.add_x_noise(batch=batch_norm, eps=x_noise)
        if edge_noise > 0.0:
            batch_norm = pb_graph.add_edge_noise_batch(batch=batch_norm, p=edge_noise, sort=True)

        stats = {}
        if self.pooling_gnn is not None:
            batch_norm = self.pooling_gnn(batch_norm)
            stats_norm = compute_stats_batch(batch_norm, num_samples=1)
            for name, value in stats_original.items():
                stats[f"{name}_ratio"] = stats_norm[name] / value

        model_output = self.model(batch_norm, **kwargs)
        is_dict_output = isinstance(model_output, dict)
        logits = model_output["logits"] if is_dict_output else model_output

        if use_act_fn:
            output = [self.act_out(logits), stats]
        else:
            if is_dict_output and target.ndim == 3:
                if target.shape[-1] != 1:
                    raise ValueError(
                        f"Expected a trailing target dimension of size 1, got shape {tuple(target.shape)}"
                    )
                target = target.squeeze(-1)
            output = [logits, target, stats]

        if return_batch_norm:
            output.append(batch_norm)
        return output

    @property
    def automatic_optimization(self):
        return False

    def get_memory_usage(self):
        """Return this process's resident set size in GiB as a 0-d tensor."""
        mem_info = psutil.Process().memory_info()
        return torch.tensor(mem_info.rss / (1024 ** 3))

    def training_step(self, train_batch, batch_idx):
        """Run one manual optimization step on ``train_batch``.

        Returns:
            dict filled by ``update_log_dict`` from the loss, plus
            ``'time_step'``: wall-clock duration of the step in seconds.
        """
        log_dict = {}
        tic = time.perf_counter()
        opt = self.optimizers(use_pl_optimizer=False)
        loss = self(train_batch)

        opt.zero_grad()
        # ``.mean()`` keeps backward valid even if loss_fn returns an unreduced loss.
        self.manual_backward(loss.mean())
        opt.step()

        self.update_log_dict(log_dict=log_dict, my_dict={'loss': loss})
        log_dict['time_step'] = torch.tensor(time.perf_counter() - tic)
        return log_dict

    def _predict_and_log(self, batch, log_dict, key_id=None, **predict_kwargs):
        """Predict on ``batch``, compute the loss and merge the results into ``log_dict``.

        The stats returned by ``predict`` are logged together with ``'loss'``,
        ``'logits'`` and ``'target'``. ``key_id`` is forwarded to
        ``update_log_dict`` only when given.
        """
        logits, target, loss_dict = self.predict(batch, **predict_kwargs)
        loss_dict['loss'] = self.loss_fn(logits, target)
        loss_dict['logits'] = logits
        loss_dict['target'] = target
        if key_id is None:
            self.update_log_dict(log_dict=log_dict, my_dict=loss_dict)
        else:
            self.update_log_dict(log_dict=log_dict, my_dict=loss_dict, key_id=key_id)

    def validation_step(self, batch, batch_idx):
        """Evaluate ``batch`` and return its logged metrics."""
        batch = batch.to(self.device)
        log_dict = {}
        self._predict_and_log(batch, log_dict)
        return log_dict

    def test_step(self, batch, batch_idx):
        """Evaluate ``batch`` and, if ``preparator.add_noise`` is set, also under noise.

        Clean-input metrics are logged without a suffix. With noise enabled,
        the batch is re-evaluated for every level in ``TEST_X_NOISE_LEVELS``
        (node-feature noise, keys suffixed ``_eps{eps}``) and every level in
        ``TEST_EDGE_NOISE_LEVELS`` (edge noise, keys suffixed ``_p{p}``).
        """
        # Same as validation_step: keep the target on the module's device.
        batch = batch.to(self.device)
        log_dict = {}
        self._predict_and_log(batch, log_dict)

        if self.preparator.add_noise:
            for eps in self.TEST_X_NOISE_LEVELS:
                self._predict_and_log(batch, log_dict, key_id=f"_eps{eps}", x_noise=eps)
            for p in self.TEST_EDGE_NOISE_LEVELS:
                self._predict_and_log(batch, log_dict, key_id=f"_p{p}", edge_noise=p)

        # NOTE: subgraph plotting (self._plot_subgraph) is currently not invoked here.
        return log_dict

    def _plot_subgraph(self, batch, batch_norm, batch_idx, epoch, split):
        """Plot up to 3 graphs of ``batch``, with node alpha taken from ``batch_norm``'s masks.

        The figure is written to the logger's ``images`` sub-folder.
        """
        batch_list = batch.to_data_list()
        # Copy each processed graph's node mask onto the corresponding raw graph.
        for graph, graph_norm in zip(batch_list, batch_norm.to_data_list()):
            graph.mask = graph_norm.mask

        now = self.get_now()
        filename = f"gnn_graphs--epoch={epoch}--batch_idx={batch_idx}--split={split}--now={now}.png"

        def nodes_alpha_fn(g):
            # .cpu(): the mask comes from the normalized batch, which may live
            # on an accelerator, and .numpy() only works on CPU tensors.
            return g.mask.float().flatten().detach().cpu().numpy(), 'Mask'

        folder = wandb_local.sub_folder(self.logger.save_dir, 'images')
        self.preparator.plot_data_batch(batch=batch_list,
                                        folder=folder,
                                        filename=filename,
                                        show=False,
                                        num_samples=3,
                                        batch_size=(1, 3),
                                        nodes_with_color=True,
                                        title_elem_idx='y',
                                        nodes_alpha_fn=nodes_alpha_fn,
                                        edges_alpha_fn=None,
                                        )

    def configure_optimizers(self):
        """Build the optimizer (and an optional LR scheduler) from ``self.optim_config``.

        ``pooling_gnn`` parameters, when present, are optimized together with
        ``model``'s. A scheduler is built only when ``optim_config.scheduler``
        is a string, and it is configured to monitor ``'val_loss'``.
        """
        self.lr = self.optim_config.base_lr
        playbook_io.print_debug(f"Setting lr: {self.lr}")

        params = list(self.model.parameters())
        if self.pooling_gnn is not None:
            params += list(self.pooling_gnn.parameters())
        opt = build_optimizer(optim_config=self.optim_config, params=params)

        output = {'optimizer': opt}
        if isinstance(self.optim_config.scheduler, str):
            # NOTE(review): automatic_optimization is False. In Lightning's manual
            # optimization mode, schedulers are not stepped automatically.
            # Confirm that BaseLightning steps it (e.g. via self.lr_schedulers()).
            output['lr_scheduler'] = build_scheduler(optim_config=self.optim_config, optimizer=opt)
            output['monitor'] = 'val_loss'
        return output

    @torch.no_grad()
    def evaluate(self, evaluation, root, device='cpu'):
        """Not implemented for this module."""
        raise NotImplementedError
