import copy
import math

import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import TQDMProgressBar


class GradientClipCallback(Callback):
    """Clip every gradient of the LightningModule to ``[-clip_val, clip_val]``.

    Args:
        clip_val: maximum absolute value allowed for any gradient element;
            forwarded to ``torch.nn.utils.clip_grad_value_``.

    Note:
        Lightning's Trainer also offers value clipping natively
        (``gradient_clip_val=...`` with ``gradient_clip_algorithm="value"``).
    """

    def __init__(self, clip_val):
        super().__init__()
        self.clip_val = clip_val

    def _clip(self, pl_module):
        """Clip the gradients of all parameters of ``pl_module`` in place."""
        torch.nn.utils.clip_grad_value_(pl_module.parameters(), self.clip_val)

    def on_after_backward(self, trainer, pl_module):
        # ``on_after_backward`` is the Callback hook Lightning calls right after
        # ``loss.backward()``. ``on_backward_end`` is not a Callback hook in the
        # Lightning versions that provide TQDMProgressBar, so clipping placed
        # only there never ran.
        # NOTE(review): under native AMP the gradients are still loss-scaled at
        # this point; confirm whether clipping should move to
        # ``on_before_optimizer_step`` for mixed-precision runs.
        self._clip(pl_module)

    def on_backward_end(self, trainer, pl_module):
        # Kept for backward compatibility with any code that calls it directly.
        # Value clipping is idempotent, so also running it here is harmless.
        self._clip(pl_module)


class LogProgressBar(TQDMProgressBar):
    """TQDM progress bar that prints blank lines around training epochs.

    A blank line is printed after every training epoch. Another blank line is
    printed before every training epoch except the first. Presumably this keeps
    each epoch's bar on its own line in the console.

    Args:
        print_val: stored as ``self.print_val``; not read anywhere in this class.
    """

    def __init__(self, print_val=False):
        super().__init__()
        self.print_val = print_val

    def on_train_epoch_start(self, trainer, *_):
        # Epoch 0 has nothing printed above it that needs separating.
        is_first_epoch = not trainer.current_epoch
        if not is_first_epoch:
            print()
        super().on_train_epoch_start(trainer, *_)

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        print()


class SaveBestCallback(Callback):
    """Save ``pl_module.model``'s weights whenever a monitored metric improves.

    At the end of the stage selected by ``monitor_type``, ``metric`` is read
    from ``trainer.callback_metrics``. When it improves, the state dict of
    ``pl_module.model`` is written through ``ckpt_file``. A deep copy of all
    current metrics is also kept, and printed when training ends or an
    exception is raised.

    Args:
        metric: key in ``trainer.callback_metrics`` to monitor.
        ckpt_file: object used to persist weights. Its ``save(state_dict)``
            is called when ``ckpt_name`` is None, otherwise
            ``save_with_version(state_dict, ckpt_name)``.
        ckpt_name: optional name forwarded to ``ckpt_file.save_with_version``.
        monitor_type: ``'train'``, ``'valid'`` or ``'test'``; selects which
            hook (train-epoch end, validation end, test end) runs the check.
        mode: ``'max'`` (default; larger is better) or ``'min'`` (smaller is
            better, e.g. for losses).

    Raises:
        ValueError: if ``monitor_type`` or ``mode`` is not one of the values above.
    """

    _MONITOR_TYPES = ('train', 'valid', 'test')
    _MODES = ('max', 'min')

    def __init__(self, metric, ckpt_file, ckpt_name=None, monitor_type='valid', mode='max'):
        super().__init__()
        # An unknown monitor_type used to make the callback a silent no-op.
        if monitor_type not in self._MONITOR_TYPES:
            raise ValueError(
                f'monitor_type must be one of {self._MONITOR_TYPES}, got {monitor_type!r}')
        if mode not in self._MODES:
            raise ValueError(f'mode must be one of {self._MODES}, got {mode!r}')
        self.metric = metric
        self.ckpt_file = ckpt_file
        self.ckpt_name = ckpt_name
        self.best = None  # best metric value seen so far (as read from callback_metrics)
        self.logs = None  # deep copy of callback_metrics at the best evaluation
        self.cur_epoch = 0  # number of (non-sanity) evaluations processed so far
        self.best_epoch = 0  # value of cur_epoch at the best evaluation
        self.monitor_type = monitor_type
        self.mode = mode

    def _is_improvement(self, cur_val):
        """Return True if ``cur_val`` beats ``self.best`` under ``self.mode``.

        NaN never counts as an improvement. Otherwise a single NaN would become
        ``best``, and every later comparison against it would be False.
        """
        cur = float(cur_val)
        if math.isnan(cur):
            return False
        if self.best is None:
            return True
        best = float(self.best)
        return cur > best if self.mode == 'max' else cur < best

    def save_checkpoint(self, trainer, pl_module):
        """Save weights if the monitored metric improved, then advance the counter."""
        # Lightning runs a few validation batches before training (the sanity
        # check) and fires on_validation_end for them. Those metrics come from
        # an untrained model. They must not be saved as "best", and must not
        # shift the epoch counter.
        if getattr(trainer, 'sanity_checking', False):
            return
        logs = trainer.callback_metrics
        if self.metric in logs:
            cur_val = logs[self.metric]
            if self._is_improvement(cur_val):
                self.best_epoch = self.cur_epoch
                self.best = cur_val
                self.logs = copy.deepcopy(logs)
                state_dict = pl_module.model.state_dict()
                if self.ckpt_name is None:
                    self.ckpt_file.save(state_dict)
                else:
                    self.ckpt_file.save_with_version(state_dict, self.ckpt_name)
        self.cur_epoch += 1

    def on_train_epoch_end(self, trainer, pl_module):
        if self.monitor_type == 'train':
            self.save_checkpoint(trainer, pl_module)

    def on_validation_end(self, trainer, pl_module):
        if self.monitor_type == 'valid':
            self.save_checkpoint(trainer, pl_module)

    def on_test_end(self, trainer, pl_module):
        if self.monitor_type == 'test':
            self.save_checkpoint(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        """Print the best epoch index and the metrics recorded at that point."""
        if self.logs is not None:
            content = f'Best epoch: {self.best_epoch}, '
            content += ', '.join([f'{key}: {val}' for key, val in self.logs.items()])
            print(content)

    def on_exception(self, trainer, pl_module, exception):
        # Report the best result so far even when training is interrupted.
        self.on_train_end(trainer, pl_module)
