#!/usr/bin/env python
# coding: utf-8
import re
import shutil
import sys
import subprocess
from pathlib import Path
from typing import Dict, Optional, List, Tuple, Union
import json
import numpy as np
from tqdm import tqdm
import multiprocessing as mp

from gluonts import transform
from gluonts.dataset.common import Dataset
from gluonts.time_feature import get_seasonality
from gluonts.transform import TransformedDataset


# Default directory scanned for experiment runs: "my_runs/meta" next to this
# file (resolved relative to the module, not the current working directory).
BASE_PATH = Path(__file__).parent.absolute() / "my_runs/meta/"


def transform_dataset(
    dataset: Dataset,
    prediction_length=0,
    lead_time=0,
    periodicity=1,
    mode="MASE_scaling",
):
    """
    Wrap ``dataset`` so that the ``"target"`` of every entry is rescaled.

    With ``mode="MASE_scaling"`` each target is divided by its seasonal
    naive error: the mean absolute difference between values that are
    ``periodicity`` steps apart, computed on the target with the last
    ``prediction_length + lead_time`` points removed. A seasonal error of
    exactly zero is replaced by ``1e-4`` to avoid dividing by zero.

    :param dataset: dataset whose entries are mappings with a "target" array.
    :param prediction_length: trailing points excluded from the scale.
    :param lead_time: additional trailing points excluded from the scale.
    :param periodicity: lag used for the seasonal difference.
    :param mode: scaling mode; only "MASE_scaling" is implemented.
    :return: a ``TransformedDataset`` applying the rescaling through an
        ``AdhocTransform``.
    :raises NotImplementedError: from the per-entry transformation (not from
        this call itself) when ``mode`` is not supported.
    """
    # Number of trailing observations that must not influence the scale.
    holdout = prediction_length + lead_time

    def transform_target(data):
        if mode != "MASE_scaling":
            # One formatted message instead of a tuple of exception args.
            raise NotImplementedError(f"mode: {mode} not implemented!")

        # Shallow copy so the entry of the wrapped dataset is not modified.
        data = data.copy()
        target = data["target"]

        # ``target[:-0]`` would be empty, hence the explicit special case.
        history = target[:-holdout] if holdout != 0 else target

        # NOTE(review): if ``history`` has no more than ``periodicity``
        # points the mean is taken over an empty array and the scaled
        # target becomes NaN -- confirm whether such series can occur.
        seasonal_error = np.mean(
            np.abs(history[periodicity:] - history[:-periodicity])
        )
        scale = seasonal_error if seasonal_error != 0 else 1e-4
        data["target"] = target / scale
        return data

    return TransformedDataset(
        dataset, transformation=transform.AdhocTransform(transform_target)
    )


