from __future__ import annotations

import argparse
import math
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, Tuple

import numpy as np

# Directory one level above the directory that contains this file.
ROOT_DIR = Path(__file__).resolve().parents[1]
# Put it at the front of sys.path (once) before the `task...` import that follows;
# presumably this is where the `task` package lives -- verify against the repo layout.
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

from task.common.sequence_core import (  # noqa: E402
    TorchLocalRuleRNN,
    build_lyapunov_driver,
    calculate_lyapunov_exponent_numpy,
    generate_lorenz_sequences,
    train_batches,
)


def step_lorenz(state: np.ndarray, dt: float, sigma: float, rho: float, beta: float) -> np.ndarray:
    """Advance a Lorenz-system state by one forward-Euler step of size ``dt``.

    Args:
        state: Three-component state ``(x, y, z)``; it is not modified.
        dt: Integration step size.
        sigma: Lorenz ``sigma`` parameter.
        rho: Lorenz ``rho`` parameter.
        beta: Lorenz ``beta`` parameter.

    Returns:
        A new array holding ``state + dt * d(state)/dt``, where the derivative
        vector is built as ``float32``.
    """
    x, y, z = state
    derivative = np.array(
        [
            sigma * (y - x),
            x * (rho - z) - y,
            x * y - beta * z,
        ],
        dtype=np.float32,
    )
    return state + dt * derivative


