import matplotlib.pyplot as plt
import torch
from tueplots import bundles
import gcip.utils.wandb_local as wandb_local

from gcip.models.model_base import BaseLightning
from gcip.utils.optimizers import build_optimizer, build_scheduler

from gcip.utils.graph import compute_stats_batch
import gcip.utils.graph as pb_graph
import gcip.utils.io as pb_io
import time
import torch.optim.lr_scheduler as t_lr

# Apply the tueplots ICML 2022 bundle to matplotlib's global rcParams at import time.
plt.rcParams.update(bundles.icml2022())


class GCIPLightning(BaseLightning):

    def __init__(self, preparator,
                 env,
                 graph_clf,
                 ppo,
                 env_steps=10,
                 ppo_steps=1,
                 warm_up_epochs=0,
                 ppo_every_n_epochs=1,
                 init_fn=None,
                 plot=False):
        """Wire together the environment, the graph classifier and the PPO agent.

        Args:
            preparator: Forwarded to the base class. ``self.preparator`` is
                read afterwards for the loss function, the output activation
                and the input scaler (presumably set by the base class).
            env: Environment handed to the PPO agent when it runs episodes.
            graph_clf: Graph classifier module.
            ppo: PPO agent module.
            env_steps: Stored as ``self.n_steps`` and passed as ``n_steps``
                to ``ppo.prepare_forward``.
            ppo_steps: Number of policy optimisation iterations performed in
                one ``training_step``.
            warm_up_epochs: Epoch index from which the policy is used and
                trained.
            ppo_every_n_epochs: Intended epoch interval between policy
                updates; stored as given.
            init_fn: Forwarded to the base class.
            plot: Forwarded to the base class.
        """
        super().__init__(preparator, init_fn=init_fn, plot=plot)

        # Sub-modules, registered in a fixed order.
        self.env = env
        self.graph_clf = graph_clf
        self.ppo = ppo

        # Training schedule.
        self.n_steps = env_steps
        self.ppo_steps = ppo_steps
        self.ppo_every_n_epochs = ppo_every_n_epochs
        self.warm_up_epochs = warm_up_epochs

        # Task-specific callables supplied by the preparator.
        self.loss_fn = self.preparator.get_loss_fn()
        self.act_out = self.preparator.get_output_act_fn()

        self.set_input_scaler()
        self.reset_parameters()

    def reset_parameters(self):
        self.graph_clf.apply(self.init_fn)
        self.ppo.apply(self.init_fn)

    def set_input_scaler(self):
        self.input_scaler = self.preparator.get_scaler(fit=True)
        print(self.input_scaler)

    def get_batch_norm(self, batch, use_policy=False, sample=True, x_noise=0.0, edge_noise=0.0, num_samples=1):
        batch_norm = self.input_scaler.transform(batch.to(self.device), inplace=False)
        if edge_noise > 0.0:
            batch_norm = pb_graph.add_edge_noise_batch(batch=batch_norm, p=edge_noise, sort=True)
        if x_noise > 0.0:
            batch_norm = pb_graph.add_x_noise(batch=batch_norm, eps=x_noise)
        if use_policy:
            batch_norm = self.ppo.run_episode(
                batch=batch_norm.to(self.device),
                env=self.env,
                sample=sample,
                transform=False,
                num_samples=num_samples
            )

        return batch_norm

    @torch.no_grad()
    def predict(self, batch, use_policy, sample, use_act_fn, num_samples=1, x_noise=0.0, edge_noise=0.0,
                return_batch_norm=False, **kwargs):
        """Run the classifier on a (possibly policy-modified) copy of ``batch``.

        Args:
            batch: Input batch; it is cloned before use.
            use_policy: Whether ``get_batch_norm`` runs the PPO policy.
            sample: Forwarded to ``get_batch_norm``.
            use_act_fn: If true, ``self.act_out`` is applied to the logits.
            num_samples: Forwarded to ``get_batch_norm``; also passed to
                ``compute_stats_batch`` for the processed batch when the
                policy is used (otherwise 1 is passed).
            x_noise: Forwarded to ``get_batch_norm``.
            edge_noise: Forwarded to ``get_batch_norm``.
            return_batch_norm: If true, the processed batch is appended to
                the returned list.
            **kwargs: Extra keyword arguments forwarded to ``self.graph_clf``.

        Returns:
            A list ``[prediction, target, stats]`` (plus the processed batch
            when ``return_batch_norm`` is set). ``stats`` maps
            ``"<name>_ratio"`` to the processed-batch statistic divided by
            the original-batch statistic.
        """
        batch = batch.clone()
        stats_original = compute_stats_batch(batch)

        batch_norm = self.get_batch_norm(batch,
                                         use_policy=use_policy,
                                         sample=sample,
                                         x_noise=x_noise,
                                         edge_noise=edge_noise,
                                         num_samples=num_samples)
        stats_norm = compute_stats_batch(batch_norm,
                                         num_samples=num_samples if use_policy else 1)

        stats = {f"{name}_ratio": stats_norm[name] / value
                 for name, value in stats_original.items()}

        # Forward extras as keywords; ``*kwargs`` would pass the dict's keys
        # as positional arguments.
        logits = self.graph_clf(batch=batch_norm, **kwargs)
        target = self.preparator.get_target(batch_norm, dtype=None)

        prediction = self.act_out(logits) if use_act_fn else logits
        output = [prediction, target, stats]

        if return_batch_norm:
            output.append(batch_norm)
        return output

    def forward(self, batch, mode='meta', use_policy=True, **kwargs):
        assert mode in ['policy', 'graph_clf']

        if mode == 'policy':
            self.ppo.train()
            self.graph_clf.eval()
            assert self.ppo.training
            loss_dict_tmp = self.ppo(shuffle=False)
            loss_dict = {}
            for key, value in loss_dict_tmp.items():
                loss_dict[f'ppo_{key}'] = value
            return loss_dict
        else:
            self.ppo.eval()
            self.graph_clf.train()
            target = self.preparator.get_target(batch, dtype=None)

            batch_norm = self.get_batch_norm(batch,
                                             use_policy=use_policy,
                                             sample=True)
            logits = self.graph_clf(batch=batch_norm.detach().clone(), *kwargs)
            loss = self.loss_fn(logits, target)
            return loss

    @property
    def train_ppo(self):
        cond1 = self.current_epoch % 1 == 0 or self.current_epoch == (self.trainer.max_epochs - 1)
        cond2 = self.current_epoch >= self.warm_up_epochs
        return cond1 and cond2

    def on_train_batch_start(self, batch, batch_idx):

        if self.train_ppo:
            self.eval()
            self.ppo.prepare_forward(env=self.env,
                                     n_steps=self.n_steps)

            self.train()

    def on_train_batch_end(self, outputs, batch, batch_idx):
        if self.train_ppo:
            self.ppo.forward_end()

    @property
    def automatic_optimization(self):
        """Return ``False``: ``training_step`` zeroes gradients, back-propagates and steps both optimizers itself."""
        return False

    # process inside the training loop
    def training_step(self, train_batch, batch_idx):

        log_dict = {}
        tic = time.time()

        opt_clf, opt_policy = self.optimizers()

        opt_clf.zero_grad()
        loss = self(batch=train_batch,
                    mode='graph_clf',
                    use_policy=self.warm_up_epochs <= self.current_epoch)

        self.manual_backward(loss.mean())
        opt_clf.step()

        self.update_log_dict(log_dict, {'loss': loss}, key_id='_base')

        if self.train_ppo:
            for k in range(self.ppo_steps):
                opt_policy.zero_grad()
                loss_dict = self(batch=None,
                                 mode='policy',
                                 use_policy=True)
                loss = loss_dict['ppo_loss'].mean()

                self.manual_backward(loss)
                opt_policy.step()

            self.update_log_dict(log_dict, loss_dict, key_id='')
        log_dict['time_step'] = torch.tensor(time.time() - tic)
        return log_dict

    def validation_step(self, batch, batch_idx):
        self.ppo.eval()
        self.graph_clf.eval()

        log_dict = {}

        logits, target, loss_dict = self.predict(batch,
                                                 use_policy=False,
                                                 sample=False,
                                                 use_act_fn=False)

        loss_dict['loss'] = self.loss_fn(logits, target)
        loss_dict['logits'] = logits
        loss_dict['target'] = target
        self.update_log_dict(log_dict, loss_dict, key_id='_base')

        if self.current_epoch >= self.warm_up_epochs:

            for num_samples in [1, 10]:
                extra = '' if num_samples == 1 else f'_{num_samples}'

                logits, target, loss_dict = self.predict(batch,
                                                         use_policy=True,
                                                         sample=True,
                                                         num_samples=num_samples,
                                                         use_act_fn=False)

                # logits, target, loss_dict = self.predict(batch,
                #                                          use_policy=self.current_epoch >= self.warm_up_epochs,
                #                                          sample=False,
                #                                          use_act_fn=False)

                if self.train_ppo:
                    loss_dict[f'loss'] = self.loss_fn(logits, target)
                else:
                    logits_rnd = torch.randn_like(logits)
                    loss_ = self.loss_fn(logits_rnd, target)
                    loss_dict[f'loss'] = loss_ + loss_.mean().abs()
                loss_dict[f'logits'] = logits
                loss_dict['target'] = target

                self.update_log_dict(log_dict, loss_dict, key_id=extra)
        else:
            self.update_log_dict(log_dict, loss_dict, key_id='')

        return log_dict

    def test_step(self, batch, batch_idx):
        log_dict = {}

        logits, target, loss_dict = self.predict(batch,
                                                 use_policy=False,
                                                 sample=False,
                                                 use_act_fn=False)

        loss_dict['loss_base'] = self.loss_fn(logits, target)
        loss_dict['logits_base'] = logits
        loss_dict['target_base'] = target
        self.update_log_dict(log_dict, loss_dict, key_id='')

        for num_samples in [1, 10]:
            extra = '' if num_samples == 1 else f'_{num_samples}'
            logits, target, loss_dict, batch_norm = self.predict(batch,
                                                                 use_policy=True,
                                                                 sample=True,
                                                                 use_act_fn=False,
                                                                 num_samples=num_samples,
                                                                 return_batch_norm=True)
            loss_dict[f'loss'] = self.loss_fn(logits, target)
            loss_dict[f'logits'] = logits
            loss_dict[f'target'] = target
            self.update_log_dict(log_dict, loss_dict, key_id=extra)
        if self.preparator.current_split == "test" and self.preparator.add_noise:

            for eps in [0.0, 0.25, 0.5, 0.75, 1.0]:
                logits, target, loss_dict, batch_norm = self.predict(batch,
                                                                     use_policy=True,
                                                                     sample=True,
                                                                     use_act_fn=False,
                                                                     num_samples=10,
                                                                     x_noise=eps,
                                                                     return_batch_norm=True)
                loss_dict['loss'] = self.loss_fn(logits, target)
                loss_dict['logits'] = logits
                loss_dict['target'] = target

                self.update_log_dict(log_dict=log_dict, my_dict=loss_dict, key_id=f"_eps{eps}")

            for p in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]:
                logits, target, loss_dict, batch_norm = self.predict(batch,
                                                                     use_policy=True,
                                                                     sample=True,
                                                                     use_act_fn=False,
                                                                     num_samples=10,
                                                                     edge_noise=p,
                                                                     return_batch_norm=True)
                loss_dict['loss'] = self.loss_fn(logits, target)
                loss_dict['logits'] = logits
                loss_dict['target'] = target

                self.update_log_dict(log_dict=log_dict, my_dict=loss_dict, key_id=f"_p{p}")

        # if batch_idx == 0:
        #     self._plot_subgraph(batch, batch_norm, batch_idx,
        #                epoch='unknown',
        #                split=self.preparator.current_split)

        return log_dict

    def _plot_subgraph(self, batch, batch_norm, batch_idx, epoch, split, num_samples=10):
        """Plot original graphs coloured by the mean policy action over samples.

        ``batch_norm`` is split into graphs and consumed in consecutive groups
        of ``num_samples``; the mean of each group's 1-D ``action`` tensors is
        attached (as a column) to the matching graph of ``batch`` as
        ``actions``. The graphs are then handed to
        ``preparator.plot_data_batch`` and written into the ``images``
        sub-folder of the logger's save directory.

        Args:
            batch: Original batch.
            batch_norm: Batch whose graphs carry an ``action`` attribute.
            batch_idx: Batch index, used in the file name.
            epoch: Epoch label, used in the file name.
            split: Split name, used in the file name.
            num_samples: Group size used when averaging actions.
        """
        sampled_graphs = batch_norm.to_data_list()
        original_graphs = batch.to_data_list()

        pending_actions = []
        for idx, graph in enumerate(sampled_graphs):
            group = idx // num_samples
            assert graph.action.ndim == 1
            pending_actions.append(graph.action)
            if (idx + 1) % num_samples != 0:
                continue
            # Group complete: average over the samples and reset.
            mean_action = torch.stack(pending_actions).mean(0).unsqueeze(-1)
            original_graphs[group].actions = mean_action
            pending_actions = []

        now = self.get_now()
        filename = f"gcip_graphs--epoch={epoch}--batch_idx={batch_idx}--split={split}--now={now}.png"

        def action_colors(g):
            return g.actions.flatten().numpy(), 'Action'

        # Colour nodes or edges depending on what the policy's actions refer to.
        nodes_with_color_fn = None
        edges_with_color_fn = None
        if self.ppo.policy.action_refers_to == 'node':
            nodes_with_color_fn = action_colors
        elif self.ppo.policy.action_refers_to == 'edge':
            edges_with_color_fn = action_colors

        folder = wandb_local.sub_folder(self.logger.save_dir, 'images')
        self.preparator.plot_data_batch(
            batch=original_graphs,
            folder=folder,
            filename=filename,
            show=False,
            num_samples=2,
            batch_size=(1, 2),
            nodes_with_color=True,
            nodes_with_color_fn=nodes_with_color_fn,
            edges_with_color_fn=edges_with_color_fn,
        )

    # Can return multiple optimizers and scheduling algorithms.
    # Here: a config-built optimizer for the classifier and Adam for the PPO policy.

    def configure_optimizers(self):
        """Build the classifier and policy optimizers (and optional schedulers).

        The classifier optimizer comes from ``build_optimizer`` with
        ``self.optim_config``. The policy uses Adam over the parameter groups
        returned by ``ppo.get_optimization_config``, with the critic learning
        rate set to one fifth of the actor's ``base_lr``. A scheduler is added
        to a configuration only when its ``scheduler`` entry is a string; it
        then monitors ``'val_loss'``.

        Returns:
            A tuple ``(classifier_config, policy_config)`` of dicts with the
            key ``'optimizer'`` and, optionally, ``'lr_scheduler'`` and
            ``'monitor'``.
        """
        clf_optimizer = build_optimizer(optim_config=self.optim_config,
                                        params=self.graph_clf.parameters())
        opt_clf = {'optimizer': clf_optimizer}

        clf_has_scheduler = isinstance(self.optim_config.scheduler, str)
        ppo_has_scheduler = isinstance(self.optim_config_2.scheduler, str)

        if clf_has_scheduler:
            opt_clf['lr_scheduler'] = build_scheduler(optim_config=self.optim_config,
                                                      optimizer=clf_optimizer)
            opt_clf['monitor'] = 'val_loss'

        actor_lr = self.optim_config_2.base_lr
        critic_lr = self.optim_config_2.base_lr / 5
        assert self.optim_config_2.optimizer == 'adam'
        policy_groups = self.ppo.get_optimization_config(lr_actor=actor_lr, lr_critic=critic_lr)
        ppo_optimizer = torch.optim.Adam(policy_groups)
        opt_ppo = {'optimizer': ppo_optimizer}

        if ppo_has_scheduler:
            opt_ppo['lr_scheduler'] = build_scheduler(optim_config=self.optim_config_2,
                                                      optimizer=ppo_optimizer)
            opt_ppo['monitor'] = 'val_loss'

        return (opt_clf, opt_ppo)

    def do_scheduler_step(self, sch, monitor, epoch_type):
        if epoch_type == 'train':
            if isinstance(sch, list):
                for i, sch_i in enumerate(sch):
                    if i == 1 and not self.train_ppo: continue
                    if not isinstance(sch_i, t_lr.ReduceLROnPlateau): sch_i.step()
            elif sch is not None and not isinstance(sch, t_lr.ReduceLROnPlateau):
                sch.step()
        elif epoch_type == 'val':
            if isinstance(sch, list):
                for i, sch_i in enumerate(sch):
                    if i == 1 and not self.train_ppo: continue
                    if isinstance(sch_i, t_lr.ReduceLROnPlateau): sch_i.step(monitor)
            elif sch is not None and isinstance(sch, t_lr.ReduceLROnPlateau):
                sch.step(monitor)

    @torch.no_grad()
    def evaluate(self, evaluation, root, device='cpu'):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError
