from pytorch_lightning import Callback

from utils.optim.adam_cpr import AdamCPR


class LogAdamTestbed(Callback):
    """Periodically log per-parameter AdamCPR state to the trainer's loggers.

    Every ``log_every_n_steps`` optimizer steps (never at step 0) the callback
    walks the param groups of all optimizers and, for groups that have
    ``apply_decay`` set and whose ``mode`` contains ``'constrain'``, logs the
    ``lagmul`` and ``kappa`` entries of each parameter's optimizer state as
    ``AdamCPR/<name>/lambda`` and ``AdamCPR/<name>/kappa``.

    The callback is inactive unless the first optimizer is an ``AdamCPR``,
    and only the process with ``trainer.local_rank == 0`` logs.

    Args:
        log_every_n_steps: Logging interval, measured in ``trainer.global_step``.
    """

    def __init__(self, log_every_n_steps: int):
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        # Most recent global step that was logged. The batch-end hook can run
        # several times for one global step (e.g. with gradient accumulation),
        # and the optimizer state does not change in between.
        self._last_logged_step = None

    def on_train_start(self, trainer, pl_module):
        """Forget the last logged step so a new training run starts clean."""
        self._last_logged_step = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log AdamCPR statistics when the current global step is due."""
        optimizers = trainer.optimizers
        if not optimizers or not isinstance(optimizers[0], AdamCPR):
            return

        step = trainer.global_step
        if step <= 0 or step % self.log_every_n_steps != 0:
            return

        # Decide whether this process logs at all before touching optimizer
        # state, so ranks that would discard the result skip the `.item()`
        # calls entirely.
        loggers = trainer.loggers
        if not loggers or trainer.local_rank != 0:
            return

        if step == self._last_logged_step:
            return
        self._last_logged_step = step

        stats = self._collect_stats(optimizers)
        if not stats:
            # Nothing matched; avoid emitting empty records.
            return

        for logger in loggers:
            logger.log_metrics(stats, step=step)

    @staticmethod
    def _collect_stats(optimizers):
        """Return a ``{metric_name: value}`` dict for all constrained parameters.

        Param groups without a truthy ``apply_decay`` or without ``'constrain'``
        in ``mode`` are skipped, as are parameters whose optimizer state does
        not (yet) hold both ``lagmul`` and ``kappa``.
        """
        stats = {}
        for optim in optimizers:
            for param_group in optim.param_groups:
                if not param_group.get('apply_decay'):
                    continue
                if 'constrain' not in param_group.get('mode', ''):
                    continue
                for name, param in zip(param_group['names'], param_group['params']):
                    # `.get` instead of indexing: a parameter without state
                    # must not raise, and must not get an empty state entry
                    # created as a side effect if `state` is a defaultdict
                    # (presumably the case for torch-based optimizers; confirm
                    # for AdamCPR).
                    state = optim.state.get(param)
                    if not state or 'lagmul' not in state or 'kappa' not in state:
                        continue
                    stats[f"AdamCPR/{name}/lambda"] = state['lagmul'].detach().item()
                    stats[f"AdamCPR/{name}/kappa"] = state['kappa']
        return stats
