# -*- coding: utf-8 -*-
from dataclasses import dataclass
from functools import reduce
from operator import and_
from typing import List, Dict, Type, Optional

import pandas as pd
import logging
import os

from ts_benchmark.data.data_source import (
    LocalForecastingDataSource,
    DataSource,
)
from ts_benchmark.data.suites.global_storage import GlobalStorageDataServer
from ts_benchmark.evaluation.evaluate_model import perturb_model
from ts_benchmark.models import get_models
from ts_benchmark.recording import save_log
from ts_benchmark.utils.parallel import ParallelBackend


@dataclass
class DatasetInfo:
    """
    Description of one predefined dataset collection.

    Bundles the values of the 'size' meta-info field that select the series of
    the collection with the data source class that is instantiated to load them.
    """

    # the possible values of the meta-info field 'size'
    size_value: List
    # the class of data source for this dataset
    datasrc_class: Type[DataSource]


# Registry of the dataset collection names accepted in data_config["data_set_name"].
# Each name maps to the 'size' meta-info values that belong to the collection and
# to the data source class used to load it. Note that "large_forecast" selects
# both the "large" and the "small" series.
PREDEFINED_DATASETS = {
    "large_forecast": DatasetInfo(
        size_value=["large", "small"],
        datasrc_class=LocalForecastingDataSource,
    ),
    "small_forecast": DatasetInfo(
        size_value=["small"], datasrc_class=LocalForecastingDataSource
    ),
    "user_forecast": DatasetInfo(
        size_value=["user"], datasrc_class=LocalForecastingDataSource
    ),
}


def filter_data(
    metadata: pd.DataFrame, size_value: List[str], feature_dict: Optional[Dict] = None
) -> List[str]:
    """
    Filters the dataset based on given filters

    :param metadata: The meta information DataFrame. It must contain the columns
        'size' and 'file_name', plus one column per key of ``feature_dict``.
    :param size_value: The allowed values of the 'size' meta-info field.
    :param feature_dict: A dictionary of filters where each key is a meta-info field
        and the corresponding value is the field value to keep. Entries whose value
        is None are ignored. If None or an empty dictionary is given, no extra
        filter is applied.
    :return: The 'file_name' values of the rows that satisfy all filters.
    """
    # Keep only the filters that carry a value; a missing dictionary is treated
    # the same way as an empty one.
    active_filters = {
        k: v for k, v in (feature_dict or {}).items() if v is not None
    }

    filt_metadata = metadata
    # reduce() without an initial value raises on an empty iterable, so the
    # combined mask is only built when at least one filter remains.
    if active_filters:
        mask = reduce(and_, (metadata[k] == v for k, v in active_filters.items()))
        filt_metadata = metadata[mask]
    filt_metadata = filt_metadata[filt_metadata["size"].isin(size_value)]

    return filt_metadata["file_name"].tolist()


def _get_model_names(model_names: List[str]):
    """
    Rename models if there exists duplications.

    If a model A appears multiple times in the list, each appearance will be renamed to
    `A`, `A_1`, `A_2`, ...

    :param model_names: A list of model names.
    :return: The renamed list of model names.
    """
    s = pd.Series(model_names)
    cumulative_counts = s.groupby(s).cumcount()
    return [
        f"{model_name}_{cnt}" if cnt > 0 else model_name
        for model_name, cnt in zip(model_names, cumulative_counts)
    ]


