"""Shared default hyperparameters for time-series experiments."""

from __future__ import annotations

from typing import Iterable


# Default hyperparameters shared by the time-series experiment scripts; each
# entry backs the CLI option of the same name registered by add_common_args.
# Sequence-valued entries are stored as (immutable) tuples and converted to
# lists only when handed to argparse.
COMMON_DEFAULTS = dict(
    T=200,
    lambdas=(0.98, 0.94, 0.9, 0.8),
    coeffs=(1.0, 0.7, 0.5, 0.3),
    # NOTE(review): 1048 is not a power of two, unlike val_size/batch_size;
    # confirm this is intentional and not a typo for 1024.
    train_size=1048,
    val_size=256,
    batch_size=64,
    noise_std=1e-2,
    d_model=4,
    n_layers=1,
    epochs=60,
    lr=1e-3,
)

# Default output locations (relative paths) for run logs and figures.
DEFAULT_LOG_PATH = "outputs/time_series_experiments/ts_runs.jsonl"
DEFAULT_FIG_PATH = "outputs/time_series_experiments/figure.png"


def _iterable_to_list(x: Iterable[float] | tuple[float, ...]) -> list[float]:
    return list(x) if not isinstance(x, list) else x


def add_common_args(parser, include_log_path: bool = True, include_out_path: bool = False) -> None:
    """Register the shared time-series experiment options on *parser*.

    Every option's default comes from ``COMMON_DEFAULTS`` so that all scripts
    using this helper agree on the same baseline hyperparameters.

    Args:
        parser: Object exposing an argparse-style ``add_argument`` method
            (presumably an ``argparse.ArgumentParser``).
        include_log_path: If true, also register ``--log_path`` with default
            ``DEFAULT_LOG_PATH``.
        include_out_path: If true, also register ``--out_path`` with default
            ``DEFAULT_FIG_PATH``.
    """
    parser.add_argument("--T", type=int, default=COMMON_DEFAULTS["T"])
    # With nargs="+", argparse yields a list when the flag is given; the tuple
    # defaults are converted to lists so the parsed attribute has the same
    # type whether or not the flag was passed on the command line.
    parser.add_argument("--lambdas", type=float, nargs="+", default=_iterable_to_list(COMMON_DEFAULTS["lambdas"]))
    parser.add_argument("--coeffs", type=float, nargs="+", default=_iterable_to_list(COMMON_DEFAULTS["coeffs"]))
    parser.add_argument("--train_size", type=int, default=COMMON_DEFAULTS["train_size"])
    parser.add_argument("--val_size", type=int, default=COMMON_DEFAULTS["val_size"])
    parser.add_argument("--batch_size", type=int, default=COMMON_DEFAULTS["batch_size"])
    parser.add_argument("--noise_std", type=float, default=COMMON_DEFAULTS["noise_std"])
    parser.add_argument("--d_model", type=int, default=COMMON_DEFAULTS["d_model"])
    parser.add_argument("--n_layers", type=int, default=COMMON_DEFAULTS["n_layers"])
    parser.add_argument("--epochs", type=int, default=COMMON_DEFAULTS["epochs"])
    parser.add_argument("--lr", type=float, default=COMMON_DEFAULTS["lr"])

    if include_log_path:
        parser.add_argument(
            "--log_path",
            type=str,
            default=DEFAULT_LOG_PATH,
            help="Path to append JSONL run summaries to.",
        )
    if include_out_path:
        parser.add_argument(
            "--out_path",
            type=str,
            default=DEFAULT_FIG_PATH,
            help="Path to save the output figure to.",
        )
