from graph_learning.data_setting import DataPipeline
from .utils import make_data, TaskerData
from graph_learning.utils import merge_dicts
from typing import Callable, List, Dict
from torch.utils.data import DataLoader

class TaskerChain(object):
    """ Chain of roottaskers.
    """
    def __init__(self):
        self.taskers = []
        self.data = None

    def new_roottasker(self):
        """Add new roottasker.
        """
        self.taskers.append(RootTasker())

    def add_tasker(self, tasker):
        """Add tasker to latest roottasker.
        """
        self.taskers[-1].register(tasker)

    def add_transformer(self, transformer):
        """Add transformer to latest roottasker.
        """
        self.taskers[-1].register_transformer(transformer)

    def __len__(self):
        return len(self.taskers)

    def apply(self, data):
        """Apply dataset.

        Parameters
        ----------
        data: object
            data object.
        """
        self.data = data
        return self

    def _run_epoch(self, taskers, dataset,
                   dl_func, action_step, on_action_end,
                   epoch):
        if len(taskers) <= 0:
            raise RuntimeError('No tasker defined.')
        elif len(taskers) > 1:
            tasker = taskers[0]
            tasker.apply(dataset)
            outputs = [
                self._run_epoch(
                    taskers[1:], data, action_step, on_action_end, epoch)
                for data in dl_func(tasker)
            ]
            return on_action_end(tasker, outputs)
        else:
            tasker = taskers[0]
            tasker.apply(dataset)
            outputs = [action_step(make_data(tasker, data), epoch)
                       for data in dl_func(tasker)]
            return on_action_end(tasker, outputs, epoch)

    def run_train_epoch(self, train_step,
                        on_train_end,
                        epoch=None):
        """Run train epoch.

        Parameters
        ----------
        train_step: Callable[TaskerData]
            Run training step over data in last dataloader.
        on_train_end: Callable[Tasker, List[Dict]]
            Merge training outputs after train step for taskers.

        Returns
        -------
        Dict
            Training outputs as dict.
        """
        dl_func = lambda tasker: tasker.train_dataloader()
        return self._run_epoch(self.taskers, self.data,
                               dl_func, train_step, on_train_end, epoch)

    def run_eval_epoch(self, eval_step,
                       on_eval_end,
                       epoch=None):
        """Run evaluate epoch.

        Parameters
        ----------
        eval_step: Callable[TaskerData]
            Run evaluating step over data in last dataloader.
        on_eval_end: Callable[Tasker, List[Dict]]
            Merge evaluating outputs after train step for taskers.

        Returns
        -------
        Dict
            Evaluating outputs as dict.
        """
        dl_func = lambda tasker: tasker.valid_dataloader()
        return self._run_epoch(self.taskers, self.data,
                               dl_func, eval_step, on_eval_end, epoch)

    def run_test_epoch(self, test_step,
                       on_test_end,
                       epoch=None):
        """Run test epoch.

        Parameters
        ----------
        test_step: Callable[TaskerData]
            Run testing step over data in last dataloader.
        on_test_end: Callable[Tasker, List[Dict]]
            Merge testing outputs after train step for taskers.

        Returns
        -------
        Dict
            Testing outputs as dict.
        """
        dl_func = lambda tasker: tasker.test_dataloader()
        return self._run_epoch(self.taskers, self.data,
                               dl_func, test_step, on_test_end, epoch)

