import os
from pathlib import Path
from typing import Any, Dict, Optional
import catch22
import click
import pandas as pd
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map
from tsbench.config import DATASET_REGISTRY


@click.command()
@click.argument(
    "path",
    type=click.Path(file_okay=False),
    nargs=1,
    default=Path.home() / "data" / "catch22",
)
@click.option(
    "--dataset",
    type=str,
    default=None,
    help="Name of a single registry dataset to process (default: all datasets).",
)
def main(path: str, dataset: Optional[str]):
    """
    Computes the catch22 features for all datasets in the registry.

    For each dataset, the catch22 features of every time series in its
    training split are computed in a process pool and written to
    PATH/<dataset_name>.parquet (one row per time series, one column per
    feature). Datasets whose output file already exists are skipped, so an
    interrupted run can simply be restarted.

    PATH is the output directory; it is created if it does not exist.
    """
    if dataset is None:
        datasets = list(DATASET_REGISTRY.items())
    else:
        try:
            datasets = [(dataset, DATASET_REGISTRY[dataset])]
        except KeyError as err:
            # Report a CLI usage error instead of a raw traceback.
            raise click.BadParameter(
                f"unknown dataset '{dataset}'", param_hint="'--dataset'"
            ) from err

    directory = Path(path)
    # The default output location usually does not exist on a fresh machine.
    directory.mkdir(parents=True, exist_ok=True)

    for dataset_name, config in tqdm(datasets):
        file = directory / f"{dataset_name}.parquet"
        if file.exists():
            continue

        ts_features = process_map(
            get_features,
            config.data.train(val=False).gluonts(),  # Get features on train set
            max_workers=os.cpu_count(),
            desc=dataset_name,
        )
        df = pd.DataFrame(ts_features)

        # Write to a temporary file and rename it into place, so a crash
        # mid-write never leaves a truncated file that the existence check
        # above would later mistake for finished output.
        tmp_file = file.with_name(f"{file.name}.tmp")
        df.to_parquet(tmp_file)
        tmp_file.replace(file)


def get_features(ts: Dict[str, Any]) -> Dict[str, Any]:
    """
    Computes the catch22 features for the given time series.

    Args:
        ts: A single dataset entry; only its ``"target"`` field is read.

    Returns:
        A mapping from each catch22 feature name to its computed value.
    """
    result = catch22.catch22_all(ts["target"])
    names = result["names"]
    values = result["values"]
    return {name: value for name, value in zip(names, values)}


if __name__ == "__main__":
    # click fills in `path` and `dataset` from the command line, so pylint's
    # complaint about the missing arguments is a false positive here.
    main()  # pylint: disable=no-value-for-parameter
