import torch
from typing import Union, List, Callable
from torch.utils.data import Dataset, DataLoader
from src.verify.trainer.callbacks import CallbackList, Callback, LogWriter
from src.verify.trainer.metrics import BaseMetric, MetricsSummary, create_metric, metric_map
# from verify.trainer.trainers import DataParallel


def _check_device(device):
    if isinstance(device, (list, tuple)):
        for d in device:
            if torch.device(d).type == 'cpu':
                raise TypeError('not support CPU & GPU parallel.')
        return torch.device("cuda")
    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(device)
    return device


def _create_metrics(metrics):
    if metrics is None:
        return []
    out_metrics = []
    for metric in metrics:
        if isinstance(metric, str):
            metric_out = metric_map.get(metric)()
            if metric_out is None:
                raise ModuleNotFoundError(f'`{metric}` not in defalut matirc list.')
            out_metrics.append(metric_out)
        elif isinstance(metric, BaseMetric):
            out_metrics.append(metric)
        elif callable(metric):
            out_metrics.append(create_metric(metric))
        else:
            raise TypeError(f'metric {metric} need be (str, callable, BaseMetric)')
    return out_metrics


def pre_process_metric(metrics):
    """Detach every tensor inside a (possibly nested) metrics structure.

    Dicts are rebuilt with the same keys and tuples are rebuilt as plain
    tuples, both recursively. Tensors are replaced by ``tensor.detach()``.
    Any other value, including lists, is returned unchanged.
    """
    if isinstance(metrics, dict):
        return {key: pre_process_metric(value) for key, value in metrics.items()}
    if isinstance(metrics, tuple):
        return tuple(pre_process_metric(element) for element in metrics)
    if isinstance(metrics, torch.Tensor):
        return metrics.detach()
    return metrics


def cat_results(items, detach=True):
    """Merge a sequence of per-batch results into a single result.

    The structure of ``items[0]`` decides how merging happens:
      * dict  -> dict with each key merged across all items;
      * tuple/list -> same container type, merged position-wise;
      * tensor -> ``torch.cat`` along dim 0 (optionally detached);
      * anything else -> the per-batch values are returned as a list.

    Args:
        items: non-empty sequence of batch results that share one structure.
        detach: whether concatenated tensors are detached from the autograd
            graph. This applies at every nesting level.

    Returns:
        The merged result, mirroring the structure of ``items[0]``.
    """
    first = items[0]
    if isinstance(first, dict):
        return {key: cat_results([item[key] for item in items], detach=detach) for key in first}
    if isinstance(first, (tuple, list)):
        merged = [cat_results(list(group), detach=detach) for group in zip(*items)]
        return tuple(merged) if isinstance(first, tuple) else merged
    if isinstance(first, torch.Tensor):
        result = torch.cat(items)
        return result.detach() if detach else result
    # Unsupported leaf type (e.g. Python scalars): keep the per-batch values
    # instead of silently dropping them.
    return list(items)


