import warnings
from typing import Dict

import torch
import torch.nn as nn
import torch.optim as optim
import tensorboardX

from antgine.callback import Callback
from antgine.core import flatten_module


class PlotCallback(Callback):
    """
        Callback writing information for Tensorboard's plots.

        Logs histograms of parameters, activations and gradients, plus the
        learning rate and the train/validation metrics, to a tensorboardX
        ``SummaryWriter``.
    """
    def __init__(self, model: nn.Module, optimizer: optim.Optimizer,
                 xs: torch.Tensor, writer: tensorboardX.SummaryWriter,
                 max_samples: int = 32):
        """
        :param torch.nn.Module model: Model.
        :param torch.optim.Optimizer optimizer: Optimizer; the learning rate of its first param group is logged.
        :param torch.Tensor xs: inputs sample for activation plotting; only the first ``max_samples``
            entries along dimension 0 are fed to the model.
        :param tensorboardX.SummaryWriter writer: SummaryWriter object for Tensorboard.
        :param int max_samples: maximum number of entries of ``xs`` used for the activation
            forward pass (default 32).
        """
        super().__init__()
        self._model = model
        self._optimizer = optimizer
        self._xs = xs
        self._writer = writer
        self._max_samples = max_samples

    def _add_histogram(self, tag: str, values, step: int):
        """
        Write one histogram to the writer, best-effort.

        A failure to build the histogram (presumably non-finite values such as NaN/Inf)
        is reported as a warning instead of interrupting training.

        :param str tag: Tensorboard tag.
        :param values: array-like values to bin.
        :param int step: global step (the epoch here).
        """
        try:
            self._writer.add_histogram(tag, values, step)
        except Exception as err:  # plotting must never abort training; report instead
            warnings.warn('Could not write histogram %r at step %d: %s' % (tag, step, err))

    def _plot_activations(self, epoch: int):
        """
        Run a forward pass on the input sample and log the output of every flattened submodule.

        The model is temporarily put in eval mode so the pass does not update
        training-time state (e.g. BatchNorm running statistics); its previous mode
        is restored afterwards. Hooks are always removed, even if the forward pass raises.

        :param int epoch: step used for the histograms.
        """
        flatmodel = flatten_module(self._model)
        activations = dict()

        def make_hook(index):
            # A factory binds `index` per module (avoids late-binding closures).
            def hook(module, inputs, output):
                # Only plain tensor outputs can be turned into a histogram (skip tuples etc.).
                if isinstance(output, torch.Tensor):
                    activations[index] = output.detach().cpu().numpy()
            return hook

        handles = [module.register_forward_hook(make_hook(index))
                   for index, module in enumerate(flatmodel)]
        was_training = self._model.training
        try:
            self._model.eval()
            with torch.no_grad():
                # Slicing clamps to the available size, so no explicit min() is needed.
                self._model(self._xs[:self._max_samples])
        finally:
            for handle in handles:
                handle.remove()
            self._model.train(was_training)

        for index, values in activations.items():
            self._add_histogram('activations/%s_%d' % (type(flatmodel[index]).__name__.lower(), index),
                                values, epoch)

    def on_epoch_begin(self, epoch: int):
        """
        Log parameter histograms, the current learning rate and activation histograms.

        :param int epoch: current epoch, used as the Tensorboard step.
        """
        for name, param in self._model.named_parameters():
            self._add_histogram('parameters/%s' % name, param.detach().cpu().numpy(), epoch)
        self._writer.add_scalar('metrics/lr', self._optimizer.param_groups[0]['lr'], epoch)
        self._plot_activations(epoch)

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]):
        """
        Log each training metric as a scalar under ``metrics/train/``.

        :param int epoch: current epoch, used as the Tensorboard step.
        :param dict metrics: metric name -> value.
        """
        for k, v in metrics.items():
            self._writer.add_scalar('metrics/train/%s' % k, v, epoch)

    def on_epoch_test_end(self, epoch: int, metrics: Dict[str, float]):
        """
        Log each evaluation metric as a scalar under ``metrics/val/``.

        :param int epoch: current epoch, used as the Tensorboard step.
        :param dict metrics: metric name -> value.
        """
        for k, v in metrics.items():
            # TODO: the "val" tag prefix is inconsistent with the "test" hook name.
            self._writer.add_scalar('metrics/val/%s' % k, v, epoch)

    def on_backward_end(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor,
                        outputs: torch.Tensor, loss: torch.Tensor):
        """
        Log gradient histograms, only when ``i == 0`` (presumably the first batch of the epoch).

        ``xs``, ``ys``, ``outputs`` and ``loss`` belong to the callback signature but are unused here.

        :param int epoch: current epoch, used as the Tensorboard step.
        :param int i: batch counter.
        """
        if i != 0:
            return
        for name, param in self._model.named_parameters():
            # Parameters that received no gradient (e.g. frozen/unused) have grad None.
            if param.grad is not None:
                self._add_histogram('grads/%s' % name, param.grad.detach().cpu().numpy(), epoch)
