import json
import os
from pathlib import Path
from typing import Dict, Any, Optional
import re
from contextlib import contextmanager
import time

import numpy as np
from gluonts import transform
from gluonts.mx import Tensor
from gluonts.transform import TransformedDataset
from joblib import Memory
from sacred import Experiment

from gluonts.core.serde import load_json, dump_json
from gluonts.dataset.common import TrainDatasets, Dataset
from gluonts.dataset.repository import datasets as gluonts_datasets, datasets

# Absolute path of the directory containing this file. Used further down as
# the starting point of the git repository lookup and as the parent directory
# of the on-disk dataset-statistics cache.
this_file_path = Path(__file__).parent.absolute()


def get_dataset(
    dataset_name,
    regenerate: bool = False,
) -> TrainDatasets:
    """
    Load a dataset through the GluonTS dataset repository.

    The directory holding the datasets is read from the
    ``SM_CHANNEL_DATASETS`` environment variable; when that variable is not
    set, the GluonTS default dataset path is used instead. ``regenerate`` is
    forwarded unchanged to the repository loader.
    """
    fallback_dir = gluonts_datasets.default_dataset_path
    root = Path(os.environ.get("SM_CHANNEL_DATASETS", fallback_dir))
    return gluonts_datasets.get_dataset(
        dataset_name, path=root, regenerate=regenerate
    )


def to_config(t: Any) -> Dict:
    """
    Convert an object into a nested "config" structure.

    The object is first serialized with ``dump_json``. Every serialized
    instance is then flattened into a dict whose ``'@'`` key holds the class
    path and whose remaining keys are the (recursively encoded) keyword
    arguments, e.g.:
    {
      '@': 'gluonts.model.deepar._estimator.DeepAREstimator',
      'cell_type': 'lstm',
      'distr_output': {
        '@': 'gluonts.mx.distribution.student_t.StudentTOutput'
      },
      ...
    }

    Instances that were serialized with positional arguments, and dicts that
    carry any ``__kind__`` marker other than ``"instance"``, fail an
    assertion.
    """

    def encode(node):
        if isinstance(node, dict):
            if node.get("__kind__") == "instance":
                # Only keyword arguments can be represented in the flat form.
                assert not node["args"], "Cannot encode non-kwargs"
                return {
                    "@": node["class"],
                    **{
                        key: encode(val)
                        for key, val in node["kwargs"].items()
                    },
                }
            assert (
                "__kind__" not in node
            ), f"Cannot encode kind {node['__kind__']}"
            return {key: encode(val) for key, val in node.items()}
        if isinstance(node, list):
            return [encode(item) for item in node]
        if isinstance(node, tuple):
            return tuple(encode(item) for item in node)
        # Scalars (str, numbers, bool, None) are passed through untouched.
        return node

    return encode(json.loads(dump_json(t)))


def from_config(d: Dict) -> Any:
    """
    Rebuild an object from its "config" structure.

    Every dict containing an ``'@'`` key is expanded into the instance
    encoding (``__kind__``, ``args``, ``class``, ``kwargs``) with an empty
    ``args`` list and all other keys as keyword arguments. The expanded
    structure is dumped to a JSON string and handed to ``load_json``.
    """

    def decode(node):
        if isinstance(node, dict):
            if "@" not in node:
                return {key: decode(val) for key, val in node.items()}
            kwargs = {
                key: decode(val) for key, val in node.items() if key != "@"
            }
            return {
                "__kind__": "instance",
                "args": [],
                "class": node["@"],
                "kwargs": kwargs,
            }
        if isinstance(node, list):
            return [decode(item) for item in node]
        if isinstance(node, tuple):
            return tuple(decode(item) for item in node)
        return node

    return load_json(json.dumps(decode(d)))


