from __future__ import annotations

import math
import sys
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch

# Resolve the parent of the directory that contains this file and put it first
# on sys.path so the `methods` / `task` imports below can be resolved
# (presumably this is the repository root -- confirm against the project layout).
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from methods.shared_rnn_utils import initialize_rnn_parameters
from task.common.sequence_core import TorchBPTTRNN, calculate_lyapunov_exponent_numpy, load_params


def set_seed(seed: int) -> None:
    """Seed NumPy's global RNG and PyTorch with the same integer.

    Args:
        seed: Value coerced with ``int()`` before being passed to
            ``np.random.seed`` and ``torch.manual_seed``.
    """
    seed_value = int(seed)
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed_value)


class _DummyRNN:
    """Bare parameter holder exposing the attributes the Lyapunov helpers read.

    The instance stores the arrays from ``params`` by reference (no copy) under
    the attribute names ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and ``b_y``,
    together with a CPU ``device`` and a ``hidden_size``.
    """

    def __init__(self, params: Dict[str, np.ndarray]) -> None:
        """Populate the holder from a dict keyed by parameter name.

        Args:
            params: Mapping that must contain the keys ``W_xh``, ``W_hh``,
                ``b_h``, ``W_hy`` and ``b_y``.
        """
        recurrent = params["W_hh"]
        self.device = torch.device("cpu")
        # The hidden width is read off the leading dimension of W_hh.
        self.hidden_size = int(recurrent.shape[0])
        self.W_xh, self.W_hh, self.b_h = params["W_xh"], recurrent, params["b_h"]
        self.W_hy, self.b_y = params["W_hy"], params["b_y"]


def _lyapunov_reference_numpy(model: object, driver_input: np.ndarray) -> float:
    """Estimate the largest Lyapunov exponent of a driven tanh RNN in float64.

    The hidden state starts at zero and is advanced with
    ``h <- tanh(W_hh h + W_xh u_t + b_h)``. At every step the Jacobian
    ``diag(1 - h_next**2) @ W_hh`` is applied to an orthonormal basis which is
    then re-orthonormalised by QR; the logs of ``|diag(R)|`` are accumulated.

    Args:
        model: Object exposing ``hidden_size``, ``W_hh``, ``W_xh`` and ``b_h``.
        driver_input: Input signal indexed as ``[input, time]``. A 1-D array is
            turned into a single column, i.e. one time step.

    Returns:
        The maximum over directions of the time-averaged log stretch factors,
        or ``nan`` if the QR factorisation raises ``LinAlgError``.
    """
    size = int(getattr(model, "hidden_size"))
    recurrent = np.asarray(getattr(model, "W_hh"), dtype=np.float64)
    input_weights = np.asarray(getattr(model, "W_xh"), dtype=np.float64)
    bias = np.asarray(getattr(model, "b_h"), dtype=np.float64).reshape(-1, 1)

    state = np.zeros((size, 1), dtype=np.float64)
    basis = np.eye(size, dtype=np.float64)
    log_stretch = np.zeros(size, dtype=np.float64)
    # Lower bound on |R_ii| so that log() never sees an exact zero.
    floor = 1e-12

    signal = np.asarray(driver_input, dtype=np.float64)
    if signal.ndim == 1:
        signal = signal[:, None]
    n_steps = int(signal.shape[1])

    for step in range(n_steps):
        drive = signal[:, step].reshape(-1, 1)
        next_state = np.tanh(recurrent @ state + input_weights @ drive + bias)
        # tanh'(x) = 1 - tanh(x)**2, applied row-wise to W_hh.
        slope = (1.0 - next_state**2).flatten()
        propagated = (slope[:, None] * recurrent) @ basis
        try:
            basis, upper = np.linalg.qr(propagated)
        except np.linalg.LinAlgError:
            return float("nan")
        log_stretch = log_stretch + np.log(np.clip(np.abs(np.diag(upper)), floor, None))
        state = next_state

    averaged = log_stretch / max(n_steps, 1)
    return float(np.max(averaged))


