from typing import Optional
import os

import numpy as np
from gluonts.model.estimator import DummyEstimator
from gluonts.model.meta_autoreg_det import MetaARDetEstimator
from gluonts.model.r_forecast import RForecastPredictor
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.model.tft import TemporalFusionTransformerEstimator
from gluonts.mx.distribution import StudentTOutput, GaussianOutput
from gluonts.mx.model.estimator import GluonEstimator
from gluonts.mx.trainer.model_averaging import SelectNLastMean
from gluonts.transform import ExpectedNumInstanceSampler
from sacred import Experiment

from gluonts.model.deepar import DeepAREstimator
from gluonts.model.n_beats import NBEATSEnsembleEstimator, NBEATSEstimator
from gluonts.mx.trainer import Trainer

from utils import (
    to_config,
    get_git_commit,
    normalize_freq,
    collect_dataset_stats,
)
from gluonts.meta_tools import MinHistorySampler


# Sacred experiment object; the config scope, named configs, commands and
# captured functions in this module are registered on it via its decorators.
ex = Experiment()


@ex.config
def my_config():
    """Base configuration for the meta-learning forecasting experiments.

    Sacred records every variable assigned in this body as a config entry;
    named configs and command-line ``with key=value`` updates override them.
    Later entries are derived from earlier ones (dataset statistics,
    horizons, lags, the estimator config in ``model``), so statement order
    matters. ``model_name`` defaults to None and has to be set, e.g. with
    the ``deepar`` named config, because ``get_model_cfg`` asserts it.
    """
    exp_name = "meta"
    user = os.environ.get("USER", "unknown")
    is_dummy_run = False
    git = get_git_commit()
    no_mp = True  # if set to True avoid using multiprocessing for evaluation (can be faster this way)

    # Refuse to run from a modified working tree, except for dummy runs or
    # when no git information is available.
    if not is_dummy_run and git:
        assert not git["is_dirty"], "Dirty workspace"

    tags = None

    # sagemaker
    sm = dict(
        run_local=False,
        job_name=None,
        instance_type="ml.c5.9xlarge",
        wait=False,
        volume_size=200,
    )

    ### datasets
    train = "m4_monthly"
    test = ["m3_monthly"]
    train_stats = collect_dataset_stats(train, which="train")
    test_stats = [collect_dataset_stats(ti, which="test") for ti in test]
    freq = normalize_freq(train_stats["freq"])
    # Forecast horizon: the longest prediction length among the test sets.
    prediction_length = max(ti["prediction_length"] for ti in test_stats)
    # None means "same as prediction_length" (resolved just below).
    train_prediction_length = None
    use_train_for_val = True  # if true uses a subset of the train set (instead of the test set) of the source for validation

    if train_prediction_length is None:
        train_prediction_length = prediction_length

    transform_dataset_mode = None

    # context_length = {"1H": 7 * 24, "1M": 12, "1Q": 8, "1Y": 10}[freq]
    compute_learning_curves = True
    num_val_callback = 100
    eval_interval = 10

    # Seasonal period per normalized frequency; any other frequency raises a
    # KeyError here.
    season = {"1H": 7 * 24, "1M": 12, "1Q": 4, "1D": 7, "1Y": 10, "1W": 4}[freq]

    # Two seasons; passed to get_model_cfg, which uses it as ``min_history``
    # of MinHistorySampler (or ExpectedNumInstanceSampler when it is -1).
    history_len = 2 * season
    # Context length expressed as a multiple of train_prediction_length.
    context_length_mult = 1
    context_length = int(context_length_mult * train_prediction_length)
    limit_training_examples = None  # used for experiments with reduced dataset size
    epochs = 200
    batch_size = 32
    test_batch_size = 1000  # only for MetaARDet
    learning_rate = 1e-3
    minimum_learning_rate = 5e-5  # used in combination with patience
    patience = 20
    num_avg_checkpoint = 5
    weight_decay = 1e-8
    clip_gradient = 10
    init = "xavier"
    hybridize = False  # MetaARDet does not support hybridization

    # Upper bound of the lag list 1..n_lags_until; None is passed on as-is.
    n_lags_until = None
    n_lags_until_local = None
    lags_seq = (
        None if n_lags_until is None else list(range(1, n_lags_until + 1))
    )
    lags_seq_local = (
        None
        if n_lags_until_local is None
        else list(range(1, n_lags_until_local + 1))
    )
    scaling = True

    # By default, runs trained on an M4 dataset get no time features (an
    # empty list); otherwise time_features=None is passed to the estimator.
    is_m_competition = "m4" in train
    add_time_features = not is_m_competition
    time_features = None if add_time_features else []
    # NOTE: despite its name, get_model_cfg forwards this flag to
    # MetaARDetEstimator as ``use_time_features``.
    no_age_and_time = True

    # assert history_len >= context_length, f"{history_len} >= {context_length}"

    model_name = None
    deterministic_loss = "sMAPE"  # used for point forecast methods implemented with MetaARDet and NBEATS
    distr_output = "student"  # used for DeepAR
    dim_transform = 40  # dimension of the representation
    iterate_forecasts_during_training = True

    residual_linear = False

    # Estimator selected by model_name, converted to a config with to_config.
    model = get_model_cfg(
        model_name,
        context_length=context_length,
        train_prediction_length=train_prediction_length,
        prediction_length=prediction_length,
        freq=freq,
        history_len=history_len,
        epochs=epochs,
        batch_size=batch_size,
        test_batch_size=test_batch_size,
        det_loss=deterministic_loss,
        lags_seq=lags_seq,
        lags_seq_local=lags_seq_local,
        scaling=scaling,
        time_features=time_features,
        no_age_and_time=no_age_and_time,
        learning_rate=learning_rate,
        minimum_learning_rate=minimum_learning_rate,
        patience=patience,
        num_avg_checkpoint=num_avg_checkpoint,
        weight_decay=weight_decay,
        init=init,
        distr_output_str=distr_output,
        dim_transform=dim_transform,
        iterate_forecasts_during_training=iterate_forecasts_during_training,
        residual_linear=residual_linear,
        clip_gradient=clip_gradient,
        hybridize=hybridize
    )

    # Hyper-parameter tuning settings; the do_hpo* named configs override
    # max_jobs, max_parallel_jobs and params.
    hpo = {
        "strategy": "Random",
        "objective": "val_sMAPE",
        "objective_type": "Minimize",
        "max_jobs": 2,
        "max_parallel_jobs": 10,
        # use [{'epochs': 'int(50, 100)'}]'
        # notation:
        #   int(a, b) -> IntegerParameter(a, b)
        #   float_auto(0.01, 10, 'Auto') -> ContinuousParameter(0.01, 10, scaling_type="Auto")
        #       Valid values: ‘float_auto’, ‘float_linear’, ‘float_log’
        #   cat(a, b, c) -> CategoricalParameter([a, b, c])
        "params": [],
    }

    ensemble_num_parallel_training = 10
    # Context-length ensemble: with n > 0, n context lengths evenly spaced
    # between min_mult and max_mult times train_prediction_length (rounded).
    context_length_ensemble = {
        "n": 0,
        "min_mult": context_length_mult,
        "max_mult": 2 * context_length_mult,
    }

    if context_length_ensemble and context_length_ensemble["n"]:
        ensemble = [
            (
                "context_length",
                np.linspace(
                    context_length_ensemble["min_mult"]
                    * train_prediction_length,
                    context_length_ensemble["max_mult"]
                    * train_prediction_length,
                    context_length_ensemble["n"],
                )
                .round()
                .astype(int)
                .tolist(),
            )
        ]
    else:
        ensemble = None

    # 1 sample for point forecasters; 400 for "deepar", "simple_feed" or any
    # model name containing "prob".
    num_samples = 1 if model_name not in ["deepar", "simple_feed"] and "prob" not in model_name else 400

    evaluation = dict(
        num_workers=None,
        num_samples=num_samples,
        quantiles=[0.5],
        limit_examples=None,
        limit_validation_examples=8000,
    )

    # Truncating the evaluation set is only allowed in dummy runs.
    if evaluation["limit_examples"] is not None:
        assert (
            is_dummy_run
        ), f"Can only use evaluation.limit_examples in dummy mode. Found value: {evaluation['limit_examples']}"


