"""
Computes ensemble forecasts on a given directory and saves the results as if they were a sacred experiment,
even if sacred is not used.

Forecasts for single runs, if missing, are recomputed and saved.

Single runs have to be passed as a list of experiment directories.
Ensembles are computed over the --ntop best runs according to --valmetric,
one for each (model_name, iterate_forecasts_during_training) combination, and are written to the --expdir directory.
Uses multiprocessing with --mp (otherwise uses the vectorization of the predictor).

Example usage:
    python run_ensemble_local.py --ntop=20 --valmetric=source_sMAPE --expdir=my_runs/meta/ my_runs/meta*
"""
from collections import defaultdict
from functools import reduce
from pathlib import Path
import multiprocessing as mp
from typing import Iterable, Sequence, Union, List, Optional, Tuple
import traceback
import random

import pandas as pd

from gluonts.time_feature import get_seasonality
from gluonts.evaluation import Evaluator
from gluonts.support.util import maybe_len
from gluonts.model.forecast import SampleForecast, Forecast, QuantileForecast
from mang import load_experiments
from gluonts.meta_tools import get_forecasts_and_series
from joblib import Parallel, delayed

from gluonts.model.predictor import Predictor, SimpleParallelizedPredictor

from mang import ExpInfo
from src.utils import get_dataset, transform_dataset
import json
import datetime
import numpy as np
import os
import argparse

# Command-line interface of the script. The positional ``exps`` arguments are the
# single-run experiment directories; every option has a default.
parser = argparse.ArgumentParser()
parser.add_argument("--expdir", help="Main Experiments directory",
                    default="my_runs/trial_runs/",
                    type=str)
parser.add_argument('exps', metavar='experiment', type=Path, nargs='+', help='experiments to use')
parser.add_argument("--ntop", help="Cardinality of ensemble",
                    default=20, required=False,
                    type=int)
parser.add_argument("--ntot", help="Total number of runs to consider",
                    default=200, required=False,
                    type=int)
parser.add_argument("--ntrain", help="Number of examples used to compute the train metric",
                    default=8000, required=False,
                    type=int)
parser.add_argument("--valmetric", help="Metric for the selection",
                    default="val_sMAPE", required=False,  # avg_last_losses, val_sMAPE, source_sMAPE
                    type=str)
parser.add_argument("--model", help="Model to consider for the ensemble if empty considers all",
                    default="", required=False,  # meta_deepar, deepar
                    type=str)
# The help text belongs to the flag itself: passing ``help=`` to ``set_defaults``
# only creates a stray ``help`` attribute on the parsed namespace.
parser.add_argument('--mp', dest='multiprocessing', action='store_true', default=False,
                    help="Multiprocessing to compute forecasts")

# Columns kept in the per-member configuration table of an ensemble
# (see get_config_df_for_ensamble); also printed by run() for each group.
config_columns = [
    "run_id",
    "model_name",
    "context_length_mult",
    "dim_transform",
    "history_len",
    "learning_rate",
    "epochs",
    "batch_size",
    "iterate_forecasts_during_training",
    "n_lags_until",
    "source",
]

# Maps a source dataset name (looked up with an experiment's config["train"])
# to the list of datasets for which forecasts are produced.
source_target_dicts = {
    "m4_hourly": ["m4_hourly", "electricity", "traffic"],
    "m4_monthly": ["m3_monthly", "m4_monthly", "tourism_monthly"],
    "m4_quarterly": ["m3_quarterly", "m4_quarterly", "m3_other", "tourism_quarterly"],
    "m4_yearly": ["m3_yearly", "m4_yearly", "tourism_yearly"],
}

# Memo used by load_and_transform_dataset so each dataset variant is built once.
dataset_cache = {}


