# pylint: disable=missing-function-docstring
import os
import time
from datetime import datetime, timezone
from pathlib import Path
import click
import sagemaker
import yaml
from botocore.exceptions import ClientError
from cli.utils import (
    explode_key_values,
    generate_configurations,
    iterate_configurations,
    run_sacred_script,
)
from tsbench.experiments.aws import account_id, default_session
from tsbench.experiments.aws.ecr import image_uri
from tsbench.experiments.aws.framework import CustomFramework
from tsbench.experiments.metrics.sagemaker import metric_definitions


@click.group()
@click.argument("config", type=click.Path(exists=True), nargs=1)
@click.option("--skip", default=0)
@click.pass_context
def main(ctx: click.Context, config: str, skip: int) -> None:
    # Root command group. Every subcommand reads its inputs from ctx.obj:
    #   "config" -> pathlib.Path to the (already existence-checked) CONFIG file
    #   "skip"   -> value forwarded to iterate_configurations(); presumably the
    #               number of leading configurations to skip — confirm there.
    shared_state = dict(config=Path(config), skip=skip)
    ctx.obj = shared_state


@main.command()
@click.option(
    "--aws/--local",
    default=False,
    help="Run jobs on SageMaker (--aws) or in SageMaker local mode (--local).",
)
@click.option(
    "--name",
    default="ts-bench",
    help="Experiment name used for job names, tags and the S3 output prefix.",
)
@click.option(
    "--bucket",
    default="ts-bench",
    help="S3 bucket that holds the input data and receives the job outputs.",
)
@click.option(
    "--sagemaker-role",
    default="Sagemaker",
    help="Name of the IAM service role the training jobs run under.",
)
@click.option(
    "--instance-type",
    default="ml.c5.2xlarge",
    help="SageMaker instance type; in local mode it only selects the CPU/GPU image.",
)
@click.option(
    "--max-runtime",
    default=240,
    type=int,
    help="Maximum runtime of each training job, in hours.",
)
@click.pass_context
def benchmark(
    ctx: click.Context,
    aws: bool,
    name: str,
    bucket: str,
    sagemaker_role: str,
    instance_type: str,
    max_runtime: int,
) -> None:
    """
    Schedule one training job per configuration generated from CONFIG.

    Jobs are submitted without waiting for them to finish. A submission that
    is rejected because of account resource limits or throttling is retried
    every five minutes; any other AWS client error aborts the command.
    """
    # Error codes that indicate a temporary capacity problem worth waiting
    # out. Anything else (bad role, invalid instance type, ...) would never
    # succeed, so retrying it forever would just hang the command.
    retryable_codes = {"ResourceLimitExceeded", "ThrottlingException"}

    # First, setup Sagemaker connection
    boto_session = default_session()
    if aws:
        sm_session = sagemaker.Session(boto_session)
    else:
        sm_session = sagemaker.LocalSession(boto_session)

    def job_factory() -> str:
        # Timestamp down to microseconds so every submission attempt gets a
        # fresh, unique job name.
        date_str = datetime.now(tz=timezone.utc).strftime("%d-%m-%Y-%H-%M-%S-%f")
        return f"{name}-{date_str}"

    # These settings do not depend on the configuration, so resolve them once
    # rather than once per scheduled job.
    if instance_type.startswith(("ml.p3", "ml.p2", "ml.g4")):
        image = "ts-bench:gpu-latest"
    else:
        image = "ts-bench:cpu-latest"
    role_arn = f"arn:aws:iam::{account_id()}:role/service-role/{sagemaker_role}"
    resolved_image_uri = image_uri(image)
    source_dir = os.path.dirname(os.path.realpath(__file__))

    # Then, generate configs
    all_configurations = generate_configurations(ctx.obj["config"])

    # Then, we can run the training, passing parameters as required
    for configuration in iterate_configurations(all_configurations, ctx.obj["skip"]):
        estimator = CustomFramework(
            sagemaker_session=sm_session,
            role=role_arn,
            tags=[
                {"Key": "Experiment", "Value": name},
            ],
            instance_type=instance_type if aws else "local",
            instance_count=1,
            volume_size=30,
            # SageMaker expects seconds; the option is given in hours.
            max_run=max_runtime * 60 * 60,
            image_uri=resolved_image_uri,
            source_dir=source_dir,
            output_path=f"s3://{bucket}/experiments/{name}",
            entry_point="cli/sagemaker/benchmark.py",
            debugger_hook_config=False,
            metric_definitions=metric_definitions(),
            hyperparameters=configuration,
        )

        dataset = configuration["dataset"]
        while True:
            # Try fitting the estimator
            try:
                estimator.fit(
                    job_name=job_factory(),
                    inputs={dataset: f"s3://{bucket}/data/{dataset}"},
                    wait=False,
                )
                break
            except ClientError as err:
                code = err.response.get("Error", {}).get("Code")
                if code not in retryable_codes:
                    raise
                print(f"+++ Scheduling failed: {err}")
                print("+++ Sleeping for 5 minutes.")
                time.sleep(300)

        print(f">>> Launched job: {estimator.latest_training_job.name}")

    print(">>> Successfully scheduled all training jobs.")


