from pytorch_lightning import Callback


class LogCPR(Callback):
    """Periodically log AdamCPR per-parameter optimizer state to the trainer's loggers.

    Every ``log_every_n_steps`` optimizer steps (step 0 excluded), this reads the
    CPR entries (``lagmul``, ``kappa``, ``prev_reg``, ``prev_reg_gradient``,
    ``prev_reg_second_derivative``) stored in the first optimizer's ``state`` for
    each trainable parameter of ``trainer.model``. The values are sent to every
    logger under ``cpr/<param name>/<stat>`` keys. The callback does nothing
    unless ``pl_module.cfg.optimizer == 'adamcpr'``.
    """

    # (optimizer state key, metric suffix) pairs; "lagmul" is reported as "lambda".
    _STATE_KEYS = (
        ("lagmul", "lambda"),
        ("kappa", "kappa"),
        ("prev_reg", "prev_reg"),
        ("prev_reg_gradient", "prev_reg_gradient"),
        ("prev_reg_second_derivative", "prev_reg_second_derivative"),
    )

    def __init__(self, log_every_n_steps: int):
        """
        Args:
            log_every_n_steps: Log whenever ``trainer.global_step`` is a positive
                multiple of this value. Must be >= 1.

        Raises:
            ValueError: If ``log_every_n_steps`` is smaller than 1.
        """
        super().__init__()
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        # global_step of the most recent log call. With gradient accumulation,
        # several consecutive batches share one global_step, and each one
        # triggers on_train_batch_end; this prevents logging the same step twice.
        self._last_logged_step = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log CPR optimizer state when the current global step is due."""
        if pl_module.cfg.optimizer != 'adamcpr':
            return

        step = trainer.global_step
        if step <= 0 or step % self.log_every_n_steps != 0:
            return
        if step == self._last_logged_step:
            return

        # Only the global rank-zero process logs. Checking before collecting
        # stats avoids needless per-value device syncs on other ranks.
        if not trainer.is_global_zero or not trainer.loggers:
            return

        self._last_logged_step = step
        stats = self._collect_stats(trainer)
        for logger in trainer.loggers:
            logger.log_metrics(stats, step=step)

    def _collect_stats(self, trainer) -> dict:
        """Return ``{metric key: float}`` for every trainable param that has CPR state."""
        optimizer_state = trainer.optimizers[0].state
        stats = {}
        for name, param in trainer.model.named_parameters():
            if not param.requires_grad:
                continue
            # Use .get() rather than []: torch.optim.Optimizer.state is a
            # defaultdict, so indexing would insert empty entries for untracked
            # tensors (which can later break optimizer.state_dict()). Parameters
            # without CPR state are skipped instead of raising KeyError.
            state = optimizer_state.get(param)
            if not state:
                continue
            for state_key, metric_suffix in self._STATE_KEYS:
                if state_key in state:
                    # float() accepts single-element tensors as well as plain
                    # Python numbers.
                    stats[f"cpr/{name}/{metric_suffix}"] = float(state[state_key])
        return stats