class RootTasker(object):
    """ Standalone tasker unit.

    Delegation for interfaces in subtaskers: any attribute not found on the
    roottasker itself is looked up on the registered subtaskers.
    """
    def __init__(self):
        self.taskers = []                  # registered subtaskers, in registration order
        self.transformer = DataPipeline()  # transformers added via register_transformer

    def register(self, tasker):
        """Add subtasker and bind it back to this roottasker."""
        tasker.register(self)
        self.taskers.append(tasker)

    def register_transformer(self, transformer):
        """Append a data transformer to ``self.transformer``."""
        self.transformer.append(transformer)

    def apply(self, data):
        """Dataset binding: store ``data`` as ``self.data``."""
        self.data = data

    def __getattr__(self, name):
        """Do interface delegation to the subtaskers.

        Only invoked when normal attribute lookup fails. ``valid_metrics``,
        ``test_metrics``, ``valid_end`` and ``test_end`` are combined across
        all subtaskers; any other name resolves to the first subtasker that
        has it.

        Raises
        ------
        AttributeError
            For dunder names, for ``taskers`` itself, or when no subtasker
            provides ``name``.
        """
        # Never delegate dunders or our own bookkeeping. An instance created
        # without __init__ (copy/pickle probe __setstate__ on a fresh object)
        # would otherwise recurse forever: self.taskers -> __getattr__('taskers').
        if name == 'taskers' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        is_cooperate = name in ('valid_metrics', 'test_metrics',
                                'valid_end', 'test_end')
        return self.call(name, is_cooperate)

    def call(self, name, cooperate=False):
        """Look ``name`` up on the subtaskers.

        Parameters
        ----------
        name: str
            Attribute name to resolve.
        cooperate: bool
            If False, return the attribute of the first subtasker that has it.
            If True, collect it from every subtasker that has it and return a
            function that calls all of them with the same arguments and merges
            their results with ``merge_dicts``.

        Raises
        ------
        AttributeError
            If no subtasker provides ``name``.
        """
        attrs = []
        for tasker in self.taskers:
            try:
                # Calling __getattribute__ directly skips the subtasker's own
                # __getattr__ fallback, which would delegate straight back here.
                attr = tasker.__getattribute__(name)
            except AttributeError:
                continue
            if not cooperate:
                return attr
            attrs.append(attr)
        # In non-cooperate mode a hit returns early, so attrs is empty here.
        if not attrs:
            raise AttributeError(f'no attribute {name} in the tasker')

        def cooperate_func(*args, **kwargs):
            """Call every collected attribute and merge their results."""
            return merge_dicts([func(*args, **kwargs) for func in attrs])
        return cooperate_func


class Tasker(object):
    """Base subtasker.

    Attributes not defined on the subtasker are looked up on the roottasker
    it was registered with. The default metric/end hooks return empty dicts;
    the roottasker combines these four hooks across all its subtaskers.
    """
    def register(self, roottasker):
        """Bind this subtasker to ``roottasker``."""
        self.roottasker = roottasker

    def __getattr__(self, name):
        """Delegate to belonging roottasker.

        Only invoked when normal attribute lookup fails.

        Raises
        ------
        AttributeError
            For dunder names, when ``roottasker`` is not set (``register`` not
            yet called), or when the roottasker does not provide ``name``.
        """
        # Without this guard, an unregistered instance (or one created by
        # copy/pickle without __init__) recurses forever:
        # self.roottasker -> __getattr__('roottasker') -> self.roottasker ...
        if name == 'roottasker' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        return getattr(self.roottasker, name)

    def valid_metrics(self, data, outputs):
        """Default validation metrics: none (empty dict)."""
        return {}

    def test_metrics(self, data, outputs):
        """Default test metrics: none (empty dict)."""
        return {}

    def valid_end(self, outputs):
        """Default end-of-validation hook: contributes an empty dict."""
        return {}

    def test_end(self, outputs):
        """Default end-of-test hook: contributes an empty dict."""
        return {}

class DataloaderTasker(Tasker):
    """Subtasker responsible for providing the train/valid/test dataloaders.

    Every loader method here is a placeholder that raises
    ``NotImplementedError``; concrete subclasses must override them.
    """

    def train_dataloader(self):
        """Provide the loader iterated during a training epoch.

        Returns
        -------
        DataLoader
            Train dataloader.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def valid_dataloader(self):
        """Provide the loader iterated during a validation epoch.

        Returns
        -------
        DataLoader
            Valid dataloader.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def test_dataloader(self):
        """Provide the loader iterated during a test epoch.

        Returns
        -------
        DataLoader
            Test dataloader.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError
