# pylint: disable=missing-function-docstring
import logging
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
import torch
from sacred import Experiment
from tsbench.config import MODEL_REGISTRY
from tsbench.evaluation import EnsembleEvaluator
from tsbench.experiments.tracking import Tracker
from tsbench.surrogate import create_surrogate

# Sacred experiment object: ``experiment_config`` and ``main`` register on it
# through its decorators, and ``main`` attaches result files via ``ex.add_artifact``.
ex = Experiment()


@ex.config
def experiment_config():
    # pylint: disable=unused-variable
    # Sacred config scope: the local variables below form the experiment
    # configuration; entries whose names match parameters of ``main`` are
    # passed to it when the experiment runs. Do not rename them casually.
    # Run name. Not a parameter of ``main`` in this file; presumably used by
    # tooling elsewhere — verify before removing.
    name = "test"
    # Experiment identifier handed to ``Tracker.for_experiment`` to load the
    # tracked results that the ensemble is built from.
    experiment = "ts-bench"

    # Optional latency limit, forwarded to ``EnsembleEvaluator`` as
    # ``max_latency`` (its exact semantics live in the evaluator).
    max_latency = None
    # Ensemble member weighting scheme, forwarded as ``ensemble_weighting``.
    weighting = "uniform"
    # Ensemble size, forwarded as ``ensemble_size``.
    size = 10
    # Optional key into ``MODEL_REGISTRY``; when set, the registry entry is
    # passed to the evaluator as ``config_class``.
    model_class = None

    # Surrogate settings. ``name`` selects the surrogate (None disables it),
    # ``metrics`` is a comma-separated string of metrics to predict,
    # ``input_flags`` is forwarded unchanged to ``create_surrogate``, and the
    # sub-dict whose key equals ``name`` (if present) supplies additional
    # keyword arguments to ``create_surrogate``. The other sub-dicts are unused.
    surrogate = {
        "name": "mlp",
        "metrics": "mean_weighted_quantile_loss_mean",
        "input_flags": {
            "use_simple_dataset_features": False,
            "use_seasonal_naive_performance": False,
            "use_catch22_features": False,
        },
        # Keyword arguments used when ``name == "nonparametric"``.
        "nonparametric": {
            "use_ranks": False,
        },
        # Keyword arguments used when ``name == "xgboost"``.
        "xgboost": {
            "objective": "regression",
        },
        # Keyword arguments used when ``name == "mlp"`` (the default above).
        "mlp": {
            "objective": "ranking",
            "discount": "linear",
            "hidden_layer_sizes": [32, 32],
            "weight_decay": 0.01,
            "dropout": 0.0,
        },
    }


def _parse_metrics(metrics: Any) -> list:
    """Normalize the surrogate ``metrics`` setting into a list of metric names.

    Accepts either a comma-separated string (the form used by the default
    config, e.g. ``"a,b"``) or an iterable of strings. Whitespace around each
    name is stripped and empty entries are dropped, so ``"a, b,"`` yields
    ``["a", "b"]``.

    Raises:
        ValueError: If no metric name remains after normalization.
    """
    raw = metrics.split(",") if isinstance(metrics, str) else metrics
    names = [item.strip() for item in raw if item.strip()]
    if not names:
        raise ValueError(f"No surrogate metrics given (got {metrics!r}).")
    return names


# NOTE: Sacred's ``automain`` starts the run while the decorator is applied
# (when this module is executed as a script), so this must remain the last
# definition in the file, and everything ``main`` uses -- including
# ``_parse_metrics`` -- has to be defined above it.
@ex.automain
def main(
    _seed: int,
    experiment: str,
    max_latency: Optional[float],
    weighting: str,
    size: int,
    model_class: Optional[str],
    surrogate: Dict[str, Any],
):
    """Evaluate the ensemble buildable from the tracked results of ``experiment``.

    All parameters are supplied by Sacred from the config scope; ``_seed`` is
    the run seed and is used to seed NumPy and PyTorch. The evaluation table
    is stored as the ``results.parquet`` artifact and the ensemble
    configurations are pickled into the ``configs.pickle`` artifact.

    Raises:
        ValueError: If a surrogate is enabled but no metric to predict is given.
    """
    np.random.seed(_seed)
    torch.manual_seed(_seed)
    # Keep PyTorch Lightning output at WARNING and above to reduce console noise.
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

    # First, get the tracker
    print("Fetching the data...")
    tracker = Tracker.for_experiment(experiment)

    # Then, potentially initialize the surrogate (``name = None`` disables it)
    surrogate_instance = None
    surrogate_name = surrogate["name"]
    if surrogate_name is not None:
        print("Initializing the surrogate...")
        surrogate_instance = create_surrogate(
            surrogate_name,
            predict=_parse_metrics(surrogate["metrics"]),
            tracker=tracker,
            input_flags=surrogate["input_flags"],
            # Surrogate-specific kwargs live under a key named after the
            # surrogate; a missing or null entry means "no extra kwargs".
            **(surrogate.get(surrogate_name) or {})
        )

    # And evaluate the ensemble that can be built
    print("Evaluating the ensemble...")
    config_class = MODEL_REGISTRY[model_class] if model_class is not None else None
    evaluator = EnsembleEvaluator(
        tracker,
        surrogate=surrogate_instance,
        ensemble_size=size,
        ensemble_weighting=weighting,
        config_class=config_class,
        max_latency=max_latency,
    )
    df, configs = evaluator.run()

    # Eventually, we store the results
    print("Storing the results...")
    with tempfile.TemporaryDirectory() as tmp_dir:
        # NOTE: the files are deleted when this block exits, so this relies on
        # Sacred observers consuming each artifact inside ``add_artifact``.
        df_path = Path(tmp_dir) / "results.parquet"
        df.to_parquet(df_path)
        ex.add_artifact(df_path, content_type="application/octet-stream")

        config_path = Path(tmp_dir) / "configs.pickle"
        # Write-only binary mode: the file is never read back here.
        with config_path.open("wb") as f:
            pickle.dump(configs, f)
        ex.add_artifact(config_path, content_type="application/octet-stream")
