# TASK_SPEC: dvs_gesture_cnn_compare_v1
import argparse
import os

import numpy as np

from sequence_utils import load_dvs_gesture_images, run_classification_task


def main() -> None:
    """Parse CLI options, load DVS-Gesture frames, and run the classification task.

    Steps performed:
      1. Parse the command-line arguments; ``args.plot`` is derived as the
         negation of ``--no-plot``.
      2. Build the gain array from ``--gains`` (comma-separated floats), or
         fall back to a 12-point default grid on [0.05, 2.3).
      3. If ``--gesture-rebuild`` is given, delete the cached npz file.
      4. Load the dataset with ``load_dvs_gesture_images`` and pass it, with
         the parsed args and the gains, to ``run_classification_task``.

    A ``RuntimeError`` raised by the loader is printed and the function
    returns without running the task.

    Raises:
        SystemExit: via ``argparse`` on invalid command-line input, including
            a ``--gains`` value containing a non-numeric entry.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="dvs_gesture")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=256)
    parser.add_argument("--enc-channels", type=str, default="32,64")
    parser.add_argument("--steps", type=int, default=12)
    parser.add_argument("--kernel-size", type=int, default=3)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-limit", type=int, default=None)
    parser.add_argument("--test-limit", type=int, default=None)
    parser.add_argument("--gesture-npz", type=str, default=None)
    parser.add_argument("--gesture-root", type=str, default=None)
    parser.add_argument("--gesture-time-bins", type=int, default=20)
    parser.add_argument(
        "--gesture-spatial-downsample",
        type=int,
        default=1,
        help="Spatial downsample factor for frames (e.g. 4 -> 32x32). Helps reduce RAM usage.",
    )
    parser.add_argument("--gesture-no-polarity", action="store_true")
    parser.add_argument(
        "--gesture-rebuild",
        action="store_true",
        help="Delete cached DVS-Gesture npz before running (forces re-download/processing).",
    )
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument("--time-weighting", type=str, choices=["none", "final", "late"], default="none")
    parser.add_argument("--step-labels", type=str, choices=["final", "fptt"], default="final")
    # Both flags write to the same destination; the default (True) is set below.
    parser.add_argument("--train-encoder", dest="train_encoder", action="store_true")
    parser.add_argument("--freeze-encoder", dest="train_encoder", action="store_false")
    parser.add_argument("--tbptt-short", type=int, default=1)
    parser.add_argument("--tbptt-long", type=int, default=None)
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    parser.set_defaults(train_encoder=True)
    args = parser.parse_args()
    # Expose the positive form of --no-plot on the namespace handed downstream.
    args.plot = not args.no_plot

    # Default gain grid; used when --gains is absent or contains no entries.
    gains = np.linspace(0.05, 2.3, 12, endpoint=False)
    if args.gains:
        try:
            # Blank entries (e.g. trailing commas) are skipped.
            parsed_gains = [float(tok) for tok in args.gains.split(",") if tok.strip()]
        except ValueError:
            # Report bad input like any other CLI error (usage + exit status 2)
            # instead of leaking a traceback. parser.error() never returns.
            parser.error(
                f"--gains must be a comma-separated list of numbers, got {args.gains!r}"
            )
        if parsed_gains:
            gains = np.array(parsed_gains, dtype=np.float32)

    if args.gesture_rebuild:
        # Path of the cache file: explicit --gesture-npz wins, otherwise
        # "dvs_gesture.npz" under --gesture-root (or data/dvs_gesture).
        default_root = args.gesture_root or os.path.join("data", "dvs_gesture")
        cached = args.gesture_npz or os.path.join(default_root, "dvs_gesture.npz")
        if os.path.exists(cached):
            print(f"[DVS] Removing cached npz: {cached}")
            os.remove(cached)

    try:
        train_images, train_labels, test_images, test_labels = load_dvs_gesture_images(
            npz_path=args.gesture_npz,
            train_limit=args.train_limit,
            test_limit=args.test_limit,
            root=args.gesture_root,
            time_bins=args.gesture_time_bins,
            spatial_downsample=args.gesture_spatial_downsample,
            use_polarity=not args.gesture_no_polarity,
        )
    except RuntimeError as exc:
        # Loader failure: print the message and stop without running the task.
        print(exc)
        return

    # Output size is (largest label seen in either split) + 1; an empty split
    # contributes -1 so it never raises the maximum.
    # NOTE(review): presumes labels are zero-based integer class ids - confirm.
    max_train = int(np.max(train_labels)) if train_labels.size else -1
    max_test = int(np.max(test_labels)) if test_labels.size else -1
    output_size = max(max_train, max_test) + 1

    task_data = {
        "task_type": "classification",
        "task_name": "DVS-Gesture (Static Frames)",
        "train_images": train_images,
        "train_labels": train_labels,
        "test_images": test_images,
        "test_labels": test_labels,
        # NOTE(review): assumes axis 1 of train_images is the channel axis - confirm.
        "input_channels": train_images.shape[1],
        "output_size": output_size,
        "time_weighting": args.time_weighting,
    }

    # Both stage banners are printed up-front, before the single call below.
    print("STAGE 1: Scanning gains for DVS-Gesture (Static Frames)...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT/TBPTT on DVS-Gesture (Static Frames)...")
    run_classification_task(task_data, args, gains)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