### test settings

@ex.named_config
def dummy():
    """Smoke-test run: one epoch of one batch, evaluated on a single example.

    Also marks the run as a dummy run, which skips the dirty-workspace check
    and permits ``evaluation.limit_examples``.
    """
    is_dummy_run = True
    epochs = 1
    model = dict(trainer=dict(num_batches_per_epoch=1))
    evaluation = dict(limit_examples=1, limit_validation_examples=1)


@ex.named_config
def local():
    """Set ``sm.run_local`` so the job is run locally."""
    sm = {"run_local": True}


@ex.named_config
def no_training():
    """Practically skip training: one batch, one epoch, tiny learning rate."""
    learning_rate = 1e-24  # must stay positive (0 is rejected), so use a negligible value
    model = dict(trainer=dict(num_batches_per_epoch=1))
    epochs = 1


@ex.named_config
def no_age_and_time_features():
    """Switch off time features and ``no_age_and_time``; use lags 1..40.

    ``no_age_and_time`` is forwarded to MetaARDetEstimator as
    ``use_time_features``, so setting it to False disables them there.
    """
    no_age_and_time = False
    add_time_features = False
    n_lags_until = 40


@ex.named_config
def hyndman():
    """Preset with MASE-scaled data, lags 1..40 and an MAE point loss.

    Mean scaling and time features are off, the representation is 20-dim,
    batches are large (1024) and the learning-rate schedule is more patient
    (patience 100, minimum learning rate 5e-8).
    """
    transform_dataset_mode = "MASE_scaling"
    scaling = False
    add_time_features = False
    no_age_and_time = False
    n_lags_until = 40
    dim_transform = 20
    deterministic_loss = "MAE"
    batch_size = 1024
    patience = 100
    minimum_learning_rate = 5e-8
    num_val_callback = 100
    eval_interval = 20


