import collections.abc
import copy
import logging
from typing import Type, Dict, Any, Tuple, List

import sacred.utils
from sacred import Ingredient

from data.dataset import DatasetMixin
from data.transforms import DatasetSource, PipelineDataset, make_dataset_split
from utils.utils import str2cls, objspec2constructor


data_ingredient = Ingredient('dataset')


def get_dataset_class(name: str) -> Type[DatasetMixin]:
    """Look up the dataset class that belongs to the dataset called ``name``.

    The fully qualified name is built by convention as
    ``data.<name in lower case>_dataset.<name>Dataset`` and handed to
    ``str2cls`` for resolution.

    :param name: Dataset name, e.g. ``'SWaT'``.
    :return: Whatever ``str2cls`` returns for the constructed name
        (presumably the dataset class itself -- see ``utils.utils.str2cls``).
    """
    module_stem = name.lower()
    qualified_name = 'data.{}_dataset.{}Dataset'.format(module_stem, name)
    return str2cls(qualified_name)


def make_pipe_from_dict(pipeline: Dict[str, Dict[str, Any]], data_source: DatasetSource,
                        _log: logging.Logger = logging.root) -> PipelineDataset:
    """Stack the transforms described by ``pipeline`` on top of ``data_source``.

    The specifications are processed in the iteration order of ``pipeline``.
    Each one is turned into a constructor via ``objspec2constructor`` (looked
    up relative to ``data.transforms``), and that constructor is called with
    the pipe built so far. An empty ``pipeline`` simply wraps ``data_source``.

    Building the pipe is best-effort: a specification that cannot be
    instantiated is skipped with a warning, and the remaining stages are
    still applied to the last pipe that was built successfully.

    :param pipeline: Mapping from stage name to transform specification. Only
        the specifications (the values) are used.
    :param data_source: The object the first transform is applied to.
    :param _log: Logger that receives the warning for skipped stages.
    :return: A ``PipelineDataset`` wrapping the final pipe.
    """
    pipe = data_source
    for pipe_info in pipeline.values():
        try:
            pipe_class = objspec2constructor(pipe_info, base_module='data.transforms')
            pipe = pipe_class(pipe)
        except Exception:
            # Deliberately broad: one broken stage must not abort the whole
            # pipeline. exc_info attaches the traceback so the reason for the
            # failure is not lost.
            _log.warning('Could not instantiate a Transform from the specification %s! Ignoring.',
                         pipe_info, exc_info=True)

    return PipelineDataset(pipe)


@data_ingredient.config
def data_config():
    # Dataset name; load_dataset resolves it to data.<name.lower()>_dataset.<name>Dataset
    name = 'SWaT'
    # Keyword arguments that load_dataset forwards to the dataset class constructor
    ds_args = dict(
        training=True
    )

    # Mapping of transform specifications that is applied to every split.
    # This can also be a list of dicts, one for each split
    pipeline = {}
    # Setting this will load the default pipeline for the dataset before the specified pipeline
    use_dataset_pipeline = True

    # Unpacked as positional arguments to make_dataset_split; load_dataset expects
    # one pipeline per entry
    split = (0.75, 0.25)
    # Passed to make_dataset_split as its `axis` keyword argument
    split_axis = 'time'


@data_ingredient.capture
def load_dataset(name: str, ds_args: Dict[str, Any], pipeline: Dict, use_dataset_pipeline: bool,
                 split: Tuple[float, ...], split_axis: str, _log: logging.Logger) -> List[PipelineDataset]:
    """Instantiate the configured dataset, split it and build one pipe per split.

    :param name: Dataset name, resolved to a class by ``get_dataset_class``.
    :param ds_args: Keyword arguments for the dataset class constructor.
    :param pipeline: Either a single mapping of transform specifications that
        is used for every split, or a sequence with one such mapping per split.
    :param use_dataset_pipeline: If true, ``ds.get_default_pipeline()`` is used
        as the base for each split and the configured pipeline is merged into
        it with ``sacred.utils.recursive_update``.
    :param split: Positional arguments for ``make_dataset_split``; must have
        as many entries as there are pipelines.
    :param split_axis: Passed to ``make_dataset_split`` as ``axis``.
    :param _log: Logger handed on to ``make_pipe_from_dict``.
    :return: One ``PipelineDataset`` per split, in the order of ``split``.
    :raises ValueError: If the number of pipelines and the number of split
        entries differ.
    """
    dataset_class = get_dataset_class(name)
    ds = dataset_class(**ds_args)

    # A mapping is not a Sequence, so a single pipeline is replicated per split.
    if isinstance(pipeline, collections.abc.Sequence):
        pipelines = pipeline
    else:
        pipelines = [pipeline] * len(split)

    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and ``zip`` below would then silently drop the surplus entries.
    if len(pipelines) != len(split):
        raise ValueError(f'Got {len(pipelines)} pipeline(s) for {len(split)} split(s); '
                         f'specify either a single pipeline or exactly one per split.')

    ds_splits = make_dataset_split(ds, *split, axis=split_axis)
    res_splits = []
    for split_pipeline, ds_pipe in zip(pipelines, ds_splits):
        if use_dataset_pipeline:
            # recursive_update modifies nested mappings in place. A deep copy
            # keeps the overrides of one split from leaking into nested dicts
            # that the dataset's default pipeline may share.
            merged_pipeline = copy.deepcopy(dict(ds.get_default_pipeline()))
            sacred.utils.recursive_update(merged_pipeline, split_pipeline)
            split_pipeline = merged_pipeline

        res_splits.append(make_pipe_from_dict(split_pipeline, ds_pipe, _log))

    return res_splits
