import ast
import json
import os
import random
import re
import string
import tempfile
import sys
from typing import Dict

from sagemaker.parameter import (
    ContinuousParameter,
    IntegerParameter,
    CategoricalParameter,
)
from sagemaker.tuner import HyperparameterTuner


def random_id(N=10):
    """Return a random string of ``N`` characters drawn from ``A-Z`` and ``0-9``.

    Uses the module-level ``random`` generator, so it is not suitable for
    security-sensitive tokens.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=N)
    return "".join(picked)


# Names of GluonTS evaluation metrics. NOTE(review): nothing in this module
# reads this list (make_metric_definitions takes its own `metric_names`
# argument, and `metric_definitions` below passes "val_"-prefixed names).
metric_names = "mean_wQuantileLoss MASE sMAPE ND".split()


def make_metric_definitions(metric_names):
    """Build SageMaker ``metric_definitions`` for this project's training logs.

    SageMaker applies each ``Regex`` to the job's log output and records the
    first capture group as the value of the metric called ``Name``.

    :param metric_names: names of evaluation metrics that the training script
        logs in the form ``gluonts[metric-<name>]: <value>``.
    :return: list of ``{"Name": ..., "Regex": ...}`` dicts. It holds the fixed
        ``training_loss``, ``validation_loss``, ``epoch`` and ``final_loss``
        definitions, then one entry per name in ``metric_names``.
    """
    # Signed decimal number such as "1", "-0.5", "+.25" or "3.".
    number = r"([-+]?(\d+(\.\d*)?|\.\d+))"

    def metric_dict(metric_name):
        """Return a definition matching ``gluonts[metric-<metric_name>]: <value>``."""
        regex = rf".*gluonts\[metric-{re.escape(metric_name)}\]: {number}"
        return {"Name": metric_name, "Regex": regex}

    # `\b` prevents this pattern from also matching inside
    # "validation_avg_epoch_loss=", which would otherwise report validation
    # values under training_loss.
    avg_training_loss_metric = {
        "Name": "training_loss",
        "Regex": r"\bavg_epoch_loss=" + number,
    }
    avg_validation_loss_metric = {
        "Name": "validation_loss",
        "Regex": r"validation_avg_epoch_loss=" + number,
    }
    epoch_metric = {"Name": "epoch", "Regex": r"epoch=(\d+)/\d+"}
    final_loss_metric = {
        "Name": "final_loss",
        "Regex": r"Final loss: " + number,
    }

    return [
        avg_training_loss_metric,
        avg_validation_loss_metric,
        epoch_metric,
        final_loss_metric,
    ] + [metric_dict(m) for m in metric_names]


from pathlib import Path

import gluonts

# AWS settings: `account`, `role` and `region` are used by get_sm_session();
# `sm_role` is the execution role handed to the SageMaker estimator.
# NOTE(review): the account ID and role ARN are hard-coded in source; consider
# reading them from the environment or a config file instead.
region, account, role = "us-west-2", "670864377759", "Admin"
sm_role = (
    f"arn:aws:iam::{account}:role/service-role/"
    "AmazonSageMaker-ExecutionRole-20181125T162939"
)


def get_sm_session():
    """Return a ``sagemaker.Session`` using a boto3 session from ``isengard``.

    The boto3 session is requested for the module-level ``account`` and
    ``role`` in ``region``. Imports are local so that the module can be
    imported without these packages (e.g. for local runs).
    """
    import sagemaker
    import isengard

    boto_session = isengard.Client().get_boto3_session(
        account, role, region=region
    )
    return sagemaker.Session(boto_session=boto_session)


# Directories of the installed gluonts package, shipped to SageMaker as extra
# source dependencies. Copied into a plain list: a package's __path__ is only
# guaranteed to be an iterable (namespace packages use a special path object
# that does not support `+`), and a copy avoids aliasing the import system's
# live, mutable path list.
dependencies = list(gluonts.__path__)


# Metric definitions attached to every SageMaker training and tuning job.
# Only the validation-set ("val_"-prefixed) metrics are tracked here;
# "mean_wQuantileLoss" is currently not included.
metric_definitions = make_metric_definitions(["val_sMAPE", "val_MASE", "val_ND"])


def _parse_param_ranges(hpo_params):
    """Convert ``hpo.params`` config entries into SageMaker parameter ranges.

    :param hpo_params: iterable of dicts mapping a hyperparameter name to a
        range spec string: ``float_auto(lo, hi)``, ``float_linear(lo, hi)``,
        ``float_log(lo, hi)``, ``int(lo, hi)`` or ``cat(v1, v2, ...)``, where
        the arguments are Python literals.
    :return: dict mapping hyperparameter name to a SageMaker parameter range.
    :raises AssertionError: if a spec does not have the form above.
    """
    scaling_types = {
        "float_auto": "Auto",
        "float_linear": "Linear",
        "float_log": "Logarithmic",
    }
    param_ranges = {}
    for d in hpo_params:
        for name, spec in d.items():
            m = re.match(
                r"^\s*(float_auto|float_linear|float_log|int|cat)\((.*)\)\s*$",
                spec,
            )
            assert m, f"Not a valid param_range {spec}"
            kind, raw_args = m.groups()
            # Parse the whole argument list as one list literal instead of
            # splitting on "," so that quoted categorical values containing
            # commas, e.g. cat('a,b', 'c'), are kept intact.
            p_args = ast.literal_eval(f"[{raw_args}]")
            if kind == "cat":
                param_ranges[name] = CategoricalParameter(p_args)
            elif kind == "int":
                param_ranges[name] = IntegerParameter(*p_args)
            elif kind in scaling_types:
                param_ranges[name] = ContinuousParameter(
                    *p_args, scaling_type=scaling_types[kind]
                )
            else:
                raise RuntimeError()
    return param_ranges


def _write_serialized_config(path, config, args):
    """Write ``config`` and ``args`` to ``path`` as an importable Python module.

    The module defines ``config`` as a JSON string and ``argv`` as a list
    literal. ``repr`` quotes the JSON so that single quotes and backslash
    escapes in config values survive being parsed as Python source. Wrapping
    the JSON in bare ``'...'`` broke on such values.
    """
    with path.open("w") as f:
        f.write(f"config = {json.dumps(config)!r}\n")
        f.write(f"argv = {json.dumps(args)}\n")


def train_eval_sagemaker(
    exp_name: str,
    run_local,
    wait,
    instance_type,
    volume_size,
    config: Dict,
    argv,
):
    """Run one train/eval experiment, either locally or as a SageMaker job.

    :param exp_name: experiment name made of ASCII letters, digits and ``-``.
        It is used as the prefix of the SageMaker job name.
    :param run_local: if truthy, call ``src.train_eval.main`` in-process from a
        temporary working directory instead of launching a SageMaker job.
    :param wait: forwarded to ``fit``; whether to block until the job finishes.
    :param instance_type: instance type forwarded to the MXNet estimator.
    :param volume_size: volume size forwarded to the MXNet estimator.
    :param config: experiment config. Reads ``is_dummy_run``, ``user``,
        ``train``, ``model_name`` and ``hpo``. The ``hpo`` entry has
        ``params``; for HPO runs it also needs ``objective``,
        ``objective_type``, ``max_jobs``, ``max_parallel_jobs`` and
        ``strategy``.
    :param argv: list of extra arguments forwarded to the training script.
        SageMaker runs also get ``user=<user>`` appended.
    :return: the result of ``main`` for local runs, otherwise ``None``.
    :raises AssertionError: on an invalid ``exp_name``, on HPO combined with
        ``run_local``, or on a malformed HPO range spec.

    When ``config["hpo"]["params"]`` is non-empty, a tuning job is launched and
    the process then exits via ``sys.exit(0)``.
    """
    # fullmatch: `re.match` with `$` would also accept a trailing newline.
    assert re.fullmatch(r"[A-Za-z0-9-]+", exp_name), f"Invalid exp_name: {exp_name!r}"

    run_id = ("dummy-" if config["is_dummy_run"] else "") + random_id()

    if run_local:
        from src.train_eval import main

        assert not config["hpo"]["params"], "Cannot run HPO using local"

        with tempfile.TemporaryDirectory() as tmp_dir:
            # os.curdir is the constant "." -- getcwd() gives the real directory.
            prev_dir = os.getcwd()
            os.chdir(tmp_dir)
            try:
                res = main(run_id=run_id, exp_name=exp_name, argv=argv)
            finally:
                # Leave tmp_dir before TemporaryDirectory deletes it, even if
                # main() raised.
                os.chdir(prev_dir)
            print(f"***** finished local run: {run_id} *****")
            return res

    # NOTE(review): SageMaker limits job-name length (tuning job names are much
    # shorter than training job names); a long exp_name may be rejected.
    requested_job_name = f"{exp_name}-{run_id}"

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Stage the config as a tiny `serialized_config` package that is
        # uploaded together with the training source.
        base_dir = Path(tmp_dir) / "serialized_config"
        base_dir.mkdir()

        user = config["user"]
        args = argv + [f"user={user}"]
        _write_serialized_config(base_dir / "__init__.py", config, args)

        dataset_s3_path = "s3://meta-learning/datasets/"

        tags = [
            {"Key": key, "Value": str(config[key])}
            for key in ["train", "model_name", "is_dummy_run", "user"]
        ]

        sm_sess = get_sm_session()
        from sagemaker.mxnet import MXNet

        sagemaker_estim = MXNet(
            "./train_eval.py",
            # list(...) accepts any iterable `dependencies`; str() passes the
            # staged directory as a plain path string.
            dependencies=list(dependencies) + [str(base_dir)],
            source_dir="./src",
            instance_type=instance_type,
            tags=tags,
            volume_size=volume_size,
            instance_count=1,
            py_version="py37",
            framework_version="1.8.0",
            hyperparameters={"run_id": run_id, "exp_name": exp_name},
            sagemaker_session=sm_sess,
            metric_definitions=metric_definitions,
            role=sm_role,
            debugger_hook_config=False,
        )

        channels = {"datasets": dataset_s3_path}

        hpo = config["hpo"]
        if hpo["params"]:
            param_ranges = _parse_param_ranges(hpo["params"])
            print(param_ranges)

            tuner = HyperparameterTuner(
                sagemaker_estim,
                hpo["objective"],
                param_ranges,
                objective_type=hpo["objective_type"],
                max_jobs=hpo["max_jobs"],
                max_parallel_jobs=hpo["max_parallel_jobs"],
                metric_definitions=metric_definitions,
                strategy=hpo["strategy"],
            )

            tuner.fit(channels, wait=wait, job_name=requested_job_name)
            # latest_tuning_job is a job object; print its name, as the
            # training path below does, rather than the object's repr.
            job_name = tuner.latest_tuning_job.job_name
            if wait:
                print(f"finished tuning job: {job_name} (run_id {run_id})")
            else:
                print(f"launched tuning job: {job_name} (run_id {run_id})")
            # NOTE(review): this terminates the whole process rather than
            # returning to the caller; kept for backward compatibility.
            sys.exit(0)

        sagemaker_estim.fit(channels, wait=wait, job_name=requested_job_name)
        job_name = sagemaker_estim.latest_training_job.job_name
        if wait:
            print(f"finished job: {job_name} (run_id {run_id})")
        else:
            print(f"launched job: {job_name} (run_id {run_id})")
