import ast
import copy
import json
import os
import re
import string
import random
from typing import Union, Iterable, Sequence

import numpy as np
import mxnet as mx
from pathlib import Path
import multiprocessing as mp
import itertools

from gluonts.model.estimator import DummyEstimator

from gluonts.meta_tools import get_forecasts_and_series


def random_id(N=10):
    """Return a random identifier made of ``N`` uppercase letters and digits.

    Characters are drawn with replacement from ``A-Z`` and ``0-9`` using the
    module-level ``random`` generator.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=N)
    return "".join(picked)


def log_metric(metric_name, value):
    """Print one metric as a ``gluonts[metric-<name>]: <value>`` line on stdout."""
    line = f"gluonts[metric-{metric_name}]: {value}"
    print(line)


import pandas as pd

from trainer_metrics_callback import TrainerMetricsCallback
from gluonts.time_feature import get_seasonality

from sacred import SETTINGS
from sacred.observers import FileStorageObserver

from gluonts.model.ensemble import EnsembleEstimator
from gluonts.model.forecast import Forecast
from gluonts.support.util import maybe_len
from sacred_utils import ImprovedS3Observer, SetRunId


# Select sacred's "sys" capture mode for this process (module-level side
# effect that runs at import time).
SETTINGS["CAPTURE_MODE"] = "sys"

from gluonts.dataset.repository.datasets import (
    dataset_recipes,
)
from utils import get_dataset, from_config, track_time, transform_dataset

from gluonts.evaluation import Evaluator
from gluonts.evaluation.backtest import (
    make_evaluation_predictions,
)
from gluonts.model.predictor import (
    SimpleParallelizedPredictor,
    Predictor,
)
from gluonts.mx.context import check_gpu_support


def get_estimator(cfg):
    """Build the estimator described by ``cfg``.

    The base estimator is created from ``cfg["model"]`` with ``from_config``.
    If ``cfg["ensemble"]`` is truthy, the base estimator is wrapped in an
    ``EnsembleEstimator`` that receives that entry as ``variants`` (a list is
    converted with ``dict(...)`` first) and
    ``cfg["ensemble_num_parallel_training"]`` as ``num_parallel_training``.
    Otherwise the base estimator is returned unchanged.
    """
    model_cfg = cfg["model"]
    variants = cfg["ensemble"]
    base_estimator = from_config(model_cfg)

    if not variants:
        return base_estimator

    # A list of (key, value) pairs is turned into a mapping.
    if isinstance(variants, list):
        variants = dict(variants)

    return EnsembleEstimator(
        base_estimator=base_estimator,
        variants=variants,
        num_parallel_training=cfg["ensemble_num_parallel_training"],
    )


def write_forecasts_parquet(
    forecasts: Iterable[Forecast],
    path: Union[Path, str],
    quantiles: Sequence[str] = ("0.5",),
) -> None:
    """
    Write forecasts to a single zstd-compressed parquet file at ``path``.

    Each forecast becomes one row with the columns ``fcst_start``, ``mean``,
    ``freq`` and ``item_id``, plus one ``p_<q>`` column per entry ``q`` of
    ``quantiles`` holding ``forecast.quantile(q)``.
    """

    def _to_record(forecast):
        # Fixed columns first, then one column per requested quantile.
        record = {
            "fcst_start": pd.Timestamp(forecast.start_date),
            "mean": forecast.mean,
            "freq": str(forecast.freq),
            "item_id": forecast.item_id,
        }
        for level in quantiles:
            record[f"p_{level}"] = forecast.quantile(level)
        return record

    table = pd.DataFrame([_to_record(forecast) for forecast in forecasts])
    table.to_parquet(str(path), compression="zstd")


def train_eval(cfg, ex):
    """Train the configured estimator, then evaluate and record its results.

    Steps, in order:

    1. Seed the Python, NumPy and MXNet random number generators from
       ``cfg["seed"]``.
    2. Load the training dataset ``cfg["train"]`` and every test dataset
       named in ``cfg["test"]``; optionally truncate, transform and
       sub-sample them as requested by the configuration.
    3. Train the estimator returned by ``get_estimator(cfg)``, optionally
       with a ``TrainerMetricsCallback``.
    4. Serialize the predictor to ``./predictor.tar.gz``, check that it
       round-trips, and attach it to the experiment.
    5. For each test dataset write forecasts, per-item metrics and aggregate
       metrics to files, attach them to ``ex`` and log the aggregate metrics.
    6. Compute, log and store aggregate metrics on the validation data.

    Parameters
    ----------
    cfg
        Configuration mapping. Keys read here: ``train``, ``test``,
        ``evaluation`` (``limit_examples``, ``limit_validation_examples``,
        ``num_workers``, ``quantiles``, ``num_samples``), ``seed``,
        ``no_mp``, ``freq``, ``transform_dataset_mode``,
        ``use_train_for_val``, ``limit_training_examples``, ``ensemble``,
        ``compute_learning_curves``, ``num_val_callback``,
        ``eval_interval`` and ``train_prediction_length``.
    ex
        Experiment object; ``log_scalar`` and ``add_artifact`` are called on
        it and it is passed to ``track_time``.

    Raises
    ------
    AssertionError
        If a dataset name is not in ``dataset_recipes`` or the predictor
        read back from disk does not compare equal to the trained one.
    """
    train = cfg["train"]
    test = cfg["test"]
    evaluation = cfg["evaluation"]
    seed = cfg["seed"]

    # Seed every generator used below. ``random`` drives the shuffle of the
    # training set, so it has to be seeded together with NumPy and MXNet.
    random.seed(seed)
    np.random.seed(seed)
    mx.random.seed(seed)
    assert train in dataset_recipes, f"unknown training dataset: {train}"
    for ti in test:
        assert ti in dataset_recipes, f"unknown test dataset: {ti}"
    no_mp = check_gpu_support() or cfg["no_mp"]

    estimator = get_estimator(cfg)
    print(f"Using estimator: {estimator}")

    train = get_dataset(train, regenerate=False)

    # Map each test dataset name to (metadata, list of test entries).
    test_ds = {}
    for test_ds_name in test:
        tmp = get_dataset(test_ds_name, regenerate=False)
        ds = list(tmp.test)
        meta = tmp.metadata
        if evaluation["limit_examples"]:
            ds = ds[: evaluation["limit_examples"]]
        test_ds[test_ds_name] = (meta, ds)

    train_ds = list(train.train)

    periodicity = get_seasonality(cfg["freq"])
    transform_dataset_mode = cfg["transform_dataset_mode"]

    if transform_dataset_mode is not None:
        # NOTE: an earlier revision disabled this branch with the message
        # "Not handled properly currently".
        print(f"Transforming dataset: mode={transform_dataset_mode}")
        train_ds = transform_dataset(
            train_ds,
            prediction_length=0,
            lead_time=0,
            periodicity=periodicity,
            mode=transform_dataset_mode,
        )

        test_ds = {
            ds_name: (
                meta,
                transform_dataset(
                    ds,
                    prediction_length=meta.prediction_length,
                    lead_time=0,
                    periodicity=periodicity,
                    mode=transform_dataset_mode,
                ),
            )
            for ds_name, (meta, ds) in test_ds.items()
        }

    # NOTE(review): the validation data is taken from the untransformed
    # dataset even when ``transform_dataset_mode`` is set; confirm intended.
    val_ds = list(train.train) if cfg["use_train_for_val"] else list(train.test)

    if evaluation["limit_validation_examples"]:
        print(
            f"limiting validation examples to: {evaluation['limit_validation_examples']}"
        )
        val_ds = _subsample_series(
            val_ds, evaluation["limit_validation_examples"]
        )

    if cfg["limit_training_examples"]:
        print(
            f"limiting training examples to: {cfg['limit_training_examples']}"
        )
        train_ds = _subsample_series(train_ds, cfg["limit_training_examples"])

    random.shuffle(train_ds)

    # Learning curves are not collected for ensembles or dummy estimators.
    skip_callback = cfg["ensemble"] or isinstance(estimator, (DummyEstimator,))
    if not skip_callback and cfg["compute_learning_curves"]:
        # NOTE(review): ``no_mp`` is also true on CPU when cfg["no_mp"] is
        # set; confirm that passing it as ``has_gpu`` is intended.
        callback = TrainerMetricsCallback(
            train_dataset=train_ds,
            val_dataset=val_ds,
            num_val=cfg["num_val_callback"],
            eval_interval=cfg["eval_interval"],
            has_gpu=no_mp,
            prediction_length=cfg["train_prediction_length"]
        )
    else:
        callback = None

    with track_time(ex, "training"):
        if callback is not None:
            predictor = estimator.train(train_ds, callback=callback)
        else:
            predictor = estimator.train(train_ds)

    # Serialize, read back and compare before attaching the archive.
    predictor_path = Path("./predictor.tar.gz")
    predictor.to_tgz(predictor_path)
    pred2 = Predictor.read_tgz(predictor_path)
    assert pred2 == predictor
    ex.add_artifact(predictor_path)

    num_eval_workers = (
        evaluation["num_workers"]
        if evaluation["num_workers"] is not None
        else mp.cpu_count()
    )
    if num_eval_workers and not no_mp:
        predictor = SimpleParallelizedPredictor(
            predictor,
            num_workers=num_eval_workers,
        )

    evaluator = Evaluator(quantiles=evaluation["quantiles"], num_workers=0)

    for ds_name, (meta, ds) in test_ds.items():
        with track_time(ex, f"{ds_name}_prediction"):
            print(
                f"Predicting {ds_name} using prediction length {meta.prediction_length}"
            )
            forecasts, series = get_forecasts_and_series(
                ds,
                predictor,
                prediction_length=meta.prediction_length,
                num_samples=evaluation["num_samples"],
            )

        forecasts_parquet = f"./forecasts_{ds_name}.parquet"
        write_forecasts_parquet(forecasts, forecasts_parquet)
        ex.add_artifact(forecasts_parquet)

        agg_metrics, item_metrics = evaluator(
            series, forecasts, num_series=maybe_len(ds)
        )
        item_metrics_parquet = f"item_metrics_{ds_name}.parquet"
        item_metrics.to_parquet(item_metrics_parquet, compression="zstd")
        ex.add_artifact(item_metrics_parquet)

        _report_agg_metrics(
            ex, agg_metrics, ds_name, f"./agg_metrics_{ds_name}.csv"
        )

    if callback is not None:
        df_learning_curves = pd.DataFrame(callback.metrics)
        learning_curves_path = "./learning_curves_metrics.csv"
        df_learning_curves.to_csv(learning_curves_path, index=False)
        ex.add_artifact(learning_curves_path)

    # validation metrics
    forecast_it, ts_it = make_evaluation_predictions(
        val_ds, predictor=predictor, num_samples=evaluation["num_samples"],
    )

    val_agg_metrics, _ = evaluator(
        ts_it, forecast_it, num_series=maybe_len(val_ds)
    )
    _report_agg_metrics(ex, val_agg_metrics, "val", "./val_agg_metrics.csv")


def _subsample_series(series, limit):
    """Return at most ``limit`` entries of ``series``, drawn without replacement.

    A fresh ``RandomState(seed=0)`` is used on every call so that the same
    time series are selected on every run.
    """
    count = min(limit, len(series))
    return list(
        np.random.RandomState(seed=0).choice(series, count, replace=False)
    )


def _report_agg_metrics(ex, agg_metrics, prefix, csv_path):
    """Log each aggregate metric as ``<prefix>_<name>`` and attach them as a CSV."""
    for metric_name, value in agg_metrics.items():
        key = f"{prefix}_{metric_name}"
        ex.log_scalar(key, value)
        log_metric(key, value)

    df = pd.DataFrame({"value": agg_metrics})
    df.index.name = "metric"
    df.to_csv(csv_path)
    ex.add_artifact(csv_path)


def main(run_id: str, exp_name, argv):
    """Configure the sacred experiment and run it with ``argv``.

    The directory of this file is appended to ``sys.path`` so that
    ``expconfig`` can be imported. A ``SetRunId`` observer is always added.
    When the ``TRAINING_JOB_NAME`` environment variable is absent, results go
    to a ``FileStorageObserver`` under ``<parent of this file's dir>/my_runs/
    <exp_name>``; otherwise an ``ImprovedS3Observer`` with base directory
    ``experiments/<exp_name>`` is used.

    Parameters
    ----------
    run_id
        Non-empty identifier handed to ``SetRunId``.
    exp_name
        Experiment name; may only contain ASCII letters, digits and the
        characters ``! - _ . * ' ( )``.
    argv
        Command line passed to ``ex.run_commandline``.

    Raises
    ------
    AssertionError
        If ``run_id`` is empty or ``exp_name`` contains other characters.
    """
    assert run_id
    import sys

    sys.path.append(os.path.dirname(__file__))

    from expconfig import ex

    # The hyphen is escaped so that it is matched literally: unescaped,
    # ``!-_`` is a range from "!" to "_" and lets characters such as "/",
    # "#" or ":" through. ``fullmatch`` also rejects a trailing newline,
    # which ``$`` would accept. Presumably this is the set of characters
    # that are safe in S3 object keys -- TODO confirm.
    assert re.fullmatch(r"[A-Za-z0-9!\-_.*'()]+", exp_name), (
        f"invalid experiment name: {exp_name!r}"
    )

    local_run = "TRAINING_JOB_NAME" not in os.environ

    @ex.pre_run_hook
    def set_logger_stream(_run):
        # Send the output of the run's first root-logger handler to stdout.
        _run.root_logger.handlers[0].stream = sys.stdout

    ex.observers.append(SetRunId(run_id))
    ex.add_config(
        {"sm": {"job_name": os.environ.get("TRAINING_JOB_NAME", "local_run")}}
    )

    if local_run:
        this_file_path = Path(__file__).parent.absolute()
        base_dir = str(this_file_path.parent / "my_runs" / exp_name)
        f_obs = FileStorageObserver(basedir=base_dir)
        ex.observers.append(f_obs)
    else:
        # Fill in with the correct S3 bucket
        s3_obs = ImprovedS3Observer(
            bucket="",
            basedir=f"experiments/{exp_name}",
            region="us-west-2",
        )
        ex.observers.append(s3_obs)

    @ex.main
    def run(_config):
        print(f"run using config:\n{json.dumps(_config, indent=2)}")
        train_eval(_config, ex)

    ex.run_commandline(argv)


if __name__ == "__main__":
    # mp.set_start_method("spawn", force=True)

    # In sagemaker this file will be available in path and contain the config
    from serialized_config import argv

    print(f"argv = {argv}")
    arg = copy.copy(argv)

    # Every hyperparameter except ``exp_name`` is appended to the command
    # line as ``key=value``.
    try:
        with open("/opt/ml/input/config/hyperparameters.json", "r") as f:
            hp = json.load(f)
        exp_name = hp.pop("exp_name").strip(' "')
        for k, v in hp.items():
            try:
                v = ast.literal_eval(v)
            except (
                ValueError,
                TypeError,
                SyntaxError,
                MemoryError,
                RecursionError,
            ):
                # Not a Python literal: pass the raw value through unchanged.
                pass
            arg.append(f"{k}={v}")
    except Exception as err:
        # Keep the original message but chain the cause so that the real
        # failure (missing file, bad JSON, missing key, ...) stays visible.
        raise RuntimeError(
            "This is only supposed to be called by sagemaker"
        ) from err

    # The run id is whatever follows "<exp_name>-" in the job name.
    job_name = os.environ["TRAINING_JOB_NAME"].strip(' "')
    prefix = f"{exp_name}-"
    assert job_name.startswith(prefix)
    run_id = job_name[len(prefix) :]

    main(
        run_id=run_id,
        exp_name=exp_name,
        argv=arg,
    )