def compute_forecasts(ds_meta, ds, predictor, num_samples, use_multiprocessing=False, num_workers=None):
    """Predict ``ds`` with ``predictor`` and return ``(forecasts, series)``.

    When both ``num_workers`` and ``use_multiprocessing`` are truthy, the
    predictor is first wrapped in a ``SimpleParallelizedPredictor`` using
    ``num_workers`` workers and a chunk size of 1000. The prediction length
    passed on is ``ds_meta.prediction_length``.
    """
    run_in_parallel = num_workers and use_multiprocessing
    if run_in_parallel:
        predictor = SimpleParallelizedPredictor(predictor, num_workers=num_workers, chunk_size=1000)

    fcsts, tss = get_forecasts_and_series(
        ds, predictor, prediction_length=ds_meta.prediction_length, num_samples=num_samples
    )
    return fcsts, tss


def load_or_recompute_fcst_and_ts(ex, dataset, multiprocessing, train=False, limit_examples=None):
    """Load the stored forecasts of experiment ``ex`` on ``dataset``, recomputing them if loading fails.

    Up to three load attempts are made through ``get_ts_and_forecast``. Whenever an
    attempt raises, the traceback is printed, the predictor is read from
    ``ex.path / "predictor.tar.gz"``, forecasts are recomputed and evaluated, and the
    results are written with ``save_results`` before the next attempt.

    Returns the value of ``get_ts_and_forecast`` on the first successful load.

    NOTE(review): if all three loads fail, the loop ends and the function implicitly
    returns None (after a third recompute whose result is never re-loaded).
    NOTE(review): exceptions raised inside the recompute path (e.g. while reading the
    predictor) are not caught here and propagate to the caller.
    """
    for i in range(3):
        try:
            return get_ts_and_forecast(forecasts_path=ex.path, dataset_name=dataset, config=ex.config, train=train,
                                       limit_examples=limit_examples, multiprocessing=multiprocessing)
        except Exception:
            # Any failure of the load (not only a missing file) triggers a recompute.
            traceback.print_exc()
            print(f"forecasts for {dataset} not present")
            pred_path = ex.path / "predictor.tar.gz"
            predictor = Predictor.read_tgz(pred_path)
            cfg = ex.config
            evaluation = cfg["evaluation"]
            seed = cfg["seed"]  # read but not used below

            # Rebinds the argument: later load attempts and the saved file names use
            # the experiment's own evaluation limit when the caller passed None.
            limit_examples = evaluation["limit_examples"] if limit_examples is None else limit_examples
            transform_dataset_mode = cfg["transform_dataset_mode"]

            print(f"Loading {dataset}")
            ds_name, meta, ds = load_and_transform_dataset(dataset, transform_dataset_mode, limit_examples=limit_examples, train=train)

            num_eval_workers = 0 if not multiprocessing else mp.cpu_count()

            model_name = cfg["model_name"]
            # 400 samples for "deepar", "simple_feed" and any model whose name contains
            # "prob"; a single sample for every other model.
            num_samples = 1 if model_name not in ["deepar", "simple_feed"] and "prob" not in model_name else 400

            print(f"Predicting {dataset} using prediction length {meta.prediction_length}")
            forecasts, series = compute_forecasts(meta, ds, predictor, num_samples=num_samples,
                                                  use_multiprocessing=multiprocessing, num_workers=num_eval_workers)

            evaluator = Evaluator(quantiles=evaluation["quantiles"], num_workers=num_eval_workers)
            agg_metrics, item_metrics = evaluator(
                series, forecasts, num_series=maybe_len(ds)
            )

            # Persist forecasts and metrics so that the next loop iteration can load them.
            save_results(ex.path, dataset, forecasts, item_metrics, agg_metrics, train=train, limit_examples=limit_examples)


def get_config_df_for_ensamble(exps_dict):
    """Return the de-duplicated configuration table of the given experiments.

    Every row of each experiment's ``data_rows`` is tagged in place with its
    ``run_id`` (the key in ``exps_dict``); the rows are then gathered in a
    DataFrame restricted to ``config_columns``, with duplicate rows dropped.
    """
    collected = []
    for run_id, experiment in exps_dict.items():
        exp_rows = experiment.data_rows
        for row in exp_rows:
            row['run_id'] = run_id
        collected.extend(exp_rows)

    table = pd.DataFrame(collected)
    return table[config_columns].drop_duplicates()


