from typing import Callable, Dict, List, Optional, Type

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

from antgine.callback import Callback
from antgine.metrics.utils import AverageMeter

# Signature every metric must follow: called as metric(outputs, targets) and
# returning a tensor (the training/validation loops call `.item()` on it).
CallableMetric = Callable[
    [torch.Tensor, torch.Tensor],
    torch.Tensor,
]


def flatten_module(module: nn.Module,
                   noflat: Optional[List[Type[nn.Module]]] = None) -> List[nn.Module]:
    """
        Flatten ``module`` and return a list of its barebone layers.
    A module is considered barebone when it has no children, or when it is an instance of
    one of the types in ``noflat``. Barebone modules are returned whole instead of being
    recursed into.

    :param torch.nn.Module module: Module to be flattened.
    :param list of type noflat: Module types which won't get flattened. They are matched
        with ``isinstance``, so subclasses are kept whole too. Defaults to no types.

    :return: Flattened module, in ``children()`` order. A module without children
        yields ``[module]``.
    :rtype: list of torch.nn.Module
    """
    # A tuple lets a single isinstance() call test every kept-whole type.
    # isinstance(x, ()) is always False, so "no types" needs no special case.
    keep_whole = tuple(noflat) if noflat is not None else ()

    def _flatten(current: nn.Module) -> List[nn.Module]:
        # Recursive worker that shares the precomputed keep_whole tuple.
        if isinstance(current, keep_whole):
            return [current]
        leaves: List[nn.Module] = []
        for child in current.children():
            leaves.extend(_flatten(child))
        # A module with no children is itself a leaf.
        return leaves or [current]

    return _flatten(module)


