# pylint: disable=missing-function-docstring
import logging
import tempfile
from pathlib import Path
from typing import Any, Dict
import numpy as np
import torch
from sacred import Experiment
from tsbench.evaluation import SurrogateEvaluator
from tsbench.experiments.tracking import Tracker
from tsbench.surrogate import create_surrogate

# Sacred experiment for this script: the config scope and the main function
# are registered on it via the `@ex.config` / `@ex.automain` decorators.
ex = Experiment()


# Default configuration. Sacred executes the body of this config scope and
# turns every local variable it defines into a config entry, so each
# assignment below can be overridden when launching the experiment.
# NOTE: keep this body free of helper variables -- they would become config
# entries too.
@ex.config
def experiment_config():
    # pylint: disable=unused-variable
    # Run name; it is not read by `main` below.
    name = "test"
    # Experiment whose recorded data is loaded via `Tracker.for_experiment`.
    experiment = "ts-bench"
    # Surrogate passed to `create_surrogate`. If a config entry with the same
    # name exists (e.g. `mlp`), it is unpacked as that surrogate's kwargs.
    surrogate = "nonparametric"

    # Comma-separated metric names the surrogate predicts and is evaluated on.
    metrics = "mean_weighted_quantile_loss_mean"
    # Forwarded unchanged as `input_flags` to `create_surrogate`; presumably
    # toggles which extra dataset features the surrogate receives -- confirm
    # against `tsbench.surrogate`.
    input_flags = {
        "use_simple_dataset_features": False,
        "use_seasonal_naive_performance": False,
        "use_catch22_features": False,
    }

    # Per-surrogate keyword arguments for `create_surrogate`; only the entry
    # whose name matches `surrogate` is used.
    nonparametric = {
        "use_ranks": False,
    }
    xgboost = {
        "objective": "ranking",
    }
    mlp = {
        "objective": "ranking",
        "discount": "linear",
        "hidden_layer_sizes": [32, 32],
        "weight_decay": 0.01,
        "dropout": 0.0,
    }


@ex.automain
def main(
    _config: Dict[str, Any],
    _seed: int,
    experiment: str,
    surrogate: str,
    metrics: str,
    input_flags: Dict[str, bool],
):
    """
    Fit and evaluate a surrogate on an experiment's data and store the results.

    The evaluation results are written to a Parquet file and attached to the
    Sacred run as the artifact ``results.parquet``.

    Args:
        _config: The full Sacred configuration (injected by Sacred). If it has
            an entry named like ``surrogate``, that entry is unpacked as
            keyword arguments into ``create_surrogate``.
        _seed: The run's seed (injected by Sacred), used to seed NumPy and
            PyTorch.
        experiment: Name of the experiment loaded via
            ``Tracker.for_experiment``.
        surrogate: Name of the surrogate to create.
        metrics: Comma-separated metric names to predict and evaluate.
            Whitespace around names is ignored, as are empty entries.
        input_flags: Flags forwarded unchanged to ``create_surrogate``.

    Raises:
        ValueError: If ``metrics`` contains no metric name.
    """
    np.random.seed(_seed)
    torch.manual_seed(_seed)
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

    # Parse the metrics before any (potentially slow) data loading so that
    # malformed input fails fast. Stripping makes "a, b" equivalent to "a,b",
    # and filtering drops empty names (e.g. from a trailing comma).
    metric_names = [name.strip() for name in metrics.split(",") if name.strip()]
    if not metric_names:
        raise ValueError(f"No metric names given in `metrics`: {metrics!r}")

    # First, get the tracker
    print("Fetching the data...")
    tracker = Tracker.for_experiment(experiment)

    # Then, initialize the surrogate. Its hyperparameters come from the config
    # entry named after it (e.g. `mlp`); without such an entry, no extra
    # keyword arguments are passed.
    print("Initializing the surrogate...")
    surrogate_kwargs = _config.get(surrogate, {})
    surrogate_instance = create_surrogate(
        surrogate,
        predict=metric_names,
        tracker=tracker,
        input_flags=input_flags,
        **surrogate_kwargs
    )

    # And evaluate it
    print("Evaluating the surrogate...")
    evaluator = SurrogateEvaluator(
        surrogate_instance, tracker=tracker, metrics=metric_names
    )
    result = evaluator.run()

    print(result.mean())

    # Eventually, we store the results. The temporary directory is deleted
    # when the `with` block exits, so the artifact must be registered inside
    # it; this relies on `add_artifact` handing the file to the observers
    # immediately.
    print("Storing the results...")
    with tempfile.TemporaryDirectory() as d:
        path = Path(d) / "results.parquet"
        result.to_parquet(path)
        ex.add_artifact(path, content_type="application/octet-stream")