def ensure_forecasts(exp_path_list):
    """Generate missing forecast files for the completed single runs in ``exp_path_list``.

    Runs whose status is not 'COMPLETED' or whose id contains "Ensemble" are ignored.
    For every remaining run and every dataset listed in ``source_target_dicts`` for
    its source (``config["train"]``), forecasts are produced through
    ``load_or_recompute_fcst_and_ts`` unless ``forecasts_<dataset>.parquet`` already
    exists in the run directory.
    """
    loaded = load_experiments(include_dummy=False, path=exp_path_list, glob_pat=None)
    exps = [
        ex for run_id, ex in loaded.items()
        if ex.run_info['status'] == 'COMPLETED' and "Ensemble" not in run_id
    ]
    total = len(exps)
    rule = "=" * 20

    for position, e in enumerate(exps, start=1):
        print(f"{rule}{position} / {total}{rule}")
        targets = source_target_dicts[e.config["train"]]
        for dataset in targets:
            if not (e.path / f"forecasts_{dataset}.parquet").exists():
                print(f"generating {e}, {dataset}")
                load_or_recompute_fcst_and_ts(e, dataset, multiprocessing=False)
            else:
                print(f"skipping {e}, {dataset}")


def ensemble_exp(all_exps, source_dataset, exps_dir, name=None, multiprocessing=None):
    """Build one ensemble out of ``all_exps`` and store it as a new experiment directory.

    A directory ``Ensemble-<name>`` (``Ensemble-<model_name>-<source_dataset>`` when
    ``name`` is None) is created under ``exps_dir`` and receives:

    * ``config.json``: configuration of the first member, with ``run_id`` replaced
      by the ensemble name;
    * ``run.json``: run info of the first member;
    * ``element_configs.csv``: configuration table of all members;
    * for every dataset in ``source_target_dicts[source_dataset]``, the ensemble
      forecasts and their metrics, as written by ``save_results``.

    Args:
        all_exps: mapping run_id -> experiment for the ensemble members.
        source_dataset: key of ``source_target_dicts`` selecting the datasets to forecast.
        exps_dir: directory in which the ensemble directory is created.
        name: optional suffix of the ensemble directory name.
        multiprocessing: when truthy, evaluation uses ``mp.cpu_count()`` workers.

    Raises:
        IndexError: if ``all_exps`` is empty.
        AssertionError: if the members do not share ``model_name`` and
            ``transform_dataset_mode``.
        FileExistsError: if the ensemble directory already exists (``os.mkdir``).
    """
    exps = list(all_exps.values())
    first_cfg = exps[0].config

    num_eval_workers = 0 if not multiprocessing else mp.cpu_count()

    model_name = first_cfg["model_name"]
    transform_dataset_mode = first_cfg["transform_dataset_mode"]
    assert all(model_name == e.config["model_name"] for e in exps)
    assert all(transform_dataset_mode == e.config["transform_dataset_mode"] for e in exps)

    exp_name = "Ensemble-"
    exp_name += f"{model_name}-{source_dataset}" if name is None else name
    exp_path = Path(exps_dir) / f"{exp_name}"
    os.mkdir(exp_path)

    # Work on a copy: the first member keeps its own run_id in its in-memory config.
    ensemble_cfg = dict(first_cfg)
    ensemble_cfg["run_id"] = exp_name

    with open(exp_path / "config.json", "w") as f:
        json.dump(ensemble_cfg, f)

    with open(exp_path / "run.json", "w") as f:
        json.dump(exps[0].run_info, f)

    config_df = get_config_df_for_ensamble(all_exps)
    config_df.to_csv(exp_path / "element_configs.csv")

    for dataset in source_target_dicts[source_dataset]:
        series, forecasts = get_ensemble_fcst(exps, dataset, multiprocessing=multiprocessing)
        evaluator = Evaluator(quantiles=[0.5], num_workers=num_eval_workers)
        agg_metrics, item_metrics = evaluator(
            series, forecasts, num_series=maybe_len(series)
        )
        save_results(exp_path, dataset, forecasts, item_metrics, agg_metrics)