@ex.named_config
def do_hpo():
    """Hyper-parameter search: 200 jobs, at most 20 in parallel.

    Tunes epochs, batch size, learning rate, context-length multiplier,
    representation size and history length; strategy and objective come from
    the base ``hpo`` config.
    """
    sm = dict(instance_type="ml.c5.4xlarge")
    hpo = dict(
        max_jobs=200,
        max_parallel_jobs=20,
        params=[
            dict(
                epochs="cat(500, 1000)",
                batch_size="cat(32, 64, 128)",
                learning_rate="float_log(1.0E-5, 2.0E-3)",
                context_length_mult="float_auto(0.3, 5.0)",
                dim_transform="int(20, 50)",
                history_len="int(24, 100)",
            )
        ],
    )


@ex.named_config
def do_hpo_100():
    """Same search space as ``do_hpo`` with 100 jobs, at most 30 in parallel."""
    sm = dict(instance_type="ml.c5.4xlarge")
    hpo = dict(
        max_jobs=100,
        max_parallel_jobs=30,
        params=[
            dict(
                epochs="cat(500, 1000)",
                batch_size="cat(32, 64, 128)",
                learning_rate="float_log(1.0E-5, 2.0E-3)",
                context_length_mult="float_auto(0.3, 5.0)",
                dim_transform="int(20, 50)",
                history_len="int(24, 100)",
            )
        ],
    )


@ex.named_config
def do_hpo_nbeats():
    """Hyper-parameter search tailored to N-BEATS: 200 jobs, 20 in parallel.

    Adds the point-forecast loss to the search space and uses larger batches
    and a wider context-length range than ``do_hpo``.
    """
    sm = dict(instance_type="ml.c5.4xlarge")
    hpo = dict(
        max_jobs=200,
        max_parallel_jobs=20,
        params=[
            dict(
                epochs="cat(300, 600)",
                batch_size="cat(256, 512, 1024)",
                learning_rate="float_log(1.0E-5, 2.0E-3)",
                context_length_mult="float_auto(0.3, 7.0)",
                history_len="int(24, 100)",
                deterministic_loss="cat('sMAPE', 'MASE', 'MAPE')",
            )
        ],
    )


