# TASK_SPEC: lorenz_attractor_localrule_v1
import argparse
import sys
from pathlib import Path

import numpy as np

# TASK_DIR is the directory one level above the directory containing this
# file. It is put at the front of sys.path (only if not already present),
# presumably so that the `common.sequence_core` import below resolves when
# this file is executed directly as a script.
TASK_DIR = Path(__file__).resolve().parents[1]
if str(TASK_DIR) not in sys.path:
    sys.path.insert(0, str(TASK_DIR))

from common.sequence_core import (  # noqa: E402
    TorchLocalRuleRNN,
    build_plot_path,
    plot_trajectory_predictions,
    predict_sequence_outputs,
    predict_sequence_rollout,
    resolve_plot_context,
    train_batches,
)


def step_lorenz(
    state: np.ndarray,
    dt: float,
    sigma: float,
    rho: float,
    beta: float,
) -> np.ndarray:
    """Advance a Lorenz-system state by one explicit (forward) Euler step.

    Args:
        state: Three-component state ``(x, y, z)``; unpacking fails with
            ``ValueError`` for any other length.
        dt: Integration step size.
        sigma: Coefficient of the ``(y - x)`` term in ``dx/dt``.
        rho: Constant in the ``x * (rho - z)`` term of ``dy/dt``.
        beta: Coefficient of ``z`` in ``dz/dt``.

    Returns:
        ``state + dt * derivative``, where the derivative is evaluated at
        ``state`` and stored as a ``float32`` array. ``state`` itself is not
        modified.
    """
    pos_x, pos_y, pos_z = state
    # The time derivative is materialised as float32 before scaling by dt.
    derivative = np.array(
        [
            sigma * (pos_y - pos_x),
            pos_x * (rho - pos_z) - pos_y,
            pos_x * pos_y - beta * pos_z,
        ],
        dtype=np.float32,
    )
    increment = dt * derivative
    return state + increment


