from __future__ import annotations

import importlib.util
import math
import sys
from pathlib import Path

import numpy as np
import torch


def _load_module(name: str, path: Path):
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to load module spec: {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


def _is_finite(x: float) -> bool:
    return math.isfinite(float(x))


def main() -> int:
    """Run one tiny CPU step per learner and verify every reported value is finite.

    The two project modules ``sequence_core.py`` and ``conv_core.py`` are
    loaded by file path relative to the repository root (one directory above
    this script's directory). For each learner a model is constructed, a
    single training call is made on a small random batch, and the returned
    loss is printed and checked. A Lyapunov estimate is also computed once
    for the local-rule RNN and once for the local-rule ConvRNN.

    Models are built and trained strictly one after another, and the run
    stops at the first non-finite value.

    Returns:
        0 if every printed loss and Lyapunov value is finite, 1 otherwise.
    """
    root = Path(__file__).resolve().parents[1]
    rnn_common = root / "Compare_RNN" / "task" / "common"
    cnn_common = root / "Compare_CNN" / "task" / "common"
    rnn_core = _load_module("audit_rnn_core", rnn_common / "sequence_core.py")
    cnn_core = _load_module("audit_cnn_core", cnn_common / "conv_core.py")

    cpu = torch.device("cpu")
    gen = np.random.default_rng(0)

    def finite_after_report(label: str, value) -> bool:
        """Print ``label=value`` with six decimals, then tell whether it is finite."""
        print(f"  - {label}={value:.6f}")
        return _is_finite(value)

    # ------------------------------------------------------------------
    # Sequence (RNN) learners
    # ------------------------------------------------------------------
    print("[SANITY] RNN: Local/BPTT/TBPTT/E-Prop/FPTT (tiny random batch)")
    batch, n_in, n_hidden, n_classes, n_steps = 4, 8, 16, 5, 12
    seq_np = gen.standard_normal((batch, n_in, n_steps), dtype=np.float32)
    seq_labels = gen.integers(0, n_classes, size=(batch,), dtype=np.int64)
    seq_targets_np = rnn_core.build_repeated_targets(seq_labels, n_classes, n_steps)

    seq_x = torch.as_tensor(seq_np, dtype=torch.float32, device=cpu)
    seq_y = torch.as_tensor(seq_targets_np, dtype=torch.float32, device=cpu)
    h_init = torch.zeros((n_hidden, batch), dtype=torch.float32, device=cpu)

    # Arguments shared by several RNN constructors; split into a leading and
    # a trailing group so per-model keywords keep their original position.
    rnn_dims = (n_in, n_hidden, n_classes)
    rnn_head = {"eta": 1e-3, "loss_mode": "ce"}
    rnn_tail = {"seed": 0, "device": cpu}

    local = rnn_core.TorchLocalRuleRNN(*rnn_dims, **rnn_head, **rnn_tail)
    value, _ = local.run_one_cycle_and_update_directly(seq_x, seq_y, h_init)
    if not finite_after_report("LocalRule loss", value):
        return 1

    bptt = rnn_core.TorchBPTTRNN(
        *rnn_dims, **rnn_head, time_normalization=False, **rnn_tail
    )
    value, _ = bptt.train_batch(seq_x, seq_y, h_init)
    if not finite_after_report("BPTT loss", value):
        return 1

    tbptt = rnn_core.TorchBPTTRNN(
        *rnn_dims,
        **rnn_head,
        tbptt_steps=5,
        time_normalization=False,
        **rnn_tail,
    )
    value, _ = tbptt.train_batch(seq_x, seq_y, h_init)
    if not finite_after_report("TBPTT-5 loss", value):
        return 1

    eprop = rnn_core.StandardEPropRNN(
        *rnn_dims,
        eta=1e-3,
        feedback="symmetric",
        seed=0,
        loss_mode="ce",
        device=cpu,
    )
    value, _ = eprop.train_batch(seq_x, seq_y, h_init)
    if not finite_after_report("E-Prop loss", value):
        return 1

    fptt = rnn_core.StrictFPTTClassifier(
        *rnn_dims,
        eta=1e-3,
        parts=4,
        clip=1.0,
        lmbda=1.0,
        oracle_momentum=1.0,
        label_mode="last",
        oracle_id="sanity",
        use_oracle=False,
        device=cpu,
    )
    value, _ = fptt.train_batch(seq_x, seq_y, h_init)
    if not finite_after_report("StrictFPTT loss", value):
        return 1

    # Lyapunov estimate for the local-rule RNN, driven by the first sample.
    seq_driver = seq_x[0].detach()
    value = rnn_core.calculate_lyapunov_exponent_numpy(local, seq_driver)
    if not finite_after_report("Lyapunov (RNN/local)", value):
        return 1

    # ------------------------------------------------------------------
    # Convolutional (ConvRNN) learners
    # ------------------------------------------------------------------
    # NOTE(review): the banner lists BPTT, but only LocalRule and TBPTT-3 are
    # exercised below — confirm whether a full-BPTT conv run is intended.
    print("[SANITY] ConvRNN: Local/BPTT/TBPTT + Lyapunov (tiny random batch)")
    batch, n_channels, img_h, img_w, n_classes, n_steps = 2, 1, 16, 16, 3, 6
    img_np = gen.standard_normal((batch, n_channels, img_h, img_w), dtype=np.float32)
    img_labels = gen.integers(0, n_classes, size=(batch,), dtype=np.int64)
    img_targets_np = cnn_core.build_repeated_targets(img_labels, n_classes, n_steps)

    img_x = torch.as_tensor(img_np, dtype=torch.float32, device=cpu)
    img_y = torch.as_tensor(img_targets_np, dtype=torch.float32, device=cpu)

    # Same leading/trailing split as above so keyword order is unchanged.
    conv_head = {
        "in_channels": n_channels,
        "enc_channels": (4, 8),
        "hidden_channels": 4,
        "output_size": n_classes,
        "steps": n_steps,
        "eta": 1e-3,
        "loss_mode": "ce",
    }
    conv_tail = {
        "train_encoder": False,
        "seed": 0,
        "device": cpu,
        "kernel_size": 3,
    }

    conv_local = cnn_core.TorchLocalRuleConvRNN(**conv_head, **conv_tail)
    value, _ = conv_local.train_batch(img_x, img_y, None)
    if not finite_after_report("Conv LocalRule loss", value):
        return 1

    conv_tbptt = cnn_core.TorchBPTTConvRNN(
        **conv_head,
        tbptt_steps=3,
        time_normalization=False,
        **conv_tail,
    )
    value, _ = conv_tbptt.train_batch(img_x, img_y, None)
    if not finite_after_report("Conv TBPTT-3 loss", value):
        return 1

    # Lyapunov estimate for the local-rule ConvRNN, driven by a one-image batch.
    img_driver = img_x[:1].detach()
    value = cnn_core.calculate_lyapunov_exponent_conv(
        conv_local,
        img_driver,
        steps=50,
        num_vectors=1,
    )
    if not finite_after_report("Lyapunov (ConvRNN/local)", value):
        return 1

    print("[SANITY] OK")
    return 0


if __name__ == "__main__":
    # Use main()'s return value (0 on success, 1 on failure) as the exit status.
    sys.exit(main())

