import time

import torch
from lightning.pytorch.callbacks import Callback


from loguru import logger


def log_metrics(pl_module, metrics):
    """Forward every metric in ``metrics`` to ``pl_module.log``.

    Args:
        pl_module: Object exposing a Lightning-style ``log(name, value, ...)``
            method.
        metrics: Mapping of metric name to value.

    Each entry is logged with ``on_step=True``, ``on_epoch=True``,
    ``prog_bar=False`` and ``logger=True``.
    """
    for metric_name, metric_value in metrics.items():
        pl_module.log(
            metric_name,
            metric_value,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
        )


class UnusedParametersCallback(Callback):
    """Print the names of parameters that have no gradient before a step.

    Right before each optimizer step, every named parameter of the module whose
    ``.grad`` is still ``None`` is printed to stdout as
    ``"Unused parameter: <name>"``.
    """

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        missing_grad = (
            param_name
            for param_name, tensor in pl_module.named_parameters()
            if tensor.grad is None
        )
        for param_name in missing_grad:
            print(f"Unused parameter: {param_name}")


class LogEpochTimeCallback(Callback):
    """Log how long each training epoch takes.

    The elapsed time in seconds (measured with ``time.time``) between
    ``on_train_epoch_start`` and ``on_train_epoch_end`` is logged as
    ``train_info/epoch_duration_secs``. Every 10th epoch it is also reported
    through the loguru logger.
    """

    # Timestamp recorded by on_train_epoch_start; None until an epoch starts.
    epoch_start = None

    def on_train_epoch_start(self, trainer, pl_module):
        self.epoch_start = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        if self.epoch_start is None:
            # No matching epoch start was recorded, so there is no duration to report.
            return
        # Elapsed seconds since the epoch started (not the raw wall-clock timestamp).
        duration = time.time() - self.epoch_start
        pl_module.log(
            "train_info/epoch_duration_secs",
            duration,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        if pl_module.current_epoch % 10 == 0:
            logger.info(
                f"Done training epoch {pl_module.current_epoch}, epoch took {duration} seconds"
            )


class LogSetpTimeCallback(Callback):
    """Log the time spent on each training batch.

    The duration in seconds (via ``time.time``) between
    ``on_train_batch_start`` and ``on_train_batch_end`` is logged per step as
    ``train_info/step_duration_secs``. The "Setp" spelling of the class name
    is kept so existing references keep working.
    """

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Remembered so on_train_batch_end can compute the elapsed time.
        self.step_start = time.time()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        elapsed = time.time() - self.step_start
        log_options = dict(
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        pl_module.log("train_info/step_duration_secs", elapsed, **log_options)


class GradAndWeightAnalysisCallback(Callback):
    """Log magnitude statistics of weights and gradients during training.

    All values are logged through ``log_metrics``:

    * ``on_after_backward``: mean/max absolute gradient
      (``avg_g_after_bwd``, ``max_g_after_bwd``).
    * ``on_before_optimizer_step``: mean/max absolute weight and gradient
      (``*_bef_step``). Once more than one finite past value is available, it
      also logs the moving average of each gradient statistic over the last
      ``moving_avg_size`` steps and the ratio of the current value to it.
    * ``on_before_zero_grad``: mean/max absolute gradient
      (``avg_g_bef_zerog``, ``max_g_bef_zerog``).

    Args:
        debug: If true, a warning is emitted through the loguru logger when
            the weight statistics contain NaN before an optimizer step.
        moving_avg_size: Maximum number of past per-step values kept for the
            moving averages.
    """

    def __init__(self, debug=True, moving_avg_size=100):
        super().__init__()
        self.debug = debug
        self.moving_avg_size = moving_avg_size
        # Finite per-step gradient statistics as Python floats, oldest first.
        self.avg_grad_history = []
        self.max_grad_history = []

    @staticmethod
    def _push_history(history, value, max_size):
        """Append the 0-d tensor ``value`` to ``history`` if it is finite.

        Afterwards ``history`` is trimmed from the front to at most
        ``max_size`` entries. A non-finite ``value`` leaves ``history``
        untouched, so one bad step does not shrink the window.
        """
        if not bool(torch.isfinite(value).all()):
            return
        history.append(value.item())
        overflow = len(history) - max_size
        if overflow > 0:
            del history[:overflow]

    def _get_avg_and_max_w(self, pl_module):
        """Return ``(mean |w|, max |w|)`` over all module parameters as 0-d tensors."""
        with torch.no_grad():
            params = torch.nn.utils.parameters_to_vector(pl_module.parameters()).abs()
            return params.sum() / params.numel(), params.max()

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        avg_w, max_w = self._get_avg_and_max_w(pl_module)
        avg_g, max_g = self._get_avg_and_max_grad(pl_module)

        metrics = {
            "avg_w_bef_step": avg_w,
            "max_w_bef_step": max_w,
            "avg_g_bef_step": avg_g,
            "max_g_bef_step": max_g,
        }

        # Moving averages cover previous steps only; the current values are
        # pushed afterwards. Each history is checked on its own because a
        # non-finite value may have been skipped in one but not the other.
        if len(self.max_grad_history) > 1:
            moving_max = sum(self.max_grad_history) / len(self.max_grad_history)
            metrics["moving_avg_max_grad_bef_step"] = moving_max
            metrics["max_g_over_avg_max_g_bef_step"] = max_g / moving_max
        if len(self.avg_grad_history) > 1:
            moving_avg = sum(self.avg_grad_history) / len(self.avg_grad_history)
            metrics["moving_avg_avg_grad_bef_step"] = moving_avg
            metrics["avg_g_over_avg_avg_g_bef_step"] = avg_g / moving_avg

        self._push_history(self.max_grad_history, max_g, self.moving_avg_size)
        self._push_history(self.avg_grad_history, avg_g, self.moving_avg_size)

        if self.debug and bool(avg_w.isnan().any() or max_w.isnan().any()):
            logger.warning(
                f"NaN detected in model weights before optimizer step "
                f"(global_step={trainer.global_step})"
            )

        log_metrics(pl_module, metrics)

    def _get_avg_and_max_grad(self, pl_module):
        """Return ``(mean |g|, max |g|)`` over all existing parameter gradients.

        Parameters whose gradient is ``None`` or empty are skipped. The mean is
        taken over all gradient elements combined. If no gradient elements
        exist, both values are zero tensors on ``pl_module.device``.
        """
        device = pl_module.device
        grad_sum = torch.tensor(0.0, device=device)
        max_grad = torch.tensor(0.0, device=device)
        count = 0
        with torch.no_grad():
            for p in pl_module.parameters():
                grad = p.grad
                if grad is None or grad.numel() == 0:
                    # .max() on an empty tensor would raise.
                    continue
                abs_grad = grad.abs()
                grad_sum += abs_grad.sum()
                # Tensor-wise max propagates NaN into the result.
                max_grad = torch.max(max_grad, abs_grad.max())
                count += grad.numel()
        if count == 0:
            # Both accumulators are still 0.0 on the module's device.
            return grad_sum, max_grad
        return grad_sum / count, max_grad

    def _count_nan_grad(self, pl_module):
        """Return ``(total gradient elements, NaN gradient elements)`` as ints.

        Parameters without a gradient are skipped.
        """
        numels, num_nans = 0, 0
        for p in pl_module.parameters():
            if p.grad is not None:
                numels += p.grad.numel()
                num_nans += p.grad.isnan().sum().item()
        return numels, num_nans

    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        avg_g, max_g = self._get_avg_and_max_grad(pl_module)
        log_metrics(
            pl_module,
            {
                "avg_g_bef_zerog": avg_g,
                "max_g_bef_zerog": max_g,
            },
        )

    def on_after_backward(self, trainer, pl_module):
        avg_g, max_g = self._get_avg_and_max_grad(pl_module)
        log_metrics(
            pl_module,
            {
                "avg_g_after_bwd": avg_g,
                "max_g_after_bwd": max_g,
            },
        )


class SkipNanGradCallback(Callback):
    """Discard a backward pass's gradients if any of them contains NaN.

    After every backward pass the parameter gradients are scanned. On the
    first NaN found, ``pl_module.zero_grad()`` clears the module's gradients
    and ``self.count`` is incremented. ``self.iter`` counts backward passes.

    NOTE(review): depending on the PyTorch version, ``zero_grad()`` may set
    gradients to ``None`` rather than zero. Later hooks should tolerate
    missing gradients.

    Args:
        debug: If true, a warning is logged through the loguru logger each
            time gradients are discarded.
    """

    def __init__(self, debug=True):
        super().__init__()
        self.debug = debug
        # Number of backward passes whose gradients were discarded.
        self.count = 0
        # Number of backward passes observed.
        self.iter = 0

    def on_after_backward(self, trainer, pl_module):
        self.iter += 1
        # any() stops at the first offending parameter instead of scanning
        # (and synchronizing on) every remaining gradient.
        has_nan = any(
            p.grad is not None and bool(p.grad.isnan().any())
            for p in pl_module.parameters()
        )
        if has_nan:
            self.count += 1
            pl_module.zero_grad()
            if self.debug:
                logger.warning(
                    f"NaN gradient after backward pass {self.iter}; "
                    f"gradients discarded ({self.count} so far)"
                )


class SkipLargeGradients(Callback):
    """Zero all gradients on steps whose max |grad| spikes above its recent average.

    Before each optimizer step, the largest absolute gradient entry across all
    parameters is compared with the moving average of that quantity over the
    previous ``moving_avg_size`` steps. Only finite values are recorded. Once
    ``trainer.global_step`` exceeds ``min_opt_steps`` and more than one past
    value is available, a step whose max gradient exceeds
    ``factor_threshold`` times that average has every gradient zeroed in
    place.

    Args:
        moving_avg_size: Maximum number of past max-gradient values kept.
        factor_threshold: Multiple of the moving average above which a step
            is skipped.
        min_opt_steps: Global step count after which skipping is enabled.
    """

    def __init__(
        self,
        moving_avg_size: int = 100,
        factor_threshold: float = 5,
        min_opt_steps: int = 2000,
    ):
        super().__init__()
        # Finite per-step max |grad| values as Python floats, oldest first.
        self.max_g_history = []
        self.moving_avg_size = moving_avg_size
        self.factor_threshold = factor_threshold
        self.min_opt_steps = min_opt_steps

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        grads = [
            param.grad
            for param in pl_module.parameters()
            if param.grad is not None and param.grad.numel() > 0
        ]
        if not grads:
            # Nothing to inspect, e.g. another callback already cleared the gradients.
            return

        with torch.no_grad():
            max_g = grads[0].abs().max()
            for grad in grads[1:]:
                # torch.max propagates NaN, unlike Python's built-in max().
                max_g = torch.max(max_g, grad.abs().max())

            # Baseline from previous steps only, so a spike does not raise its
            # own reference value.
            baseline = None
            if len(self.max_g_history) > 1:
                baseline = sum(self.max_g_history) / len(self.max_g_history)

            # Record only finite values; trim only after appending so a
            # non-finite step does not shrink the window.
            if bool(torch.isfinite(max_g)):
                self.max_g_history.append(max_g.item())
                overflow = len(self.max_g_history) - self.moving_avg_size
                if overflow > 0:
                    del self.max_g_history[:overflow]

            if (
                trainer.global_step > self.min_opt_steps
                and baseline is not None
                and bool(max_g > self.factor_threshold * baseline)
            ):
                for grad in grads:
                    grad.zero_()
