from typing import Dict
import random
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from antgine.callback import Callback


class SaveCallback(Callback):
    """
        This callback saves checkpoints holding the model, optimizer and scheduler state dicts.

        Two kinds of files are written with ``torch.save``:

        * ``<filename>_<epoch>.pth`` from :meth:`on_epoch_end`, every ``freq`` epochs;
        * ``<filename>_best.pth`` from :meth:`on_epoch_test_end`, each time the ``'top1'``
          metric is greater than or equal to the best value recorded so far.

        Every checkpoint also stores the epoch number and the states returned by
        ``random.getstate()``, ``np.random.get_state()`` and ``torch.get_rng_state()``.
    """
    def __init__(self, model: nn.Module, optimizer: optim.Optimizer,
                 scheduler: optim.lr_scheduler._LRScheduler, filename: str,
                 freq: int, best_metric: float = 0.0):
        """
        :param torch.nn.Module model: Model.
        :param torch.optim.Optimizer optimizer: Optimizer.
        :param torch.optim.lr_scheduler._LRScheduler scheduler: Learning rate scheduler.
            May be ``None``, in which case ``None`` is stored as the scheduler state.
        :param str filename: Where to save checkpoints (prefix; a suffix and ``.pth`` are appended).
        :param int freq: Frequency (in epochs) of saved checkpoints. ``0`` disables the
            periodic checkpoints; the best checkpoint is still saved.
        :param float best_metric: Initial value the ``'top1'`` metric has to reach for a
            best checkpoint to be saved.
        """
        super().__init__()
        self._model = model
        self._optimizer = optimizer
        self._scheduler = scheduler
        self._filename = filename
        self._freq = freq
        self._best_metric = best_metric

    def _save_checkpoint(self, filename: str, epoch: int, metrics=None):
        """
        Build the checkpoint dictionary and write it to ``filename``.

        :param str filename: Destination file.
        :param int epoch: Epoch number stored in the checkpoint.
        :param metrics: Optional metrics stored under the ``'metrics'`` key. The key is
            omitted when this is ``None``.
        """
        checkpoint = {
            'epoch': epoch,
            'optimizer_state_dict': self._optimizer.state_dict(),
            'model_state_dict': self._model.state_dict(),  # TODO save only on cpu
            'pyrandom_state': random.getstate(),
            'nprandom_state': np.random.get_state(),
            'torchrandom_state': torch.get_rng_state(),
        }
        if metrics is not None:
            checkpoint['metrics'] = metrics
        checkpoint['scheduler_state_dict'] = (
            self._scheduler.state_dict() if self._scheduler is not None else None
        )
        torch.save(checkpoint, filename)

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]):
        """
        Save a periodic checkpoint when ``epoch`` is a multiple of ``freq``.

        :param int epoch: Epoch that just ended.
        :param dict metrics: Metrics of the epoch (not stored in periodic checkpoints).
        """
        # A frequency of 0 would make the modulo below raise ZeroDivisionError;
        # treat it as "no periodic checkpoints" instead of crashing the training run.
        if self._freq == 0:
            return
        if epoch % self._freq == 0:
            filename = '%s_%d.pth' % (self._filename, epoch)
            logging.info('Saving checkpoint at %s', filename)
            self._save_checkpoint(filename, epoch)

    # TODO merge on_epoch_end and on_test_end
    # TODO find a better way to compare which is the best model (right now it relies on the 'top1' metric)
    # TODO assume top1 is there
    def on_epoch_test_end(self, epoch: int, metrics: Dict[str, float]):
        """
        Save the best checkpoint when ``metrics['top1']`` reaches the best value so far.

        :param int epoch: Epoch whose evaluation just ended.
        :param dict metrics: Evaluation metrics; stored in the checkpoint under ``'metrics'``.
        :raises KeyError: If ``metrics`` has no ``'top1'`` entry.
        """
        top1 = metrics['top1']
        # "<=" means a tie overwrites the previous best checkpoint with the newer one.
        if self._best_metric <= top1:
            filename = '%s_best.pth' % self._filename
            logging.info('Saving best checkpoint at %s', filename)
            self._save_checkpoint(filename, epoch, metrics)
            self._best_metric = top1