@contextmanager
def track_time(ex: Experiment, name: str):
    """
    Context manager that measures how long its body takes to run.

    On exit (including when the body raises) the elapsed time in seconds is
    printed as ``<name>_time=<seconds>`` and logged on ``ex`` through
    ``ex.log_scalar("<name>_time", <seconds>)``.

    The duration is taken from ``time.perf_counter()``, a monotonic clock,
    so system clock adjustments during the body cannot distort it or make it
    negative.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Compute once so the printed and the logged value are identical.
        elapsed = time.perf_counter() - start
        print(f"{name}_time={elapsed}")
        ex.log_scalar(f"{name}_time", elapsed)


def get_git_commit() -> Optional[Dict]:
    """
    Describe the git repository that contains this file.

    Returns ``None`` if no git repository is found at this file's directory
    or any of its parents. Otherwise returns a dict with the keys ``path``
    (``"git:/"`` followed by the working directory), ``remote_path`` (URL of
    the repository's remote, or ``None`` if looking it up raises
    ``ValueError``), ``commit`` (hex SHA of the HEAD commit) and
    ``is_dirty``.

    Raises ``ValueError`` if the ``git`` module cannot be imported.
    """
    try:
        from git import InvalidGitRepositoryError, Repo
    except ImportError as import_error:
        raise ValueError(
            "Cannot import git (pip install GitPython).\n"
            "Either GitPython or the git executable is missing.\n"
        ) from import_error

    try:
        repo = Repo(str(this_file_path), search_parent_directories=True)
    except InvalidGitRepositoryError:
        return None

    try:
        remote_url = repo.remote().url
    except ValueError:
        remote_url = None

    info = {"path": "git:/" + repo.working_dir, "remote_path": remote_url}
    dirty = repo.is_dirty()
    info["commit"] = repo.head.commit.hexsha
    info["is_dirty"] = dirty
    return info


def normalize_freq(freq_str):
    """
    Normalize a frequency string so it always carries an explicit integer
    multiple, e.g. ``"H"`` -> ``"1H"``, ``" 15min "`` -> ``"15min"``,
    ``"05D"`` -> ``"5D"``.

    The input must consist of an optional run of digits followed by a
    non-empty part that contains no digits; surrounding whitespace is
    ignored. A missing multiple is treated as 1.

    Raises ``ValueError`` if the string does not have that shape.
    """
    m = re.match(r"^(\d*)(\D+)$", freq_str.strip())
    if m is None:
        # Explicit check rather than `assert`: assertions are removed under
        # `python -O`, which would turn bad input into an opaque
        # AttributeError on None below.
        raise ValueError(f"Cannot parse frequency string: {freq_str!r}")
    digits, unit = m.groups()
    multiple = int(digits) if digits else 1
    return f"{multiple}{unit}"


# joblib on-disk cache stored in the "_dataset_stat" directory next to this
# file (verbose=0); applied below via the @memory.cache decorator.
memory = Memory(str(this_file_path / "_dataset_stat"), verbose=0)


@memory.cache
def collect_dataset_stats(dataset_name, which="train"):
    """
    Compute length statistics for one split of a repository dataset.

    ``which`` selects the ``"train"`` or ``"test"`` split of the dataset
    returned by ``get_dataset``. The result is a dict with the dataset's
    ``freq`` and ``prediction_length`` metadata plus the minimum, maximum,
    median and mean length of the ``"target"`` field and the number of time
    series. Results are cached through the module-level ``memory`` object.
    """
    assert which in ["train", "test"]
    assert dataset_name in datasets.dataset_names
    metadata, train, test = get_dataset(dataset_name)
    split = train if which == "train" else test
    lens = [len(entry["target"]) for entry in split]

    return {
        "freq": metadata.freq,
        "prediction_length": metadata.prediction_length,
        "min_len": int(np.min(lens)),
        "max_len": int(np.max(lens)),
        "median_len": int(np.quantile(lens, 0.5)),
        "avg_len": float(np.mean(lens)),
        "num_ts": len(lens),
    }


def shorten_dataset(dataset: Dataset, new_length, lead_time):
    """
    Wrap ``dataset`` so that every entry's ``"target"`` is cut down to its
    trailing ``new_length + lead_time`` positions along the last axis.

    Entries are copied before modification, and the truncation is applied
    through a ``TransformedDataset`` rather than to ``dataset`` itself.
    Note that when ``new_length + lead_time`` is 0 the slice start is ``-0``,
    i.e. the whole target is kept.
    """

    def _keep_tail(entry):
        shortened = entry.copy()
        tail = slice(-new_length - lead_time, None)
        shortened["target"] = shortened["target"][..., tail]
        return shortened

    return TransformedDataset(
        dataset, transformation=transform.AdhocTransform(_keep_tail)
    )


def transform_dataset(
    dataset: Dataset,
    prediction_length=0,
    lead_time=0,
    periodicity=1,
    mode="MASE_scaling",
):
    """
    Wrap ``dataset`` so that every entry's ``"target"`` is rescaled.

    With ``mode="MASE_scaling"`` (the only mode implemented) each target is
    divided by its seasonal error: the mean absolute difference between
    values ``periodicity`` steps apart. The error is computed on the target
    with its final ``prediction_length + lead_time`` steps removed (slicing
    the first axis) and is floored at 1e-3 so the division never uses a
    (near-)zero scale.

    Entries are copied before modification, and the scaling is applied
    through a ``TransformedDataset`` rather than to ``dataset`` itself.

    Raises ``NotImplementedError`` for any other ``mode``; the error is
    raised from inside the per-entry transformation, not when this function
    is called.
    """

    def transform_target(data):
        data = data.copy()
        target = data["target"]
        if mode != "MASE_scaling":
            # Single formatted message; passing several positional args
            # would render the error as a tuple repr.
            raise NotImplementedError(f"mode: {mode} not implemented!")

        horizon = prediction_length + lead_time
        # A plain `target[:-0]` would be empty, hence the explicit zero case.
        trunc_target = target[:-horizon] if horizon != 0 else target
        # NOTE(review): if the truncated history has no more than
        # `periodicity` points the mean is taken over an empty array and,
        # assuming NumPy targets, the scaled target becomes NaN -- confirm
        # whether callers can pass such short series.
        seasonal_error = np.mean(
            np.abs(trunc_target[periodicity:] - trunc_target[:-periodicity])
        )
        seasonal_error_no_zero = seasonal_error.clip(1e-3, np.inf)
        data["target"] = target / seasonal_error_no_zero
        return data

    return TransformedDataset(
        dataset, transformation=transform.AdhocTransform(transform_target)
    )