def get_ensemble_fcst(exps: List[ExpInfo], dataset: str, method='median', multiprocessing=None):
    """Combine the forecasts of ``exps`` on ``dataset`` into one forecast per series.

    The (series, forecast) pairs of every experiment are obtained with
    ``load_or_recompute_fcst_and_ts``. The series of the first experiment are the
    reference: later experiments are asserted to yield the same number of pairs and
    equal series. With ``method='median'`` the ensemble forecast of a series is the
    element-wise median of the members' 0.5-quantile forecasts, returned as a
    ``QuantileForecast`` with the single key "0.5"; start date, frequency and item id
    are taken from the first member.

    Returns ``(series, ensemble_forecasts)``.

    Raises ``NotImplementedError`` for any other ``method`` (only reached when there
    is at least one series to combine).
    """
    series = []
    forecasts = []
    for exp_pos, experiment in enumerate(exps):
        pairs = load_or_recompute_fcst_and_ts(experiment, dataset, multiprocessing)
        if series:
            assert len(series) == len(pairs)

        if exp_pos == 0:
            # The first experiment defines the reference series and opens one
            # forecast bucket per series.
            for ts, fc in pairs:
                series.append(ts)
                forecasts.append([fc])
            continue

        for k, (ts, fc) in enumerate(pairs):
            assert (series[k] == ts).all()
            forecasts[k].append(fc)

    ensemble_forecasts = []
    for _, members in zip(series, forecasts):
        if method != 'median':
            raise NotImplementedError()

        medians = np.asarray([member.quantile(0.5) for member in members], dtype=np.float32)
        reference = members[0]
        assert medians.ndim == 2
        combined = np.median(medians, axis=0)[None, :]
        ensemble_forecasts.append(
            QuantileForecast(
                combined,
                start_date=reference.start_date,
                freq=reference.freq,
                forecast_keys=["0.5"],
                item_id=reference.item_id,
            )
        )

    return series, ensemble_forecasts


#### UTILS

def load_and_transform_dataset(dataset, transform_dataset_mode, limit_examples, train):
    """Return ``(dataset, meta, ds)`` for a dataset variant, memoised in ``dataset_cache``.

    ``ds`` is the list of train entries when ``train`` is truthy, otherwise the list
    of test entries, optionally truncated to ``limit_examples`` entries and, when
    ``transform_dataset_mode`` is not None, passed through ``transform_dataset``.
    ``meta`` is the dataset metadata. The cache key is the concatenation of the
    dataset name and the string forms of the three other arguments.
    """
    cache_key = "".join((dataset, str(transform_dataset_mode), str(limit_examples), str(train)))
    try:
        return dataset_cache[cache_key]
    except KeyError:
        pass

    raw = get_dataset(dataset, regenerate=False)
    if train:
        ds = list(raw.train)
    else:
        ds = list(raw.test)
    meta = raw.metadata

    if limit_examples:
        print("reducing dataset size")
        ds = ds[:limit_examples]

    periodicity = get_seasonality(meta.freq)

    if transform_dataset_mode is not None:
        print(f"Transforming {dataset}: mode={transform_dataset_mode}")
        ds = transform_dataset(
            ds,
            prediction_length=meta.prediction_length,
            lead_time=0,
            periodicity=periodicity,
            mode=transform_dataset_mode,
        )

    entry = (dataset, meta, ds)
    dataset_cache[cache_key] = entry
    return entry


def log_metric(metric_name, value):
    """Print one metric as ``gluonts[metric-<metric_name>]: <value>``."""
    line = f"gluonts[metric-{metric_name}]: {value}"
    print(line)


