# TASK_SPEC: lorenz_attractor_v1
import argparse

import numpy as np

from sequence_utils import run_regression_task


def step_lorenz(
    state: np.ndarray,
    dt: float,
    sigma: float,
    rho: float,
    beta: float,
) -> np.ndarray:
    """Advance a Lorenz-system state by one explicit (forward) Euler step.

    Args:
        state: Length-3 array holding the current ``(x, y, z)`` coordinates.
        dt: Integration step size.
        sigma: Coefficient of the ``y - x`` term in the x-derivative.
        rho: Coefficient that ``z`` is subtracted from in the y-derivative.
        beta: Decay coefficient applied to ``z`` in the z-derivative.

    Returns:
        A new array equal to ``state + dt * derivative``; the input array
        is not modified.
    """
    px, py, pz = state
    # The derivative vector is always materialised as float32, regardless
    # of the dtype of the incoming state.
    derivative = np.array(
        [
            sigma * (py - px),
            px * (rho - pz) - py,
            px * py - beta * pz,
        ],
        dtype=np.float32,
    )
    return state + dt * derivative


def generate_lorenz_dataset(
    num_samples: int,
    seq_len: int,
    horizon: int,
    dt: float,
    sigma: float,
    rho: float,
    beta: float,
    warmup: int,
    seed: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Simulate Lorenz trajectories and cut them into input/target windows.

    For every sample a random initial state is drawn (normal, scale 0.5),
    advanced ``warmup`` steps whose results are discarded, and then
    advanced ``seq_len + horizon`` further steps that are recorded. The
    input window is the first ``seq_len`` recorded states; the target
    window is the same-length window starting ``horizon`` steps later.

    Args:
        num_samples: Number of independent trajectories to generate.
        seq_len: Number of time steps in each input/target window.
        horizon: Offset, in steps, between an input and its target.
        dt: Integration step passed to ``step_lorenz``.
        sigma: Lorenz parameter passed to ``step_lorenz``.
        rho: Lorenz parameter passed to ``step_lorenz``.
        beta: Lorenz parameter passed to ``step_lorenz``.
        warmup: Number of initial steps to discard per trajectory.
        seed: Seed for the NumPy random generator.

    Returns:
        ``(inputs, targets)``, two float32 arrays of shape
        ``(num_samples, 3, seq_len)``.
    """
    generator = np.random.default_rng(seed)
    lorenz_params = (dt, sigma, rho, beta)
    recorded_steps = seq_len + horizon

    inputs = np.zeros((num_samples, 3, seq_len), dtype=np.float32)
    targets = np.zeros_like(inputs)

    for sample_idx in range(num_samples):
        current = generator.normal(scale=0.5, size=(3,)).astype(np.float32)

        # Run the system forward without recording anything.
        for _ in range(warmup):
            current = step_lorenz(current, *lorenz_params)

        trajectory = []
        for _ in range(recorded_steps):
            current = step_lorenz(current, *lorenz_params)
            trajectory.append(current.copy())

        # Shape (3, recorded_steps): coordinates on rows, time on columns.
        channels_first = np.stack(trajectory, axis=0).T
        inputs[sample_idx] = channels_first[:, :seq_len]
        targets[sample_idx] = channels_first[:, horizon : horizon + seq_len]

    return inputs, targets


def main() -> None:
    """Parse CLI options, build the Lorenz dataset and launch the task.

    The function generates ``train_samples + test_samples`` trajectories
    in one call, splits them into train and test sets in that order,
    assembles the ``task_data`` dictionary and passes it, together with
    the parsed arguments and the gain values, to ``run_regression_task``.
    """
    # Ordered table of (flag, add_argument keyword arguments); the order
    # here is the order in which options are registered on the parser.
    option_table = (
        ("--task", {"type": str, "default": "lorenz_attractor"}),
        ("--epochs", {"type": int, "default": 35}),
        ("--scan-epochs", {"type": int, "default": 5}),
        ("--batch-size", {"type": int, "default": 64}),
        ("--hidden", {"type": int, "default": 192}),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--seed", {"type": int, "default": 42}),
        ("--train-samples", {"type": int, "default": 2000}),
        ("--test-samples", {"type": int, "default": 400}),
        ("--seq-len", {"type": int, "default": 1000}),
        ("--pred-horizon", {"type": int, "default": 3}),
        ("--dt", {"type": float, "default": 0.01}),
        ("--sigma", {"type": float, "default": 10.0}),
        ("--rho", {"type": float, "default": 28.0}),
        ("--beta", {"type": float, "default": 2.6666667}),
        ("--warmup", {"type": int, "default": 1000}),
        ("--gains", {"type": str, "default": None}),
        ("--no-plot", {"action": "store_true"}),
        ("--no-eprop", {"action": "store_true"}),
        ("--plot-path", {"type": str, "default": None}),
        ("--plot-dir", {"type": str, "default": None}),
        ("--plot-tag", {"type": str, "default": None}),
        ("--pred-all-models", {"action": "store_true"}),
        ("--pred-mode", {"type": str, "default": "teacher"}),
        ("--pred-warmup", {"type": int, "default": 1}),
        ("--eval-mode", {"type": str, "default": None}),
        ("--eval-warmup", {"type": int, "default": None}),
        (
            "--bptt-mode",
            {"type": str, "choices": ["full", "tbptt"], "default": "tbptt"},
        ),
        ("--bptt-steps", {"type": int, "default": 50}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    args.plot = not args.no_plot

    # Start from the default gain grid; override it only when --gains
    # yields at least one value.
    gains = np.linspace(0.1, 3.5, 15, endpoint=False)
    if args.gains:
        parsed_gains = np.array(
            [float(token) for token in args.gains.split(",") if token.strip()],
            dtype=np.float32,
        )
        if parsed_gains.size > 0:
            gains = parsed_gains

    n_train = args.train_samples
    inputs, targets = generate_lorenz_dataset(
        num_samples=n_train + args.test_samples,
        seq_len=args.seq_len,
        horizon=max(1, args.pred_horizon),
        dt=args.dt,
        sigma=args.sigma,
        rho=args.rho,
        beta=args.beta,
        warmup=args.warmup,
        seed=args.seed,
    )
    # The first n_train samples form the training set; the rest are test.
    train_inputs, test_inputs = inputs[:n_train], inputs[n_train:]
    train_targets, test_targets = targets[:n_train], targets[n_train:]

    # Evaluation settings fall back to the prediction settings when unset.
    eval_mode = args.eval_mode or args.pred_mode
    if args.eval_warmup is not None:
        eval_warmup = args.eval_warmup
    else:
        eval_warmup = args.pred_warmup

    # Keep the truncation length within [1, seq_len] (lower bound wins).
    bptt_steps = min(int(args.bptt_steps), int(args.seq_len))
    bptt_steps = max(1, bptt_steps)

    plot_models = None if args.pred_all_models else ["Local Rule", "BPTT"]

    task_data = {
        "task_type": "regression",
        "task_name": "Lorenz Attractor Prediction",
        "train_inputs": train_inputs,
        "train_targets": train_targets,
        "test_inputs": test_inputs,
        "test_targets": test_targets,
        "input_size": train_inputs.shape[1],
        "output_size": train_targets.shape[1],
        "bptt_config": {
            "mode": args.bptt_mode,
            "tbptt_steps": bptt_steps,
        },
        "evaluation": {
            "mode": eval_mode,
            "warmup_steps": eval_warmup,
        },
        "prediction_plot": {
            "type": "trajectory3d",
            "dims": [0, 1, 2],
            "mode": args.pred_mode,
            "warmup_steps": args.pred_warmup,
            "models": plot_models,
        },
    }

    print("STAGE 1: Scanning gains for Lorenz attractor prediction...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on Lorenz attractor prediction...")
    run_regression_task(task_data, args, gains)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