class BaseTrainer(object):
    """Generic training-loop driver; subclasses supply the per-batch logic.

    Subclasses implement ``train_on_batch``, ``eval_on_batch`` and
    ``predict_on_batch``. This class handles:
      * wrapping Datasets in DataLoaders;
      * device placement;
      * optional multi-GPU wrapping;
      * metric aggregation via ``MetricsSummary``;
      * callback dispatch via ``CallbackList``.

    Args:
        net: the model; moved to ``self.device`` on construction.
        callbacks: callbacks notified of train/epoch/batch/eval events.
        batch_metrics: metric specs (names, callables or ``BaseMetric``)
            aggregated during training.
        epoch_metrics: metric specs aggregated during training and evaluation.
        device: a single device spec, ``None`` (CUDA if available, else CPU),
            or a list/tuple of GPU devices for data-parallel training.
        parallel_dim: dimension along which ``DataParallel`` scatters inputs.
            It defaults to 0 when more than one device is given.
    """

    def __init__(self, net,
                 callbacks: List[Callback] = None,
                 batch_metrics: List[Union[BaseMetric, str, Callable]] = None,
                 epoch_metrics: List[Union[BaseMetric, str, Callable]] = None,
                 device: Union[str, torch.device, list, tuple] = None,
                 parallel_dim: Union[None, int] = None):

        # Validate the device spec first, so an invalid CPU/GPU mix is reported
        # before any model wrapping happens.
        self.device = _check_device(device)

        if isinstance(device, (tuple, list)) and len(device) > 1:
            if parallel_dim is None:
                parallel_dim = 0
            # The project-level `DataParallel` import is commented out in this
            # module. Use PyTorch's implementation instead, which takes the same
            # (module, device_ids, dim=...) arguments.
            # NOTE(review): torch's DataParallel requires parameters on
            # device_ids[0]. `self.device` is the generic "cuda" device, which
            # only matches when device_ids[0] is the current CUDA device — confirm.
            net = torch.nn.DataParallel(net, device, dim=parallel_dim)

        self.net = net

        self.callbacks = CallbackList(callbacks)
        self.callbacks.set_trainer(self)

        self.batch_metrics = _create_metrics(batch_metrics)
        self.epoch_metrics = _create_metrics(epoch_metrics)

        self.net.to(self.device)

        # Checked after every epoch in `train`; this class never sets it itself.
        self._early_stop = False

    @staticmethod
    def check_indices(names):
        """Normalise an index spec into an iterable of batch indices/keys.

        A single ``int`` or ``str`` is wrapped in a list, so a multi-character
        key is not iterated character by character. A list/tuple is returned
        unchanged.

        Raises:
            TypeError: if ``names`` is not an int, str, list or tuple.
        """
        if not isinstance(names, (list, tuple, int, str)):
            raise TypeError(
                f'indices must be int, str, list or tuple, got {type(names).__name__}')
        if isinstance(names, (int, str)):
            return [names]
        return names

    def select_batch(self, batch, names):
        """Pick the entries ``names`` from ``batch`` and move them to ``self.device``.

        Entries that are lists/tuples are moved element-wise and returned as a
        tuple. The result is a tuple with one item per requested name.
        """
        result = []
        for name in self.check_indices(names):
            items = batch[name]
            if isinstance(items, (list, tuple)):
                result.append(tuple([item.to(self.device) for item in items]))
            else:
                result.append(items.to(self.device))
        return tuple(result)

    def change_verbose(self, verbose):
        """Set ``verbose`` on every LogWriter callback.

        If no LogWriter is registered yet, one is added first.
        """
        if not any(isinstance(item, LogWriter) for item in self.callbacks.callbacks):
            self.callbacks.append(LogWriter(self.__class__.__name__, verbose))

        for callback in self.callbacks.callbacks:
            if isinstance(callback, LogWriter):
                callback.verbose = verbose

    @staticmethod
    def _get_batch_size(batch):
        """Return the number of samples in ``batch``.

        Dict batches are measured through their first value, not the number of
        keys. Tuple/list batches are measured through their first element.
        """
        if isinstance(batch, dict) and batch:
            batch = next(iter(batch.values()))
        if isinstance(batch, (tuple, list)):
            return len(batch[0])
        return len(batch)

    def train(self, epochs, train_items, eval_items=None, verbose=None, *, batch_size=128):
        """Run the training loop for up to ``epochs`` epochs.

        Datasets are wrapped in DataLoaders; training data is shuffled. After
        each epoch, the loop stops early if ``self._early_stop`` has been set.
        """
        if verbose is not None:
            self.change_verbose(verbose)
        if isinstance(train_items, Dataset):
            train_items = DataLoader(train_items, batch_size=batch_size, shuffle=True)
        if isinstance(eval_items, Dataset):
            eval_items = DataLoader(eval_items, batch_size=batch_size)

        self.callbacks.set_params({'epochs': epochs})

        self.callbacks.on_train_begin({'epochs': epochs})
        for e in range(epochs):
            self.callbacks.on_epoch_begin(e, {'epochs': epochs})
            self.implement_train(train_items, eval_items)
            self.callbacks.on_epoch_end(e, None)
            if self._early_stop:
                break
        self.callbacks.on_train_end(None)

    def implement_train(self, train_loader, eval_loader=None):
        """Train for one epoch, then evaluate without gradients if ``eval_loader`` is given."""
        self.net.train()
        self.train_epoch(train_loader)
        if eval_loader is not None:
            self.net.eval()
            with torch.no_grad():
                self.evaluate(eval_loader, verbose=None, user_call=False)

    def train_epoch(self, train_loader):
        """Run every batch of ``train_loader`` and report the epoch's mean metrics to callbacks."""
        self.callbacks.on_train_epoch_begin({'batch_nums': len(train_loader)})
        metrics_summary = MetricsSummary(batch_metrics=self.batch_metrics, epoch_metrics=self.epoch_metrics)
        for i, batch in enumerate(train_loader):
            self._pack_train_on_batch(batch, i, metrics_summary)
        metrics = metrics_summary.mean()
        self.callbacks.on_train_epoch_end(metrics)

    def evaluate(self, eval_items, verbose=None, *, batch_size=128, user_call=True):
        """Evaluate on ``eval_items`` without gradients and return the mean metrics.

        The net is put in eval mode for the duration of the call. Its previous
        train/eval mode is restored afterwards. Eval callbacks fire only when
        ``user_call`` is False, i.e. when called from the training loop.
        """
        if verbose is not None:
            self.change_verbose(verbose)
        if isinstance(eval_items, Dataset):
            eval_loader = DataLoader(eval_items, batch_size=batch_size)
        else:
            eval_loader = eval_items
        if not user_call:
            self.callbacks.on_eval_begin(None)
        metrics_summary = MetricsSummary(epoch_metrics=self.epoch_metrics)
        # Without this, a user-initiated evaluate after training would run with
        # dropout/batch-norm still in training mode.
        was_training = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                for i, batch in enumerate(eval_loader):
                    self._pack_eval_on_batch(batch, i, metrics_summary, with_callback=(not user_call))
        finally:
            self.net.train(was_training)
        metrics = metrics_summary.mean()
        if not user_call:
            self.callbacks.on_eval_end(metrics)
        return metrics

    def predict(self, test_items, *, batch_size=128):
        """Run ``predict_on_batch`` over ``test_items`` and concatenate the detached outputs.

        The net is in eval mode during prediction. Its previous mode is
        restored afterwards.
        """
        if isinstance(test_items, Dataset):
            test_items = DataLoader(test_items, batch_size=batch_size)
        was_training = self.net.training
        self.net.eval()
        try:
            result = [self.predict_on_batch(batch) for batch in test_items]
        finally:
            self.net.train(was_training)
        return cat_results(result, detach=True)

    def _pack_train_on_batch(self, batch, i, metrics_summary: MetricsSummary):
        """Wrap ``train_on_batch`` with batch callbacks and record its detached output.

        The batch size is passed to ``MetricsSummary.append``; presumably it is
        used as an averaging weight.
        """
        self.callbacks.on_train_batch_begin(i, None)
        summary = pre_process_metric(self.train_on_batch(batch))
        batch_summary = metrics_summary.append(summary, self._get_batch_size(batch))
        self.callbacks.on_train_batch_end(i, batch_summary)

    def _pack_eval_on_batch(self, batch, i, metrics_summary: MetricsSummary, *, with_callback=True):
        """Wrap ``eval_on_batch`` and record its detached output.

        Batch callbacks fire only when ``with_callback`` is True.
        """
        if with_callback:
            self.callbacks.on_eval_batch_begin(i, None)
        summary = pre_process_metric(self.eval_on_batch(batch))
        batch_summary = metrics_summary.append(summary, self._get_batch_size(batch))
        if with_callback:
            self.callbacks.on_eval_batch_end(i, batch_summary)

    def train_on_batch(self, batch):
        """Perform one training step on ``batch``; must be implemented by subclasses."""
        raise NotImplementedError

    def eval_on_batch(self, batch):
        """Compute evaluation outputs/metrics for ``batch``; must be implemented by subclasses."""
        raise NotImplementedError

    def predict_on_batch(self, batch):
        """Compute predictions for ``batch``; must be implemented by subclasses."""
        raise NotImplementedError