def _one_hot(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Return float32 one-hot rows for integer class labels.

    Args:
        labels: Array of class indices; cast to ``int64`` before use.
        num_classes: Number of classes, i.e. the length of each one-hot row.

    Returns:
        Array of shape ``labels.shape + (num_classes,)`` with dtype float32.
    """
    class_index = labels.astype(np.int64)
    # Row i of the identity matrix is the one-hot encoding of class i.
    return np.eye(int(num_classes), dtype=np.float32)[class_index]


def _torch_params_from_np(params_np, device: torch.device) -> Dict[str, torch.Tensor]:
    """Convert the five RNN parameter attributes of ``params_np`` to tensors.

    Args:
        params_np: Object exposing ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and
            ``b_y`` attributes.
        device: Device on which the tensors are placed.

    Returns:
        Dict keyed by parameter name holding float32 tensors on ``device``.
    """
    names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")
    return {
        name: torch.as_tensor(getattr(params_np, name), dtype=torch.float32, device=device)
        for name in names
    }


def _param_max_abs_diff(a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]) -> float:
    """Return the largest element-wise absolute difference over the keys of ``a``.

    Only keys present in ``a`` are compared; each one is looked up in ``b``.
    """
    worst_per_key = [float(np.max(np.abs(a[name] - b[name]))) for name in a]
    return float(max(worst_per_key))


def _forward_and_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params: Dict[str, torch.Tensor],
    h_prev: torch.Tensor,
    *,
    loss_mode: str,
    step_weights: torch.Tensor | None,
    start: int = 0,
    end: int | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Unroll the tanh RNN over ``[start, end)`` and sum the per-step losses.

    Args:
        inputs: Tensor indexed as ``[batch, input, time]``.
        targets: Tensor indexed as ``[batch, output, time]``.
        params: Dict with keys ``W_xh``, ``W_hh``, ``b_h``, ``W_hy``, ``b_y``.
        h_prev: Initial hidden state, indexed as ``[hidden, batch]``.
        loss_mode: ``"ce"`` (targets weighted against log-softmax over the
            output axis) or ``"mse"`` (half squared error summed over outputs).
            Both are averaged over the batch.
        step_weights: Optional per-time-step multipliers indexed by the
            absolute step ``t``; ``None`` leaves every step unweighted.
        start: First time index to process.
        end: One past the last time index; ``None`` means ``inputs.shape[2]``.

    Returns:
        ``(total_loss, h_state)``: the summed loss as a 0-d tensor and the
        hidden state after the last processed step (``h_prev`` itself when the
        range is empty).

    Raises:
        ValueError: If ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
    """
    if loss_mode not in {"ce", "mse"}:
        raise ValueError("loss_mode must be 'ce' or 'mse'.")

    n_steps = int(inputs.shape[2])
    stop = int(end) if end is not None else n_steps
    hidden = h_prev
    running_loss = torch.zeros((), device=inputs.device)

    for step in range(int(start), stop):
        # Slices are [batch, feature]; transpose to the [feature, batch] layout
        # used by the weight matrices.
        drive = inputs[:, :, step].T
        pre_activation = params["W_hh"] @ hidden + params["W_xh"] @ drive + params["b_h"]
        hidden = torch.tanh(pre_activation)
        logits = params["W_hy"] @ hidden + params["b_y"]
        wanted = targets[:, :, step].T

        if loss_mode == "mse":
            residual = logits - wanted
            step_loss = 0.5 * torch.sum(residual**2, dim=0).mean()
        else:
            log_probs = torch.log_softmax(logits, dim=0)
            step_loss = -(wanted * log_probs).sum(dim=0).mean()

        if step_weights is not None:
            step_loss = step_loss * step_weights[step]
        running_loss = running_loss + step_loss

    return running_loss, hidden


def _reference_full_bptt_sgd(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    h_prev: torch.Tensor,
    *,
    eta: float,
    loss_mode: str,
    step_weights: torch.Tensor | None,
    time_normalization: bool,
    clip_norm: float,
) -> Tuple[Dict[str, torch.Tensor], float, float]:
    """Take one plain SGD step using autograd over the whole sequence.

    Args:
        inputs: Passed through to ``_forward_and_loss``.
        targets: Passed through to ``_forward_and_loss``.
        params_init: Starting parameters; cloned, so the caller's tensors are
            not modified.
        h_prev: Initial hidden state.
        eta: Learning rate.
        loss_mode: ``"ce"`` or ``"mse"``.
        step_weights: Optional per-step loss weights (cast to float32).
        time_normalization: When true the summed loss is divided by the weight
            sum (or by the number of time steps when there are no weights),
            floored at 1.0.
        clip_norm: If positive, gradients are clipped to this global L2 norm
            before the update; otherwise no clipping is applied.

    Returns:
        ``(updated, loss, grad_norm)``: detached updated parameters, the scaled
        loss as a float, and the global gradient L2 norm measured *before*
        clipping.

    Raises:
        RuntimeError: If any parameter ends up without a gradient.
    """
    leaves = {
        name: tensor.clone().detach().requires_grad_(True)
        for name, tensor in params_init.items()
    }
    n_steps = int(inputs.shape[2])

    if step_weights is None:
        weights = None
        normalizer = float(n_steps)
    else:
        weights = step_weights.to(dtype=torch.float32)
        normalizer = float(weights.sum().item())
    normalizer = max(normalizer, 1.0)

    raw_loss, _ = _forward_and_loss(
        inputs,
        targets,
        leaves,
        h_prev,
        loss_mode=loss_mode,
        step_weights=weights,
    )
    divisor = max(normalizer if time_normalization else 1.0, 1.0)
    objective = raw_loss / divisor
    objective.backward()

    if any(leaf.grad is None for leaf in leaves.values()):
        raise RuntimeError("Missing gradient in reference_full_bptt_sgd.")

    # Global L2 norm over all parameters, recorded before any clipping.
    squared_norm = torch.zeros((), device=inputs.device)
    for leaf in leaves.values():
        squared_norm = squared_norm + torch.sum(leaf.grad**2)  # type: ignore[operator]
    pre_clip_norm = float(torch.sqrt(squared_norm).item())

    if clip_norm > 0:
        torch.nn.utils.clip_grad_norm_(list(leaves.values()), clip_norm)

    stepped: Dict[str, torch.Tensor] = {}
    for name, leaf in leaves.items():
        stepped[name] = (leaf - float(eta) * leaf.grad).detach()  # type: ignore[operator]
    return stepped, float(objective.detach().item()), pre_clip_norm


def _reference_tbptt_sgd(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    h_prev: torch.Tensor,
    *,
    eta: float,
    loss_mode: str,
    step_weights: torch.Tensor | None,
    tbptt_steps: int,
    time_normalization: bool,
) -> Tuple[Dict[str, torch.Tensor], float]:
    """Run truncated BPTT with one SGD update per chunk, using autograd.

    The sequence is cut into consecutive chunks of ``tbptt_steps`` time steps
    (at least 1). The hidden state is detached at every chunk boundary, the
    chunk loss is back-propagated, and the parameters are updated in place
    before the next chunk is processed.

    Args:
        inputs: Passed through to ``_forward_and_loss``.
        targets: Passed through to ``_forward_and_loss``.
        params_init: Starting parameters; cloned, so the caller's tensors are
            not modified.
        h_prev: Initial hidden state.
        eta: Learning rate.
        loss_mode: ``"ce"`` or ``"mse"``.
        step_weights: Optional per-step loss weights (cast to float32).
        tbptt_steps: Chunk length; values below 1 are treated as 1.
        time_normalization: When true each chunk loss is divided by the chunk's
            weight sum (or its length when there are no weights), floored at
            1.0, before back-propagation.

    Returns:
        ``(updated, avg_loss)``: detached parameters after the last chunk and
        the sum of the unscaled chunk losses, divided by the number of time
        steps (floored at 1.0) when ``time_normalization`` is true.

    Raises:
        RuntimeError: If any parameter has no gradient after a chunk's backward
            pass.
    """
    leaves = {
        name: tensor.clone().detach().requires_grad_(True)
        for name, tensor in params_init.items()
    }
    n_steps = int(inputs.shape[2])
    window = max(1, int(tbptt_steps))
    weights = None if step_weights is None else step_weights.to(dtype=torch.float32)

    accumulated = 0.0
    carried = h_prev
    for lo in range(0, n_steps, window):
        hi = min(n_steps, lo + window)
        # Cut the graph so gradients never flow into an earlier chunk.
        carried = carried.detach()
        chunk_loss, next_hidden = _forward_and_loss(
            inputs,
            targets,
            leaves,
            carried,
            loss_mode=loss_mode,
            step_weights=weights,
            start=lo,
            end=hi,
        )

        if not time_normalization:
            scale = 1.0
        elif weights is None:
            scale = max(float(hi - lo), 1.0)
        else:
            scale = max(float(weights[lo:hi].sum().item()), 1.0)
        (chunk_loss / scale).backward()

        with torch.no_grad():
            for leaf in leaves.values():
                if leaf.grad is None:
                    raise RuntimeError("Missing gradient in reference_tbptt_sgd.")
                leaf.add_(-float(eta) * leaf.grad)
                # Reset so the next chunk starts from a clean gradient.
                leaf.grad = None

        accumulated += float(chunk_loss.detach().item())
        carried = next_hidden

    denominator = float(inputs.shape[2]) if time_normalization else 1.0
    mean_loss = accumulated / max(denominator, 1.0)
    return {name: leaf.detach() for name, leaf in leaves.items()}, mean_loss


def test_lyapunov_consistency() -> None:
    """Check that the imported Lyapunov estimator agrees with the local reference.

    A small parameter set is generated with ``initialize_rnn_parameters``,
    wrapped in ``_DummyRNN`` and driven by a fixed random float32 signal of
    50 steps. The value returned by ``calculate_lyapunov_exponent_numpy``
    (labelled ``torch`` in the output) must match
    ``_lyapunov_reference_numpy`` to within ``1e-6``.
    """
    print("[TEST] Lyapunov torch(QR) vs numpy(QR) consistency ...")
    set_seed(0)

    n_in, n_hidden, n_out = 3, 7, 2
    init = initialize_rnn_parameters(n_in, n_hidden, n_out, gain=1.3, seed=123)
    names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")
    rnn = _DummyRNN({name: getattr(init, name) for name in names})

    noise = np.random.default_rng(999).standard_normal((n_in, 50))
    driver = (0.2 * noise).astype(np.float32)

    lyap_torch = float(calculate_lyapunov_exponent_numpy(rnn, driver))
    lyap_np = float(_lyapunov_reference_numpy(rnn, driver))
    diff = abs(lyap_torch - lyap_np)
    print(f"  diff={diff:.3e} torch={lyap_torch:.8f} numpy={lyap_np:.8f}")
    assert diff < 1e-6


def test_bptt_clipping_update() -> None:
    """Check one clipped ``TorchBPTTRNN.train_batch`` step against autograd.

    Two cases are run on the same fixed inputs, labels and initial parameters.
    In each, the un-clipped reference gradient norm is measured first and the
    clipping threshold is set to half of it, so clipping is active:

    * Case A uses ``time_normalization=True`` and gives the model the same
      threshold as the reference.
    * Case B uses ``time_normalization=False``; the reference clips at
      ``clip_norm_eff`` while the model receives
      ``clip_norm_eff / sqrt(time_steps)`` as ``max_grad_norm``.

    In both cases the model parameters after one step must match the
    reference update to within ``1e-6`` (maximum absolute difference).
    """
    print("[TEST] BPTT gradient clipping (on/off) matches autograd reference ...")
    set_seed(1)
    device = torch.device("cpu")

    batch_size, input_size, hidden_size, output_size, time_steps = 4, 3, 6, 3, 7
    eta = 1e-1
    loss_mode = "ce"
    param_names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")

    rng = np.random.default_rng(2024)
    raw_inputs = rng.standard_normal((batch_size, input_size, time_steps)).astype(np.float32)
    inputs = torch.as_tensor(raw_inputs, device=device)
    labels = rng.integers(0, output_size, size=(batch_size,), dtype=np.int64)
    # The same one-hot label is repeated at every time step.
    targets = np.repeat(_one_hot(labels, output_size)[:, :, None], time_steps, axis=2)
    targets_t = torch.as_tensor(targets, dtype=torch.float32, device=device)
    h_prev = torch.zeros((hidden_size, batch_size), dtype=torch.float32, device=device)

    params_np = initialize_rnn_parameters(input_size, hidden_size, output_size, gain=1.0, seed=42)
    params_init_t = _torch_params_from_np(params_np, device)

    def reference_step(time_normalization: bool, clip_norm: float):
        """Run the autograd reference on the shared fixtures."""
        return _reference_full_bptt_sgd(
            inputs,
            targets_t,
            params_init_t,
            h_prev,
            eta=eta,
            loss_mode=loss_mode,
            step_weights=None,
            time_normalization=time_normalization,
            clip_norm=clip_norm,
        )

    def model_vs_reference(reference_params, time_normalization: bool, max_grad_norm: float) -> float:
        """Train a fresh model for one batch and return its max abs deviation."""
        model = TorchBPTTRNN(
            input_size,
            hidden_size,
            output_size,
            eta=eta,
            loss_mode=loss_mode,
            max_grad_norm=max_grad_norm,
            time_normalization=time_normalization,
            seed=0,
            device=device,
        )
        load_params(model, {k: getattr(params_np, k) for k in param_names})
        model.train_batch(inputs, targets_t, h_prev)
        actual = {k: getattr(model, k).detach().cpu().numpy() for k in param_names}
        expected = {k: v.detach().cpu().numpy() for k, v in reference_params.items()}
        return _param_max_abs_diff(actual, expected)

    # Case A: time_normalization=True, clip_norm = max_grad_norm
    _, _, grad_norm = reference_step(True, 0.0)
    if not (math.isfinite(grad_norm) and grad_norm > 0):
        raise RuntimeError(f"Unexpected grad_norm={grad_norm}")
    max_grad_norm = grad_norm / 2.0
    updated_ref, _, _ = reference_step(True, max_grad_norm)
    diff = model_vs_reference(updated_ref, True, max_grad_norm)
    print(f"  time_norm=True clip: max abs diff={diff:.3e} (max_grad_norm={max_grad_norm:.3e})")
    assert diff < 1e-6

    # Case B: time_normalization=False, clip_norm = max_grad_norm * sqrt(T)
    # Pick max_grad_norm so that effective clip_norm is about half the (un-normalized) grad norm.
    _, _, grad_norm_unnorm = reference_step(False, 0.0)
    weight_norm = float(time_steps)
    clip_norm_eff = grad_norm_unnorm / 2.0
    max_grad_norm = clip_norm_eff / math.sqrt(weight_norm)
    updated_ref, _, _ = reference_step(False, clip_norm_eff)
    diff = model_vs_reference(updated_ref, False, max_grad_norm)
    print(
        "  time_norm=False clip: max abs diff="
        f"{diff:.3e} (max_grad_norm={max_grad_norm:.3e}, clip_norm_eff={clip_norm_eff:.3e})"
    )
    assert diff < 1e-6


def test_tbptt_detach_matches_reference() -> None:
    """Check truncated BPTT in ``TorchBPTTRNN`` against the autograd reference.

    A 6-step sequence is processed with ``tbptt_steps=2``, no gradient
    clipping (``max_grad_norm=0.0``) and no time normalisation. After one
    ``train_batch`` call the model parameters must match the result of
    ``_reference_tbptt_sgd`` to within ``1e-6`` (maximum absolute difference).
    """
    print("[TEST] TBPTT(chunk detach) matches autograd reference ...")
    set_seed(2)
    device = torch.device("cpu")

    batch_size, input_size, hidden_size, output_size, time_steps = 3, 4, 5, 3, 6
    tbptt_steps = 2
    eta = 5e-2
    loss_mode = "ce"
    param_names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")

    rng = np.random.default_rng(7)
    raw_inputs = rng.standard_normal((batch_size, input_size, time_steps)).astype(np.float32)
    inputs = torch.as_tensor(raw_inputs, device=device)
    labels = rng.integers(0, output_size, size=(batch_size,), dtype=np.int64)
    # The same one-hot label is repeated at every time step.
    one_hot_labels = _one_hot(labels, output_size)
    targets = np.repeat(one_hot_labels[:, :, None], time_steps, axis=2)
    targets_t = torch.as_tensor(targets, dtype=torch.float32, device=device)
    h_prev = torch.zeros((hidden_size, batch_size), dtype=torch.float32, device=device)

    params_np = initialize_rnn_parameters(input_size, hidden_size, output_size, gain=1.0, seed=99)
    updated_ref, _ = _reference_tbptt_sgd(
        inputs,
        targets_t,
        _torch_params_from_np(params_np, device),
        h_prev,
        eta=eta,
        loss_mode=loss_mode,
        step_weights=None,
        tbptt_steps=tbptt_steps,
        time_normalization=False,
    )

    model = TorchBPTTRNN(
        input_size,
        hidden_size,
        output_size,
        eta=eta,
        loss_mode=loss_mode,
        max_grad_norm=0.0,
        tbptt_steps=tbptt_steps,
        time_normalization=False,
        seed=0,
        device=device,
    )
    load_params(model, {k: getattr(params_np, k) for k in param_names})
    model.train_batch(inputs, targets_t, h_prev)

    actual = {k: getattr(model, k).detach().cpu().numpy() for k in param_names}
    expected = {k: v.detach().cpu().numpy() for k, v in updated_ref.items()}
    diff = _param_max_abs_diff(actual, expected)
    print(f"  TBPTT param max abs diff = {diff:.3e}")
    assert diff < 1e-6


def main() -> None:
    """Run every check in this file in order, then print ``OK``.

    Any failing assertion or raised exception propagates and stops the run
    before ``OK`` is printed.
    """
    checks = (
        test_lyapunov_consistency,
        test_bptt_clipping_update,
        test_tbptt_detach_matches_reference,
    )
    for check in checks:
        check()
    print("OK")


# Run all checks when this file is executed directly as a script.
if __name__ == "__main__":
    main()

