# TASK_SPEC: ready_set_go_v1
import argparse

import numpy as np

from sequence_utils import run_regression_task


def generate_ready_set_go(
    num_samples: int,
    seq_len: int,
    min_set: int,
    max_set: int,
    gain: float,
    response_width: int,
    noise_std: float,
    rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
    """Generate a batch of Ready-Set-Go timing trials.

    Each trial spans ``seq_len`` time steps and has two input channels and
    one target channel:

    * input channel 0 ("Ready") is a unit pulse at step 0;
    * input channel 1 ("Set") is a unit pulse at step ``t_s``, drawn
      uniformly from the integers ``[min_set, max_set]`` (after the clamping
      described below);
    * the target is 1.0 for ``response_width`` steps starting at
      ``round((2 + gain) * t_s)`` (clamped into ``[0, seq_len - 1]`` and
      clipped at the sequence end), and 0.0 elsewhere.

    ``min_set``, ``max_set`` and ``response_width`` are truncated to ints and
    raised to at least 1. ``max_set`` is then capped at ``seq_len - 1`` and at
    ``floor((seq_len - response_width) / max(1, 2 + gain))`` (never below 1),
    so that the response window fits. If that cap falls below ``min_set``,
    ``min_set`` is lowered to match.

    Args:
        num_samples: Number of trials to generate.
        seq_len: Time steps per trial; must be at least 2.
        min_set: Smallest requested Set time.
        max_set: Largest requested Set time (may be reduced, see above).
        gain: Go is placed ``(2 + gain) * t_s`` steps after Ready.
        response_width: Length of the target response window.
        noise_std: Std of Gaussian noise added to the inputs only. No noise
            is drawn unless this is strictly positive.
        rng: Randomness source. One integer is drawn per trial, in trial
            order, followed (if noise is enabled) by one normal array.

    Returns:
        ``(inputs, targets)`` as float32 arrays of shape
        ``(num_samples, 2, seq_len)`` and ``(num_samples, 1, seq_len)``.

    Raises:
        ValueError: If ``seq_len < 2``.
    """
    if seq_len < 2:
        raise ValueError("seq_len must be >= 2.")

    width = max(1, int(response_width))
    lo = max(1, int(min_set))
    hi = max(1, int(max_set))

    # Go sits (2 + gain) * t_s after Ready, i.e. (1 + gain) * t_s after Set.
    # NOTE(review): confirm this offset matches the intended task spec.
    go_factor = 2.0 + gain

    # Cap the Set time so that Go plus its response window still fits in the
    # sequence. The divisor is floored at 1 so small or non-positive
    # (2 + gain) cannot inflate the bound past the available room.
    room = max(1, seq_len - width)
    fit_cap = max(1, int(room / max(1.0, go_factor)))
    hi = min(hi, fit_cap, seq_len - 1)
    lo = min(lo, hi)

    inputs = np.zeros((num_samples, 2, seq_len), dtype=np.float32)
    targets = np.zeros((num_samples, 1, seq_len), dtype=np.float32)
    inputs[:, 0, 0] = 1.0  # Ready cue at the first step of every trial

    last_step = seq_len - 1
    for trial in range(num_samples):
        set_step = int(rng.integers(lo, hi + 1))
        go_step = min(max(int(round(go_factor * set_step)), 0), last_step)
        inputs[trial, 1, set_step] = 1.0
        targets[trial, 0, go_step:min(seq_len, go_step + width)] = 1.0

    if noise_std > 0:
        jitter = rng.normal(0.0, noise_std, size=inputs.shape)
        inputs += jitter.astype(np.float32)

    return inputs, targets


def main() -> None:
    """Command-line entry point for the Ready-Set-Go experiment.

    The function runs in four steps:

    1. Parse the CLI options.
    2. Build the ``gains`` array passed to ``run_regression_task``.
    3. Generate a train split and then a test split with
       ``generate_ready_set_go``. Both are drawn from a single
       ``np.random.default_rng(--seed)``, so the pair is reproducible.
    4. Hand the task description, ``args`` and ``gains`` to
       ``run_regression_task``.

    ``--gain`` is the task's interval multiplier, given to the data
    generator. ``--gains`` is a separate comma-separated list that is only
    forwarded to ``run_regression_task``. It presumably holds the values
    scanned in "STAGE 1"; confirm in ``sequence_utils``. When ``--gains`` is
    omitted or contains no numbers, ``np.linspace(0.3, 1.8, 6,
    endpoint=False)`` is used. A malformed ``--gains`` entry is reported as
    a CLI usage error.

    With the default options, ``generate_ready_set_go`` caps ``--max-set``
    (50) at 14 so that the response window fits in ``--seq-len`` (40).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="ready_set_go")
    parser.add_argument("--epochs", type=int, default=40)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-samples", type=int, default=4000)
    parser.add_argument("--test-samples", type=int, default=800)
    parser.add_argument("--seq-len", type=int, default=40)
    parser.add_argument("--min-set", type=int, default=10)
    parser.add_argument("--max-set", type=int, default=50)
    parser.add_argument("--gain", type=float, default=0.7)
    parser.add_argument("--response-width", type=int, default=2)
    parser.add_argument("--noise-std", type=float, default=0.01)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    parser.add_argument("--plot-dir", type=str, default=None)
    parser.add_argument("--plot-tag", type=str, default=None)
    parser.add_argument("--pred-all-models", action="store_true")
    args = parser.parse_args()
    args.plot = not args.no_plot

    # Default grid: 0.30, 0.55, 0.80, 1.05, 1.30, 1.55 (endpoint excluded).
    gains = np.linspace(0.3, 1.8, 6, endpoint=False)
    if args.gains:
        try:
            requested = [float(tok) for tok in args.gains.split(",") if tok.strip()]
        except ValueError as err:
            # Report bad input as a normal CLI usage error (exit status 2)
            # instead of an uncaught traceback.
            parser.error(f"--gains must be a comma-separated list of numbers ({err})")
        if requested:
            # float64 so user-supplied gains have the same dtype as the
            # default grid (previously float32 here vs float64 there).
            gains = np.array(requested, dtype=np.float64)

    # Options shared by both splits; only the sample count differs.
    split_kwargs = {
        "seq_len": args.seq_len,
        "min_set": args.min_set,
        "max_set": args.max_set,
        "gain": args.gain,
        "response_width": args.response_width,
        "noise_std": args.noise_std,
    }
    # One generator for both splits, drawn train-then-test, so the pair is
    # reproducible from --seed.
    rng = np.random.default_rng(args.seed)
    train_inputs, train_targets = generate_ready_set_go(
        num_samples=args.train_samples, rng=rng, **split_kwargs
    )
    test_inputs, test_targets = generate_ready_set_go(
        num_samples=args.test_samples, rng=rng, **split_kwargs
    )

    task_data = {
        "task_type": "regression",
        "task_name": "Ready-Set-Go",
        "train_inputs": train_inputs,
        "train_targets": train_targets,
        "test_inputs": test_inputs,
        "test_targets": test_targets,
        # Channel axis is 1: (samples, channels, time).
        "input_size": train_inputs.shape[1],
        "output_size": train_targets.shape[1],
        "prediction_plot": {
            "type": "timeseries",
            "dims": [0],
            # NOTE(review): None presumably means "plot every model" when
            # --pred-all-models is set; confirm in sequence_utils.
            "models": None if args.pred_all_models else ["Local Rule", "BPTT"],
        },
    }

    print("STAGE 1: Scanning gains for Ready-Set-Go...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on Ready-Set-Go...")
    run_regression_task(task_data, args, gains)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
