import logging
import os
from typing import Union, Optional, List

import torch

from . import Callback


class Checkpoint(Callback):
    """Callback that writes model checkpoints to ``base_dir`` during training.

    Two kinds of checkpoint are written from :meth:`on_eval_end`:

    * periodic checkpoints named ``{epoch_no + 1}.pth``, written whenever
      ``epoch_no % saving_freq == 0`` (skipped when ``save_best_only`` is set);
    * ``best.pth``, overwritten each time the monitored value improves in the
      configured ``direction``.

    Args:
        base_dir: Directory the checkpoints are written to; created if missing.
        save_best_only: If True, only ``best.pth`` is written.
        saving_freq: Period (in epochs) of the periodic checkpoints. Must be
            a positive integer unless ``save_best_only`` is True.
        save_weight_only: If True, save ``net.state_dict()``; otherwise pass
            the whole ``net`` object to ``torch.save``.
        monitor: Key looked up in the ``logs`` dict given to ``on_eval_end``.
            NOTE(review): the annotation admits a list, but the lookup is
            ``logs.get(monitor)``, so only a single hashable key works.
        direction: ``'min'`` or ``'max'`` (case-insensitive); which way counts
            as an improvement of the monitored value.
        verbose: If True, log whether the monitored value improved.

    Raises:
        ValueError: If ``direction`` is not ``'min'``/``'max'``, or if
            ``saving_freq`` is not positive while periodic saving is enabled.
    """

    def __init__(self,
                 base_dir: str,
                 save_best_only: bool = False,
                 saving_freq: int = 10,
                 save_weight_only: bool = True,
                 monitor: Optional[Union[str, List[str]]] = 'loss',
                 direction: Optional[str] = 'min',
                 verbose: bool = False):
        super().__init__()

        # Validate up front: previously None crashed on .lower(), and any other
        # string silently disabled best-model saving.
        if not isinstance(direction, str) or direction.lower() not in ('min', 'max'):
            raise ValueError(f"direction must be 'min' or 'max', got {direction!r}")
        # saving_freq is used as a modulus in on_eval_end; 0 would divide by zero.
        if not save_best_only and saving_freq < 1:
            raise ValueError(f'saving_freq must be a positive integer, got {saving_freq!r}')

        self.base_dir = base_dir
        self.save_best_only = save_best_only
        self.saving_freq = saving_freq
        self.save_weight_only = save_weight_only
        self.monitor = monitor
        self.direction = direction
        self.verbose = verbose

        # Normalised comparison mode ('min' or 'max').
        self._mode = direction.lower()
        # Start from the worst possible value so the first observation wins.
        self._best_monitor = float('inf') if self._mode == 'min' else float('-inf')

        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(self.base_dir, exist_ok=True)

        self.epoch_no = 0

    def on_epoch_begin(self, e, logs):
        """Remember the current epoch index for naming periodic checkpoints."""
        self.epoch_no = e

    @property
    def best_path(self):
        """Path of the best-so-far checkpoint file."""
        return os.path.join(self.base_dir, 'best.pth')

    def _save_to(self, net, path):
        """Write ``net`` (or only its ``state_dict``) to ``path``."""
        payload = net.state_dict() if self.save_weight_only else net
        torch.save(payload, path)

    def save(self, net):
        """Write a periodic checkpoint named after the current epoch."""
        self._save_to(net, os.path.join(self.base_dir, f'{self.epoch_no + 1}.pth'))

    def save_best(self, net):
        """Overwrite the best-so-far checkpoint."""
        self._save_to(net, self.best_path)

    def _is_improvement(self, value):
        """Return True if ``value`` beats the best monitored value seen so far."""
        if self._mode == 'min':
            return value < self._best_monitor
        return value > self._best_monitor

    def on_eval_end(self, logs):
        """Save a periodic checkpoint and/or the best checkpoint after evaluation.

        Args:
            logs: Mapping of metric names to values (may be None).
        """
        monitor_value = self.get_monitor_value(logs)

        # NOTE(review): with the "+ 1" used in the filename, saving_freq=10
        # writes 1.pth, 11.pth, 21.pth, ... if epochs are 0-based; confirm
        # this cadence is intended.
        if not self.save_best_only and self.epoch_no % self.saving_freq == 0:
            self.save(self.trainer.net)

        # A missing metric cannot be compared (None < float raises TypeError);
        # get_monitor_value has already logged a warning.
        if monitor_value is None:
            return

        if self._is_improvement(monitor_value):
            if self.verbose:
                logging.info(
                    'The monitor improved from %s to %s, saving model to %s.',
                    self._best_monitor, monitor_value, self.best_path,
                )
            self.save_best(self.trainer.net)
            self._best_monitor = float(monitor_value)
        elif self.verbose:
            logging.info('The monitor did not improve from %s.', self._best_monitor)

    def get_monitor_value(self, logs):
        """Return ``logs[self.monitor]``, or None (with a warning) if absent.

        Args:
            logs: Mapping of metric names to values; None is treated as empty.
        """
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logging.warning('Checkpoint conditioned on metric `%s` '
                            'which is not available. Available metrics are: %s',
                            self.monitor, ','.join(map(str, logs)))
        return monitor_value