@ex.named_config
def do_hpo_small():
    """Like ``do_hpo`` but without tuning the context-length multiplier.

    200 jobs, at most 30 in parallel.
    """
    sm = dict(instance_type="ml.c5.4xlarge")
    hpo = dict(
        max_jobs=200,
        max_parallel_jobs=30,
        params=[
            dict(
                epochs="cat(500, 1000)",
                batch_size="cat(32, 64, 128)",
                learning_rate="float_log(1.0E-5, 2.0E-3)",
                dim_transform="int(20, 50)",
                history_len="int(24, 100)",
            )
        ],
    )


@ex.named_config
def yearly():
    """Train on M4 yearly; test on M3 yearly and tourism yearly."""
    test = ["m3_yearly", "tourism_yearly"]
    train = "m4_yearly"


@ex.named_config
def yearly_full():
    """Like ``yearly`` but also tests on the M4 yearly test split."""
    test = ["m3_yearly", "tourism_yearly", "m4_yearly"]
    train = "m4_yearly"


@ex.named_config
def quarterly():
    """Train on M4 quarterly; test on M3 quarterly/other and tourism quarterly."""
    test = ["m3_quarterly", "m3_other", "tourism_quarterly"]
    train = "m4_quarterly"


@ex.named_config
def quarterly_full():
    """Like ``quarterly`` but also tests on the M4 quarterly test split."""
    test = ["m3_quarterly", "m3_other", "tourism_quarterly", "m4_quarterly"]
    train = "m4_quarterly"


@ex.named_config
def monthly():
    """Train on M4 monthly; test on M3 monthly and tourism monthly."""
    test = ["m3_monthly", "tourism_monthly"]
    train = "m4_monthly"
    train_prediction_length = 18  # without this, train and test both use a prediction length of 24 (can work better)


@ex.named_config
def monthly_full():
    """Like ``monthly`` but also tests on the M4 monthly test split."""
    test = ["m3_monthly", "tourism_monthly", "m4_monthly"]
    train = "m4_monthly"
    train_prediction_length = 18  # without this, train and test both use a prediction length of 24 (can work better)


@ex.named_config
def weekly_full():
    """Train and test on M4 weekly."""
    test = ["m4_weekly"]
    train = "m4_weekly"


@ex.named_config
def daily_full():
    """Train and test on M4 daily."""
    test = ["m4_daily"]
    train = "m4_daily"

@ex.named_config
def hourly():
    """Train on M4 hourly; test on the electricity and traffic datasets."""
    test = ["electricity", "traffic"]
    train = "m4_hourly"


@ex.named_config
def hourly_full():
    """Like ``hourly`` but also tests on the M4 hourly test split."""
    test = ["electricity", "traffic", "m4_hourly"]
    train = "m4_hourly"


### model configs

@ex.named_config
def deepar():
    """Select the DeepAR estimator (built in ``get_model_cfg``)."""
    model_name = "deepar"


@ex.named_config
def meta_glar():
    """Select the ``meta_glar`` model.

    In ``get_model_cfg`` this builds a MetaARDetEstimator with only the
    shared keyword arguments, leaving every variant option at the estimator
    default. Like ``deepar``, this must be a named config: as an
    ``ex.command`` the assignment below never reached the configuration.
    """
    model_name = "meta_glar"


