# TASK_SPEC: adding_task_v1
import argparse

import numpy as np

from sequence_utils import run_regression_task


def generate_adding_task(
    num_samples: int,
    seq_len: int,
    rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
    """Build a batch of the two-marker adding problem.

    Every sample has two input channels over the time axis:

    * channel 0 holds uniform random values in ``[0, 1)``;
    * channel 1 is a 0/1 marker channel with exactly two markers.

    The first marker is placed before the split point ``length // 2`` and the
    second at or after it. The target is the running (cumulative) sum of the
    marked values, so its final step equals the sum of the two marked values.

    Args:
        num_samples: Number of sequences to generate.
        seq_len: Sequence length. Values below 2 are raised to 2 so that both
            marker windows contain at least one position.
        rng: Random generator. All draws happen in a fixed order (the values
            first, then one early/late marker pair per sample), so a seeded
            generator yields reproducible data.

    Returns:
        ``(inputs, targets)`` as float32 arrays of shapes
        ``(num_samples, 2, length)`` and ``(num_samples, 1, length)``.
    """
    length = max(seq_len, 2)
    values = rng.random((num_samples, length)).astype(np.float32)

    # Boundary between the early and late marker windows, always in [1, length - 1].
    split = min(max(1, length // 2), length - 1)
    markers = np.zeros_like(values)
    for row in markers:
        # Draw early then late for each sample to keep the RNG stream order fixed.
        early = int(rng.integers(0, split))
        late = int(rng.integers(split, length))
        row[early] = 1.0
        row[late] = 1.0

    # Channel axis is 1: (samples, channels, time).
    inputs = np.stack((values, markers), axis=1).astype(np.float32)
    running_total = np.cumsum(values * markers, axis=1, dtype=np.float32)
    targets = np.expand_dims(running_total, axis=1).astype(np.float32)
    return inputs, targets


def main() -> None:
    """Command-line entry point for the adding-task experiment.

    Steps:

    1. Parse the CLI options.
    2. Draw a training set and then a test set from one generator seeded
       with ``--seed``.
    3. Pack both sets into a ``task_data`` dict.
    4. Call ``run_regression_task(task_data, args, gains)``.

    ``--gains`` is a comma-separated list of numbers, and empty entries are
    ignored. When the option is absent or yields no numbers, eight gains from
    ``np.linspace(0.3, 1.8, 8, endpoint=False)`` are used instead. A malformed
    entry is reported as an argparse usage error (exit status 2) rather than
    an uncaught ``ValueError`` traceback.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="adding_task")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-samples", type=int, default=8000)
    parser.add_argument("--test-samples", type=int, default=2000)
    parser.add_argument("--seq-len", type=int, default=60)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument(
        "--time-weighting",
        type=str,
        choices=["none", "final", "late"],
        default="none",
    )
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    parser.add_argument("--plot-dir", type=str, default=None)
    parser.add_argument("--plot-tag", type=str, default=None)
    parser.add_argument("--pred-all-models", action="store_true")
    args = parser.parse_args()
    # ``args`` is forwarded to run_regression_task, which presumably reads the
    # positive ``plot`` flag rather than ``no_plot``. TODO confirm there.
    args.plot = not args.no_plot

    # NOTE(review): the default gains are float64 while user-supplied gains
    # are float32. Confirm run_regression_task does not depend on the dtype.
    gains = np.linspace(0.3, 1.8, 8, endpoint=False)
    if args.gains:
        try:
            parsed = [float(tok) for tok in args.gains.split(",") if tok.strip()]
        except ValueError as err:
            # parser.error prints usage and exits with status 2 (it never returns).
            parser.error(
                f"argument --gains: expected comma-separated numbers ({err})"
            )
        if parsed:
            gains = np.array(parsed, dtype=np.float32)

    # Train and test sets come from the same generator, drawn one after the other.
    rng = np.random.default_rng(args.seed)
    train_inputs, train_targets = generate_adding_task(
        args.train_samples, args.seq_len, rng
    )
    test_inputs, test_targets = generate_adding_task(
        args.test_samples, args.seq_len, rng
    )

    task_data = {
        "task_type": "regression",
        "task_name": "Adding Task",
        "train_inputs": train_inputs,
        "train_targets": train_targets,
        "test_inputs": test_inputs,
        "test_targets": test_targets,
        # generate_adding_task puts channels on axis 1: (samples, channels, time).
        "input_size": train_inputs.shape[1],
        "output_size": train_targets.shape[1],
        "time_weighting": args.time_weighting,
        "prediction_plot": {
            "type": "timeseries",
            "dims": [0],
            # None presumably means "plot every model". Verify in the runner.
            "models": None if args.pred_all_models else ["Local Rule", "BPTT"],
        },
    }

    # NOTE(review): both stage banners are printed before run_regression_task
    # starts. Presumably it performs both stages internally.
    print("STAGE 1: Scanning gains for Adding Task...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on Adding Task...")
    run_regression_task(task_data, args, gains)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
