# TASK_SPEC: pixel_mnist_v1
import argparse

import numpy as np

from sequence_utils import (
    build_repeated_targets,
    load_mnist_images,
    run_classification_task,
)


def load_pixel_mnist_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST and lay each image out as a one-feature pixel sequence.

    Every image is flattened into a single row of pixels and given a singleton
    axis at position 1. The result has shape ``(N, 1, pixels_per_image)`` and
    dtype float32. No rescaling is applied here: pixel values are only cast.

    Args:
        train_limit: Forwarded unchanged to ``load_mnist_images``.
        test_limit: Forwarded unchanged to ``load_mnist_images``.

    Returns:
        ``(train_inputs, train_targets, train_labels,
        test_inputs, test_targets, test_labels)``. The targets come from
        ``build_repeated_targets`` with 10 classes, and their length is the
        number of pixels per training image. Their exact shape is defined by
        that helper (presumably one-hot repeated over time; confirm in
        ``sequence_utils``).
    """

    def _as_pixel_sequence(images: np.ndarray) -> np.ndarray:
        # Flatten everything after the sample axis, insert a size-1 axis at
        # position 1, and cast to float32 (always returns a new array).
        count = images.shape[0]
        return images.reshape(count, 1, -1).astype(np.float32)

    train_raw, train_labels, test_raw, test_labels = load_mnist_images(
        train_limit=train_limit,
        test_limit=test_limit,
    )
    train_seq = _as_pixel_sequence(train_raw)
    test_seq = _as_pixel_sequence(test_raw)

    # The last axis counts pixels, i.e. sequence steps. The training length
    # is used for both splits (MNIST images share one size).
    steps = train_seq.shape[-1]
    train_onehot = build_repeated_targets(train_labels, 10, steps)
    test_onehot = build_repeated_targets(test_labels, 10, steps)

    return (
        train_seq,
        train_onehot,
        train_labels,
        test_seq,
        test_onehot,
        test_labels,
    )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="pixel_mnist")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-limit", type=int, default=60000)
    parser.add_argument("--test-limit", type=int, default=10000)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument(
        "--time-weighting",
        type=str,
        choices=["none", "final", "late"],
        default="none",
    )
    parser.add_argument("--step-labels", type=str, choices=["final", "fptt"], default="final")
    parser.add_argument("--tbptt-short", type=int, default=1)
    parser.add_argument("--tbptt-long", type=int, default=None)
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    args = parser.parse_args()
    args.plot = not args.no_plot

    gains_default = np.linspace(0.5, 1.6, 8, endpoint=False)
    if args.gains:
        gains = np.array(
            [float(x) for x in args.gains.split(",") if x.strip()],
            dtype=np.float32,
        )
        if gains.size == 0:
            gains = gains_default
    else:
        gains = gains_default

    (
        train_inputs,
        train_targets,
        train_labels,
        test_inputs,
        test_targets,
        test_labels,
    ) = load_pixel_mnist_sequences(
        train_limit=args.train_limit,
        test_limit=args.test_limit,
    )

    task_data = {
        "task_type": "classification",
        "task_name": "Pixel MNIST",
        "train_inputs": train_inputs,
        "train_targets": train_targets,
        "train_labels": train_labels,
        "test_inputs": test_inputs,
        "test_targets": test_targets,
        "test_labels": test_labels,
        "input_size": train_inputs.shape[1],
        "output_size": train_targets.shape[1],
        "time_weighting": args.time_weighting,
    }

    print("STAGE 1: Scanning gains for Pixel MNIST...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on Pixel MNIST...")
    run_classification_task(task_data, args, gains)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
