from typing import Sequence

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only


class LogParameterUpdates(Callback):
    """Periodically log per-parameter statistics of how much the weights changed.

    On the rank-zero process a detached copy of every parameter is kept. Every
    ``log_every_n_steps`` values of ``trainer.global_step`` the difference
    ``previous - current`` is computed for each parameter. This is the
    *negative* of the update accumulated since the last snapshot. The
    difference is summarised and sent to every attached logger under keys of
    the form ``pdiff/<param name>/<stat>``, and the snapshot is then refreshed.

    Args:
        log_every_n_steps: Log whenever ``trainer.global_step`` is a multiple of
            this value. Must be >= 1.
        log_quantiles: Also log quantiles of the difference tensor.
        quantiles: Quantile levels to log when ``log_quantiles`` is set. Each is
            logged as ``pdiff/<name>/quantile-<level>``.
        max_quantile_numel: Quantiles are only computed for parameters with
            fewer elements than this, because they are costly on large tensors.

    Raises:
        ValueError: If ``log_every_n_steps`` is smaller than 1.

    Note:
        The snapshot is a full copy of every parameter, held on the same device
        as the parameter itself.
    """

    def __init__(
        self,
        log_every_n_steps: int,
        log_quantiles: bool,
        quantiles: Sequence[float] = (0.25, 0.5, 0.75),
        max_quantile_numel: int = 10_000_000,
    ):
        super().__init__()
        if log_every_n_steps < 1:
            raise ValueError(f"log_every_n_steps must be >= 1, got {log_every_n_steps}")
        self.log_every_n_steps = log_every_n_steps
        self.log_quantiles = log_quantiles
        # Plain Python floats: used both as quantile levels and verbatim in metric
        # names. Formatting a float32 tensor element would give names such as
        # "0.10000000149011612" for levels that are not exactly representable.
        self.quantiles = [float(level) for level in quantiles]
        self.max_quantile_numel = max_quantile_numel

        # Parameter name -> detached copy taken at the last snapshot.
        self.previous_weights = {}
        # trainer.global_step at which the current snapshot was taken (None = none yet).
        self._snapshot_step = None

    @rank_zero_only
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Take the initial snapshot of every parameter of ``pl_module``."""
        # detach() before clone() so the copy is never recorded by autograd.
        self.previous_weights = {
            name: param.detach().clone() for name, param in pl_module.named_parameters()
        }
        self._snapshot_step = trainer.global_step

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log parameter-difference statistics every ``log_every_n_steps`` steps."""
        step = trainer.global_step
        # Lightning's global_step counts optimizer steps, so with gradient
        # accumulation this hook fires several times for the same step. Only log
        # once per step, and not before the step has advanced past the snapshot.
        # Otherwise all-zero diffs would be logged and overwrite the real ones.
        if step % self.log_every_n_steps != 0 or step == self._snapshot_step:
            return

        stats = {}
        for name, param in pl_module.named_parameters():
            current_weight = param.detach().clone()
            previous_weight = self.previous_weights.get(name)
            self.previous_weights[name] = current_weight
            if previous_weight is None:
                # Parameter was not in the last snapshot: nothing to compare against yet.
                continue
            stats.update(self._diff_stats(name, previous_weight - current_weight))

        self._snapshot_step = step
        for logger in trainer.loggers or []:
            logger.log_metrics(stats, step=step)

    def _diff_stats(self, name, weight_diff):
        """Summarise one parameter's difference tensor into ``pdiff/<name>/...`` metrics."""
        if weight_diff.numel() == 0:
            # min/max are undefined on empty tensors.
            return {}

        prefix = f"pdiff/{name}"
        abs_diff = weight_diff.abs()
        stats = {
            f"{prefix}/std": weight_diff.std().item(),
            f"{prefix}/min": weight_diff.min().item(),
            f"{prefix}/max": weight_diff.max().item(),
            f"{prefix}/abs_mean": abs_diff.mean().item(),
            f"{prefix}/mean": weight_diff.mean().item(),
            f"{prefix}/abs_std": abs_diff.std().item(),
        }

        if self.log_quantiles and self.quantiles and weight_diff.numel() < self.max_quantile_numel:
            # torch.quantile needs a floating input whose dtype and device match q.
            # Build q on this tensor's own device, so wrapped or model-parallel
            # modules work.
            diff_f32 = weight_diff.float()
            levels = torch.tensor(self.quantiles, dtype=diff_f32.dtype, device=diff_f32.device)
            values = torch.quantile(diff_f32, levels, interpolation="linear")
            for level, value in zip(self.quantiles, values.tolist()):
                stats[f"{prefix}/quantile-{level}"] = value

        return stats