# Variant-specific keyword arguments for the MetaARDetEstimator-based models,
# keyed by model name. get_model_cfg merges each entry with the keyword
# arguments shared by all of these models; an empty dict keeps the estimator
# defaults for every variant option.
# NOTE(review): "global_linearAR"/"local_linearAR" and
# "feedforwardAR"/"meta_feedforwardAR" resolve to identical arguments (only
# "global_linearAR" additionally gets init="zero"), and
# "meta_linear_biased_regAR" uses the "identity" encoder - confirm intended.
_METAAR_DET_VARIANTS = {
    "global_linearAR": dict(
        is_adaptive_training=False,
        is_adaptive_prediction=False,
        use_shared_linear=False,
        encoder_type="identity",
    ),
    "feedforwardAR": dict(
        is_adaptive_training=False,
        is_adaptive_prediction=False,
        use_shared_linear=False,
        encoder_type="feedforward",
    ),
    "meta_glar": dict(),
    "rnn": dict(
        use_shared_linear=False,
        is_adaptive_training=False,
        is_adaptive_prediction=False,
    ),
    "rnn_adapred": dict(
        use_shared_linear=False,
        is_adaptive_training=False,
        is_adaptive_prediction=True,
    ),
    "rnn_adapred_biased_reg": dict(
        biased_regularization=True,
        use_shared_linear=False,
        is_adaptive_training=False,
        is_adaptive_prediction=True,
    ),
    "local_linearAR": dict(
        is_adaptive_training=False,
        is_adaptive_prediction=False,
        use_shared_linear=False,
        encoder_type="identity",
    ),
    "meta_linearAR": dict(
        is_adaptive_training=False,
        is_adaptive_prediction=False,
        use_shared_linear=False,
        encoder_type="linear",
    ),
    "meta_linear_biased_regAR": dict(
        biased_regularization=True,
        is_adaptive_training=False,
        is_adaptive_prediction=False,
        use_shared_linear=False,
        encoder_type="identity",
    ),
    "meta_feedforwardAR": dict(
        is_adaptive_training=False,
        is_adaptive_prediction=False,
        use_shared_linear=False,
        encoder_type="feedforward",
    ),
}


