# Downloads LOTSA datasets from the Hugging Face Hub and converts each config locally into a flat Parquet file.
import multiprocessing as mp
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm.auto import tqdm

import datasets


def process_entry(entry: dict) -> dict:
    """Flatten one LOTSA record into a dict of per-series columns.

    Args:
        entry: A single dataset row. ``target``, ``start`` and ``freq`` are
            required; ``item_id`` and ``past_feat_dynamic_real`` are optional.
            Any other key is rejected.

    Returns:
        A dict containing:
          * ``id`` -- copied from ``item_id`` when present;
          * ``timestamp`` -- a ``datetime64[ms]`` array of length
            ``target.shape[-1]`` starting at ``start`` with step ``freq``;
          * ``target`` for a 1-D target, or ``target_0``, ``target_1``, ...
            (one per row) for a 2-D target;
          * ``exog_0``, ``exog_1``, ... -- one per row of
            ``past_feat_dynamic_real``, when present.

    Raises:
        ValueError: if ``entry`` has an unexpected key, if the target has more
            than 2 dimensions, or if ``past_feat_dynamic_real`` is not 2-D.
    """
    known_keys = {"target", "start", "item_id", "freq", "past_feat_dynamic_real"}
    # Reject unknown fields before doing any work so they are never silently dropped.
    for key in entry:
        if key not in known_keys:
            raise ValueError(f"Unprocessed key {key}")

    new_entry = {}
    if "item_id" in entry:
        new_entry["id"] = entry["item_id"]

    target = entry["target"]
    start = entry["start"]
    # ``start`` may arrive as a size-1 ndarray (e.g. with the "numpy" dataset
    # format); unwrap it to a scalar that pandas accepts.
    if isinstance(start, np.ndarray):
        start = start.item()

    # NOTE(review): FutureWarnings are silenced here, presumably from pandas
    # frequency-alias deprecations in the source data -- confirm.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        timestamps = pd.date_range(start, freq=entry["freq"], periods=target.shape[-1])
        new_entry["timestamp"] = timestamps.to_numpy(dtype="datetime64[ms]")

    if target.ndim == 1:
        new_entry["target"] = target
    elif target.ndim == 2:
        # One output column per row of a 2-D target.
        for idx, row in enumerate(target):
            new_entry[f"target_{idx}"] = row
    else:
        raise ValueError("target has >2 dimensions!")

    exog = entry.get("past_feat_dynamic_real")
    if exog is not None:
        # An explicit check instead of ``assert``, which is stripped under ``python -O``.
        if exog.ndim != 2:
            raise ValueError(f"past_feat_dynamic_real must be 2-D, got shape {exog.shape}")
        for idx, row in enumerate(exog):
            new_entry[f"exog_{idx}"] = row

    return new_entry


def save_lotsa(config_name, out_dir, root="salesforce/lotsa_data", num_proc=mp.cpu_count()):
    """Download one LOTSA config, flatten it with ``process_entry`` and save it as Parquet.

    The output is written to ``<out_dir>/<config_name>/data.parquet``.

    Args:
        config_name: Config (subset) name to load from ``root``.
        out_dir: Directory under which a ``config_name`` sub-directory is created.
        root: Dataset repository passed to ``datasets.load_dataset``.
        num_proc: Upper bound on worker processes for ``Dataset.map``. It is
            clamped to ``[1, number of rows]``. The default is evaluated once,
            at import time.
    """
    ds = datasets.load_dataset(root, config_name, split="train")
    ds.set_format("numpy")
    # Never use more workers than rows, and always at least one:
    # ``Dataset.map`` requires ``num_proc`` > 0, and an empty split would otherwise give 0.
    num_proc = max(1, min(num_proc, len(ds)))
    processed = ds.map(process_entry, num_proc=num_proc, remove_columns=ds.column_names)
    if "id" not in processed.column_names:
        # Synthesize ids "T<idx>", zero-padded to the digit count of the row count.
        digits = len(str(len(processed)))
        ids = [f"T{idx:0{digits}}" for idx in range(len(processed))]
        # ``add_column`` returns a new Dataset; it does not modify ``processed`` in place.
        processed = processed.add_column("id", ids)
    out_path = Path(out_dir) / config_name
    out_path.mkdir(exist_ok=True, parents=True)
    processed.to_parquet(out_path / "data.parquet")
    print(f"{config_name} saved")


# LOTSA config names exported when this file is run as a script; each name is
# passed as ``config_name`` to ``save_lotsa``. Order is preserved by ``split``.
LOTSA_DATASETS = """
    BEIJING_SUBWAY_30MIN HZMETRO LOS_LOOP M_DENSE
    PEMS03 PEMS08 SHMETRO SZ_TAXI
    bdg-2_bear bdg-2_fox bdg-2_panther bdg-2_rat
    beijing_air_quality borealis cdc_fluview_ilinet hierarchical_sales
    ideal kdd2022 project_tycho smart subseasonal_precip
""".split()


if __name__ == "__main__":
    # Export every config in LOTSA_DATASETS, running n_jobs configs concurrently.
    n_jobs = 8
    out_dir = Path("./lotsa")
    # Split the CPUs between the concurrent jobs. Otherwise each job's ``Dataset.map``
    # would use ``cpu_count()`` workers, giving n_jobs * cpu_count processes in total.
    num_proc_per_task = max(mp.cpu_count() // n_jobs, 1)
    # NOTE: tqdm advances as joblib dispatches tasks, not as they finish.
    Parallel(n_jobs=n_jobs)(
        delayed(save_lotsa)(name, out_dir=out_dir, num_proc=num_proc_per_task)
        for name in tqdm(LOTSA_DATASETS)
    )