def pretty_date(time):
    """
    Return a human readable age such as 'an hour ago', 'Yesterday',
    '3 months ago' or 'just now' for a point in time.

    :param time: one of
        * ``str`` -- parsed with ``pandas.Timestamp(time, tz="utc")``,
        * ``int`` -- seconds since the Unix epoch (UTC),
        * ``datetime`` / ``pandas.Timestamp`` -- naive values are assumed
          to be UTC,
        * ``None``, NaN or NaT -- "no value".
    :return: the description, ``""`` for a time in the future, or ``None``
        when ``time`` carries no value.
    :raises RuntimeError: for any other truthy, non-NaN value.
    """
    from datetime import datetime
    import pandas as pd
    import numpy as np

    if isinstance(time, str):
        moment = pd.Timestamp(time, tz="utc")
    elif isinstance(time, int) and not isinstance(time, bool):
        # Build a tz-aware timestamp; a naive datetime.fromtimestamp() value
        # cannot be subtracted from the tz-aware "now" below.
        moment = pd.Timestamp(time, unit="s", tz="utc")
    elif isinstance(time, (datetime, pd.Timestamp)) or time is pd.NaT:
        moment = pd.Timestamp(time)
    elif not time or np.isnan(time):
        return None
    else:
        raise RuntimeError()

    if pd.isna(moment):
        return None
    if moment.tzinfo is None:
        # Naive timestamps are interpreted as UTC so they can be compared
        # with the tz-aware current time.
        moment = moment.tz_localize("utc")

    diff = pd.Timestamp.now(tz="utc") - moment
    second_diff = diff.seconds
    day_diff = diff.days

    if day_diff < 0:
        return ""

    if day_diff == 0:
        if second_diff < 10:
            return "just now"
        if second_diff < 60:
            return str(second_diff) + " seconds ago"
        if second_diff < 120:
            return "a minute ago"
        if second_diff < 3600:
            return str(second_diff // 60) + " minutes ago"
        if second_diff < 7200:
            return "an hour ago"
        return str(second_diff // 3600) + " hours ago"
    if day_diff == 1:
        return "Yesterday"
    if day_diff < 7:
        return str(day_diff) + " days ago"
    if day_diff < 31:
        return str(day_diff // 7) + " weeks ago"
    # Integer division: "3 months ago" rather than "3.0333 months ago".
    if day_diff < 365:
        return str(day_diff // 30) + " months ago"
    return str(day_diff // 365) + " years ago"


def flatten_dict(config):
    """
    Flatten a nested dict into a single level with dot-joined keys.

    ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``. Nested dicts that are empty
    contribute no entry; a non-dict ``config`` is stored under the key "".
    """

    def walk(node, trail):
        # Leaves are emitted as (dotted key, value) pairs in insertion order.
        if not isinstance(node, dict):
            yield ".".join(trail), node
            return
        for name, child in node.items():
            yield from walk(child, trail + [name])

    return dict(walk(config, []))


class ExpInfo:
    """
    In-memory view of one experiment run directory.

    A run directory is expected to contain ``config.json``, ``run.json`` and
    ``metrics.json`` and may contain ``learning_curves_metrics.csv``,
    ``item_metrics.parquet``, ``agg_metrics_<target>.csv`` and
    ``forecasts_<dataset>.parquet`` files.
    """

    def __init__(
        self,
        path: Path,
        run_id: str,
        config: Dict,
        run_info: Dict,
        metrics: Dict,
        learning_curves_metrics: Optional["pandas.DataFrame"],
    ):
        self.path = path
        self.run_id = run_id
        self.config = config
        self.run_info = run_info
        self.metrics = metrics
        self.learning_curves_metrics = learning_curves_metrics
        # dataset name -> (test entries, forecast frame, freq); keyed per
        # dataset so that data cached for one dataset is never returned for
        # another one.
        self._cache = {}

    @classmethod
    def load_from_path(cls, path: Path) -> Optional["ExpInfo"]:
        """
        Build an instance from the files in the run directory ``path``.

        :return: the loaded run, or ``None`` when ``config.json`` or
            ``run.json`` is missing or empty.
        :raises FileNotFoundError: when ``metrics.json`` does not exist.
        """
        import pandas as pd

        # The directory name doubles as the run id.
        run_id = path.name

        def read_json(file_name):
            fp = path / file_name
            if not fp.is_file():
                return None
            with fp.open("r") as f:
                return json.load(f)

        config = read_json("config.json")
        run_info = read_json("run.json")

        if not config or not run_info:
            return None

        with open(path / "metrics.json") as f:
            raw_metrics = json.load(f)

        # Every metric entry carries a "values" list that must hold exactly
        # one value.
        metrics = {}
        for name, entry in raw_metrics.items():
            values = entry["values"]
            assert len(values) == 1
            metrics[name] = values[0]

        learning_curves_csv = path / "learning_curves_metrics.csv"
        if learning_curves_csv.is_file():
            learning_curves_metrics = pd.read_csv(learning_curves_csv)
        else:
            learning_curves_metrics = None

        # ``cls`` rather than ``ExpInfo`` so subclasses get their own type.
        return cls(
            path, run_id, config, run_info, metrics, learning_curves_metrics
        )

    def get_metric_value(self, metric: str) -> Optional[float]:
        """Return the stored value of ``metric`` or ``None`` if absent."""
        return self.metrics.get(metric)

    @property
    def data_rows(self):
        """
        One flat record per evaluated target dataset.

        Each record holds the flattened config, the flattened run info
        (prefixed with ``run_info.``), ``run_id``, ``path``, ``source``,
        ``target`` and the MASE / MAPE / sMAPE / ND metrics together with
        their ``val_`` counterparts. Targets are taken from the
        ``agg_metrics_<target>.csv`` files found in the run directory.
        """
        row = flatten_dict(self.config)
        row.update(flatten_dict({"run_info": self.run_info}))
        row["run_id"] = self.run_id
        row["path"] = self.path
        row["source"] = self.config["train"]

        # Strip the fixed prefix/suffix from the file name instead of using
        # a greedy regex on the whole path, which mis-parses targets whose
        # name contains "metrics_". Sorted for a reproducible row order.
        prefix, suffix = "agg_metrics_", ".csv"
        targets = sorted(
            p.name[len(prefix):-len(suffix)]
            for p in self.path.glob("agg_metrics_*.csv")
        )

        result = []
        for target in targets:
            r = row.copy()
            r["target"] = target
            for metric_name in ["MASE", "MAPE", "sMAPE", "ND"]:
                r[metric_name] = self.metrics.get(f"{target}_{metric_name}")
                r[f"val_{metric_name}"] = self.metrics.get(
                    f"val_{metric_name}"
                )
            result.append(r)
        return result

    @property
    def item_metrics(self) -> Optional["pandas.DataFrame"]:
        """Per-item metrics from ``item_metrics.parquet`` or ``None``."""
        import pandas as pd

        p = self.path / "item_metrics.parquet"
        if p.is_file():
            return pd.read_parquet(p)
        return None

    def load_fcst_and_ts(self, dataset_name):
        """
        Load the stored forecasts and the matching test series.

        :param dataset_name: gluonts repository dataset name; also selects
            ``forecasts_<dataset_name>.parquet`` in the run directory.
        :return: ``(test_ds, fcst_df, freq)`` where ``test_ds`` is a numpy
            array of the test entries, ``fcst_df`` the forecast frame and
            ``freq`` the dataset frequency from its metadata.
        """
        from gluonts.dataset.repository import datasets as gluonts_datasets
        import pandas as pd
        import numpy as np

        p = self.path / f"forecasts_{dataset_name}.parquet"
        fcst_df = pd.read_parquet(p)
        meta, _, test_ds = gluonts_datasets.get_dataset(dataset_name)

        # transform the dataset if it was done during training
        transform_dataset_mode = self.config.get("transform_dataset_mode")
        if transform_dataset_mode is not None:
            print(f"Transforming dataset: mode={transform_dataset_mode}")
            test_ds = transform_dataset(
                test_ds,
                prediction_length=self.config["prediction_length"],
                lead_time=0,
                # "freq" is only needed (and only read) when transforming.
                periodicity=get_seasonality(self.config["freq"]),
                mode=transform_dataset_mode,
            )

        # A comprehension rather than ``list(...)``: this module rebinds the
        # global name ``list`` to a click command further down.
        test_ds = np.array([ts for ts in test_ds])
        freq = meta.freq
        assert len(test_ds) == len(fcst_df)
        return test_ds, fcst_df, freq

    def get_ts_and_forecast(
        self,
        dataset_name: str,
        indices: Optional[List[int]] = None,
        do_cache=False,
        do_cache_dataset=False,
    ) -> Optional[
        List[Tuple["pandas.Series", "gluonts.model.forecast.Forecast"]]
    ]:
        """
        Return ``(series, forecast)`` pairs for ``dataset_name``.

        :param dataset_name: dataset to load, see ``load_fcst_and_ts``.
        :param indices: positions of the items to return; all when ``None``.
        :param do_cache: keep the loaded data on this instance so later
            calls for the same dataset skip the loading step.
        :param do_cache_dataset: currently unused; kept for compatibility.
        :return: list of ``(pandas.Series, QuantileForecast)`` tuples, the
            forecast holding the "mean" and "0.5" keys.
        """
        from gluonts.model.forecast import QuantileForecast
        import pandas as pd
        import numpy as np

        loaded = self._cache.get(dataset_name)
        if loaded is None:
            loaded = self.load_fcst_and_ts(dataset_name)
            if do_cache:
                self._cache[dataset_name] = loaded
        test_ds, fcst_df, freq = loaded

        if indices is not None:
            fcst_entries = fcst_df.iloc[indices]
            ts_entries = test_ds[indices]
        else:
            fcst_entries = fcst_df
            ts_entries = test_ds

        assert len(ts_entries) == len(fcst_entries)

        result = []
        for ts_row, fcst_row in zip(
            ts_entries, fcst_entries.to_dict(orient="records")
        ):
            # Series and forecasts are matched by position; make sure both
            # sides refer to the same item.
            ts_item_id = str(ts_row.get("item_id"))
            fcst_item_id = str(fcst_row.get("item_id"))
            assert (
                ts_item_id == fcst_item_id
            ), f"{ts_item_id} == {fcst_item_id}"
            target = ts_row["target"]
            ts_index = pd.date_range(
                start=ts_row["start"], periods=len(target), freq=freq
            )
            ts = pd.Series(index=ts_index, data=target)
            # Row 0 is the mean forecast, row 1 the median ("p_0.5").
            ar = np.stack([fcst_row["mean"], fcst_row["p_0.5"]])
            assert ar.shape[0] == 2
            fcst = QuantileForecast(
                ar,
                start_date=fcst_row["fcst_start"],
                freq=freq,
                forecast_keys=["mean", "0.5"],
                item_id=fcst_row.get("item_id"),
            )
            result.append((ts, fcst))
        return result


def diff_config_and_metrics(
    exps: List[ExpInfo], metrics=["sMAPE", "MASE", "MAPE"]
):
    """
    Tabulate how the given experiments differ.

    The returned frame has a ``key`` column and one column per run id. The
    first rows hold the requested ``metrics``; the remaining rows hold every
    flattened config key whose value is not identical across all
    experiments, top-level keys first and otherwise sorted by key path.
    """
    import pandas as pd

    flat_configs = [flatten_dict(e.config) for e in exps]
    run_ids = [e.run_id for e in exps]

    every_key = set()
    for cfg in flat_configs:
        every_key.update(cfg.keys())

    def order(record):
        # Keys without a dot sort before nested ones, then by path segments.
        parts = record["key"].split(".")
        return [len(parts) != 1] + parts

    config_rows = []
    for key in every_key:
        values = [cfg.get(key) for cfg in flat_configs]
        # Neighbouring values all equal <=> the setting is shared by all runs.
        if values[1:] == values[:-1]:
            continue
        record = {"key": key}
        for rid, value in zip(run_ids, values):
            record[rid] = value
        config_rows.append(record)
    config_rows.sort(key=order)

    metric_rows = []
    for name in metrics:
        record = {"key": name}
        record.update({e.run_id: e.get_metric_value(name) for e in exps})
        metric_rows.append(record)

    return pd.DataFrame(metric_rows + config_rows)


def compare_forecasts(
    exp1,
    exp2,
    metric="sMAPE",
    label_config=["model_name", "run_id"],
    dataset_name=None,
):
    """
    Build an interactive notebook widget comparing two experiments.

    The left panel is a scatter plot of the per-item ``metric`` of ``exp1``
    against ``exp2``; hovering a point shows that item's series with both
    median forecasts in the right panel.

    :param exp1: first experiment (``ExpInfo``).
    :param exp2: second experiment (``ExpInfo``).
    :param metric: column of the item metrics used for the scatter plot.
    :param label_config: row fields joined into each experiment's label;
        only the part before the first "-" of each value is used.
    :param dataset_name: dataset whose forecasts are shown. When ``None``
        the first target listed by ``exp1.data_rows`` is used.
        NOTE(review): assumes target names equal gluonts dataset names --
        confirm, or always pass ``dataset_name`` explicitly.
    :return: an ``ipywidgets.HBox`` holding the scatter and forecast plots.
    :raises ValueError: if no dataset can be determined or an experiment
        has no item metrics.
    """
    import numpy as np
    import plotly.graph_objs as go
    import plotly.offline as py
    from ipywidgets import interactive, HBox, VBox

    py.init_notebook_mode()

    def make_label(exp):
        # ``data_rows`` is a list of per-target rows; the label fields are
        # the same in each of them, so the first row is sufficient.
        rows = exp.data_rows
        fields = rows[0] if rows else {**exp.config, "run_id": exp.run_id}
        return ", ".join(
            str(fields.get(name, "")).split("-")[0] for name in label_config
        )

    label1 = make_label(exp1)
    label2 = make_label(exp2)

    if dataset_name is None:
        rows = exp1.data_rows
        if not rows:
            raise ValueError(
                "dataset_name was not given and could not be inferred"
            )
        dataset_name = rows[0]["target"]

    im1 = exp1.item_metrics
    im2 = exp2.item_metrics
    if im1 is None or im2 is None:
        raise ValueError("both experiments need item metrics")
    assert np.all(im1.item_id == im2.item_id)

    m1 = im1[metric]
    m2 = im2[metric]
    max_val = max(m1.max(), m2.max())

    def fetch(index):
        # Load one item from both runs; both must refer to the same series.
        ts, fcst1 = exp1.get_ts_and_forecast(
            dataset_name, [index], do_cache=True
        )[0]
        ts_, fcst2 = exp2.get_ts_and_forecast(
            dataset_name, [index], do_cache=True
        )[0]
        assert np.all(ts == ts_)
        return ts, fcst1, fcst2

    ts, fcst1, fcst2 = fetch(1)

    f = go.FigureWidget(
        [
            go.Scatter(x=ts.index, y=ts.values, name="target"),
            go.Scatter(
                x=fcst1.index,
                y=fcst1.quantile(0.5),
                name=label1,
            ),
            go.Scatter(
                x=fcst2.index,
                y=fcst2.quantile(0.5),
                name=label2,
            ),
        ]
    )

    scat = go.FigureWidget(
        [
            go.Scatter(
                x=m1,
                y=m2,
                mode="markers",
                text=im1.item_id,
            ),
            # Diagonal: points on it score the same in both experiments.
            go.Scatter(x=[0, max_val], y=[0, max_val]),
        ],
    )
    scat.layout.xaxis.title = label1
    scat.layout.yaxis.title = label2
    scat.layout.width = 600
    scat.layout.height = 600

    f.layout.legend = dict(
        x=0.02,
        y=0.98,
        traceorder="normal",
        font=dict(
            size=12,
        ),
        bgcolor="rgba(255,255,255,0.5)",
    )

    def update(s, points, input_state):
        # Ignore hover events that carry no point.
        if not points.point_inds:
            return
        ts, fcst1, fcst2 = fetch(points.point_inds[0])

        f.data[0].x = ts.index
        f.data[0].y = ts.values
        f.data[1].x = fcst1.index
        f.data[1].y = fcst1.quantile(0.5)
        f.data[2].x = fcst2.index
        f.data[2].y = fcst2.quantile(0.5)

        f.update_layout()

    scat.data[0].on_hover(update)
    return HBox([scat, f])


def load_literature_results() -> List[Dict]:
    """
    Read ``literature_results.xlsx`` located next to this file.

    Every model name gets the suffix " (lit)" and the ``train`` / ``test``
    columns are exposed as ``source`` / ``target``. Returns one dict per
    spreadsheet row.
    """
    import pandas as pd
    import os

    here = os.path.dirname(os.path.realpath(__file__))
    table: pd.DataFrame = pd.read_excel(f"{here}/literature_results.xlsx")
    table["model_name"] = table["model_name"] + " (lit)"
    # Copy each column under its new name first, then drop the old names.
    for old_name, new_name in (("train", "source"), ("test", "target")):
        table[new_name] = table[old_name]
    for old_name in ("train", "test"):
        del table[old_name]
    return table.to_dict(orient="records")


def load_experiments(
    path: Union[Path, List[Path]] = BASE_PATH, glob_pat="*", include_dummy=False,
) -> Dict[str, ExpInfo]:
    """
    Load experiment runs and return them keyed by run id.

    * If ``path`` is a single path (``Path`` or ``str``), ``glob_pat`` is
      applied to it to find the run directories. With ``glob_pat=None`` the
      path itself is loaded as one run directory.
    * If ``path`` is a list of paths, exactly these directories are loaded
      and ``glob_pat`` is ignored.

    Directories that cannot be loaded (missing or unparsable files, or no
    config / run info) are skipped. Runs whose config marks them as dummy
    runs are skipped unless ``include_dummy`` is true. The git commit hash
    stored in a run's config, if any, is shortened to 7 characters.
    """
    if isinstance(path, (str, Path)):
        root = Path(path)
        if glob_pat is None:
            path_list = [root]
        else:
            # A comprehension rather than ``list(...)``: this module rebinds
            # the global name ``list`` to a click command further down.
            path_list = [p for p in root.glob(glob_pat)]
    else:
        path_list = [Path(p) for p in path]

    exps = {}
    for p in path_list:
        try:
            e = ExpInfo.load_from_path(p)
        except (FileNotFoundError, json.JSONDecodeError) as ex:
            # e.g. a run that has not written all of its files yet
            print(f"Error loading {p} skipping this: {ex}")
            continue

        if e is None:
            continue
        # A missing flag is treated as "not a dummy run".
        if not include_dummy and e.config.get("is_dummy_run"):
            continue

        git = e.config.get("git")
        if git and git.get("commit"):
            git["commit"] = git["commit"][:7]
        exps[e.run_id] = e

    return exps


import click


@click.group()
def cli():
    """cli tool for showing and analyzing experiment runs."""
    # Nothing to do here: the body is empty and the docstring above is the
    # group's help text.


@cli.command()
@click.option(
    "--dummy", help="include dummy runs", is_flag=True, default=False
)
@click.option(
    "--filter",
    help="filter by one or more fields. e.g. --filter=status=RUNNING",
    default=None,
)
@click.option(
    "--sort",
    help="sort by one or more fields. e.g. --sort=run_id",
    default=None,
)
@click.option(
    "--col",
    help="include additional columns. e.g. --col=epochs,learning_rate",
    default=None,
)
@click.option(
    "--group",
    help="group by one or more fields. e.g. --group=source,target",
    default=None,
)
@click.option(
    "--lit", help="include literature results", is_flag=True, default=False
)
def list(dummy, filter, sort, col, group, lit):
    """List runs"""
    # NOTE: this function rebinds the module-global name ``list``; do not
    # call the builtin ``list()`` in this module.
    import pandas as pd
    import numpy as np
    from tabulate import tabulate

    exps = load_experiments(include_dummy=dummy)

    # (source column, displayed name or None to keep the name)
    cols = [
        ("user", None),
        ("run_id", None),
        ("run_info.start_time", "start"),
        ("run_info.status", "status"),
        ("model_name", None),
        ("ensemble", None),
        ("source", None),
        ("target", None),
        ("duration [min]", None),
        ("sMAPE", None),
        ("MASE", None),
        ("ND", None),
    ]

    if col is not None:
        cols += [(c.strip(), None) for c in col.split(",") if c.strip()]

    rows = []
    for exp_info in exps.values():
        rows += exp_info.data_rows

    if lit:
        rows.extend(load_literature_results())

    if not rows:
        print("No runs found.")
        return

    df = pd.DataFrame(rows)

    def time_column(name):
        # Rows without run info (e.g. literature results only) have no time
        # columns at all; use an all-NaT column in that case.
        if name in df.columns:
            return pd.to_datetime(df[name])
        return pd.Series(pd.NaT, index=df.index)

    stop_time = time_column("run_info.stop_time")
    heartbeat = time_column("run_info.heartbeat")
    start_time = time_column("run_info.start_time")
    # Runs without a stop time fall back to their last heartbeat.
    df["duration [min]"] = (
        (stop_time.combine_first(heartbeat) - start_time)
        / np.timedelta64(60, "s")
    ).round()

    # De-duplicate (``--col`` may repeat a default column) and use reindex so
    # a column that no run provides shows up empty instead of raising.
    wanted = [*dict.fromkeys(name for name, _ in cols)]
    df = df.reindex(columns=wanted)
    df = df.rename(columns={name: alias for name, alias in cols if alias})
    if sort is not None:
        df = df.sort_values(by=[s.strip() for s in sort.split(",")])
    # Sorting happens on the raw timestamps, before they are prettified.
    df["start"] = df["start"].apply(pretty_date)

    def parse_ensemble(e):
        # Summarise a list of (name, variants) pairs by the size of the
        # first pair's variants; everything else is shown unchanged.
        if isinstance(e, List):
            for k, v in e:
                return f"{len(v)} variants"
        return e

    df["ensemble"] = df["ensemble"].apply(parse_ensemble)

    if filter is not None:
        mask = pd.Series(True, index=df.index)
        for clause in filter.split(","):
            # Split on the first "=" only so values may contain "=".
            field, sep, value = clause.partition("=")
            field = field.strip()
            if not sep or field not in df.columns:
                raise click.UsageError(
                    f"invalid --filter clause: {clause!r}"
                )
            # Compare as text so non-string columns can be filtered too.
            mask &= df[field].astype(str).str.strip() == value.strip()
        df = df[mask]
    df = df.fillna(value="")
    if group is not None:
        group_fields = [s.strip() for s in group.split(",")]
        for _, sub_df in df.groupby(group_fields):
            print(
                tabulate(
                    sub_df,
                    headers=df.columns,
                    showindex=False,
                    tablefmt="github",
                )
            )
            print("\n")
    else:
        print(
            tabulate(
                df, headers=df.columns, showindex=False, tablefmt="github"
            )
        )


@cli.command()
@click.argument("run_ids", nargs=-1)
def diff(run_ids):
    """Get a diff between run configurations."""
    from tabulate import tabulate

    # Explicit usage errors instead of asserts, which are stripped under -O
    # and only produce a bare traceback.
    if len(run_ids) < 2:
        raise click.UsageError("at least two run ids are required")

    exps = load_experiments(include_dummy=True)
    unknown = [r for r in run_ids if r not in exps]
    if unknown:
        raise click.UsageError(f"unknown run id(s): {', '.join(unknown)}")

    df = diff_config_and_metrics([exps[r] for r in run_ids])
    print(tabulate(df, headers=df.columns, showindex=False, tablefmt="github"))


@cli.command()
@click.argument("run_id")
def show(run_id):
    """Show a run."""
    # ``run_id`` is used as the glob pattern, so it has to match exactly one
    # run directory.
    exps = load_experiments(include_dummy=True, glob_pat=run_id)
    if len(exps) != 1:
        raise click.UsageError(
            f"expected exactly one run matching {run_id!r}, "
            f"found {len(exps)}"
        )
    (exp,) = exps.values()
    print(json.dumps(exp.config, indent=2))


@cli.command()
@click.argument("run_ids", nargs=-1)
def ax(run_ids):
    """Archive runs"""
    run_ids = [r.strip() for r in run_ids]
    if not run_ids or not all(run_ids):
        raise click.UsageError("provide at least one non-empty run id")
    import isengard

    client = isengard.Client()

    # Fill in the following variables for your isengard profile
    region = ""
    account = ""
    role = ""
    bucket_name = ""

    experiment_folder = "experiments/meta"
    experiment_archive_folder = "experiments/meta_archive"

    sess = client.get_boto3_session(account, role, region=region)
    s3 = sess.resource("s3")
    bucket = s3.Bucket(bucket_name)

    for run_id in run_ids:
        # Remote copy: move the run's prefix into the archive folder.
        if any(bucket.objects.filter(Prefix=f"{experiment_folder}/{run_id}/")):
            print("found")
            # Argument list without a shell, so a run id can never be
            # interpreted as shell syntax.
            cmd = [
                "aws",
                "s3",
                "--profile=mlf-bench",
                "mv",
                "--recursive",
                f"s3://{bucket_name}/{experiment_folder}/{run_id}",
                f"s3://{bucket_name}/{experiment_archive_folder}/{run_id}",
            ]
            print(f"running {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        else:
            print(f"Did not find run on s3: {run_id}")

        # Local copy, relative to the current working directory.
        p = Path(f"./my_runs/meta/{run_id}")
        if p.is_dir():
            new_p = Path(f"./my_runs/meta_archive/{run_id}")
            print(f"moving local run: {p} -> {new_p}")
            new_p.parent.mkdir(parents=True, exist_ok=True)
            # str() keeps shutil.move working on older Python versions.
            shutil.move(str(p), str(new_p))
        # Remove whatever directory is left behind after the move.
        if p.is_dir():
            shutil.rmtree(p)

    # def do_archive(old_key):
    #     s3_sess = sess.resource("s3")
    #     assert o.key.startswith(experiment_folder)
    #     new_key = experiment_archive_folder + o.key[len(experiment_folder):]
    #     s3_sess.Object(bucket_name, new_key).copy_from(CopySource=f"{bucket_name}/{old_key}")
    #     # s3_sess.Object(bucket_name, old_key).delete()
    #     return old_key, new_key
    #
    # old_keys = [o.key for key in obs]
    # with mp.Pool(10) as p:
    #     for ok, nk in tqdm(p.imap(do_archive, old_keys), total=len(old_keys), desc="moving keys"):
    #         print(f"{ok} -> {nk}")
    #
    # # for o in tqdm(obs, "moving keys"):
    # #
    # #     assert o.key.startswith(experiment_folder)
    # #     new_key = experiment_archive_folder + o.key[len(experiment_folder):]
    # #     s3.Object(bucket_name, new_key).copy_from(CopySource=f"{bucket_name}/{o.key}")
    # #     o.delete()
    # #


# Run the click command group when the file is executed as a script.
if __name__ == "__main__":
    cli()