def generate_lorenz_dataset(
    num_samples: int,
    seq_len: int,
    horizon: int,
    dt: float,
    sigma: float,
    rho: float,
    beta: float,
    warmup: int,
    seed: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Simulate Lorenz trajectories and pair each with a time-shifted copy.

    For every sample a random initial state is drawn from a zero-mean normal
    distribution (scale 0.5), advanced ``warmup`` steps whose states are
    discarded, and then advanced a further ``seq_len + horizon`` steps whose
    states are recorded.

    Args:
        num_samples: Number of independent trajectories to generate.
        seq_len: Number of time steps in each input/target sequence.
        horizon: Offset, in steps, between an input step and its target.
        dt: Step size passed to ``step_lorenz``.
        sigma: Lorenz parameter passed to ``step_lorenz``.
        rho: Lorenz parameter passed to ``step_lorenz``.
        beta: Lorenz parameter passed to ``step_lorenz``.
        warmup: Number of initial steps to run and discard per sample.
        seed: Seed for the NumPy random generator that draws initial states.

    Returns:
        ``(inputs, targets)``, both ``float32`` arrays of shape
        ``(num_samples, 3, seq_len)``. ``inputs[i]`` holds the first
        ``seq_len`` recorded states; ``targets[i]`` holds the recorded states
        from index ``horizon`` to ``horizon + seq_len``.
    """
    rng = np.random.default_rng(seed)
    steps_recorded = seq_len + horizon
    inputs = np.zeros((num_samples, 3, seq_len), dtype=np.float32)
    targets = np.zeros_like(inputs)
    # Targets are the same trajectory viewed `horizon` steps later.
    target_window = slice(horizon, horizon + seq_len)

    for sample_idx in range(num_samples):
        current = rng.normal(scale=0.5, size=(3,)).astype(np.float32)

        # Burn-in: advance without recording.
        for _ in range(warmup):
            current = step_lorenz(current, dt, sigma, rho, beta)

        trajectory = []
        for _ in range(steps_recorded):
            current = step_lorenz(current, dt, sigma, rho, beta)
            trajectory.append(current.copy())

        # (time, 3) -> stored as (3, time) per sample.
        recorded = np.stack(trajectory, axis=0)
        inputs[sample_idx] = recorded[:seq_len].T
        targets[sample_idx] = recorded[target_window].T

    return inputs, targets


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    Options are registered from a table, in the listed order.
    """
    option_table = (
        ("--task", {"type": str, "default": "lorenz_attractor_localrule"}),
        # Tuned training hyperparameters (author's notes): more epochs for
        # convergence, a slightly larger batch for stability, a large hidden
        # layer for the chaotic dynamics, a slightly reduced learning rate to
        # avoid oscillation, and more training samples for phase-space coverage.
        ("--epochs", {"type": int, "default": 100}),
        ("--batch-size", {"type": int, "default": 64}),
        ("--hidden", {"type": int, "default": 512}),
        ("--lr", {"type": float, "default": 5e-4}),
        ("--train-samples", {"type": int, "default": 4000}),
        ("--seed", {"type": int, "default": 42}),
        ("--test-samples", {"type": int, "default": 400}),
        ("--seq-len", {"type": int, "default": 1000}),
        ("--pred-horizon", {"type": int, "default": 3}),
        ("--dt", {"type": float, "default": 0.01}),
        ("--sigma", {"type": float, "default": 10.0}),
        ("--rho", {"type": float, "default": 28.0}),
        ("--beta", {"type": float, "default": 2.6666667}),
        ("--warmup", {"type": int, "default": 1000}),
        ("--no-plot", {"action": "store_true"}),
        # NOTE(review): --no-eprop is accepted but not read anywhere in main().
        ("--no-eprop", {"action": "store_true"}),
        ("--plot-path", {"type": str, "default": None}),
        ("--plot-dir", {"type": str, "default": None}),
        ("--plot-tag", {"type": str, "default": None}),
        ("--pred-mode", {"type": str, "default": "teacher"}),
        ("--pred-warmup", {"type": int, "default": 1}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser


def main() -> None:
    """Train a local-rule RNN on Lorenz data and optionally plot a prediction.

    Steps, in order:
      1. Parse command-line options.
      2. Generate ``train_samples + test_samples`` Lorenz sequences and split
         them into a leading training part and a trailing test part.
      3. Build ``TorchLocalRuleRNN`` and initialise its weights with a fixed
         gain.
      4. Call ``train_batches`` once per epoch, printing progress roughly ten
         times over the run and always after the final epoch.
      5. Unless ``--no-plot`` was given (or the test set is empty), predict
         the first test sequence and plot it against its target.
    """
    args = _build_arg_parser().parse_args()
    # Stored on the namespace itself because `args` is later handed to
    # resolve_plot_context, which may read it.
    args.plot = not args.no_plot

    gain = 2.593
    task_label = "Lorenz Attractor Prediction (Local Rule)"

    n_train = args.train_samples
    inputs, targets = generate_lorenz_dataset(
        num_samples=n_train + args.test_samples,
        seq_len=args.seq_len,
        horizon=max(1, args.pred_horizon),  # the target is always at least one step ahead
        dt=args.dt,
        sigma=args.sigma,
        rho=args.rho,
        beta=args.beta,
        warmup=args.warmup,
        seed=args.seed,
    )
    train_inputs, test_inputs = inputs[:n_train], inputs[n_train:]
    train_targets, test_targets = targets[:n_train], targets[n_train:]

    if args.plot:
        plot_dir, plot_tag = resolve_plot_context(args, task_label)
    else:
        plot_dir = plot_tag = None

    model = TorchLocalRuleRNN(
        train_inputs.shape[1],
        args.hidden,
        train_targets.shape[1],
        eta=args.lr,
        loss_mode="mse",
        seed=args.seed,
    )
    model.initialize_weights_with_gain(gain, seed=args.seed)

    log_every = max(1, args.epochs // 10)
    print(f"Start Training: Hidden={args.hidden}, Epochs={args.epochs}, Samples={args.train_samples}")

    for completed in range(1, args.epochs + 1):
        epoch = completed - 1
        # One training pass per call; the seed argument changes every epoch.
        train_batches(
            model,
            train_inputs,
            train_targets,
            args.batch_size,
            1,
            args.seed + 10 + epoch,
            epoch_offset=epoch,
        )
        if completed % log_every == 0 or completed == args.epochs:
            print(f"[Local Rule] epoch={completed:03d} | gain={gain:.3f}")

    if not args.plot:
        return
    if len(test_inputs) == 0:
        print("[PLOT] Skipping prediction plot: empty test set.")
        return

    sample_input, sample_target = test_inputs[0], test_targets[0]

    # Clamp the warm-up to [1, sequence length].
    pred_warmup = max(1, min(int(args.pred_warmup), sample_input.shape[1]))

    # "autoregressive" and "ar" are aliases for "rollout"; any other value
    # falls back to teacher mode.
    requested_mode = str(args.pred_mode).lower()
    if requested_mode in {"rollout", "autoregressive", "ar"}:
        pred = predict_sequence_rollout(model, sample_input, warmup_steps=pred_warmup)
    else:
        pred = predict_sequence_outputs(model, sample_input)

    if plot_dir is not None and plot_tag is not None:
        pred_plot_path = build_plot_path(plot_dir, plot_tag, "pred_trajectory")
    else:
        pred_plot_path = None

    plot_trajectory_predictions(
        task_label,
        {"Local Rule": pred},
        sample_target,
        dims=[0, 1, 2],
        plot_path=pred_plot_path,
        show=args.plot,
        skip_steps=pred_warmup,  # the warm-up count is also passed as skip_steps in both modes
    )


# Run the training/plotting pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()