import argparse
import numpy as np

from sequence_utils import generate_lorenz_sequences, run_regression_task


def _build_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for the Lorenz image-prediction script.

    Option names, types and defaults form the script's public interface and
    are kept stable so existing invocations keep working.
    """
    parser = argparse.ArgumentParser(description="Lorenz image prediction experiment.")
    parser.add_argument("--task", type=str, default="lorenz_image")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=192)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--train-samples", type=int, default=10000,
        help="Leading generated sequences used for training (must be >= 1).",
    )
    parser.add_argument(
        "--test-samples", type=int, default=1000,
        help="Sequences after the training split used for testing (must be >= 1).",
    )
    parser.add_argument("--seq-len", type=int, default=30)
    parser.add_argument("--frame-h", type=int, default=16)
    parser.add_argument("--frame-w", type=int, default=16)
    parser.add_argument("--dt", type=float, default=0.01)
    parser.add_argument("--sigma", type=float, default=10.0)
    parser.add_argument("--rho", type=float, default=28.0)
    parser.add_argument(
        "--beta", type=float, default=2.6667,
        help="Lorenz beta; the default approximates the classic 8/3.",
    )
    parser.add_argument("--warmup", type=int, default=100)
    parser.add_argument("--blur-sigma", type=float, default=1.2)
    parser.add_argument(
        "--gains", type=str, default=None,
        help="Comma-separated gain values to scan (default: 1.0,1.1,1.2).",
    )
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    parser.add_argument("--plot-dir", type=str, default=None)
    parser.add_argument("--plot-tag", type=str, default=None)
    # These two flags pick contradictory model sets for the prediction plot,
    # so reject them together instead of letting one silently win.
    pred_models_group = parser.add_mutually_exclusive_group()
    pred_models_group.add_argument(
        "--pred-local-only", action="store_true",
        help="Plot predictions for the Local Rule model only.",
    )
    pred_models_group.add_argument(
        "--pred-all-models", action="store_true",
        help="Plot predictions for every model (models=None).",
    )
    parser.add_argument("--pred-mode", type=str, default="rollout")
    parser.add_argument("--pred-warmup", type=int, default=1)
    parser.add_argument(
        "--eval-mode", type=str, default=None,
        help="Evaluation mode; defaults to --pred-mode.",
    )
    parser.add_argument(
        "--eval-warmup", type=int, default=None,
        help="Evaluation warmup steps; defaults to --pred-warmup.",
    )
    return parser


def _parse_gains(spec: "str | None", parser: argparse.ArgumentParser) -> np.ndarray:
    """Turn the ``--gains`` option into a float32 array of gain values.

    ``spec`` is a comma-separated list such as ``"1.0,1.15"``; blank entries
    are ignored. A missing or all-blank spec yields the default gains
    ``[1.0, 1.1, 1.2]``. Malformed or non-finite entries abort through
    ``parser.error`` (usage message, exit status 2) instead of a traceback.
    """
    # Same three default points as before, now float32 so the default path
    # and the user-supplied path produce arrays of the same dtype.
    default = np.linspace(1.0, 1.3, 3, endpoint=False, dtype=np.float32)
    tokens = [tok for tok in (spec or "").split(",") if tok.strip()]
    if not tokens:
        return default
    try:
        gains = np.array([float(tok) for tok in tokens], dtype=np.float32)
    except ValueError:
        parser.error(f"--gains expects comma-separated numbers, got {spec!r}")
    # Catches "nan"/"inf" and values that overflow float32.
    if not np.all(np.isfinite(gains)):
        parser.error(f"--gains values must be finite, got {spec!r}")
    return gains


def _select_pred_models(args: argparse.Namespace) -> "list[str] | None":
    """Return the model names for the prediction plot (None means no filter)."""
    if args.pred_local_only:
        return ["Local Rule"]
    if args.pred_all_models:
        return None
    return ["Local Rule", "BPTT"]


def main() -> None:
    """Run the Lorenz image-prediction experiment from the command line.

    Parses CLI options, generates ``train_samples + test_samples`` Lorenz
    image sequences in a single ``generate_lorenz_sequences`` call seeded by
    ``--seed``, splits them into a leading train block and a trailing test
    block, and passes the task description, the parsed ``args`` and the gain
    values to ``run_regression_task``.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # Non-positive counts would turn the slicing below into negative indexing
    # and silently produce a nonsensical train/test split.
    if args.train_samples < 1 or args.test_samples < 1:
        parser.error("--train-samples and --test-samples must both be >= 1")
    # ``args`` is forwarded to run_regression_task; expose a positive flag.
    args.plot = not args.no_plot

    gains = _parse_gains(args.gains, parser)

    total_samples = args.train_samples + args.test_samples
    inputs, targets = generate_lorenz_sequences(
        num_samples=total_samples,
        seq_len=args.seq_len,
        frame_h=args.frame_h,
        frame_w=args.frame_w,
        dt=args.dt,
        sigma=args.sigma,
        rho=args.rho,
        beta=args.beta,
        warmup=args.warmup,
        seed=args.seed,
        blur_sigma=args.blur_sigma,
    )
    split = args.train_samples
    train_inputs, test_inputs = inputs[:split], inputs[split:]
    train_targets, test_targets = targets[:split], targets[split:]

    pred_models = _select_pred_models(args)

    # Evaluation settings fall back to the prediction-plot settings.
    eval_mode = args.eval_mode if args.eval_mode else args.pred_mode
    eval_warmup = args.eval_warmup if args.eval_warmup is not None else args.pred_warmup

    task_data = {
        "task_type": "regression",
        "task_name": "Lorenz Image Prediction",
        "train_inputs": train_inputs,
        "train_targets": train_targets,
        "test_inputs": test_inputs,
        "test_targets": test_targets,
        # NOTE(review): axis 1 is taken as the feature size. If the generator
        # returns (samples, time, features) this is the sequence length
        # instead -- confirm against generate_lorenz_sequences.
        "input_size": train_inputs.shape[1],
        "output_size": train_targets.shape[1],
        "evaluation": {
            "mode": eval_mode,
            "warmup_steps": eval_warmup,
        },
        "prediction_plot": {
            "type": "image",
            "frame_h": args.frame_h,
            "frame_w": args.frame_w,
            "models": pred_models,
            "mode": args.pred_mode,
            "warmup_steps": args.pred_warmup,
            "show_input": True,
        },
    }

    # NOTE(review): both stage banners print before the run starts; the
    # stages themselves presumably execute inside run_regression_task.
    print("STAGE 1: Scanning gains for Lorenz image prediction...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on Lorenz image prediction...")
    run_regression_task(task_data, args, gains)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