def perturb(
    data_config: dict,
    model_config: dict,
    evaluation_config: dict,
    report_config: Optional[dict] = None,
) -> List[str]:
    """
    Execute the perturbation evaluation process

    The process resolves the series to evaluate, loads them, starts the data server,
    builds the models and runs ``perturb_model`` on every model. Messages of the run
    are logged to ``<save_path>/perturb_<perturb_ratio>_<perturb_type>/<perturb_name>.log``
    and to the standard error stream.

    :param data_config: Configuration for data loading. The keys read here are
        'data_set_name' (a name or a list of names from ``PREDEFINED_DATASETS``,
        defaults to 'small_forecast'), 'data_name_list' (explicit series file names,
        which take precedence over the dataset filters) and 'feature_dict' (extra
        meta-info filters passed to ``filter_data``).
    :param model_config: Configuration for model construction, passed to ``get_models``.
    :param evaluation_config: Configuration for model evaluation. It must contain
        'strategy_args', from which 'perturb_type' and 'perturb_ratio' are read to
        name the output directory. The whole configuration is forwarded to
        ``perturb_model``.
    :param report_config: Configuration for the output. It must contain 'save_path';
        'perturb_name' is used as the name of the log file.
    :return: Nothing is returned explicitly (the result is ``None``); the
        ``List[str]`` annotation is kept for interface compatibility.
    :raises ValueError: If 'save_path' is missing, if a dataset name is unknown, if
        the datasets use different data source classes, or if no series is selected.
    :raises KeyError: If ``evaluation_config`` has no 'strategy_args'.
    """
    # Validate the output configuration first, so that a misconfiguration fails
    # before any data is loaded or the data server is started.
    report_config = report_config or {}
    save_path = report_config.get("save_path")
    if save_path is None:
        raise ValueError("report_config must provide 'save_path'.")
    perturb_name = report_config.get("perturb_name")
    strategy_args = evaluation_config["strategy_args"]
    perturb_type = strategy_args.get("perturb_type")
    perturb_ratio = strategy_args.get("perturb_ratio")

    # prepare data
    # TODO: move these code into the data module, after the pipeline interface is unified
    dataset_name_list = data_config.get("data_set_name", ["small_forecast"])
    if not dataset_name_list:
        dataset_name_list = ["small_forecast"]
    if isinstance(dataset_name_list, str):
        dataset_name_list = [dataset_name_list]
    for dataset_name in dataset_name_list:
        if dataset_name not in PREDEFINED_DATASETS:
            raise ValueError(f"Unknown dataset {dataset_name}.")

    # A single data source instance serves all datasets, so they must share one class.
    data_src_type = PREDEFINED_DATASETS[dataset_name_list[0]].datasrc_class
    if not all(
        PREDEFINED_DATASETS[dataset_name].datasrc_class is data_src_type
        for dataset_name in dataset_name_list
    ):
        raise ValueError("Not supporting different types of data sources.")

    data_src: DataSource = data_src_type()
    data_name_list = data_config.get("data_name_list", None)
    if not data_name_list:
        # No explicit series given: select them from the meta information.
        feature_dict = data_config.get("feature_dict", None)
        data_name_list = []
        for dataset_name in dataset_name_list:
            size_value = PREDEFINED_DATASETS[dataset_name].size_value
            data_name_list.extend(
                filter_data(
                    data_src.dataset.metadata, size_value, feature_dict=feature_dict
                )
            )
    # Remove duplicates while keeping the first-seen order, so that the evaluation
    # order is the same on every run.
    data_name_list = list(dict.fromkeys(data_name_list))
    if not data_name_list:
        raise ValueError("No dataset specified.")
    data_src.load_series_list(data_name_list)
    data_server = GlobalStorageDataServer(data_src, ParallelBackend())
    data_server.start_async()

    # modeling
    model_factory_list = get_models(model_config)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    output_path = os.path.join(save_path, f'perturb_{perturb_ratio}_{perturb_type}')
    os.makedirs(output_path, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(output_path, f'{perturb_name}.log'))
    stream_handler = logging.StreamHandler()
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    try:
        for model_factory in model_factory_list:
            perturb_model(model_factory, data_name_list, evaluation_config, logger)
    finally:
        # The module logger is shared by all calls: handlers left attached would
        # duplicate every message on the next call and keep the log file open.
        logger.removeHandler(file_handler)
        logger.removeHandler(stream_handler)
        file_handler.close()
        stream_handler.close()
