# TASK_SPEC: movie_clip_replay_v1
import argparse
import math

import numpy as np

from sequence_utils import run_regression_task


def upsample_frames(frames: np.ndarray, seq_len: int) -> np.ndarray:
    """Stretch a clip to ``seq_len`` frames by repeating each frame in place.

    Each frame along axis 0 is repeated ``seq_len // frames.shape[0]`` times
    consecutively (``np.repeat``), so the original frame order is preserved.

    Args:
        frames: Array whose first axis indexes frames.
        seq_len: Desired number of frames along axis 0.

    Returns:
        ``frames`` itself (not a copy) when it already has ``seq_len`` frames,
        otherwise a new array with ``seq_len`` frames along axis 0.

    Raises:
        ValueError: If ``frames`` contains no frames, or ``seq_len`` is not a
            multiple of the current frame count.
    """
    num_frames = frames.shape[0]
    if num_frames == seq_len:
        return frames
    if num_frames == 0:
        # Without this guard the modulo below raises ZeroDivisionError.
        raise ValueError("frames must contain at least one frame.")
    if seq_len % num_frames != 0:
        raise ValueError("seq_len must be a multiple of base_frames.")
    return np.repeat(frames, seq_len // num_frames, axis=0)


def generate_movie_clip(
    rng: np.random.Generator,
    base_frames: int,
    frame_h: int,
    frame_w: int,
    channels: int,
    num_blobs: int,
    blob_sigma: float,
    speed_min: float,
    speed_max: float,
    noise_std: float,
) -> np.ndarray:
    """Render a clip of coloured Gaussian blobs drifting across the frame.

    Pixel coordinates span ``[0, 1]`` on both axes, so ``blob_sigma`` and the
    speeds are expressed in those normalised units (speed is per frame).
    For each blob the generator draws, in order:

    - a start point uniform in ``[0.1, 0.9)`` per axis;
    - a speed uniform in ``[speed_min, speed_max)``;
    - a heading uniform in ``[0, 2*pi)``;
    - a per-channel colour uniform in ``[0.3, 1.0)``.

    The blob centre moves linearly and wraps modulo 1 on each axis. The
    distance to pixels is plain (non-wrapped) Euclidean distance, so a blob
    crossing an edge is cut off there and reappears on the opposite side.
    Blob intensities ``exp(-d^2 / (2 * blob_sigma^2)) * colour`` are summed.
    Optional Gaussian pixel noise is then added, and the result is clipped.

    Returns:
        float32 array of shape ``(base_frames, frame_h, frame_w, channels)``
        with values in ``[0, 1]``.
    """
    col_coords = np.linspace(0.0, 1.0, frame_w, dtype=np.float32)
    row_coords = np.linspace(0.0, 1.0, frame_h, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(col_coords, row_coords)
    gauss_denom = 2.0 * blob_sigma**2
    movie = np.zeros((base_frames, frame_h, frame_w, channels), dtype=np.float32)

    for _blob in range(num_blobs):
        # The order of these draws determines the clip produced for a given
        # rng state; keep it stable for reproducibility.
        start_x, start_y = rng.uniform(0.1, 0.9, size=2)
        speed = rng.uniform(speed_min, speed_max)
        heading = rng.uniform(0.0, 2.0 * math.pi)
        tint = rng.uniform(0.3, 1.0, size=(channels,))
        step_x = speed * math.cos(heading)
        step_y = speed * math.sin(heading)

        for frame_idx in range(base_frames):
            centre_x = (start_x + step_x * frame_idx) % 1.0
            centre_y = (start_y + step_y * frame_idx) % 1.0
            sq_dist = (grid_x - centre_x) ** 2 + (grid_y - centre_y) ** 2
            glow = np.exp(-sq_dist / gauss_denom).astype(np.float32)
            movie[frame_idx] += glow[:, :, None] * tint

    if noise_std > 0.0:
        jitter = rng.normal(0.0, noise_std, size=movie.shape)
        movie += jitter.astype(np.float32)
    return np.clip(movie, 0.0, 1.0)


def generate_movie_dataset(
    num_samples: int,
    seq_len: int,
    base_frames: int,
    frame_h: int,
    frame_w: int,
    channels: int,
    num_blobs: int,
    blob_sigma: float,
    speed_min: float,
    speed_max: float,
    noise_std: float,
    seed: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build (input, next-frame target) sequences from synthetic movie clips.

    Each sample is a ``base_frames``-frame clip from ``generate_movie_clip``.
    The clip is upsampled to ``seq_len`` steps with ``upsample_frames`` and
    flattened so that every time step is one column of
    ``frame_h * frame_w * channels`` pixel values. All samples share one RNG
    seeded with ``seed``.

    The target at step ``t`` is the input at step ``t + 1``. The last step has
    no successor, so its target repeats the last input frame.

    Returns:
        ``(inputs, targets)``: two float16 arrays of shape
        ``(num_samples, frame_h * frame_w * channels, seq_len)``.

    Raises:
        ValueError: Propagated from ``upsample_frames`` when ``seq_len`` is not
            a multiple of ``base_frames``.
    """
    rng = np.random.default_rng(seed)
    frame_size = frame_h * frame_w * channels
    # float16 halves the footprint of these (samples x pixels x steps) arrays;
    # an 80x92 RGB frame alone is 22,080 values per time step.
    inputs = np.zeros((num_samples, frame_size, seq_len), dtype=np.float16)
    targets = np.zeros((num_samples, frame_size, seq_len), dtype=np.float16)

    for idx in range(num_samples):
        clip = generate_movie_clip(
            rng,
            base_frames,
            frame_h,
            frame_w,
            channels,
            num_blobs,
            blob_sigma,
            speed_min,
            speed_max,
            noise_std,
        )
        clip = upsample_frames(clip, seq_len)
        # Convert to float16 once (rows = time steps, cols = flattened pixels),
        # instead of making an extra full-size float32 copy per sample.
        flat = clip.reshape(seq_len, -1).astype(np.float16)
        inputs[idx] = flat.T
        targets[idx, :, :-1] = flat[1:].T
        targets[idx, :, -1] = flat[-1]

    return inputs, targets


def build_time_indices(seq_len: int) -> list[int]:
    """Choose distinct time steps spread across a sequence of length ``seq_len``.

    The candidates are step 0, then ``seq_len // 5``, ``seq_len // 3``,
    ``seq_len // 2``, ``(2 * seq_len) // 3``, and finally ``seq_len - 1``.
    Candidates outside ``[0, seq_len)`` are dropped. Duplicates, which are
    common for short sequences, are removed while first-seen order is kept.

    A non-positive ``seq_len`` yields ``[0]`` as a fallback, even though 0 is
    not a valid index into an empty sequence.
    """
    if seq_len <= 0:
        return [0]
    last_step = seq_len - 1
    spread = (
        0,
        seq_len // 5,
        seq_len // 3,
        seq_len // 2,
        (2 * seq_len) // 3,
        last_step,
    )
    # dict.fromkeys deduplicates while preserving insertion order.
    return list(dict.fromkeys(step for step in spread if 0 <= step < seq_len))


def _frame_label(frame_h: int, frame_w: int, channels: int) -> str:
    """Return a short frame description such as ``"80x92 RGB"`` for task names."""
    if channels == 3:
        colour = "RGB"
    elif channels == 1:
        colour = "grayscale"
    else:
        colour = f"{channels}-channel"
    return f"{frame_h}x{frame_w} {colour}"


def main() -> None:
    """Parse CLI options, build the movie-clip dataset, and run the regression task.

    Generates ``--train-samples + --test-samples`` clips with
    ``generate_movie_dataset`` using a single ``--seed``. The first
    ``--train-samples`` become the training split and the rest the test split.
    A task-description dict, ``args`` and the gain list are then passed to
    ``run_regression_task``.

    Raises:
        ValueError: If ``--seq-len`` or ``--base-frames`` is not positive, or if
            ``--seq-len`` is not a multiple of ``--base-frames``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="movie_clip_replay")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--hidden", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-samples", type=int, default=4)
    parser.add_argument("--test-samples", type=int, default=2)
    parser.add_argument("--seq-len", type=int, default=500)
    parser.add_argument("--base-frames", type=int, default=100)
    parser.add_argument("--frame-h", type=int, default=80)
    parser.add_argument("--frame-w", type=int, default=92)
    parser.add_argument("--channels", type=int, default=3)
    parser.add_argument("--num-blobs", type=int, default=3)
    parser.add_argument("--blob-sigma", type=float, default=0.08)
    parser.add_argument("--speed-min", type=float, default=0.004)
    parser.add_argument("--speed-max", type=float, default=0.02)
    parser.add_argument("--noise-std", type=float, default=0.01)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    parser.add_argument("--plot-dir", type=str, default=None)
    parser.add_argument("--plot-tag", type=str, default=None)
    parser.add_argument("--pred-local-only", action="store_true")
    parser.add_argument("--pred-all-models", action="store_true")
    parser.add_argument("--pred-mode", type=str, default="rollout")
    parser.add_argument("--pred-warmup", type=int, default=30)
    parser.add_argument("--eval-mode", type=str, default=None)
    parser.add_argument("--eval-warmup", type=int, default=None)
    parser.add_argument("--bptt-mode", type=str, choices=["full", "tbptt"], default="tbptt")
    parser.add_argument("--bptt-steps", type=int, default=30)
    args = parser.parse_args()
    args.plot = not args.no_plot

    # Checked first: --base-frames 0 would otherwise crash the modulo below
    # with ZeroDivisionError instead of a clear message.
    if args.base_frames <= 0 or args.seq_len <= 0:
        raise ValueError("--seq-len and --base-frames must be positive.")
    if args.seq_len % args.base_frames != 0:
        raise ValueError("--seq-len must be a multiple of --base-frames.")

    # --gains is a comma-separated list; an empty/blank list falls back to the
    # default sweep. NOTE(review): the default is float64 while parsed gains
    # are float32 -- confirm run_regression_task does not depend on the dtype.
    gains_default = np.linspace(0.1, 2.2, 10, endpoint=False)
    if args.gains:
        gains = np.array([float(x) for x in args.gains.split(",") if x.strip()], dtype=np.float32)
        if gains.size == 0:
            gains = gains_default
    else:
        gains = gains_default

    total_samples = args.train_samples + args.test_samples
    inputs, targets = generate_movie_dataset(
        num_samples=total_samples,
        seq_len=args.seq_len,
        base_frames=args.base_frames,
        frame_h=args.frame_h,
        frame_w=args.frame_w,
        channels=args.channels,
        num_blobs=args.num_blobs,
        blob_sigma=args.blob_sigma,
        speed_min=args.speed_min,
        speed_max=args.speed_max,
        noise_std=args.noise_std,
        seed=args.seed,
    )
    train_inputs = inputs[: args.train_samples]
    train_targets = targets[: args.train_samples]
    test_inputs = inputs[args.train_samples :]
    test_targets = targets[args.train_samples :]

    # --pred-local-only takes precedence over --pred-all-models.
    # NOTE(review): None presumably means "plot every model" downstream -- confirm.
    if args.pred_local_only:
        pred_models = ["Local Rule"]
    elif args.pred_all_models:
        pred_models = None
    else:
        pred_models = ["Local Rule", "BPTT"]

    # Evaluation settings default to the prediction-plot settings when unset.
    eval_mode = args.eval_mode if args.eval_mode else args.pred_mode
    eval_warmup = args.eval_warmup if args.eval_warmup is not None else args.pred_warmup

    # Truncation window is clamped to [1, seq_len].
    bptt_steps = max(1, min(int(args.bptt_steps), int(args.seq_len)))
    task_data = {
        "task_type": "regression",
        # Derived from the CLI so the name matches the frames actually generated.
        "task_name": f"Movie Clip Replay ({_frame_label(args.frame_h, args.frame_w, args.channels)})",
        "train_inputs": train_inputs,
        "train_targets": train_targets,
        "test_inputs": test_inputs,
        "test_targets": test_targets,
        "input_size": train_inputs.shape[1],
        "output_size": train_targets.shape[1],
        "bptt_config": {
            "mode": args.bptt_mode,
            "tbptt_steps": bptt_steps,
        },
        "evaluation": {
            "mode": eval_mode,
            "warmup_steps": eval_warmup,
        },
        "prediction_plot": {
            "type": "image",
            "frame_h": args.frame_h,
            "frame_w": args.frame_w,
            "frame_channels": args.channels,
            "time_indices": build_time_indices(args.seq_len),
            "models": pred_models,
            "mode": args.pred_mode,
            "warmup_steps": args.pred_warmup,
            "show_input": True,
        },
    }

    # NOTE(review): both stage banners print before run_regression_task starts;
    # presumably it performs both stages -- confirm.
    print("STAGE 1: Scanning gains for Movie Clip replay...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on Movie Clip replay...")
    run_regression_task(task_data, args, gains)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