def write_forecasts_parquet(
    forecasts: Iterable[Forecast],
    path: Union[Path, str],
    quantiles: Sequence[str] = ("0.5",),
) -> None:
    """
    Write forecasts to a single zstd-compressed parquet file at ``path``.

    Each forecast becomes one row holding its start timestamp (``fcst_start``),
    ``mean``, ``freq`` (as a string), ``item_id`` and one ``p_<q>`` column per
    requested quantile ``q``.
    """
    def as_record(fcst):
        record = {
            "fcst_start": pd.Timestamp(fcst.start_date),
            "mean": fcst.mean,
            "freq": str(fcst.freq),
            "item_id": fcst.item_id,
        }
        record.update({f"p_{q}": fcst.quantile(q) for q in quantiles})
        return record

    table = pd.DataFrame([as_record(fcst) for fcst in forecasts])
    table.to_parquet(str(path), compression="zstd")


def save_results(parent_path, ds_name, forecasts, item_metrics, agg_metrics, train=False, limit_examples=None):
    """Store forecasts and metrics of one dataset in the experiment directory ``parent_path``.

    With ``name = get_dataset_name(ds_name, train, limit_examples)`` the following
    files are written:

    * ``forecasts_<name>.parquet``: the forecasts (see ``write_forecasts_parquet``);
    * ``item_metrics_<name>.parquet``: the per-item metrics DataFrame;
    * ``metrics.json``: the existing file, if readable, updated with one
      ``<name>_<metric>`` entry per aggregate metric (each also printed via
      ``log_metric``);
    * ``agg_metrics_<name>.csv``: the aggregate metrics as a one-column table.

    Args:
        parent_path: experiment directory (must support the ``/`` operator).
        ds_name: dataset name.
        forecasts: iterable of forecasts to store.
        item_metrics: DataFrame of per-item metrics.
        agg_metrics: mapping metric name -> value.
        train: whether the results refer to the training split (affects file names).
        limit_examples: example limit used, if any (affects file names).
    """
    ds_name = get_dataset_name(ds_name, train, limit_examples)

    forecasts_parquet = parent_path / f"forecasts_{ds_name}.parquet"
    write_forecasts_parquet(forecasts, forecasts_parquet)

    item_metrics_parquet = parent_path / f"item_metrics_{ds_name}.parquet"
    item_metrics.to_parquet(item_metrics_parquet, compression="zstd")

    metrics_path = parent_path / "metrics.json"
    try:
        with open(metrics_path) as f:
            j = json.load(f)
    except FileNotFoundError:
        print("metric file not found, a new one will be created")
        j = {}
    except (OSError, ValueError):
        # Unreadable or corrupt file (json.JSONDecodeError is a ValueError): keep the
        # best-effort behaviour of starting afresh, but show what went wrong.
        # KeyboardInterrupt / SystemExit are deliberately no longer swallowed.
        traceback.print_exc()
        print("metric file could not be read, a new one will be created")
        j = {}

    for metric_name, v in agg_metrics.items():
        m = f"{ds_name}_{metric_name}"
        log_metric(m, v)
        # Same layout as the other entries of metrics.json: a single step 0.
        j[m] = {'steps': [0],
                'timestamps': [datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%f")],
                'values': [v],
                }

    with open(metrics_path, "w") as f:
        json.dump(j, f)

    df = pd.DataFrame({"value": agg_metrics})
    df.index.name = "metric"
    agg_metrics_path = parent_path / f"./agg_metrics_{ds_name}.csv"
    df.to_csv(agg_metrics_path)


def run(exps_dir, exp_path_list, n_top, n_tot, val_metric, multiprocessing, n_train, model_str):
    """Select the best single runs of every group and build one ensemble per group.

    Experiments are loaded from ``exp_path_list``; runs that are not 'COMPLETED' or
    whose id contains "Ensemble" are ignored. Runs are grouped by
    (model_name, iterate_forecasts_during_training, source). In every group at most
    ``n_tot`` randomly chosen runs are considered, and the ``n_top`` runs with the
    smallest ``val_metric`` form the ensemble, which is written by ``ensemble_exp``.

    Args:
        exps_dir: directory in which ensemble experiments are created.
        exp_path_list: experiment directories of the single runs.
        n_top: number of runs per ensemble.
        n_tot: maximum number of runs considered per group.
        val_metric: column used to rank the runs. If it contains "train", the training
            metrics (on ``n_train`` examples) are computed for runs lacking them.
        multiprocessing: when truthy, ``mp.cpu_count()`` workers are used for evaluation.
        n_train: number of training examples used for the training metric.
        model_str: if non-empty, only groups with this model_name are processed.

    NOTE(review): the first ensemble that fails stops the whole run (the traceback is
    printed and the function returns); confirm this fail-fast behaviour is intended.
    """
    all_exps = load_experiments(include_dummy=False, path=exp_path_list, glob_pat=None)
    all_exps = {
        run_id: ex for run_id, ex in all_exps.items()
        if ex.run_info['status'] == 'COMPLETED' and "Ensemble" not in run_id
    }

    num_eval_workers = 0 if not multiprocessing else mp.cpu_count()
    uses_train_metric = "train" in val_metric
    n_mean = 5  # number of trailing training-loss records averaged into 'avg_last_losses'

    rows = []
    for k, exp in all_exps.items():
        source = exp.config["train"]
        tr_name = get_dataset_name(source, True, n_train)

        if uses_train_metric and exp.metrics.get(f"{tr_name}_sMAPE") is None:
            print("computing training metric")
            results = load_or_recompute_fcst_and_ts(exp, source, multiprocessing, train=True,
                                                    limit_examples=n_train)
            series = [r[0] for r in results]
            forecasts = [r[1] for r in results]
            evaluator = Evaluator(quantiles=[0.5], num_workers=num_eval_workers)
            agg_metrics, item_metrics = evaluator(
                series, forecasts, num_series=maybe_len(series)
            )
            save_results(exp.path, source, forecasts, item_metrics, agg_metrics, train=True, limit_examples=n_train)
            # NOTE(review): the metrics read below come from exp.metrics; verify that it
            # reflects the file just written, otherwise train_* stays None for this run.

        more_rows = exp.data_rows
        for r in more_rows:
            r['run_id'] = k
            lcm = exp.learning_curves_metrics

            for m in ["MASE", "sMAPE", "ND"]:
                r[f"source_{m}"] = exp.metrics.get(f"{source}_{m}")
                if uses_train_metric:
                    r[f"train_{m}"] = exp.metrics.get(f"{tr_name}_{m}")

            if lcm is not None:
                train_curves = lcm[lcm.dataset.str.contains('train')]
                val_curves = lcm[lcm.dataset.str.contains('val')]
                r['avg_last_losses'] = train_curves['epoch_loss'][-n_mean:].mean()
                r['last_train_mase'] = train_curves['MASE'].dropna().values[-1]
                r['last_val_mase'] = val_curves['MASE'].dropna().values[-1]

        rows += more_rows

    if not rows:
        print("no completed single runs found, nothing to ensemble")
        return

    df = pd.DataFrame(rows)

    # One ensemble per distinct combination of these columns.
    val_cols = ["model_name", "iterate_forecasts_during_training", "source"]
    group_keys = [key for key, _ in df[val_cols].drop_duplicates().groupby(val_cols)]

    for vc in group_keys:
        print(f"model_string={model_str}")
        if model_str != "" and model_str != vc[0]:
            continue

        in_group = reduce(lambda x, y: x & y, [df[col] == v for col, v in zip(val_cols, vc)])
        tmp_df = df[in_group]

        # Random subset of at most n_tot runs of the group.
        run_ids_total = tmp_df["run_id"].drop_duplicates().tolist()
        random.shuffle(run_ids_total)
        run_ids_total = run_ids_total[:n_tot]
        # Number of runs actually considered; may be smaller than n_tot.
        n_tot_true = len(run_ids_total)

        # The mask is built on tmp_df itself so that it is aligned with the rows it filters.
        tmp_df = tmp_df[tmp_df["run_id"].isin(run_ids_total)].sort_values(val_metric)
        run_ids = tmp_df["run_id"].drop_duplicates().iloc[:n_top].tolist()

        for c in config_columns:
            print(f"{c}: {tmp_df[c].drop_duplicates().sort_values().values}")

        print(f"{n_top} best ({val_metric}) run ids among {n_tot_true} total runs :")
        print(run_ids)
        val_metric_str = val_metric if not val_metric.startswith("train") else f"{val_metric}N{n_train}"
        exp_name = f"{vc[0]}-IT{vc[1]}-S{vc[2]}-M{val_metric_str}-N{n_top}-TOT{n_tot_true}"
        try:
            ensemble_exp({run_id: all_exps[run_id] for run_id in run_ids}, vc[2],
                         exps_dir=exps_dir, name=exp_name, multiprocessing=multiprocessing)
        except Exception:
            traceback.print_exc()
            return

def main():
    """Parse the command-line arguments and build the ensembles with ``run``."""
    args = parser.parse_args()
    # run() has no defaults: every argument has to be forwarded, including the
    # experiment directories, the train-metric size and the model filter.
    run(
        exps_dir=args.expdir,
        exp_path_list=args.exps,
        n_top=args.ntop,
        n_tot=args.ntot,
        val_metric=args.valmetric,
        multiprocessing=args.multiprocessing,
        n_train=args.ntrain,
        model_str=args.model,
    )


## UTILS
def load_fcst_and_ts(dataset_name, forecasts_path, config, train, limit_examples=None):
    """Load the stored forecasts of one experiment together with the matching dataset.

    The forecasts are read from
    ``forecasts_path / "forecasts_<name>.parquet"``, where ``<name>`` is
    ``get_dataset_name(dataset_name, train, limit_examples)``. The dataset is obtained
    from ``load_and_transform_dataset`` using ``config.get("transform_dataset_mode")``.

    Returns:
        ``(ds, fcst_df, freq)``: the dataset entries as a numpy array, the forecasts
        DataFrame and the dataset frequency (``meta.freq``).

    Raises:
        AssertionError: if the dataset and the forecasts differ in length.
        Whatever ``pandas.read_parquet`` raises when the file is missing or unreadable.
    """
    stored_name = get_dataset_name(dataset_name, train, limit_examples)
    fcst_df = pd.read_parquet(forecasts_path / f"forecasts_{stored_name}.parquet")

    _, meta, ds = load_and_transform_dataset(dataset_name, config.get("transform_dataset_mode"),
                                             limit_examples=limit_examples, train=train)

    # An array (rather than a list) so that callers can select entries with a list of indices.
    ds = np.array([ts for ts in ds])
    assert len(ds) == len(fcst_df), f"{len(ds)} dataset entries but {len(fcst_df)} forecasts"
    return ds, fcst_df, meta.freq


def get_dataset_name(dataset_name, train, limit_examples=None):
    """Return the name under which results for a dataset variant are stored.

    The dataset name is prefixed with ``train_`` when ``train`` is truthy and
    suffixed with ``_<limit_examples>`` when ``limit_examples`` is not None.
    """
    prefix = "train_" if train else ""
    suffix = f"_{limit_examples}" if limit_examples is not None else ""
    return f"{prefix}{dataset_name}{suffix}"


# Cache of the pandas series built by get_ts_and_forecast: one inner dict per dataset key.
ts_dict = dict()


def get_ts_and_forecast(
        forecasts_path: Path,
        dataset_name: str,
        config: dict,
        indices: Optional[List[int]] = None,
        train: Optional[bool] = False,
        limit_examples: Optional[int] = None,
        multiprocessing: Optional[bool] = False,
    ) -> List[Tuple["pandas.Series", "gluonts.model.forecast.Forecast"]]:
    """Return the ``(series, forecast)`` pairs stored for one experiment and dataset.

    Forecasts and dataset entries are loaded with ``load_fcst_and_ts``. Every dataset
    entry is turned into a ``pandas.Series`` indexed by a date range starting at the
    entry's ``start``; every forecast row becomes a ``QuantileForecast`` with the
    keys "mean" and "0.5". The series are cached in ``ts_dict`` per dataset variant
    and dataset position, so they are built only once across experiments.

    Args:
        forecasts_path: experiment directory holding the forecast parquet files.
        dataset_name: name of the dataset.
        config: experiment configuration; only "transform_dataset_mode" is read (by
            ``load_fcst_and_ts``).
        indices: optional positions of the entries to return; all entries if None.
        train: whether the training split is used.
        limit_examples: example limit of the stored forecasts, if any.
        multiprocessing: currently unused. A joblib-based parallel version of the
            loop below was tried and did not help.

    Raises:
        AssertionError: if series and forecasts differ in number or in ``item_id``.

    NOTE(review): the series cache does not depend on "transform_dataset_mode"; confirm
    that runs with different modes never share a process, or that the mode leaves the
    targets unchanged.
    """
    from gluonts.model.forecast import QuantileForecast

    ds, fcst_df, freq = load_fcst_and_ts(dataset_name, forecasts_path=forecasts_path, config=config, train=train,
                                         limit_examples=limit_examples)

    if indices is not None:
        fcst_entries = fcst_df.iloc[indices]
        ts_entries = ds[indices]
        # Cache under the position in the full dataset, not the position in the subset.
        cache_positions = list(indices)
    else:
        fcst_entries = fcst_df
        ts_entries = ds
        cache_positions = range(len(ds))

    assert len(ts_entries) == len(fcst_entries)

    dataset_key = get_dataset_name(dataset_name, train, limit_examples)

    # The presence check must use the same key as the store: checking the bare
    # dataset name either reset the cache on every call or raised a KeyError for
    # train / limited variants.
    if dataset_key not in ts_dict:
        ts_dict[dataset_key] = {}
        print("processing of ts and forecasts loop (not cached)")
    else:
        print("processing of ts and forecasts (cached)")

    ts_proc = ts_dict[dataset_key]

    result = []
    fcst_records = fcst_entries.to_dict(orient="records")
    for pos, ts_row, fcst_row in zip(cache_positions, ts_entries, fcst_records):
        ts_item_id = str(ts_row.get("item_id"))
        fcst_item_id = str(fcst_row.get("item_id"))
        assert ts_item_id == fcst_item_id, f"item_id mismatch: {ts_item_id} != {fcst_item_id}"

        if pos in ts_proc:
            ts = ts_proc[pos]
        else:
            target = ts_row["target"]
            ts_index = pd.date_range(
                start=ts_row["start"], periods=len(target), freq=freq
            )
            ts = pd.Series(index=ts_index, data=target)
            ts_proc[pos] = ts

        # Row 0 holds the mean and row 1 the median, matching forecast_keys below.
        ar = np.stack([fcst_row["mean"], fcst_row["p_0.5"]])
        assert ar.shape[0] == 2
        fcst = QuantileForecast(
            ar,
            start_date=fcst_row["fcst_start"],
            freq=freq,
            forecast_keys=["mean", "0.5"],
            item_id=fcst_row.get("item_id"),
        )
        result.append((ts, fcst))

    return result


if __name__ == '__main__':
    # Script entry point: parse the command line, echo it, and build the ensembles.
    args = parser.parse_args()
    print(args)
    run(
        exps_dir=args.expdir,
        exp_path_list=args.exps,
        n_top=args.ntop,
        n_tot=args.ntot,
        val_metric=args.valmetric,
        multiprocessing=args.multiprocessing,
        n_train=args.ntrain,
        model_str=args.model,
    )