# pylint: disable=missing-function-docstring
import logging
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from sacred import Experiment

from tsbench.evaluation import RecommenderEvaluator
from tsbench.experiments.tracking import Tracker
from tsbench.recommender import create_recommender
from tsbench.surrogate import create_surrogate

# Sacred experiment object: the config scope and the `main` entry point below are
# registered on it, and the evaluation results are attached to it as an artifact.
ex = Experiment()


@ex.config
def experiment_config():
    # Sacred config scope: every local variable assigned below becomes a config
    # entry that can be overridden from the command line; `main` receives the
    # entries it names as parameters.
    # pylint: disable=unused-variable
    # NOTE(review): `name` is not read by `main` in this file; presumably it is
    # used for run naming/bookkeeping elsewhere -- confirm.
    name = "test"
    # Experiment whose tracked results are loaded via `Tracker.for_experiment`.
    experiment = "ts-bench"

    # Recommender type passed to `create_recommender`. "surrogate" additionally
    # builds a surrogate model from the `surrogate` entry below; "optimal" is
    # handed the tracker instead.
    recommender = "surrogate"
    # Number of recommendations requested from `RecommenderEvaluator`.
    num_recommendations = 10
    # "maximize" / "minimize" are comma-separated metric names (an empty string
    # means no metrics); "focus" is a single metric name passed through as-is.
    # Minimized metrics not prefixed with "latency" or "num_model_parameters"
    # are the ones the surrogate is asked to predict.
    metrics = {
        "maximize": "",
        "minimize": "mean_weighted_quantile_loss_mean,latency_mean",
        "focus": "mean_weighted_quantile_loss_mean",
    }

    # Surrogate settings, only used when `recommender == "surrogate"`.
    # `name` selects the surrogate passed to `create_surrogate`; the sub-dict
    # stored under the same key (if any) is forwarded as extra keyword
    # arguments, so only the entry matching `name` takes effect.
    surrogate = {
        "name": "mlp",
        # Forwarded unchanged as `input_flags` to `create_surrogate`.
        "input_flags": {
            "use_simple_dataset_features": False,
            "use_seasonal_naive_performance": False,
            "use_catch22_features": False,
        },
        # Extra kwargs when `name == "nonparametric"`.
        "nonparametric": {
            "use_ranks": True,
        },
        # Extra kwargs when `name == "xgboost"`.
        "xgboost": {
            "objective": "regression",
        },
        # Extra kwargs when `name == "mlp"`.
        "mlp": {
            "objective": "ranking",
            "discount": "linear",
            "hidden_layer_sizes": [32, 32],
            "weight_decay": 0.01,
            "dropout": 0.0,
        },
    }


# NOTE: helpers must be defined before `main`. When this file runs as a script,
# `@ex.automain` starts the experiment as soon as `main` is decorated, so
# anything defined after it would not exist yet.
def _parse_metric_list(spec: str) -> Optional[List[str]]:
    """
    Split a comma-separated metric specification into metric names.

    Whitespace around each name is stripped and empty entries are dropped, so
    ``""``, ``" "`` and ``"a,,b"`` never yield empty metric names.

    Args:
        spec: Comma-separated metric names, e.g. ``"a,b"``. May be empty.

    Returns:
        The metric names in their original order, or ``None`` if ``spec``
        contains no names.
    """
    names = [part.strip() for part in spec.split(",")]
    non_empty = [part for part in names if part]
    return non_empty or None


@ex.automain
def main(
    _seed: int,
    experiment: str,
    recommender: str,
    num_recommendations: int,
    metrics: Dict[str, str],
    surrogate: Dict[str, Any],
):
    """
    Evaluate a recommender on the tracked results of an experiment.

    All arguments are injected by Sacred from the config scope.

    Args:
        _seed: Seed used for NumPy and PyTorch.
        experiment: Experiment name passed to ``Tracker.for_experiment``.
        recommender: Recommender type passed to ``create_recommender``.
        num_recommendations: Number of recommendations to evaluate.
        metrics: ``"maximize"`` / ``"minimize"`` hold comma-separated metric
            names (empty string for none); ``"focus"`` is passed through as-is.
        surrogate: Must contain ``"name"`` and ``"input_flags"``; an optional
            sub-dict under the key ``surrogate["name"]`` supplies extra keyword
            arguments for ``create_surrogate``.

    Raises:
        ValueError: If the surrogate recommender is selected without a
            surrogate name.

    The pickled recommendations are attached to the Sacred run as the
    artifact ``recommendations.pickle``.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if recommender == "surrogate" and surrogate["name"] is None:
        raise ValueError(
            "Name of surrogate must be provided if surrogate recommender is used."
        )

    np.random.seed(_seed)
    torch.manual_seed(_seed)
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

    # First, get the tracker
    print("Fetching the data...")
    tracker = Tracker.for_experiment(experiment)

    # Then, potentially initialize the surrogate
    minimize = _parse_metric_list(metrics["minimize"])
    recommender_args: Dict[str, Any] = {
        "maximize": _parse_metric_list(metrics["maximize"]),
        "minimize": minimize,
        "focus": metrics["focus"],
    }

    if recommender == "surrogate":
        print("Initializing the surrogate...")
        # Latency / model-size metrics are excluded from what the surrogate is
        # asked to predict. An empty `minimize` spec yields no metrics here
        # (previously it produced a bogus "" metric name).
        # NOTE(review): maximized metrics are never passed to the surrogate;
        # confirm this is intended if `maximize` is ever non-empty.
        surrogate_metrics = [
            m
            for m in (minimize or [])
            if not m.startswith(("latency", "num_model_parameters"))
        ]
        recommender_args["surrogate"] = create_surrogate(
            surrogate["name"],
            predict=surrogate_metrics,
            tracker=tracker,
            input_flags=surrogate["input_flags"],
            **surrogate.get(surrogate["name"], {})
        )
    elif recommender == "optimal":
        recommender_args["tracker"] = tracker

    # Then, we can create the recommender
    print("Initializing the recommender...")
    recommender_instance = create_recommender(recommender, **recommender_args)

    # And evaluate it
    print("Evaluating the recommender...")
    evaluator = RecommenderEvaluator(
        tracker,
        recommender_instance,
        num_recommendations=num_recommendations,
    )
    recommendations = evaluator.run()

    # Eventually, we store the results. The artifact must be added while the
    # temporary directory still exists, i.e. inside the `with` block.
    print("Storing the results...")
    with tempfile.TemporaryDirectory() as d:
        path = Path(d) / "recommendations.pickle"
        with path.open("wb") as f:
            pickle.dump(recommendations, f)
        ex.add_artifact(path, content_type="application/octet-stream")
