"""Synthetic decay mixture experiment: N=1 vs N>1 diagonal SSM."""

from __future__ import annotations

import argparse
from typing import List

import numpy as np
import torch

from .data import make_decay_dataloaders
from .models_mamba import MambaSSMRegressor
from .train_ts import log_run, train_regression


def run_experiment(args: argparse.Namespace) -> List[dict]:
    """Train every (d_state, model_kind) combination on the synthetic decay task.

    One train/val split is built from ``args`` and shared by all runs. For each
    value in ``args.N_values`` (used as ``d_state``) and each entry in
    ``args.models``, a ``MambaSSMRegressor`` is trained with
    ``train_regression``. A summary dict (configuration, per-epoch history,
    final and best losses) is built for each run. That summary is appended to
    ``args.log_path`` via ``log_run`` and collected. The final validation MSE of
    each run is printed.

    Args:
        args: Parsed CLI namespace; see ``main`` for the attributes read here.

    Returns:
        One summary dict per trained model, in loop order (N outer, model inner).

    Raises:
        ValueError: if a model kind is unknown, or if training returned an empty
            per-epoch history (e.g. ``epochs < 1``).
    """
    # CLI/log name of a model -> ``block_kind`` passed to MambaSSMRegressor.
    block_kinds = {"mamba_diag_exp": "diag_exp", "mamba_simple": "simple"}

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader = make_decay_dataloaders(
        T=args.T,
        lambdas=tuple(args.lambdas),
        coeffs=tuple(args.coeffs),
        train_size=args.train_size,
        val_size=args.val_size,
        batch_size=args.batch_size,
        noise_std=args.noise_std,
        seed=args.seed,
    )

    results = []
    for d_state in args.N_values:
        for model_kind in args.models:
            try:
                block_kind = block_kinds[model_kind]
            except KeyError:
                raise ValueError(f"Unknown model_kind: {model_kind}") from None

            # Reseed per run so each model's initialization does not depend on
            # how many runs were executed before it. Otherwise the N=1 vs N>1
            # comparison would change with the order of --N_values / --models.
            # (If the loaders shuffle with torch's global RNG, this also makes
            # batch order identical across runs — presumably; not visible here.)
            torch.manual_seed(args.seed)
            np.random.seed(args.seed)

            model = MambaSSMRegressor(
                block_kind=block_kind,
                d_model=args.d_model,
                d_state=d_state,
                n_layers=args.n_layers,
            )

            train_out = train_regression(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                epochs=args.epochs,
                lr=args.lr,
                verbose=args.verbose,
            )
            epoch_history = train_out["history"]
            if not epoch_history:
                raise ValueError(
                    f"train_regression returned no history for model={model_kind} "
                    f"N={d_state} (epochs={args.epochs}); need at least one epoch"
                )
            final_metrics = epoch_history[-1]
            best_val = min(h["val_loss"] for h in epoch_history)

            summary = {
                "task": "synthetic_decays",
                "model_kind": model_kind,
                "d_state": d_state,
                "d_model": args.d_model,
                "n_layers": args.n_layers,
                "dataset": {
                    "T": args.T,
                    "lambdas": list(args.lambdas),
                    "coeffs": list(args.coeffs),
                    "noise_std": args.noise_std,
                    "train_size": args.train_size,
                    "val_size": args.val_size,
                },
                "train_config": {
                    "epochs": args.epochs,
                    "lr": args.lr,
                    "batch_size": args.batch_size,
                    "seed": args.seed,
                },
                "history": epoch_history,
                "final_train_loss": final_metrics["train_loss"],
                "final_val_loss": final_metrics["val_loss"],
                "best_val_loss": best_val,
            }
            results.append(summary)
            print(f"model={model_kind} N={d_state} : final val MSE = {final_metrics['val_loss']:.6f}")
            log_run(args.log_path, summary)
    return results


def main() -> None:
    """Parse command-line arguments, validate them, and run the experiment.

    Inconsistent arguments are rejected via ``parser.error``, which exits with
    status 2 before any data is generated or any model is trained:
    ``--lambdas``/``--coeffs`` of different lengths, ``--epochs`` < 1, or any
    ``--N_values`` entry < 1.
    """
    parser = argparse.ArgumentParser(
        description="Synthetic decay mixture experiment: N=1 vs N>1 diagonal SSM."
    )
    parser.add_argument("--T", type=int, default=60, help="Sequence length.")
    parser.add_argument(
        "--lambdas",
        type=float,
        nargs="+",
        default=[0.98, 0.94, 0.90, 0.80],
        help="Decay rates of the mixture components.",
    )
    parser.add_argument(
        "--coeffs",
        type=float,
        nargs="+",
        default=[1.0, 0.7, 0.5, 0.3],
        help="Mixture coefficients; one per entry of --lambdas.",
    )
    parser.add_argument("--train_size", type=int, default=2048)
    parser.add_argument("--val_size", type=int, default=512)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--noise_std", type=float, default=0.01)
    parser.add_argument("--d_model", type=int, default=64)
    parser.add_argument("--n_layers", type=int, default=1)
    parser.add_argument(
        "--N_values",
        type=int,
        nargs="+",
        default=[1, 2],
        help="State sizes (d_state) to sweep over.",
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=["mamba_diag_exp"],
        choices=["mamba_diag_exp", "mamba_simple"],
        help="Models to run; logged as `model_kind`.",
    )
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--log_path",
        type=str,
        default="outputs/time_series_experiments/ts_runs.jsonl",
        help="Path to append JSONL run summaries to.",
    )
    parser.add_argument("--verbose", action="store_true", help="Print per-epoch losses.")
    args = parser.parse_args()

    # Fail fast on inconsistent input instead of crashing mid-sweep.
    if len(args.lambdas) != len(args.coeffs):
        parser.error(
            "--lambdas and --coeffs must have the same length "
            f"(got {len(args.lambdas)} and {len(args.coeffs)})"
        )
    if args.epochs < 1:
        parser.error("--epochs must be >= 1")
    if any(n < 1 for n in args.N_values):
        parser.error("--N_values entries must all be >= 1")

    results = run_experiment(args)
    if results:
        print("\nFinished synthetic decay experiment.")


if __name__ == "__main__":
    # Script entry point: only run the sweep when executed directly, not on import.
    main()