@main.command()
@click.option("--name", default="eval-surrogates", help="The name of the benchmark.")
@click.option("--experiment", default="ts-bench", help="The name of the benchmark experiment.")
@click.pass_context
def evaluate_surrogates(ctx: click.Context, name: str, experiment: str) -> None:
    """
    Run evaluate_surrogate.py once per configuration derived from CONFIG.

    Configurations are produced by explode_key_values() using the
    "surrogate" key of the loaded YAML document.
    """
    # Pin the encoding so parsing does not depend on the platform locale
    # (YAML files are conventionally UTF-8).
    with ctx.obj["config"].open("r", encoding="utf-8") as f:
        content = yaml.safe_load(f)
    configs = explode_key_values("surrogate", content)

    # Honor the group-level --skip option.
    for configuration in iterate_configurations(configs, ctx.obj["skip"]):
        run_sacred_script(
            "evaluate_surrogate.py",
            name=name,
            experiment=experiment,
            **configuration,
        )


@main.command()
@click.option("--name", default="eval-recommenders", help="The name of the benchmark.")
@click.option("--experiment", default="ts-bench", help="The name of the benchmark experiment.")
@click.pass_context
def evaluate_recommenders(ctx: click.Context, name: str, experiment: str) -> None:
    """
    Run evaluate_recommender.py once per configuration derived from CONFIG.

    Configurations are produced by explode_key_values() using the
    "recommender" key of the loaded YAML document.
    """
    # Pin the encoding so parsing does not depend on the platform locale
    # (YAML files are conventionally UTF-8).
    with ctx.obj["config"].open("r", encoding="utf-8") as f:
        content = yaml.safe_load(f)
    configs = explode_key_values("recommender", content)

    # Honor the group-level --skip option.
    for configuration in iterate_configurations(configs, ctx.obj["skip"]):
        run_sacred_script(
            "evaluate_recommender.py",
            name=name,
            experiment=experiment,
            **configuration,
        )


@main.command()
@click.option("--name", default="eval-ensembles", help="The name of the benchmark.")
@click.option("--experiment", default="ts-bench", help="The name of the benchmark experiment.")
@click.pass_context
def evaluate_ensembles(ctx: click.Context, name: str, experiment: str) -> None:
    """
    Run evaluate_ensemble.py once per configuration derived from CONFIG.
    """
    # Pin the encoding so parsing does not depend on the platform locale
    # (YAML files are conventionally UTF-8).
    with ctx.obj["config"].open("r", encoding="utf-8") as f:
        content = yaml.safe_load(f)
    # NOTE(review): "__" appears to be a placeholder explode key (it is
    # stripped again below before forwarding) — confirm against
    # explode_key_values.
    configs = explode_key_values("__", content)

    # Honor the group-level --skip option.
    for configuration in iterate_configurations(configs, ctx.obj["skip"]):
        run_sacred_script(
            "evaluate_ensemble.py",
            name=name,
            experiment=experiment,
            **{k: v for k, v in configuration.items() if k != "__"},
        )
