import logging
from typing import Dict
import torch
import torch.utils.data as data
from antgine.callback import Callback
from antgine.metrics.accuracy import accuracy


class LogCallback(Callback):
    """
    Logging callback that reports training and test progress through the root
    ``logging`` logger at INFO level.

    This class only emits log records; where they end up (stdout, a log file,
    ...) is decided by the logging configuration installed elsewhere in the
    application.
    """
    def __init__(self, dataloader: data.DataLoader, itfreq: int):
        """
        :param torch.utils.data.DataLoader dataloader: Training set dataloader.
            Only its length is used, to show the iteration count per epoch.
        :param int itfreq: Print frequency: a loss/accuracy line is logged on
            every iteration ``i`` with ``i % itfreq == 0``.
        :raises ValueError: if ``itfreq`` is not positive.
        """
        super().__init__()
        if itfreq <= 0:
            # Fail fast here: itfreq == 0 would otherwise surface as a
            # ZeroDivisionError on the first training step.
            raise ValueError('itfreq must be positive, got %r' % (itfreq,))
        self._dataloader = dataloader
        self._itfreq = itfreq

    @staticmethod
    def _format_metrics(metrics: Dict[str, float]) -> str:
        """Render ``metrics`` as ``'name value | name value | ...'``."""
        return ' | '.join('%s %f' % (name, value) for name, value in metrics.items())

    def on_epoch_begin(self, epoch: int) -> None:
        """Log the start of ``epoch``."""
        logging.info('Epoch %d begin.', epoch)

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        """
        Log the training metrics of ``epoch`` followed by an end-of-epoch line.

        :param int epoch: Epoch index.
        :param dict metrics: Metric name -> value, each printed with ``%f``.
        """
        logging.info('Epoch[%d] train %s', epoch, self._format_metrics(metrics))
        logging.info('Epoch %d ended.', epoch)

    def on_loss_end(self, epoch: int, i: int, xs: torch.Tensor,
                    ys: torch.Tensor, outputs: torch.Tensor,
                    loss: torch.Tensor) -> None:
        """
        Every ``itfreq`` iterations, log the current loss and top-1/top-5 accuracy.

        :param int epoch: Epoch index.
        :param int i: Iteration index within the epoch.
        :param torch.Tensor xs: Input batch (unused here).
        :param torch.Tensor ys: Targets passed to ``accuracy``.
        :param torch.Tensor outputs: Model outputs passed to ``accuracy``.
        :param torch.Tensor loss: Loss tensor; read with ``.item()``, so it must
            hold a single element.
        """
        if i % self._itfreq != 0:
            return
        # NOTE(review): assumes accuracy() returns the top-1 and top-5 accuracies
        # as fractions (they are scaled by 100 to print as percentages), and
        # top-5 presumably needs at least 5 output classes -- confirm against
        # antgine.metrics.accuracy.
        top1, top5 = [e * 100 for e in accuracy(outputs, ys, topk=(1, 5))]
        logging.info('Epoch[%d][%d/%d] Loss: %.20f | Top1 %f%% | Top5 %f%%',
                     epoch, i, len(self._dataloader), loss.item(), top1, top5)

    def on_epoch_test_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        """
        Log the test metrics of ``epoch``.

        :param int epoch: Epoch index.
        :param dict metrics: Metric name -> value, each printed with ``%f``.
        """
        logging.info('Epoch[%d] test %s', epoch, self._format_metrics(metrics))
