# TASK_SPEC: row_mnist_fptt_only_v1
import argparse
import math
import time

import numpy as np
import torch

from sequence_utils import build_repeated_targets, load_mnist_images
from common.sequence_core import (
    DEFAULT_DEVICE,
    StrictFPTTClassifier,
    build_lyapunov_driver,
    calculate_lyapunov_exponent_numpy,
    extract_params,
    estimate_model_complexity,
    estimate_training_counts,
    evaluate_classifier_final_step,
    load_params,
    plot_comparison_results,
    resolve_plot_context,
    save_results_summary,
    split_train_val,
    train_batches,
)
from methods.strict_fptt import OracleBufferStore


def load_row_mnist_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST and reshape each image into a sequence for row-wise feeding.

    Each image array is transposed with axes ``(0, 2, 1)`` and cast to
    ``float32``. The resulting axis 2 is treated as the time axis: its length
    sets the number of steps over which the 10-class targets are repeated.
    Assuming ``load_mnist_images`` returns images as (N, H, W), which is not
    verified here, the sequence layout is (N, W, H), i.e. one image row per
    time step.

    Args:
        train_limit: Forwarded to ``load_mnist_images`` to cap the train split.
        test_limit: Forwarded to ``load_mnist_images`` to cap the test split.

    Returns:
        ``(train_inputs, train_targets, train_labels,
        test_inputs, test_targets, test_labels)``.
    """
    mnist_splits = load_mnist_images(train_limit=train_limit, test_limit=test_limit)
    train_images, train_labels, test_images, test_labels = mnist_splits

    def _as_sequence(images):
        # Swap the last two axes so that axis 2 becomes the time axis.
        return np.transpose(images, (0, 2, 1)).astype(np.float32)

    train_inputs = _as_sequence(train_images)
    test_inputs = _as_sequence(test_images)

    # The train split defines the sequence length used for both target sets.
    n_steps = train_inputs.shape[2]
    train_targets = build_repeated_targets(train_labels, 10, n_steps)
    test_targets = build_repeated_targets(test_labels, 10, n_steps)

    return (
        train_inputs,
        train_targets,
        train_labels,
        test_inputs,
        test_targets,
        test_labels,
    )


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the row-MNIST FPTT-only run.

    Returns:
        The parsed namespace, with an extra ``plot`` attribute set to
        ``not args.no_plot``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="row_mnist_fptt")
    parser.add_argument("--epochs", type=int, default=40)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-limit", type=int, default=60000)
    parser.add_argument("--test-limit", type=int, default=10000)
    parser.add_argument("--gain", type=float, default=0.533)
    parser.add_argument("--parts", type=int, default=10)
    parser.add_argument("--clip", type=float, default=1.0)
    parser.add_argument("--alpha", type=float, default=0.1)
    parser.add_argument("--beta", type=float, default=0.5)
    parser.add_argument("--rho", type=float, default=0.0)
    parser.add_argument("--lmbda", type=float, default=1.0)
    parser.add_argument("--oracle-momentum", type=float, default=1.0)
    parser.add_argument("--warmup-epochs", type=int, default=20)
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    parser.add_argument("--plot-dir", type=str, default=None)
    parser.add_argument("--plot-tag", type=str, default=None)
    args = parser.parse_args()
    args.plot = not args.no_plot
    return args


def main() -> None:
    """Train a StrictFPTTClassifier on row-sequential MNIST and report metrics.

    Flow:
      1. Seed NumPy/PyTorch, load MNIST as sequences, and hold out 10% of the
         training split for validation.
      2. Train one epoch per iteration. After each epoch, evaluate validation
         accuracy/loss and snapshot the parameters with the best val accuracy.
      3. Restore the best snapshot and evaluate once on the test split.
      4. Report the Lyapunov exponent before training and after restoring the
         best snapshot, using one fixed validation-derived driver.
      5. Estimate model complexity and update counts. Unless ``--no-plot`` is
         given, save a results summary and a comparison plot.
    """
    args = _parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    (
        train_inputs,
        train_targets,
        train_labels,
        test_inputs,
        test_targets,
        test_labels,
    ) = load_row_mnist_sequences(
        train_limit=args.train_limit,
        test_limit=args.test_limit,
    )
    # Hold out 10% of the training split for model selection. The test split
    # is only evaluated once, after training.
    rng = np.random.default_rng(args.seed)
    tr_inputs, tr_targets, tr_labels, val_inputs, val_targets, val_labels = split_train_val(
        train_inputs, train_targets, train_labels, 0.1, rng
    )

    task_label = "Row-Sequential MNIST (FPTT-only)"
    plot_dir = None
    plot_tag = None
    if args.plot:
        plot_dir, plot_tag = resolve_plot_context(args, task_label)

    device = DEFAULT_DEVICE
    oracle_id = args.task
    # Clear any oracle buffers stored under this id before building the model.
    OracleBufferStore.reset(oracle_id)
    model = StrictFPTTClassifier(
        train_inputs.shape[1],
        args.hidden,
        train_targets.shape[1],
        eta=args.lr,
        parts=args.parts,
        clip=args.clip,
        alpha=args.alpha,
        beta=args.beta,
        rho=args.rho,
        lmbda=args.lmbda,
        oracle_momentum=args.oracle_momentum,
        warmup_epochs=args.warmup_epochs,
        oracle_id=oracle_id,
        label_mode="last",
        use_oracle=True,
        device=device,
    )
    model.initialize_weights_with_gain(args.gain, seed=args.seed)

    # Reuse one driver so the pre- and post-training exponents are comparable.
    lyapunov_driver = build_lyapunov_driver(val_inputs)
    lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
    history: list[float] = []
    # Print roughly five progress lines per run, plus the final epoch.
    log_every = max(1, args.epochs // 5)
    train_runtime_sec = 0.0
    eval_runtime_sec = 0.0
    best_epoch = 0
    best_val_acc = -float("inf")
    best_val_loss = float("inf")
    best_params = None

    for epoch in range(args.epochs):
        train_start = time.perf_counter()
        # One epoch per call. The shuffle seed varies per epoch, and
        # epoch_offset passes the global epoch index through.
        train_batches(
            model,
            tr_inputs,
            tr_targets,
            args.batch_size,
            1,
            args.seed + 10 + epoch,
            epoch_offset=epoch,
        )
        train_runtime_sec += time.perf_counter() - train_start
        eval_start = time.perf_counter()
        val_loss, val_acc = evaluate_classifier_final_step(
            model,
            val_inputs,
            val_targets,
            val_labels,
            args.batch_size,
        )
        eval_runtime_sec += time.perf_counter() - eval_start
        history.append(float(val_acc))
        # Strict '>' keeps the earliest epoch on ties.
        if val_acc > best_val_acc:
            best_val_acc = float(val_acc)
            best_val_loss = float(val_loss)
            best_epoch = epoch + 1
            best_params = extract_params(model)
        if (epoch + 1) % log_every == 0 or (epoch + 1) == args.epochs:
            print(
                f"[FPTT] epoch={epoch+1:02d} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f}"
            )

    # Evaluate (and measure Lyapunov on) the best validation checkpoint.
    if best_params is not None:
        load_params(model, best_params)
    eval_start = time.perf_counter()
    test_loss, test_acc = evaluate_classifier_final_step(
        model,
        test_inputs,
        test_targets,
        test_labels,
        args.batch_size,
    )
    eval_runtime_sec += time.perf_counter() - eval_start
    lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
    delta = lambda_post - lambda_pre
    print(
        f"[FPTT] test_acc={test_acc:.4f} | test_loss={test_loss:.4f} | "
        f"best_val_acc={best_val_acc:.4f} (epoch={best_epoch:02d}) | "
        f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f})"
    )

    # Training iterates over tr_inputs (the split left after holding out
    # validation), not the full train_inputs. Counting batches from
    # train_inputs would overstate updates/steps and understate the
    # per-update runtimes. The ceil assumes train_batches keeps a final
    # partial batch.
    batches_per_epoch = math.ceil(tr_inputs.shape[0] / args.batch_size)
    time_steps = tr_inputs.shape[2]
    complexity = estimate_model_complexity(model)
    update_stats = estimate_training_counts(model, time_steps, batches_per_epoch, args.epochs)
    runtime_sec = train_runtime_sec + eval_runtime_sec
    updates_total = update_stats["updates_total"]
    steps_total = update_stats["steps_total"]
    runtime_per_update_sec = train_runtime_sec / updates_total if updates_total > 0 else float("nan")
    runtime_per_step_sec = train_runtime_sec / steps_total if steps_total > 0 else float("nan")

    if args.plot and plot_dir is not None and plot_tag is not None:
        results = {
            "FPTT": {
                "metric": float(test_acc),
                "val_metric": float(best_val_acc),
                "val_loss": float(best_val_loss),
                "best_epoch": int(best_epoch),
                "lyap_pre": float(lambda_pre),
                "lyap_post": float(lambda_post),
                "history": history,
                "complexity_params": float(complexity["params"]),
                "complexity_state": float(complexity["state"]),
                "complexity_total": float(complexity["total"]),
                "runtime_sec": float(runtime_sec),
                "train_runtime_sec": float(train_runtime_sec),
                "eval_runtime_sec": float(eval_runtime_sec),
                "runtime_per_update_sec": float(runtime_per_update_sec),
                "runtime_per_step_sec": float(runtime_per_step_sec),
                "batches_per_epoch": int(update_stats["batches_per_epoch"]),
                "time_steps": int(update_stats["time_steps"]),
                "update_factor": int(update_stats["update_factor"]),
                "updates_per_epoch": int(update_stats["updates_per_epoch"]),
                "updates_total": int(update_stats["updates_total"]),
                "steps_total": int(update_stats["steps_total"]),
            }
        }
        save_results_summary(task_label, "Test Accuracy", results, plot_dir, plot_tag)
        plot_comparison_results(
            task_label,
            results,
            metric_label="Test Accuracy",
            history_label="Val Accuracy",
            higher_is_better=True,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
