# TASK_SPEC: movie_prediction_v1
import argparse
import math

import numpy as np

from sequence_utils import run_regression_task


def generate_moving_wave_dataset(
    num_samples: int,
    seq_len: int,
    frame_h: int,
    frame_w: int,
    horizon_a: int,
    horizon_b: int,
    freq_min: float,
    freq_max: float,
    speed_min: float,
    speed_max: float,
    noise_std: float,
    seed: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Generate sequences of drifting 2-D sine-wave frames with future-frame targets.

    Each sample is a plane wave ``0.5 * (sin(2*pi*(fx*x + fy*y + phase_t)) + 1)``
    evaluated on a ``frame_h`` x ``frame_w`` grid spanning ``[0, 1]`` in both
    axes, where ``phase_t = phase + speed * t`` advances with the time step.
    Frames are flattened row-major to ``frame_h * frame_w`` values.

    Args:
        num_samples: Number of independent sequences to generate.
        seq_len: Number of time steps in every input sequence.
        frame_h: Frame height in pixels.
        frame_w: Frame width in pixels.
        horizon_a: Look-ahead (in steps) of the first target frame; must be >= 0.
        horizon_b: Look-ahead (in steps) of the second target frame; must be >= 0.
        freq_min: Lower bound of the uniform draw for both spatial frequencies.
        freq_max: Upper bound of the uniform draw for both spatial frequencies.
        speed_min: Lower bound of the uniform draw for the per-step phase increment.
        speed_max: Upper bound of the uniform draw for the per-step phase increment.
        noise_std: Standard deviation of Gaussian noise added to every frame;
            no noise is drawn when this is not positive. Noisy values are not
            clipped back into ``[0, 1]``.
        seed: Seed for ``np.random.default_rng``; identical arguments give
            identical data.

    Returns:
        ``(inputs, targets)``, both ``float32``:
        ``inputs`` has shape ``(num_samples, frame_h * frame_w, seq_len)``;
        ``targets`` has shape ``(num_samples, 2 * frame_h * frame_w, seq_len)``.
        At time ``t`` the first half of the target rows holds the frame at
        ``t + horizon_a`` and the second half the frame at ``t + horizon_b``.

    Raises:
        ValueError: If ``horizon_a`` or ``horizon_b`` is negative.
    """
    # A negative horizon would make `t + horizon` index from the END of the
    # generated sequence (NumPy wrap-around) and silently yield wrong targets.
    if horizon_a < 0 or horizon_b < 0:
        raise ValueError(
            "horizon_a and horizon_b must be non-negative, "
            f"got horizon_a={horizon_a}, horizon_b={horizon_b}"
        )

    rng = np.random.default_rng(seed)
    # Enough frames so the furthest target of the last input step exists.
    total_len = seq_len + max(horizon_a, horizon_b)
    grid_x, grid_y = np.meshgrid(
        np.linspace(0.0, 1.0, frame_w),
        np.linspace(0.0, 1.0, frame_h),
    )
    frame_size = frame_h * frame_w
    two_pi = 2.0 * math.pi
    inputs = np.zeros((num_samples, frame_size, seq_len), dtype=np.float32)
    targets = np.zeros((num_samples, frame_size * 2, seq_len), dtype=np.float32)
    # Scratch buffer reused for every sample; its rows are copied into the
    # outputs before the next sample overwrites them.
    flat_frames = np.empty((total_len, frame_size), dtype=np.float32)

    for i in range(num_samples):
        # Draw order (freq_x, freq_y, phase, speed, then per-frame noise) is
        # part of the seeded output and must not be rearranged.
        freq_x = rng.uniform(freq_min, freq_max)
        freq_y = rng.uniform(freq_min, freq_max)
        phase = rng.uniform(0.0, 2.0 * math.pi)
        speed = rng.uniform(speed_min, speed_max)
        # The spatial part of the wave argument does not depend on t.
        spatial = freq_x * grid_x + freq_y * grid_y
        for t in range(total_len):
            phase_t = phase + speed * t
            wave = np.sin(two_pi * (spatial + phase_t))
            frame = 0.5 * (wave + 1.0)  # rescale from [-1, 1] to [0, 1]
            if noise_std > 0:
                frame += rng.normal(0.0, noise_std, size=frame.shape)
            # Assignment casts float64 -> float32 and flattens row-major.
            flat_frames[t] = frame.reshape(frame_size)

        # Outputs are laid out (features, time), hence the transposes.
        inputs[i] = flat_frames[:seq_len].T
        targets[i, :frame_size] = flat_frames[horizon_a : horizon_a + seq_len].T
        targets[i, frame_size:] = flat_frames[horizon_b : horizon_b + seq_len].T

    return inputs, targets


def main() -> None:
    """Command-line entry point for the Movie Prediction regression task.

    Parses the options, generates a single moving-wave dataset sized for
    train plus test samples, splits it in that order, assembles the task
    description dict and passes it, together with the parsed arguments and
    the gain values, to ``run_regression_task``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--task", {"type": str, "default": "movie_prediction"}),
        ("--epochs", {"type": int, "default": 25}),
        ("--scan-epochs", {"type": int, "default": 5}),
        ("--batch-size", {"type": int, "default": 64}),
        ("--hidden", {"type": int, "default": 128}),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--seed", {"type": int, "default": 42}),
        ("--train-samples", {"type": int, "default": 1000}),
        ("--test-samples", {"type": int, "default": 200}),
        ("--seq-len", {"type": int, "default": 30}),
        ("--frame-h", {"type": int, "default": 20}),
        ("--frame-w", {"type": int, "default": 20}),
        ("--horizon-a", {"type": int, "default": 6}),
        ("--horizon-b", {"type": int, "default": 12}),
        ("--freq-min", {"type": float, "default": 0.5}),
        ("--freq-max", {"type": float, "default": 2.5}),
        ("--speed-min", {"type": float, "default": 0.05}),
        ("--speed-max", {"type": float, "default": 0.25}),
        ("--noise-std", {"type": float, "default": 0.01}),
        ("--gains", {"type": str, "default": None}),
        ("--no-plot", {"action": "store_true"}),
        ("--no-eprop", {"action": "store_true"}),
        ("--plot-path", {"type": str, "default": None}),
        ("--plot-dir", {"type": str, "default": None}),
        ("--plot-tag", {"type": str, "default": None}),
        ("--pred-all-models", {"action": "store_true"}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    # Expose the positive form of --no-plot on the namespace.
    args.plot = not args.no_plot

    # Start from the default gain grid; a non-empty --gains list replaces it.
    gains = np.linspace(0.3, 1.0, 4, endpoint=False)
    if args.gains:
        requested = [float(tok) for tok in args.gains.split(",") if tok.strip()]
        if requested:
            gains = np.array(requested, dtype=np.float32)

    n_train = args.train_samples
    inputs, targets = generate_moving_wave_dataset(
        num_samples=n_train + args.test_samples,
        seq_len=args.seq_len,
        frame_h=args.frame_h,
        frame_w=args.frame_w,
        horizon_a=args.horizon_a,
        horizon_b=args.horizon_b,
        freq_min=args.freq_min,
        freq_max=args.freq_max,
        speed_min=args.speed_min,
        speed_max=args.speed_max,
        noise_std=args.noise_std,
        seed=args.seed,
    )
    # The first n_train samples are the training set, the remainder the test set.
    train_inputs, test_inputs = inputs[:n_train], inputs[n_train:]
    train_targets, test_targets = targets[:n_train], targets[n_train:]

    if args.pred_all_models:
        plotted_models = None
    else:
        plotted_models = ["Local Rule", "BPTT"]

    task_data = dict(
        task_type="regression",
        task_name="Movie Prediction",
        train_inputs=train_inputs,
        train_targets=train_targets,
        test_inputs=test_inputs,
        test_targets=test_targets,
        input_size=train_inputs.shape[1],
        output_size=train_targets.shape[1],
        prediction_plot=dict(
            type="image",
            frame_h=args.frame_h,
            frame_w=args.frame_w,
            frame_index=0,
            models=plotted_models,
        ),
    )

    # Both banners are emitted before the runner is invoked.
    print("STAGE 1: Scanning gains for Movie Prediction...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on Movie Prediction...")
    run_regression_task(task_data, args, gains)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