def generate_lorenz_attractor_dataset(
    num_samples: int,
    seq_len: int,
    horizon: int,
    dt: float,
    sigma: float,
    rho: float,
    beta: float,
    warmup: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build input/target trajectory pairs from Euler-integrated Lorenz runs.

    Each sample starts from a random state drawn from a zero-mean normal with
    scale 0.5, is advanced ``warmup`` steps (discarded), and is then recorded
    for ``seq_len + horizon`` steps with ``step_lorenz``.

    Args:
        num_samples: Number of independent trajectories to generate.
        seq_len: Number of time steps kept per input/target sequence.
        horizon: Offset, in steps, between an input and its target.
        dt: Integration step passed to ``step_lorenz``.
        sigma: Lorenz ``sigma`` parameter.
        rho: Lorenz ``rho`` parameter.
        beta: Lorenz ``beta`` parameter.
        warmup: Number of initial steps that are integrated but not recorded.
        seed: Seed for the NumPy random generator that draws initial states.

    Returns:
        ``(inputs, targets)``, both ``float32`` arrays of shape
        ``(num_samples, 3, seq_len)``. ``targets`` is the same trajectory as
        ``inputs`` shifted forward by ``horizon`` steps.
    """
    rng = np.random.default_rng(seed)
    n_recorded = seq_len + horizon
    out_shape = (num_samples, 3, seq_len)
    inputs = np.zeros(out_shape, dtype=np.float32)
    targets = np.zeros(out_shape, dtype=np.float32)

    def _advance(current: np.ndarray) -> np.ndarray:
        # One integration step with the fixed system parameters.
        return step_lorenz(current, dt, sigma, rho, beta)

    for sample in range(num_samples):
        state = rng.normal(scale=0.5, size=(3,)).astype(np.float32)

        # Discard the warm-up portion of the run before recording.
        for _ in range(warmup):
            state = _advance(state)

        trajectory = []
        for _ in range(n_recorded):
            state = _advance(state)
            trajectory.append(state.copy())

        # Transposed view has shape (3, n_recorded); slice along time.
        channels_first = np.stack(trajectory, axis=0).T
        inputs[sample] = channels_first[:, :seq_len]
        targets[sample] = channels_first[:, horizon : horizon + seq_len]

    return inputs, targets


def compute_lyapunov_diagnostics_numpy(
    model: Any,
    driver_input: np.ndarray,
    seed: int,
    min_steps: int = 50,
) -> Dict[str, Any]:
    """Collect Lyapunov-related diagnostics for a tanh RNN driven by an input sequence.

    The recurrence ``h <- tanh(W_hh @ h + W_xh @ I_t + b_h)`` is run from a zero
    hidden state. Along the way a single tangent vector is propagated through
    the per-step Jacobian ``(1 - h**2) * W_hh`` and renormalised, and the mean
    log growth is reported as a power-iteration estimate of the exponent.

    Args:
        model: Object exposing ``W_hh``, ``W_xh`` and ``b_h``. Values with a
            ``detach`` attribute are converted via ``detach().cpu().numpy()``;
            anything else goes through ``np.asarray``. The model is also passed
            to ``calculate_lyapunov_exponent_numpy``.
        driver_input: Driving input indexed as ``driver[:, t]`` per step. A 1-D
            array is treated as a single column.
        seed: Seed for the random initial tangent vector.
        min_steps: If the driver has fewer columns than this, it is tiled along
            the time axis until it has at least this many.

    Returns:
        A dict with the scalar entries ``lyap_qr``, ``lyap_power_iter``,
        ``contraction_factor``, ``rho_W_hh``, ``sigma_W_hh`` and
        ``upper_bound_log_phiMax_plus_log_sigmaW``, plus the per-step float64
        arrays ``phi_max``, ``phi_mean``, ``sat_frac`` and ``log_growth``.
    """
    eps = 1e-12
    nan = float("nan")

    def _as_float64(value: Any) -> np.ndarray:
        # Tensor-like objects (anything with ``detach``) are converted first.
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        else:
            value = np.asarray(value)
        return value.astype(np.float64, copy=False)

    def _or_nan(compute: Any) -> float:
        # Best effort: a failing spectral computation is reported as NaN.
        try:
            return float(compute())
        except Exception:
            return nan

    W_hh = _as_float64(model.W_hh)
    W_xh = _as_float64(model.W_xh)
    b_h = _as_float64(model.b_h).reshape(-1, 1)
    hidden = int(W_hh.shape[0])

    driver = _as_float64(driver_input)
    if driver.ndim == 1:
        driver = np.expand_dims(driver, axis=1)
    if int(driver.shape[1]) < min_steps:
        # Repeat the driver along time until at least ``min_steps`` columns exist.
        repeats = int(math.ceil(min_steps / max(int(driver.shape[1]), 1)))
        driver = np.tile(driver, (1, repeats))
    time_steps = int(driver.shape[1])

    # Spectral radius and largest singular value of the recurrent matrix.
    rho_w = _or_nan(lambda: np.max(np.abs(np.linalg.eigvals(W_hh))))
    sigma_w = _or_nan(lambda: np.linalg.svd(W_hh, compute_uv=False)[0])

    h = np.zeros((hidden, 1), dtype=np.float64)
    rng = np.random.default_rng(int(seed))
    v = rng.standard_normal((hidden, 1)).astype(np.float64)
    v_norm0 = float(np.linalg.norm(v))
    if not (np.isfinite(v_norm0) and v_norm0 > 0.0):
        # Degenerate random draw: fall back to the all-ones direction.
        v = np.ones((hidden, 1), dtype=np.float64)
        v_norm0 = float(np.linalg.norm(v))
    v = v / max(v_norm0, 1e-12)

    phi_max = np.empty(time_steps, dtype=np.float64)
    phi_mean = np.empty(time_steps, dtype=np.float64)
    sat_frac = np.empty(time_steps, dtype=np.float64)
    log_growth = np.empty(time_steps, dtype=np.float64)

    for t in range(time_steps):
        I_t = driver[:, t].reshape(-1, 1)
        h = np.tanh(W_hh @ h + W_xh @ I_t + b_h)
        phi = 1.0 - h * h  # tanh derivative evaluated at the new state

        phi_max[t] = float(np.max(phi))
        phi_mean[t] = float(np.mean(phi))
        sat_frac[t] = float(np.mean(np.abs(h) >= 0.95))

        # Column-vector broadcast scales row i of W_hh by phi[i].
        v_next = (phi * W_hh) @ v
        growth = float(np.linalg.norm(v_next))
        safe_growth = max(growth, eps)
        log_growth[t] = float(np.log(safe_growth))
        v = v_next / safe_growth

    lyap_pi = float(np.mean(log_growth)) if time_steps else nan
    lyap_qr = float(calculate_lyapunov_exponent_numpy(model, driver))

    if phi_max.size:
        log_phi_max_mean = float(np.mean(np.log(np.clip(phi_max, eps, None))))
    else:
        log_phi_max_mean = nan

    if np.isfinite(log_phi_max_mean):
        upper_bound = log_phi_max_mean + float(np.log(max(sigma_w, eps)))
    else:
        upper_bound = nan

    if np.isfinite(lyap_pi):
        contraction = float(np.exp(lyap_pi))
    else:
        contraction = nan

    return {
        "lyap_qr": lyap_qr,
        "lyap_power_iter": lyap_pi,
        "contraction_factor": contraction,
        "rho_W_hh": rho_w,
        "sigma_W_hh": sigma_w,
        "phi_max": phi_max,
        "phi_mean": phi_mean,
        "sat_frac": sat_frac,
        "log_growth": log_growth,
        "upper_bound_log_phiMax_plus_log_sigmaW": upper_bound,
    }


def summarize_diag(tag: str, diag: Dict[str, Any]) -> None:
    """Print a three-line summary of a diagnostics dict to stdout.

    Args:
        tag: Label printed at the start of every line.
        diag: Dict providing the scalar keys ``lyap_qr``, ``lyap_power_iter``,
            ``contraction_factor``, ``rho_W_hh``, ``sigma_W_hh`` and
            ``upper_bound_log_phiMax_plus_log_sigmaW``, and the array keys
            ``phi_mean``, ``phi_max``, ``sat_frac`` and ``log_growth``.
            Empty arrays are summarised as NaN.
    """
    eps = 1e-12
    nan = float("nan")

    def _reduce(key: str, reducer: Any) -> float:
        # Apply ``reducer`` to a per-step series, or report NaN if it is empty.
        series = diag[key]
        return float(reducer(series)) if series.size else nan

    phi_mean_med = _reduce("phi_mean", np.median)
    phi_max_med = _reduce("phi_max", np.median)
    sat_med = _reduce("sat_frac", np.median)
    log_growth_mean = _reduce("log_growth", np.mean)
    log_growth_std = _reduce("log_growth", np.std)
    log_phi_mean = _reduce("phi_mean", lambda arr: np.mean(np.log(np.clip(arr, eps, None))))

    sigma_w = diag["sigma_W_hh"]
    if np.isfinite(sigma_w):
        log_sigma = math.log(max(sigma_w, eps))
        eff_gain = sigma_w * phi_mean_med
    else:
        log_sigma = nan
        eff_gain = nan

    if np.isfinite(log_phi_mean) and np.isfinite(log_sigma):
        approx = log_phi_mean + log_sigma
    else:
        approx = nan

    print(
        f"  {tag} lyap_qr={diag['lyap_qr']:.4f} lyap_pi={diag['lyap_power_iter']:.4f} "
        f"exp(lyap)={diag['contraction_factor']:.4f} log_growth={log_growth_mean:.4f}±{log_growth_std:.4f}"
    )
    print(
        f"  {tag} rho_W={diag['rho_W_hh']:.4f} sigma_W={diag['sigma_W_hh']:.4f} "
        f"phi_mean_med={phi_mean_med:.4f} phi_max_med={phi_max_med:.4f} sat_med={sat_med:.4f}"
    )
    print(
        f"  {tag} log_phi_mean+log_sigma={approx:.4f} upper_bound={diag['upper_bound_log_phiMax_plus_log_sigmaW']:.4f} "
        f"eff_gain_med={eff_gain:.4f}"
    )


def parse_gains(text: str) -> Iterable[float]:
    """Parse a comma-separated string of numbers into a list of floats.

    Surrounding whitespace is stripped from each entry, and entries that are
    empty after stripping are skipped.

    Args:
        text: Comma-separated gains, e.g. ``"0.5, 1.0,1.5"``.

    Returns:
        The parsed values as a list, in their original order.

    Raises:
        ValueError: If a non-empty entry cannot be converted by ``float``.
    """
    tokens = (piece.strip() for piece in text.split(","))
    return [float(token) for token in tokens if token]


def run_diagnosis(
    task_label: str,
    train_inputs: np.ndarray,
    train_targets: np.ndarray,
    test_inputs: np.ndarray,
    gains: Iterable[float],
    hidden: int,
    epochs: int,
    batch_size: int,
    lr: float,
    seed: int,
) -> None:
    """Train one model per gain and print diagnostics before and after training.

    For every gain a fresh ``TorchLocalRuleRNN`` is created, its weights are
    initialised with that gain, diagnostics are computed, the model is trained
    with ``train_batches``, and diagnostics are computed again on the same
    driver.

    Args:
        task_label: Title printed in the section header.
        train_inputs: Training inputs; axis 1 gives the model input size.
        train_targets: Training targets; axis 1 gives the model output size.
        test_inputs: Inputs handed to ``build_lyapunov_driver`` to make the driver.
        gains: Initialisation gains to sweep over.
        hidden: Hidden size passed to the model constructor.
        epochs: Epoch count forwarded to ``train_batches``.
        batch_size: Batch size forwarded to ``train_batches``.
        lr: Passed to the model as ``eta``.
        seed: Seed reused for model construction, weight initialisation,
            training, and the diagnostics.
    """
    n_in = int(train_inputs.shape[1])
    n_out = int(train_targets.shape[1])
    driver = build_lyapunov_driver(test_inputs)

    print(f"\n=== {task_label} ===")
    for g in gains:
        model = TorchLocalRuleRNN(n_in, hidden, n_out, eta=lr, loss_mode="mse", seed=seed)
        model.initialize_weights_with_gain(float(g), seed=seed)

        # Same driver and seed on both sides so the two snapshots are comparable.
        before = compute_lyapunov_diagnostics_numpy(model, driver, seed)
        train_batches(model, train_inputs, train_targets, batch_size, epochs, seed)
        after = compute_lyapunov_diagnostics_numpy(model, driver, seed)

        print(f"g={float(g):.3f}")
        for stage, snapshot in (("pre ", before), ("post", after)):
            summarize_diag(stage, snapshot)


def main() -> None:
    """Parse command-line options and run the diagnosis for the selected task(s).

    ``--task`` selects the Lorenz attractor dataset, the Lorenz image dataset,
    or both (the default). Each dataset is generated with
    ``train_samples + test_samples`` samples; the first ``train_samples`` are
    used for training and the remainder as test inputs for ``run_diagnosis``.
    If ``--gains`` yields no values, the gains ``0.5, 1.0, 1.5`` are used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", choices=["lorenz_attractor", "lorenz_image", "both"], default="both")

    # (flag, type, default) triples, registered in this order.
    option_table = (
        # Model / training options.
        ("--hidden", int, 128),
        ("--epochs", int, 5),
        ("--batch-size", int, 64),
        ("--lr", float, 1e-3),
        ("--seed", int, 42),
        ("--gains", str, "0.5,1.0,1.5"),
        # Dataset size and Lorenz-system options.
        ("--train-samples", int, 512),
        ("--test-samples", int, 128),
        ("--seq-len", int, 200),
        ("--pred-horizon", int, 3),
        ("--dt", float, 0.01),
        ("--sigma", float, 10.0),
        ("--rho", float, 28.0),
        ("--beta", float, 2.6667),
        ("--warmup", int, 100),
        # Image-task options.
        ("--seq-len-image", int, 30),
        ("--frame-h", int, 16),
        ("--frame-w", int, 16),
        ("--blur-sigma", float, 1.2),
    )
    for flag, caster, fallback in option_table:
        parser.add_argument(flag, type=caster, default=fallback)

    args = parser.parse_args()
    gains = list(parse_gains(args.gains)) or [0.5, 1.0, 1.5]

    def _system_kwargs() -> dict:
        # Lorenz-system settings shared by both dataset generators.
        return {
            "dt": float(args.dt),
            "sigma": float(args.sigma),
            "rho": float(args.rho),
            "beta": float(args.beta),
            "warmup": int(args.warmup),
            "seed": int(args.seed),
        }

    def _split_and_run(label: str, inputs: np.ndarray, targets: np.ndarray) -> None:
        # Leading samples are for training; the rest provide the test inputs.
        cut = args.train_samples
        run_diagnosis(
            label,
            inputs[:cut],
            targets[:cut],
            inputs[cut:],
            gains,
            args.hidden,
            args.epochs,
            args.batch_size,
            args.lr,
            args.seed,
        )

    if args.task in {"lorenz_attractor", "both"}:
        total = int(args.train_samples) + int(args.test_samples)
        inputs, targets = generate_lorenz_attractor_dataset(
            num_samples=total,
            seq_len=int(args.seq_len),
            horizon=max(1, int(args.pred_horizon)),
            **_system_kwargs(),
        )
        _split_and_run("Lorenz Attractor", inputs, targets)

    if args.task in {"lorenz_image", "both"}:
        total = int(args.train_samples) + int(args.test_samples)
        inputs, targets = generate_lorenz_sequences(
            num_samples=total,
            seq_len=int(args.seq_len_image),
            frame_h=int(args.frame_h),
            frame_w=int(args.frame_w),
            blur_sigma=float(args.blur_sigma),
            **_system_kwargs(),
        )
        _split_and_run("Lorenz Image", inputs, targets)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