def train_epoch(epoch: int, device: str, dataloder: data.DataLoader,
                model: nn.Module, optimizer: optim.Optimizer,
                criterion: nn.Module, regularizers: Optional[List[nn.Module]] = None,
                metrics: Optional[Dict[str, CallableMetric]] = None,
                callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
    """
        Perform one training epoch.
    For every batch: move it to ``device``, forward it through ``model``, compute
    ``criterion`` plus every regularizer, back-propagate, evaluate the metrics and call
    ``optimizer.step()``. Callback hooks are fired around each of those phases.

    :param int epoch: Current epoch.
    :param str device: Master device on which the data and model will be placed on.
    :param torch.utils.data.DataLoader dataloder: Training loader. The parameter keeps this
        historical spelling so existing keyword callers keep working.
    :param torch.nn.Module model: Model.
    :param torch.optim.Optimizer optimizer: Optimizer.
    :param torch.nn.Module criterion: Criterion.
    :param list of antgine.regularizer.AbstractRegularizer regularizers: List of regularizers
        applied with criterion. Each is called as ``reg(epoch, batch_index)``. Defaults to none.
    :param dict[str, function] metrics: Metrics being computed to evaluate model. Each is
        called as ``metric(outputs, targets)``. Defaults to none.
    :param list of antgine.callback.Callback callbacks: List of callback being called.
        Defaults to none.
    :return: ``AverageMeter.avg`` of each metric over the epoch. The same dict is passed
        to ``on_epoch_end``.
    :rtype: dict[str, float]
    """
    regularizers = [] if regularizers is None else regularizers
    metrics = {} if metrics is None else metrics
    callbacks = [] if callbacks is None else callbacks

    # Snapshot the metric names so meters and names stay aligned for the whole epoch.
    keys = list(metrics)
    avgmeters = [AverageMeter() for _ in keys]
    model.train()
    optimizer.zero_grad()
    for c in callbacks:
        c.on_epoch_begin(epoch)
    for i, (xs, ys) in enumerate(dataloder):
        xs, ys = xs.to(device), ys.to(device)
        for c in callbacks:
            c.on_forward_begin(epoch, i, xs, ys)
        outputs = model(xs)
        for c in callbacks:
            c.on_forward_end(epoch, i, xs, ys, outputs)
        for c in callbacks:
            c.on_loss_begin(epoch, i, xs, ys, outputs)
        loss = criterion(outputs, ys)
        for reg in regularizers:
            # Out-of-place add: modifying the criterion's output in place raises an
            # autograd error when that output is saved for its own backward pass.
            loss = loss + reg(epoch, i)

        for c in callbacks:
            c.on_loss_end(epoch, i, xs, ys, outputs, loss)
        for c in callbacks:
            c.on_backward_begin(epoch, i, xs, ys, outputs, loss)
        loss.backward()
        for c in callbacks:
            c.on_backward_end(epoch, i, xs, ys, outputs, loss)
        for c in callbacks:
            c.on_optimizer_step_begin(epoch, i, xs, ys, outputs, loss)
        # Metrics are for reporting only, so don't record autograd history for them.
        with torch.no_grad():
            batch_metrics = [metrics[k](outputs, ys) for k in keys]
        # NOTE(review): assumes AverageMeter.update(value, n) keeps an n-weighted running
        # average (n = batch size here); confirm in antgine.metrics.utils.
        for meter, value in zip(avgmeters, batch_metrics):
            meter.update(value.item(), xs.size(0))
        optimizer.step()
        for c in callbacks:
            c.on_optimizer_step_end(epoch, i, xs, ys, outputs, loss)
        optimizer.zero_grad()
    metric_avgs = {k: meter.avg for k, meter in zip(keys, avgmeters)}
    for c in callbacks:
        c.on_epoch_end(epoch, metric_avgs)
    return metric_avgs


@torch.no_grad()
def validate(dataloader: data.DataLoader, model: nn.Module,
             device: str, metrics: Optional[Dict[str, CallableMetric]] = None) -> Dict[str, float]:
    """
        Evaluate the model on test set.
    The whole function runs under ``torch.no_grad()`` with ``model`` in eval mode. The
    model's previous train/eval mode is restored afterwards, even if an exception is raised.

    :param torch.utils.data.DataLoader dataloader: Testing loader.
    :param torch.nn.Module model: Model.
    :param str device: Master device on which the data and model will be placed on.
    :param dict[str, function] metrics: Metrics being computed to evaluate model. Each is
        called as ``metric(outputs, targets)``. Defaults to none.
    :return: Dictionary of metrics evaluation (``AverageMeter.avg`` per metric name).
    :rtype: dict[str, float]
    """
    metrics = {} if metrics is None else metrics
    keys = list(metrics)
    avgmeters = [AverageMeter() for _ in keys]
    # Remember the caller's mode instead of forcing train mode on the way out.
    was_training = model.training
    model.eval()
    try:
        for xs, ys in dataloader:
            xs, ys = xs.to(device), ys.to(device)
            outputs = model(xs)
            batch_metrics = [metrics[k](outputs, ys) for k in keys]
            # NOTE(review): presumably a batch-size-weighted running mean; confirm
            # AverageMeter.update semantics.
            for meter, value in zip(avgmeters, batch_metrics):
                meter.update(value.item(), xs.size(0))
    finally:
        model.train(was_training)
    return {k: meter.avg for k, meter in zip(keys, avgmeters)}


def train(epochs: int, device: str, train_loader: data.DataLoader,
          test_loader: data.DataLoader, model: nn.Module, optimizer: optim.Optimizer,
          criterion: nn.Module, regularizers: Optional[List[nn.Module]] = None,
          metrics: Optional[Dict[str, CallableMetric]] = None,
          callbacks: Optional[List[Callback]] = None, start_epoch: int = 0):
    """
        This function trains the model.
    For each epoch index in ``range(start_epoch, epochs)`` it runs ``train_epoch`` on
    ``train_loader``, then ``validate`` on ``test_loader``. The validation results are
    passed to ``Callback.on_epoch_test_end``.

    :param int epochs: Index one past the last epoch to run. This is not a count: with
        ``start_epoch > 0`` only ``epochs - start_epoch`` epochs are run.
    :param str device: Master device on which the data and model will be placed on.
    :param torch.utils.data.DataLoader train_loader: Training loader.
    :param torch.utils.data.DataLoader test_loader: Testing loader.
    :param torch.nn.Module model: Model.
    :param torch.optim.Optimizer optimizer: Optimizer.
    :param torch.nn.Module criterion: Criterion.
    :param list of antgine.regularizer.AbstractRegularizer regularizers: List of regularizers
        applied with criterion. Defaults to none.
    :param dict[str, function] metrics: Metrics being computed to evaluate model, both during
        training and validation. Defaults to none.
    :param list of antgine.callback.Callback callbacks: List of callback being called.
        Defaults to none.
    :param int start_epoch: Starting epoch, e.g. when resuming training.
    """
    # Normalize here so the helpers always receive real containers.
    regularizers = [] if regularizers is None else regularizers
    metrics = {} if metrics is None else metrics
    callbacks = [] if callbacks is None else callbacks

    for c in callbacks:
        c.on_train_begin()
    for epoch in range(start_epoch, epochs):
        train_epoch(epoch, device, train_loader, model, optimizer, criterion,
                    regularizers, metrics, callbacks=callbacks)
        for c in callbacks:
            c.on_epoch_test_begin(epoch)
        test_metrics = validate(test_loader, model, device, metrics)
        for c in callbacks:
            c.on_epoch_test_end(epoch, test_metrics)
    for c in callbacks:
        c.on_train_end()
