# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import json
import os
import re
import warnings
from pathlib import Path
from typing import NamedTuple

import numpy as np
import pandas as pd

from gluonts.dataset.repository._util import metadata, save_to_file, to_dict
from gluonts.gluonts_tqdm import tqdm


def check_dataset(dataset_path: Path, length: int):
    """
    Reload a dataset from disk and run basic sanity checks on it.

    The metadata is read from ``dataset_path`` and the train/test splits
    from its ``train`` and ``test`` sub-directories.

    Parameters
    ----------
    dataset_path
        Directory the dataset was written to.
    length
        Number of entries expected in both the train and the test split.

    Raises
    ------
    AssertionError
        If the test split is missing, if either split does not contain
        exactly ``length`` entries, or if the metadata carries no
        prediction length.
    """
    from gluonts.dataset.common import load_datasets

    datasets = load_datasets(
        metadata=dataset_path,
        train=dataset_path / "train",
        test=dataset_path / "test",
    )

    assert datasets.test is not None
    assert len(list(datasets.train)) == length
    assert len(list(datasets.test)) == length

    assert datasets.metadata.prediction_length is not None

    paired_entries = zip(datasets.train, datasets.test)
    progress = tqdm(
        paired_entries, total=length, desc="checking consistency"
    )
    for train_entry, test_entry in progress:
        # Only the presence of the "target" field is exercised here.
        # NOTE: the stricter per-series checks (train target being the test
        # target without its last `prediction_length` values, identical
        # "start" values, and the start-date format) are currently disabled.
        _ = train_entry["target"]
        _ = test_entry["target"]


# class M3Setting(NamedTuple):
#     sheet_name: str
#     prediction_length: int
#     freq: str


def load_series_from_df(df, freq):
    """
    Extract one time series per column of a tourism data frame.

    Every column is read positionally: row 0 holds the number of
    observations, row 1 the start year and, for "monthly" and "quarterly"
    data, row 2 the start period within that year. The observations follow
    directly after these header rows.

    Parameters
    ----------
    df
        Data frame with one series per column, laid out as described above.
    freq
        One of "monthly", "quarterly" or "yearly".

    Returns
    -------
    list
        One dict per column with the keys "target" (array of the
        observations), "item_id" (the column name) and "start" (a string,
        "YYYY-MM" for monthly/quarterly data and "YYYY" for yearly data).

    Raises
    ------
    RuntimeError
        If ``freq`` is not one of the supported values.
    """
    # Factor that turns the value in row 2 into a month number; for
    # quarterly data the value is multiplied by 3 (presumably a quarter
    # index mapped to the last month of that quarter -- confirm).
    months_per_period = {"monthly": 1, "quarterly": 3}

    # Validate with exact matching. A substring test such as
    # `freq in "monthly"` would silently accept "" or "month". Checking up
    # front also means an unsupported value is rejected even for an empty
    # data frame.
    if freq != "yearly" and freq not in months_per_period:
        raise RuntimeError(
            f"invalid freq {freq} should be one of ['yearly', 'quarterly', 'monthly']"
        )

    data = []

    for series in df.columns:
        column = df[series]
        length = int(column.iloc[0])
        start_year = int(column.iloc[1])

        if freq == "yearly":
            # Yearly columns only have two header rows (length, year).
            offset = 2
            start = f"{start_year}"
        else:
            # Monthly/quarterly columns have a third header row.
            offset = 3
            start_month = int(column.iloc[2]) * months_per_period[freq]
            start = f"{start_year}-{start_month:02d}"

        data.append(
            {
                # Slice by the declared length so trailing padding rows
                # of shorter series are dropped.
                "target": column.iloc[offset : length + offset].values,
                "item_id": series,
                "start": start,
            }
        )
    return data


def generate_tourism_dataset(dataset_path: Path, tourism_freq: str):
    """
    Build the tourism dataset for one frequency and write it to disk.

    The in-sample (``<tourism_freq>_in.csv``) and out-of-sample
    (``<tourism_freq>_oos.csv``) files are read from the ``tourism`` folder
    under the default dataset path. The training series are the in-sample
    values; each test series is the in-sample values followed by the
    out-of-sample values. ``metadata.json``, ``train/data.json`` and
    ``test/data.json`` are written below ``dataset_path`` and the result is
    verified with ``check_dataset``.

    Parameters
    ----------
    dataset_path
        Directory the dataset is written to; created if it does not exist.
    tourism_freq
        One of "monthly", "quarterly" or "yearly".

    Raises
    ------
    RuntimeError
        If any of the expected tourism CSV files is missing.
    AssertionError
        If ``tourism_freq`` is unsupported or the input files are
        inconsistent with each other.
    """
    assert tourism_freq in ["monthly", "quarterly", "yearly"]
    # Imported inside the function as in the original layout (presumably to
    # avoid a circular import with the datasets registry -- confirm).
    from gluonts.dataset.repository.datasets import default_dataset_path

    tourism_csv_path = default_dataset_path / "tourism"
    file_list = [
        "monthly_in.csv",
        "monthly_oos.csv",
        "quarterly_in.csv",
        "quarterly_oos.csv",
        "yearly_in.csv",
        "yearly_oos.csv",
    ]
    for filename in file_list:
        if not (tourism_csv_path / filename).is_file():
            # The two sentences are separated by a newline; without a
            # separator the URL would run straight into "Please ...".
            raise RuntimeError(
                "The tourism data is available at https://robjhyndman.com/data/27-3-Athanasopoulos1.zip\n"
                f"Please download the file and copy the {', '.join(file_list)} to this location: {tourism_csv_path}"
            )

    freq = {"monthly": "M", "quarterly": "Q", "yearly": "Y"}[tourism_freq]

    ins = pd.read_csv(tourism_csv_path / f"{tourism_freq}_in.csv")
    oos = pd.read_csv(tourism_csv_path / f"{tourism_freq}_oos.csv")

    series_in = load_series_from_df(ins, tourism_freq)
    series_oos = load_series_from_df(oos, tourism_freq)
    assert len(series_oos) == len(series_in)

    # The prediction length is the (common) length of the out-of-sample
    # series; all of them must agree.
    pred_length = None
    for s in series_oos:
        if pred_length is not None:
            assert pred_length == len(s["target"])
        pred_length = len(s["target"])

    train_data = []
    test_data = []
    for t, s in zip(series_in, series_oos):
        assert t["item_id"] == s["item_id"]
        # Plain float lists so the entries can be serialized to JSON.
        t["target"] = t["target"].astype(float).tolist()
        train_data.append(t)
        # Shallow copy: only "target" is replaced for the test entry, so
        # the training entry keeps its own (shorter) target.
        ts = t.copy()
        ts["target"] = (
            np.concatenate([t["target"], s["target"]]).astype(float).tolist()
        )
        test_data.append(ts)

    os.makedirs(dataset_path, exist_ok=True)
    with open(dataset_path / "metadata.json", "w") as f:
        json.dump(
            metadata(
                cardinality=[len(test_data)],
                freq=freq,
                prediction_length=pred_length,
            ),
            f,
        )

    train_file = dataset_path / "train" / "data.json"
    test_file = dataset_path / "test" / "data.json"

    save_to_file(train_file, train_data)
    save_to_file(test_file, test_data)

    check_dataset(dataset_path, len(train_data))