def get_model_cfg(
    model_name: Optional[str],
    context_length: int,
    prediction_length: Optional[int],
    train_prediction_length: int,
    freq: str,
    history_len: int,
    batch_size: int,
    test_batch_size: int,  # only for MetaARDet
    epochs: int,
    patience: int,
    init: str,
    weight_decay: float,
    learning_rate: float,
    minimum_learning_rate: float,
    det_loss: str,
    distr_output_str: str,
    lags_seq: Optional[list],
    lags_seq_local: Optional[list],
    scaling: bool,
    dim_transform: int,
    time_features: Optional[list],
    no_age_and_time: bool,
    iterate_forecasts_during_training: bool,
    num_avg_checkpoint: int,
    residual_linear: bool,
    clip_gradient: int,
    hybridize: bool,
):
    """Build the estimator selected by ``model_name`` and return its config.

    One ``Trainer`` and one training instance sampler are built from the
    arguments and shared by whichever estimator is selected:

    * names in ``_METAAR_DET_VARIANTS`` -> ``MetaARDetEstimator`` with the
      shared keyword arguments plus the variant's own options;
    * ``"ets"``, ``"arima"``, ``"thetaf"`` -> ``RForecastPredictor`` wrapped
      in a ``DummyEstimator``;
    * ``"deepar"``, ``"simple_feed"``, ``"nbeats_sh"``,
      ``"nbeats_sh_single"``, ``"nbeats_nsh_single"`` -> the corresponding
      GluonTS estimators.

    Args:
        model_name: Model identifier (see above); must be non-empty.
        history_len: ``min_history`` of ``MinHistorySampler`` when positive;
            ``-1`` selects ``ExpectedNumInstanceSampler(1.0)`` instead.
        test_batch_size: Test-time batch size of the MetaARDet models.
        det_loss: Point-forecast loss of the MetaARDet and single N-BEATS
            models.
        distr_output_str: ``"student"`` or ``"gaussian"``; the output
            distribution of the DeepAR and simple feed-forward models.
        no_age_and_time: Forwarded to ``MetaARDetEstimator`` as
            ``use_time_features`` (despite its name).
        lags_seq_local, residual_linear: Accepted for configuration
            compatibility; not used by any model built here.
        The remaining arguments are forwarded to the trainer / estimators
        under the same (or an obviously corresponding) name.

    Returns:
        ``to_config(model)`` of the constructed estimator.

    Raises:
        AssertionError: if ``model_name`` is empty.
        NotImplementedError: for an unsupported ``history_len`` or
            ``distr_output_str``.
        RuntimeError: for an unknown ``model_name``.
    """
    # The plain global linear AR model always uses the "zero" initializer.
    init = "zero" if model_name == "global_linearAR" else init
    assert model_name, "model_name must be set (e.g. with a named config such as 'deepar')"

    trainer = Trainer(
        epochs=epochs,
        num_batches_per_epoch=50,
        learning_rate=learning_rate,
        learning_rate_decay_factor=0.5,
        patience=patience,
        minimum_learning_rate=minimum_learning_rate,
        batch_size=batch_size,  # deprecated in Trainer; kept for compatibility
        weight_decay=weight_decay,
        init=init,
        clip_gradient=clip_gradient,
        avg_strategy=SelectNLastMean(num_models=num_avg_checkpoint),
        hybridize=hybridize,
    )

    if history_len > 0:
        train_sampler = MinHistorySampler(
            min_history=history_len, num_instances=1.0
        )
    elif history_len == -1:
        train_sampler = ExpectedNumInstanceSampler(1.0)
    else:
        raise NotImplementedError(
            f"history_len must be a positive integer or -1, got {history_len}"
        )

    # Resolved up front, so an unsupported value is rejected for every model.
    if distr_output_str == "student":
        distr_output = StudentTOutput()
    elif distr_output_str == "gaussian":
        distr_output = GaussianOutput()
    else:
        raise NotImplementedError(
            f"distribution {distr_output_str} not implemented"
        )

    metaar_det_common_kwargs = dict(
        freq=freq,
        prediction_length=prediction_length,
        train_prediction_length=train_prediction_length,
        loss=det_loss,
        context_length=context_length,
        output_dim=dim_transform,
        net_context_length=None,
        test_context_length=None,
        test_batch_size=test_batch_size,
        train_sampler=train_sampler,
        trainer=trainer,
        lags_seq=lags_seq,
        scaling=scaling,
        time_features=time_features,
        use_time_features=no_age_and_time,
        batch_size=batch_size,
        iterated_forecasts_during_training=iterate_forecasts_during_training,
    )

    if model_name in _METAAR_DET_VARIANTS:
        model = MetaARDetEstimator(
            **metaar_det_common_kwargs, **_METAAR_DET_VARIANTS[model_name]
        )

    elif model_name in ("ets", "arima", "thetaf"):
        model = DummyEstimator(
            predictor=RForecastPredictor(
                freq=freq,
                prediction_length=prediction_length,
                method_name=model_name,
                trunc_length=context_length,
            )
        )

    elif model_name == "deepar":
        model = DeepAREstimator(
            context_length=context_length,
            prediction_length=prediction_length,
            freq=freq,
            train_sampler=train_sampler,
            trainer=trainer,
            lags_seq=lags_seq,
            scaling=scaling,
            time_features=time_features,
            batch_size=batch_size,
            distr_output=distr_output,
        )

    elif model_name == "simple_feed":
        model = SimpleFeedForwardEstimator(
            freq=freq,
            distr_output=distr_output,
            prediction_length=prediction_length,
            context_length=context_length,
            train_sampler=train_sampler,
            mean_scaling=scaling,
            trainer=trainer,
            batch_size=batch_size,
        )

    elif model_name == "nbeats_sh":
        model = NBEATSEnsembleEstimator(
            freq=freq,
            prediction_length=prediction_length,
            sharing=[True],
            scale=True,
            num_stacks=1,
            num_blocks=[30],
            meta_bagging_size=3,
            train_sampler=train_sampler,
            trainer=trainer,
            batch_size=batch_size,
        )

    elif model_name in ("nbeats_sh_single", "nbeats_nsh_single"):
        # The two single N-BEATS variants differ only in weight sharing.
        model = NBEATSEstimator(
            freq=freq,
            prediction_length=prediction_length,
            loss_function=det_loss,
            sharing=[model_name == "nbeats_sh_single"],
            scale=True,
            num_stacks=1,
            num_blocks=[30],
            trainer=trainer,
            train_sampler=train_sampler,
            batch_size=batch_size,
            context_length=context_length,
        )

    else:
        raise RuntimeError(f"model_name={model_name} not implemented")

    return to_config(model)


@ex.capture
def get_cfg(_config):
    """Return the configuration of the current run.

    ``_config`` is injected by Sacred because this function is captured by
    the experiment; callers invoke ``get_cfg()`` without arguments.
    """
    return _config
