from __future__ import annotations

import argparse
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import numpy as np
import torch

# Headless plotting (terminal/CI friendly)
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from methods.shared_rnn_utils import RNNParameters, initialize_rnn_parameters
from methods.standard_eprop import StandardEPropRNN
from methods.strict_fptt import (
    ClassOracleBuffer,
    OracleBufferStore,
    StrictFPTTClassifier,
    build_chunk_schedule,
)
from task.common.sequence_core import TorchBPTTRNN, TorchLocalRuleRNN


def set_seed(seed: int) -> None:
    """Seed NumPy's legacy global RNG and torch's default generator with ``int(seed)``."""
    value = int(seed)
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(value)


def _softmax_cols(logits: torch.Tensor) -> torch.Tensor:
    logits_shifted = logits - torch.max(logits, dim=0, keepdim=True).values
    exp_logits = torch.exp(logits_shifted)
    return exp_logits / (torch.sum(exp_logits, dim=0, keepdim=True) + 1e-12)


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def _torch_params_from_rnnparams(
    params: RNNParameters, *, device: torch.device
) -> Dict[str, torch.Tensor]:
    return {
        "W_xh": torch.tensor(params.W_xh, dtype=torch.float32, device=device),
        "W_hh": torch.tensor(params.W_hh, dtype=torch.float32, device=device),
        "b_h": torch.tensor(params.b_h, dtype=torch.float32, device=device),
        "W_hy": torch.tensor(params.W_hy, dtype=torch.float32, device=device),
        "b_y": torch.tensor(params.b_y, dtype=torch.float32, device=device),
    }


def _rnnparams_from_torch_params(params: Dict[str, torch.Tensor]) -> RNNParameters:
    """Build an ``RNNParameters`` of independent float32 NumPy copies from a tensor dict.

    ``params`` must contain the keys ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and ``b_y``;
    tensors are detached and moved to CPU before conversion.
    """

    def as_numpy(key: str) -> np.ndarray:
        return params[key].detach().cpu().numpy().astype(np.float32).copy()

    return RNNParameters(
        W_xh=as_numpy("W_xh"),
        W_hh=as_numpy("W_hh"),
        b_h=as_numpy("b_h"),
        W_hy=as_numpy("W_hy"),
        b_y=as_numpy("b_y"),
    )


def _extract_rnn_parameters_any(model: object) -> RNNParameters:
    """
    Snapshot a model's ``W_xh``/``W_hh``/``b_h``/``W_hy``/``b_y`` as float32 NumPy copies.

    Like methods.shared_rnn_utils.extract_rnn_parameters, but also accepts
    attributes that are torch tensors (including torch.nn.Parameter with
    requires_grad=True) by detaching them first; non-tensor values go through
    ``np.asarray``. Every returned array is a copy independent of the model.
    """

    def _as_float32(value) -> np.ndarray:
        if not torch.is_tensor(value):
            return np.asarray(value, dtype=np.float32)
        return value.detach().cpu().numpy().astype(np.float32, copy=False)

    snapshot = {
        name: _as_float32(getattr(model, name)).copy()
        for name in ("W_xh", "W_hh", "b_h", "W_hy", "b_y")
    }
    return RNNParameters(**snapshot)


def _param_max_abs_diff(a: RNNParameters, b: RNNParameters) -> float:
    diffs = [
        float(np.max(np.abs(a.W_xh - b.W_xh))),
        float(np.max(np.abs(a.W_hh - b.W_hh))),
        float(np.max(np.abs(a.b_h - b.b_h))),
        float(np.max(np.abs(a.W_hy - b.W_hy))),
        float(np.max(np.abs(a.b_y - b.b_y))),
    ]
    return float(max(diffs))


def _copy_params_into_bptt(model: TorchBPTTRNN, params: RNNParameters) -> None:
    device = model.device
    with torch.no_grad():
        model.W_xh.copy_(torch.as_tensor(params.W_xh, dtype=torch.float32, device=device))
        model.W_hh.copy_(torch.as_tensor(params.W_hh, dtype=torch.float32, device=device))
        model.b_h.copy_(torch.as_tensor(params.b_h, dtype=torch.float32, device=device))
        model.W_hy.copy_(torch.as_tensor(params.W_hy, dtype=torch.float32, device=device))
        model.b_y.copy_(torch.as_tensor(params.b_y, dtype=torch.float32, device=device))


def _copy_params_into_local(model: TorchLocalRuleRNN, params: RNNParameters) -> None:
    device = model.device
    model.W_xh = torch.tensor(params.W_xh, dtype=torch.float32, device=device)
    model.W_hh = torch.tensor(params.W_hh, dtype=torch.float32, device=device)
    model.b_h = torch.tensor(params.b_h, dtype=torch.float32, device=device)
    model.W_hy = torch.tensor(params.W_hy, dtype=torch.float32, device=device)
    model.b_y = torch.tensor(params.b_y, dtype=torch.float32, device=device)


def _copy_params_into_eprop(model: StandardEPropRNN, params: RNNParameters) -> None:
    device = model.device
    model.W_xh = torch.tensor(params.W_xh, dtype=torch.float32, device=device)
    model.W_hh = torch.tensor(params.W_hh, dtype=torch.float32, device=device)
    model.b_h = torch.tensor(params.b_h, dtype=torch.float32, device=device)
    model.W_hy = torch.tensor(params.W_hy, dtype=torch.float32, device=device)
    model.b_y = torch.tensor(params.b_y, dtype=torch.float32, device=device)


def _make_last_step_classification_batch(
    *,
    batch_size: int,
    input_size: int,
    time_steps: int,
    num_classes: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    rng = np.random.default_rng(int(seed))
    inputs = rng.standard_normal((batch_size, input_size, time_steps)).astype(np.float32)
    score = inputs[:, 0, :].sum(axis=1)
    labels = (score > 0).astype(np.int64) % num_classes

    targets = np.zeros((batch_size, num_classes, time_steps), dtype=np.float32)
    targets[np.arange(batch_size), labels, -1] = 1.0

    step_weights = np.zeros((time_steps,), dtype=np.float32)
    step_weights[-1] = 1.0
    return inputs, targets, labels, step_weights


def _plot_lines(
    series: Dict[str, List[float]],
    *,
    title: str,
    xlabel: str,
    ylabel: str,
    out_path: Path,
) -> None:
    """Save a line plot with one curve per entry of ``series`` to ``out_path``.

    Each curve is drawn against a 1-based x index (1..len(values)) and labelled
    with its dict key. The parent directory of ``out_path`` is created if needed.
    The figure is always closed, even if plotting or saving raises, so repeated
    calls (or a failing one) do not leave open figures behind.
    """
    _ensure_dir(out_path.parent)
    fig = plt.figure(figsize=(8.6, 4.8), constrained_layout=True)
    try:
        ax = fig.add_subplot(111)
        for name, values in series.items():
            ax.plot(np.arange(1, len(values) + 1), values, linewidth=2.0, label=name)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.25)
        # Only build a legend when at least one labelled curve was drawn.
        if series:
            ax.legend(loc="best", fontsize=9)
        fig.savefig(out_path, dpi=160)
    finally:
        plt.close(fig)


@dataclass
class ValidationArtifacts:
    """Output locations for a single validation run."""

    # Root directory created for this run.
    out_dir: Path
    # Directory where the PNG plots are written.
    plots_dir: Path


def _prepare_artifacts(*, out_root: Path | None = None) -> ValidationArtifacts:
    """Create a timestamped run directory and return its locations.

    The directory is ``<base>/validate_YYYYmmdd_HHMMSS`` (local time), where
    ``<base>`` is ``out_root`` if given, otherwise a ``plots`` folder next to this
    file. Plots are written directly into the run directory, so ``out_dir`` and
    ``plots_dir`` are the same path.
    """
    if out_root is None:
        base = Path(__file__).resolve().parent / "plots"
    else:
        base = Path(out_root)
    run_dir = base / time.strftime("validate_%Y%m%d_%H%M%S")
    _ensure_dir(run_dir)
    return ValidationArtifacts(out_dir=run_dir, plots_dir=run_dir)


def _reference_full_bptt_sgd(
    inputs: torch.Tensor,  # (B, F, T)
    targets: torch.Tensor,  # (B, C, T)
    params_init: Dict[str, torch.Tensor],  # b_* are (., 1)
    h_prev: torch.Tensor,  # (H, B)
    *,
    eta: float,
    loss_mode: str,
    step_weights: torch.Tensor | None,
    time_normalization: bool,
) -> Tuple[Dict[str, torch.Tensor], float]:
    if loss_mode not in {"ce", "mse"}:
        raise ValueError("loss_mode must be 'ce' or 'mse'.")
    device = inputs.device
    params = {k: v.clone().detach().to(device).requires_grad_(True) for k, v in params_init.items()}

    time_steps = int(inputs.shape[2])
    if step_weights is not None:
        step_weights_t = step_weights.to(device=device, dtype=torch.float32)
        weight_norm = float(step_weights_t.sum().item())
    else:
        step_weights_t = None
        weight_norm = float(time_steps)
    weight_norm = max(weight_norm, 1.0)

    h_state = h_prev.clone().detach().to(device)
    total_loss = torch.zeros((), device=device)

    for t in range(time_steps):
        step_weight = step_weights_t[t] if step_weights_t is not None else 1.0
        I_t = inputs[:, :, t].T
        x_t = params["W_hh"] @ h_state + params["W_xh"] @ I_t + params["b_h"]
        h_state = torch.tanh(x_t)
        y_hat_t = params["W_hy"] @ h_state + params["b_y"]
        y_true_t = targets[:, :, t].T

        if loss_mode == "ce":
            log_probs = torch.log_softmax(y_hat_t, dim=0)
            loss_t = -(y_true_t * log_probs).sum(dim=0).mean()
        else:
            error = y_hat_t - y_true_t
            loss_t = 0.5 * torch.sum(error**2, dim=0).mean()

        if step_weights_t is not None:
            loss_t = loss_t * step_weight
        total_loss = total_loss + loss_t

    loss_scale = weight_norm if time_normalization else 1.0
    scaled_loss = total_loss / max(loss_scale, 1.0)
    grads = torch.autograd.grad(scaled_loss, tuple(params.values()))

    with torch.no_grad():
        updated = {}
        for (name, value), grad in zip(params.items(), grads):
            updated[name] = (value - float(eta) * grad).detach().clone()

    return updated, float(scaled_loss.detach().cpu().item())


def _reference_tbptt_sgd(
    inputs: torch.Tensor,  # (B, F, T)
    targets: torch.Tensor,  # (B, C, T)
    params_init: Dict[str, torch.Tensor],  # b_* are (., 1)
    h_prev: torch.Tensor,  # (H, B)
    *,
    eta: float,
    loss_mode: str,
    step_weights: torch.Tensor | None,
    tbptt_steps: int,
    time_normalization: bool,
) -> Tuple[Dict[str, torch.Tensor], float]:
    if loss_mode not in {"ce", "mse"}:
        raise ValueError("loss_mode must be 'ce' or 'mse'.")
    device = inputs.device
    params = {k: v.clone().detach().to(device).requires_grad_(True) for k, v in params_init.items()}

    time_steps = int(inputs.shape[2])
    tbptt_steps = int(tbptt_steps)
    if tbptt_steps <= 0:
        raise ValueError("tbptt_steps must be positive.")

    if step_weights is not None:
        step_weights_t = step_weights.to(device=device, dtype=torch.float32)
        weight_norm = float(step_weights_t.sum().item())
    else:
        step_weights_t = None
        weight_norm = float(time_steps)
    weight_norm = max(weight_norm, 1.0)

    total_loss_value = 0.0
    h_state = h_prev.clone().detach().to(device)

    for start in range(0, time_steps, tbptt_steps):
        end = min(time_steps, start + tbptt_steps)
        chunk_loss = torch.zeros((), device=device)

        for t in range(start, end):
            step_weight = step_weights_t[t] if step_weights_t is not None else 1.0
            I_t = inputs[:, :, t].T
            x_t = params["W_hh"] @ h_state + params["W_xh"] @ I_t + params["b_h"]
            h_state = torch.tanh(x_t)
            y_hat_t = params["W_hy"] @ h_state + params["b_y"]
            y_true_t = targets[:, :, t].T

            if loss_mode == "ce":
                log_probs = torch.log_softmax(y_hat_t, dim=0)
                loss_t = -(y_true_t * log_probs).sum(dim=0).mean()
            else:
                error = y_hat_t - y_true_t
                loss_t = 0.5 * torch.sum(error**2, dim=0).mean()

            if step_weights_t is not None:
                loss_t = loss_t * step_weight
            chunk_loss = chunk_loss + loss_t

        if time_normalization:
            if step_weights_t is not None:
                chunk_weight_norm = float(step_weights_t[start:end].sum().item())
            else:
                chunk_weight_norm = float(end - start)
            chunk_weight_norm = max(chunk_weight_norm, 1.0)
        else:
            chunk_weight_norm = 1.0

        scaled_chunk_loss = chunk_loss / max(chunk_weight_norm, 1.0)
        grads = torch.autograd.grad(scaled_chunk_loss, tuple(params.values()))
        with torch.no_grad():
            new_params = {}
            for (name, value), grad in zip(params.items(), grads):
                new_params[name] = (value - float(eta) * grad).detach().requires_grad_(True)
            params = new_params
        total_loss_value += float(chunk_loss.detach().cpu().item())
        h_state = h_state.detach()

    avg_loss_scale = weight_norm if time_normalization else 1.0
    avg_loss = total_loss_value / max(avg_loss_scale, 1.0)
    updated = {k: v.detach().clone() for k, v in params.items()}
    return updated, float(avg_loss)


def _reference_online_detach_sgd(
    inputs: torch.Tensor,  # (B, F, T)
    targets: torch.Tensor,  # (B, C, T)
    params_init: Dict[str, torch.Tensor],  # b_* are (., 1)
    h_prev: torch.Tensor,  # (H, B)
    *,
    eta: float,
    loss_mode: str,
    step_weights: torch.Tensor | None,
) -> Tuple[Dict[str, torch.Tensor], float]:
    """Reference online learning with the hidden state detached between steps.

    At every time step the tanh RNN advances one step from the detached previous
    hidden state. The step loss is computed and the parameters are immediately
    updated by ``-eta * grad`` of that step's loss alone. The step loss is
    batch-averaged CE via ``_softmax_cols`` + ``log(p + 1e-12)``, or
    ``0.5 * sum(err**2)``, scaled by ``step_weights[t]`` (1.0 if ``None``).
    No gradient flows through time.

    Returns:
        ``(updated_params, loss)``: detached final parameters (keys/order of
        ``params_init``) and the summed weighted step losses divided by
        ``max(sum(step_weights), 1)`` (``max(T, 1)`` without weights).

    Raises:
        ValueError: if ``loss_mode`` is not ``"ce"`` or ``"mse"``.
    """
    if loss_mode not in {"ce", "mse"}:
        raise ValueError("loss_mode must be 'ce' or 'mse'.")
    device = inputs.device
    params = {k: v.clone().detach().to(device).requires_grad_(True) for k, v in params_init.items()}
    time_steps = int(inputs.shape[2])

    if step_weights is not None:
        step_weights_t = step_weights.to(device=device, dtype=torch.float32)
        weight_norm = float(step_weights_t.sum().item())
    else:
        step_weights_t = None
        weight_norm = float(time_steps)
    weight_norm = max(weight_norm, 1.0)

    total_loss = torch.zeros((), device=device)
    h_state = h_prev.clone().detach().to(device)
    eps = 1e-12

    for t in range(time_steps):
        step_weight = step_weights_t[t] if step_weights_t is not None else 1.0
        I_t = inputs[:, :, t].T
        y_true_t = targets[:, :, t].T

        x_t = params["W_hh"] @ h_state + params["W_xh"] @ I_t + params["b_h"]
        h_t = torch.tanh(x_t)
        y_hat_t = params["W_hy"] @ h_t + params["b_y"]

        if loss_mode == "ce":
            probs = _softmax_cols(y_hat_t)
            loss_t = step_weight * -torch.mean(torch.sum(y_true_t * torch.log(probs + eps), dim=0))
        else:
            error = y_hat_t - y_true_t
            loss_t = step_weight * 0.5 * torch.mean(torch.sum(error**2, dim=0))

        # The running sum only feeds the reported loss. Accumulating the
        # detached value avoids chaining each step's autograd graph onto
        # total_loss for the whole sequence; the number is unchanged.
        total_loss = total_loss + loss_t.detach()

        grads = torch.autograd.grad(loss_t, tuple(params.values()))
        with torch.no_grad():
            new_params = {}
            for (name, value), grad in zip(params.items(), grads):
                new_params[name] = (value - float(eta) * grad).detach().requires_grad_(True)
            params = new_params
        h_state = h_t.detach()

    avg_loss = float((total_loss / weight_norm).cpu().item())
    updated = {k: v.detach().clone() for k, v in params.items()}
    return updated, avg_loss


def _reference_fptt_classifier_chunk_sgd(
    inputs: torch.Tensor,  # (B, F, T)
    targets: torch.Tensor,  # (B, C, T)
    params_init: Dict[str, torch.Tensor],  # b_* are 1D
    *,
    eta: float,
    parts: int,
) -> Tuple[Dict[str, torch.Tensor], float]:
    """Reference chunked classifier training: one SGD step per chunk of the sequence.

    Chunks come from ``build_chunk_schedule(T, parts)``. Starting from a zero
    (batch-major) hidden state, each chunk runs the tanh RNN over its time steps,
    reads logits from the chunk's final hidden state, and takes CE against the
    last-step target ``targets[:, :, -1]``. It then updates all parameters by
    ``-eta * grad``. The hidden state carries over to the next chunk but is
    detached.

    Returns:
        ``(updated_params, loss)``: detached final parameters and the mean chunk
        loss. An empty schedule returns unchanged copies and ``0.0``.
    """
    dev = inputs.device
    theta = {
        name: tensor.clone().detach().to(dev).requires_grad_(True)
        for name, tensor in params_init.items()
    }
    batch_size, _, time_steps = inputs.shape
    chunks = build_chunk_schedule(time_steps, parts)
    if not chunks:
        return {name: tensor.detach().clone() for name, tensor in theta.items()}, 0.0

    seq_major = inputs.permute(0, 2, 1)  # (B, T, F)
    final_onehot = targets[:, :, -1].to(device=dev, dtype=torch.float32)
    hidden = torch.zeros((batch_size, theta["W_hh"].shape[0]), dtype=torch.float32, device=dev)

    loss_sum = 0.0
    for chunk_start, chunk_end in chunks:
        for x_t in seq_major[:, chunk_start:chunk_end, :].unbind(dim=1):
            hidden = torch.tanh(
                torch.matmul(hidden, theta["W_hh"].t())
                + torch.matmul(x_t, theta["W_xh"].t())
                + theta["b_h"]
            )
        logits = torch.matmul(hidden, theta["W_hy"].t()) + theta["b_y"]
        chunk_loss = torch.sum(-final_onehot * torch.log_softmax(logits, dim=1), dim=1).mean()
        grads = torch.autograd.grad(chunk_loss, tuple(theta.values()))
        with torch.no_grad():
            theta = {
                name: (tensor - float(eta) * grad).detach().requires_grad_(True)
                for (name, tensor), grad in zip(theta.items(), grads)
            }
        hidden = hidden.detach()
        loss_sum += float(chunk_loss.detach().cpu().item())

    mean_loss = loss_sum / max(1, len(chunks))
    return {name: tensor.detach().clone() for name, tensor in theta.items()}, float(mean_loss)


def validate_bptt(*, device: torch.device, plots_dir: Path, epochs: int) -> None:
    """Check TorchBPTTRNN full-BPTT and TBPTT updates against autograd references.

    1. One full-BPTT ``train_batch`` step must match ``_reference_full_bptt_sgd``
       (params to 1e-6, loss to 1e-7).
    2. One TBPTT(2) step must match ``_reference_tbptt_sgd`` (same tolerances).
    3. ``epochs`` full-BPTT steps run side by side with the reference; params
       must agree to 1e-6 after every step. The per-epoch diffs are plotted to
       ``plots_dir / "bptt_vs_ref_param_diff.png"``.

    Raises:
        AssertionError: on any mismatch. Raised explicitly rather than via
            ``assert`` so the checks still run under ``python -O``.
    """
    print("\n[TEST] BPTT: full/TBPTT vs reference autograd ...")
    set_seed(0)

    batch_size = 5
    input_size = 4
    hidden_size = 7
    output_size = 3
    time_steps = 6
    eta = 0.05

    inputs_np, targets_np, _labels_np, step_weights_np = _make_last_step_classification_batch(
        batch_size=batch_size,
        input_size=input_size,
        time_steps=time_steps,
        num_classes=output_size,
        seed=123,
    )
    step_weights_t = torch.as_tensor(step_weights_np, dtype=torch.float32, device=device)
    h_prev = torch.zeros((hidden_size, batch_size), dtype=torch.float32, device=device)
    # The reference functions only read their inputs, so convert the batch once.
    inputs_ref = torch.as_tensor(inputs_np, device=device)
    targets_ref = torch.as_tensor(targets_np, device=device)

    params_np = initialize_rnn_parameters(input_size, hidden_size, output_size, gain=1.0, seed=999)
    params_t = _torch_params_from_rnnparams(params_np, device=device)

    # Full BPTT (single step)
    bptt = TorchBPTTRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        loss_mode="ce",
        max_grad_norm=0.0,
        tbptt_steps=None,
        time_normalization=True,
        seed=0,
        device=device,
    )
    bptt.step_weights = step_weights_t
    _copy_params_into_bptt(bptt, params_np)
    loss_model, _ = bptt.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
    params_after_model = _extract_rnn_parameters_any(bptt)

    ref_after, loss_ref = _reference_full_bptt_sgd(
        inputs_ref,
        targets_ref,
        params_t,
        h_prev,
        eta=eta,
        loss_mode="ce",
        step_weights=step_weights_t,
        time_normalization=True,
    )
    params_after_ref = _rnnparams_from_torch_params(ref_after)
    diff = _param_max_abs_diff(params_after_model, params_after_ref)
    print(f"  full: diff={diff:.3e} loss(model)={loss_model:.6f} loss(ref)={loss_ref:.6f}")
    # `not (x < tol)` also trips on NaN, exactly like the equivalent assert.
    if not (diff < 1e-6):
        raise AssertionError(f"Full BPTT mismatch: diff={diff}")
    if not (abs(float(loss_model) - float(loss_ref)) < 1e-7):
        raise AssertionError("Full BPTT loss mismatch.")

    # TBPTT (single step)
    tbptt_steps = 2
    bptt_tb = TorchBPTTRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        loss_mode="ce",
        max_grad_norm=0.0,
        tbptt_steps=tbptt_steps,
        time_normalization=True,
        seed=0,
        device=device,
    )
    bptt_tb.step_weights = step_weights_t
    _copy_params_into_bptt(bptt_tb, params_np)
    loss_model_tb, _ = bptt_tb.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
    params_after_model_tb = _extract_rnn_parameters_any(bptt_tb)

    ref_after_tb, loss_ref_tb = _reference_tbptt_sgd(
        inputs_ref,
        targets_ref,
        params_t,
        h_prev,
        eta=eta,
        loss_mode="ce",
        step_weights=step_weights_t,
        tbptt_steps=tbptt_steps,
        time_normalization=True,
    )
    params_after_ref_tb = _rnnparams_from_torch_params(ref_after_tb)
    diff_tb = _param_max_abs_diff(params_after_model_tb, params_after_ref_tb)
    print(
        f"  tbptt({tbptt_steps}): diff={diff_tb:.3e} loss(model)={loss_model_tb:.6f} loss(ref)={loss_ref_tb:.6f}"
    )
    if not (diff_tb < 1e-6):
        raise AssertionError(f"TBPTT mismatch: diff={diff_tb}")
    if not (abs(float(loss_model_tb) - float(loss_ref_tb)) < 1e-7):
        raise AssertionError("TBPTT loss mismatch.")

    # Multi-epoch consistency
    diffs: List[float] = []
    bptt_loop = TorchBPTTRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        loss_mode="ce",
        max_grad_norm=0.0,
        tbptt_steps=None,
        time_normalization=True,
        seed=0,
        device=device,
    )
    bptt_loop.step_weights = step_weights_t
    _copy_params_into_bptt(bptt_loop, params_np)
    ref_params = params_t
    for epoch in range(int(epochs)):
        _loss_m, _ = bptt_loop.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
        ref_params, _loss_r = _reference_full_bptt_sgd(
            inputs_ref,
            targets_ref,
            ref_params,
            h_prev,
            eta=eta,
            loss_mode="ce",
            step_weights=step_weights_t,
            time_normalization=True,
        )
        diff_e = _param_max_abs_diff(_extract_rnn_parameters_any(bptt_loop), _rnnparams_from_torch_params(ref_params))
        diffs.append(diff_e)
        if not (diff_e < 1e-6):
            raise AssertionError(f"BPTT multi-epoch mismatch at epoch={epoch}: diff={diff_e}")

    _plot_lines(
        {"BPTT vs ref max|Δθ|": diffs},
        title="BPTT correctness (multi-epoch)",
        xlabel="Epoch",
        ylabel="max abs parameter diff",
        out_path=plots_dir / "bptt_vs_ref_param_diff.png",
    )


def validate_eprop(*, device: torch.device, plots_dir: Path, epochs: int) -> None:
    """Check StandardEPropRNN (decay_lambda=0) against the online-detach autograd reference.

    For ``epochs`` rounds, one ``train_batch`` call on the model and one
    ``_reference_online_detach_sgd`` call on the reference advance in lockstep
    from the same initial parameters. After each round the parameters must
    agree to 2e-6 and the reported losses to 1e-6. Per-round diffs are plotted
    to ``plots_dir / "eprop_vs_detach_param_diff.png"``.

    Raises:
        AssertionError: on any mismatch. Raised explicitly rather than via
            ``assert`` so the checks still run under ``python -O``.
    """
    print("\n[TEST] E-Prop: decay_lambda=0 should match torch autograd(detach), multi-epoch ...")
    set_seed(1)

    batch_size = 6
    input_size = 5
    hidden_size = 8
    output_size = 4
    time_steps = 7
    eta = 0.05

    inputs_np, targets_np, _labels_np, step_weights_np = _make_last_step_classification_batch(
        batch_size=batch_size,
        input_size=input_size,
        time_steps=time_steps,
        num_classes=output_size,
        seed=321,
    )
    step_weights_t = torch.as_tensor(step_weights_np, dtype=torch.float32, device=device)
    h_prev = torch.zeros((hidden_size, batch_size), dtype=torch.float32, device=device)
    # The reference only reads its inputs, so convert the batch once instead of per epoch.
    inputs_ref = torch.as_tensor(inputs_np, device=device)
    targets_ref = torch.as_tensor(targets_np, device=device)

    params_np = initialize_rnn_parameters(input_size, hidden_size, output_size, gain=1.0, seed=111)

    eprop = StandardEPropRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        decay_lambda=0.0,
        feedback="symmetric",
        loss_mode="ce",
        max_grad_norm=0.0,
        device=device,
    )
    eprop.step_weights = step_weights_t
    _copy_params_into_eprop(eprop, params_np)

    ref_params = _torch_params_from_rnnparams(params_np, device=device)
    diffs: List[float] = []
    for epoch in range(int(epochs)):
        loss_m, _ = eprop.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
        ref_params, loss_r = _reference_online_detach_sgd(
            inputs_ref,
            targets_ref,
            ref_params,
            h_prev,
            eta=eta,
            loss_mode="ce",
            step_weights=step_weights_t,
        )
        params_after_model = _extract_rnn_parameters_any(eprop)
        params_after_ref = _rnnparams_from_torch_params(ref_params)
        diff = _param_max_abs_diff(params_after_model, params_after_ref)
        diffs.append(diff)
        if epoch in {0, int(epochs) - 1}:
            print(f"  epoch {epoch:02d}: diff={diff:.3e} loss(model)={loss_m:.6f} loss(ref)={loss_r:.6f}")
        # `not (x < tol)` also trips on NaN, exactly like the equivalent assert.
        if not (diff < 2e-6):
            raise AssertionError(f"E-Prop vs torch(detach) mismatch at epoch={epoch}: diff={diff}")
        if not (abs(float(loss_m) - float(loss_r)) < 1e-6):
            raise AssertionError("E-Prop loss mismatch vs torch(detach).")

    _plot_lines(
        {"E-Prop vs torch(detach) max|Δθ|": diffs},
        title="E-Prop correctness (decay_lambda=0, multi-epoch)",
        xlabel="Epoch",
        ylabel="max abs parameter diff",
        out_path=plots_dir / "eprop_vs_detach_param_diff.png",
    )


def validate_local_rule(*, device: torch.device, plots_dir: Path, epochs: int) -> None:
    """Check TorchLocalRuleRNN with ``lambda_cap=0`` against the online-detach reference.

    For ``epochs`` rounds, one ``run_one_cycle_and_update_directly`` call on the
    model and one ``_reference_online_detach_sgd`` call advance in lockstep from
    the same initial parameters. Parameters must agree to 2e-6 after every round.
    Losses are printed but not compared. Per-round diffs are plotted to
    ``plots_dir / "localrule_vs_detach_param_diff.png"``.

    Raises:
        AssertionError: on any mismatch. Raised explicitly rather than via
            ``assert`` so the checks still run under ``python -O``.
    """
    print("\n[TEST] Local Rule: lambda_cap=0 should match torch autograd(detach), multi-epoch ...")
    set_seed(2)

    batch_size = 6
    input_size = 5
    hidden_size = 8
    output_size = 4
    time_steps = 7
    eta = 0.05

    inputs_np, targets_np, _labels_np, step_weights_np = _make_last_step_classification_batch(
        batch_size=batch_size,
        input_size=input_size,
        time_steps=time_steps,
        num_classes=output_size,
        seed=777,
    )
    step_weights_t = torch.as_tensor(step_weights_np, dtype=torch.float32, device=device)
    h_prev = torch.zeros((hidden_size, batch_size), dtype=torch.float32, device=device)
    # The reference only reads its inputs, so convert the batch once instead of per epoch.
    inputs_ref = torch.as_tensor(inputs_np, device=device)
    targets_ref = torch.as_tensor(targets_np, device=device)

    params_np = initialize_rnn_parameters(input_size, hidden_size, output_size, gain=1.0, seed=222)

    local = TorchLocalRuleRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        lambda_window=50,
        loss_mode="ce",
        max_grad_norm=0.0,
        seed=0,
        device=device,
    )
    local.step_weights = step_weights_t
    local.lambda_cap = 0.0  # freeze lambda -> exact local gradient regime
    _copy_params_into_local(local, params_np)

    ref_params = _torch_params_from_rnnparams(params_np, device=device)
    diffs: List[float] = []
    for epoch in range(int(epochs)):
        loss_m, _ = local.run_one_cycle_and_update_directly(inputs_np, targets_np, h_prev.cpu().numpy())
        ref_params, loss_r = _reference_online_detach_sgd(
            inputs_ref,
            targets_ref,
            ref_params,
            h_prev,
            eta=eta,
            loss_mode="ce",
            step_weights=step_weights_t,
        )
        params_after_model = _extract_rnn_parameters_any(local)
        params_after_ref = _rnnparams_from_torch_params(ref_params)
        diff = _param_max_abs_diff(params_after_model, params_after_ref)
        diffs.append(diff)
        if epoch in {0, int(epochs) - 1}:
            print(f"  epoch {epoch:02d}: diff={diff:.3e} loss(model)={loss_m:.6f} loss(ref)={loss_r:.6f}")
        # `not (x < tol)` also trips on NaN, exactly like the equivalent assert.
        if not (diff < 2e-6):
            raise AssertionError(f"Local Rule vs torch(detach) mismatch at epoch={epoch}: diff={diff}")

    _plot_lines(
        {"LocalRule(λ_cap=0) vs torch(detach) max|Δθ|": diffs},
        title="Local Rule correctness (λ_cap=0, multi-epoch)",
        xlabel="Epoch",
        ylabel="max abs parameter diff",
        out_path=plots_dir / "localrule_vs_detach_param_diff.png",
    )


def validate_fptt(*, device: torch.device, plots_dir: Path, epochs: int) -> None:
    """Check StrictFPTTClassifier limiting cases and its oracle mixing/update logic.

    (A) ``parts=1``, oracle off, ``lmbda=0``: one ``train_batch`` call must match
        one full-BPTT ``TorchBPTTRNN`` step weighted on the last step only
        (params to 1e-6, loss to 1e-6).
    (B) ``parts=3``, oracle off: ``epochs`` rounds of ``train_batch`` must track
        ``_reference_fptt_classifier_chunk_sgd`` (params to 2e-6 each round).
        Diffs are plotted to ``fptt_chunk_vs_ref_param_diff.png``.
    (C) Oracle on, ``eta=0`` and ``W_hy=0`` so logits equal a fixed output bias:
        the reported loss and the oracle buffer contents must match the values
        recomputed here (to 1e-6). Expected per-chunk losses are plotted to
        ``fptt_oracle_chunk_losses.png``.

    Raises:
        AssertionError: on any mismatch. Raised explicitly rather than via
            ``assert`` so the checks still run under ``python -O``.
    """
    print("\n[TEST] Strict FPTT: limiting cases + oracle mixing/update ...")
    set_seed(3)

    batch_size = 5
    input_size = 4
    hidden_size = 7
    output_size = 3
    time_steps = 9
    eta = 0.05

    inputs_np, targets_np, labels_np, _step_weights_np = _make_last_step_classification_batch(
        batch_size=batch_size,
        input_size=input_size,
        time_steps=time_steps,
        num_classes=output_size,
        seed=2024,
    )
    h_prev = torch.zeros((hidden_size, batch_size), dtype=torch.float32, device=device)
    params_np = initialize_rnn_parameters(input_size, hidden_size, output_size, gain=1.0, seed=333)
    # Torch copies of the batch, shared by the chunk reference in (B) and the oracle check in (C).
    inputs_t = torch.as_tensor(inputs_np, dtype=torch.float32, device=device)
    targets_t = torch.as_tensor(targets_np, dtype=torch.float32, device=device)

    # (A) parts=1, oracle/reg off -> should reduce to full BPTT on final step.
    step_weights = np.zeros((time_steps,), dtype=np.float32)
    step_weights[-1] = 1.0
    step_weights_t = torch.as_tensor(step_weights, dtype=torch.float32, device=device)

    bptt = TorchBPTTRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        loss_mode="ce",
        max_grad_norm=0.0,
        tbptt_steps=None,
        time_normalization=True,
        seed=0,
        device=device,
    )
    bptt.step_weights = step_weights_t
    _copy_params_into_bptt(bptt, params_np)

    fptt_parts1 = StrictFPTTClassifier(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        parts=1,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=1.0,
        lmbda=0.0,
        oracle_momentum=1.0,
        warmup_epochs=0,
        oracle_id="validate_parts1",
        label_mode="last",
        use_oracle=False,
        optimizer_cls=torch.optim.SGD,
        device=device,
    )
    # Hand each classifier private copies, so that a setter storing arrays by
    # reference could never mutate params_np, which later sections reuse.
    fptt_parts1.W_xh = params_np.W_xh.copy()
    fptt_parts1.W_hh = params_np.W_hh.copy()
    fptt_parts1.b_h = params_np.b_h.copy()
    fptt_parts1.W_hy = params_np.W_hy.copy()
    fptt_parts1.b_y = params_np.b_y.copy()
    fptt_parts1.reset_state_buffers()
    fptt_parts1.set_epoch(0)

    loss_bptt, _ = bptt.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
    loss_fptt, _ = fptt_parts1.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
    diff_parts1 = _param_max_abs_diff(_extract_rnn_parameters_any(bptt), _extract_rnn_parameters_any(fptt_parts1))
    print(f"  parts=1: diff={diff_parts1:.3e} loss(BPTT)={loss_bptt:.6f} loss(FPTT)={loss_fptt:.6f}")
    # `not (x < tol)` also trips on NaN, exactly like the equivalent assert.
    if not (diff_parts1 < 1e-6):
        raise AssertionError(f"FPTT(parts=1) should match BPTT: diff={diff_parts1}")
    if not (abs(float(loss_bptt) - float(loss_fptt)) < 1e-6):
        raise AssertionError("FPTT(parts=1) loss mismatch vs BPTT.")

    # (B) parts>1, oracle/reg off -> should match reference chunked BPTT update.
    parts = 3
    fptt_chunk = StrictFPTTClassifier(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        parts=parts,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=1.0,
        lmbda=0.0,
        oracle_momentum=1.0,
        warmup_epochs=0,
        oracle_id="validate_chunks",
        label_mode="last",
        use_oracle=False,
        optimizer_cls=torch.optim.SGD,
        device=device,
    )
    fptt_chunk.W_xh = params_np.W_xh.copy()
    fptt_chunk.W_hh = params_np.W_hh.copy()
    fptt_chunk.b_h = params_np.b_h.copy()
    fptt_chunk.W_hy = params_np.W_hy.copy()
    fptt_chunk.b_y = params_np.b_y.copy()
    fptt_chunk.reset_state_buffers()

    # The chunk reference expects 1-D biases.
    ref_params = {
        "W_xh": torch.as_tensor(params_np.W_xh, dtype=torch.float32, device=device),
        "W_hh": torch.as_tensor(params_np.W_hh, dtype=torch.float32, device=device),
        "b_h": torch.as_tensor(params_np.b_h.reshape(-1), dtype=torch.float32, device=device),
        "W_hy": torch.as_tensor(params_np.W_hy, dtype=torch.float32, device=device),
        "b_y": torch.as_tensor(params_np.b_y.reshape(-1), dtype=torch.float32, device=device),
    }

    diffs: List[float] = []
    for epoch in range(int(epochs)):
        fptt_chunk.set_epoch(epoch)
        _loss_m, _ = fptt_chunk.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
        ref_params, _loss_r = _reference_fptt_classifier_chunk_sgd(
            inputs_t,
            targets_t,
            ref_params,
            eta=eta,
            parts=parts,
        )
        params_after_model = _extract_rnn_parameters_any(fptt_chunk)
        params_after_ref = _rnnparams_from_torch_params(
            {
                "W_xh": ref_params["W_xh"],
                "W_hh": ref_params["W_hh"],
                "b_h": ref_params["b_h"].reshape(-1, 1),
                "W_hy": ref_params["W_hy"],
                "b_y": ref_params["b_y"].reshape(-1, 1),
            }
        )
        diff = _param_max_abs_diff(params_after_model, params_after_ref)
        diffs.append(diff)
        if not (diff < 2e-6):
            raise AssertionError(f"FPTT chunk update mismatch at epoch={epoch}: diff={diff}")

    _plot_lines(
        {"StrictFPTT(parts>1, no-oracle/reg) vs ref max|Δθ|": diffs},
        title="StrictFPTT chunk correctness (multi-epoch)",
        xlabel="Epoch",
        ylabel="max abs parameter diff",
        out_path=plots_dir / "fptt_chunk_vs_ref_param_diff.png",
    )

    # (C) Oracle mixing and buffer update: deterministic logits, compare expected loss and buffer states.
    print("  oracle mixing/update: deterministic check ...")
    oracle_key = "validate_oracle"
    OracleBufferStore.reset(oracle_key)
    buf: ClassOracleBuffer = OracleBufferStore.get(
        oracle_key, num_classes=output_size, max_parts=parts, momentum=1.0
    )
    buf.ensure(parts)

    # Seed two buffer slots with distinctive distributions; these literal rows
    # assume output_size == 3.
    custom = np.full_like(buf._storage, 1.0 / float(output_size), dtype=np.float32)
    custom[1, 0] = np.array([0.7, 0.2, 0.1], dtype=np.float32)
    custom[2, 1] = np.array([0.05, 0.1, 0.85], dtype=np.float32)
    buf._storage[:] = custom

    fptt_oracle = StrictFPTTClassifier(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=0.0,  # freeze weights -> isolate oracle logic
        parts=parts,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=1.0,
        lmbda=0.0,
        oracle_momentum=1.0,
        warmup_epochs=0,
        oracle_id=oracle_key,
        label_mode="last",
        use_oracle=True,
        optimizer_cls=torch.optim.SGD,
        device=device,
    )

    fptt_oracle.W_xh = params_np.W_xh.copy()
    fptt_oracle.W_hh = params_np.W_hh.copy()
    fptt_oracle.b_h = params_np.b_h.copy()
    # Zero readout weights -> logits equal the fixed output bias for every sample.
    fptt_oracle.W_hy = np.zeros_like(params_np.W_hy)
    fixed_bias = np.array([[2.0], [-1.0], [-1.5]], dtype=np.float32)
    fptt_oracle.b_y = fixed_bias.copy()
    fptt_oracle.reset_state_buffers()
    fptt_oracle.set_epoch(0)

    schedule = build_chunk_schedule(time_steps, parts)
    total_chunks = len(schedule)
    oracle_cutoff = max(0, min(parts - 1, total_chunks - 1))

    logits = torch.as_tensor(fixed_bias.reshape(-1), device=device).view(1, -1).repeat(batch_size, 1)
    log_probs = torch.log_softmax(logits, dim=1)
    probs = torch.softmax(logits, dim=1).detach().cpu().numpy().T  # (C, B) constant
    preds = np.argmax(probs, axis=0)
    label_idx = labels_np.astype(np.int64)
    onehot = targets_t[:, :, -1]

    expected_losses: List[float] = []
    expected_storage = custom.copy()
    for chunk_idx in range(total_chunks):
        alpha = float(chunk_idx + 1) / float(max(1, total_chunks))
        oracle_active = chunk_idx < oracle_cutoff
        oracle_weight = (1.0 - alpha) if oracle_active else 0.0
        if oracle_active:
            surrogate = expected_storage[label_idx, chunk_idx]  # (B, C)
            surrogate_t = torch.as_tensor(surrogate, dtype=torch.float32, device=device)
            mix_target = alpha * onehot + oracle_weight * surrogate_t
        else:
            mix_target = onehot
        loss_ce = torch.sum(-mix_target * log_probs, dim=1).mean()
        expected_losses.append(float(loss_ce.detach().cpu().item()))

        if oracle_active:
            # Expected buffer update: for each class, the first misclassified
            # sample of that class writes its probability vector into the slot.
            filled = np.zeros(output_size, dtype=bool)
            for col, (y, y_hat) in enumerate(zip(label_idx, preds)):
                if y < 0 or y >= output_size:
                    continue
                if filled[y] or (y_hat == y):
                    continue
                expected_storage[y, chunk_idx] = probs[:, col]
                filled[y] = True

    loss_actual, _ = fptt_oracle.train_batch(inputs_np, targets_np, h_prev.cpu().numpy())
    expected_avg = float(np.mean(expected_losses))
    print(f"    expected avg loss={expected_avg:.6f} | model avg loss={loss_actual:.6f}")
    if not (abs(float(loss_actual) - expected_avg) < 1e-6):
        raise AssertionError("FPTT oracle loss mismatch.")
    storage_diff = float(np.max(np.abs(buf._storage - expected_storage)))
    print(f"    oracle buffer max abs diff = {storage_diff:.3e}")
    if not (storage_diff < 1e-6):
        raise AssertionError("FPTT oracle buffer update mismatch.")

    _plot_lines(
        {"expected chunk loss": expected_losses},
        title="FPTT oracle: per-chunk CE loss (deterministic logits)",
        xlabel="Chunk",
        ylabel="loss",
        out_path=plots_dir / "fptt_oracle_chunk_losses.png",
    )


def validate_toy_training(*, device: torch.device, plots_dir: Path, epochs: int) -> None:
    """
    Not a proof test: provides a sanity-check training curve for all methods on a learnable toy task.

    Task: each sample is a random ``(input_size, time_steps)`` sequence whose label is 1 when the
    sum of input channel 0 over time is positive, else 0. The one-hot target is placed only at the
    last time step, and ``step_weights`` is 1 only at that step. Every method starts from the same
    initial weights (``params0``) and trains on the same precomputed minibatch schedule. After each
    epoch, the average training loss and the accuracy over all training samples are recorded. The
    curves are written to ``toy_training_loss.png`` and ``toy_training_accuracy.png`` in
    ``plots_dir``.

    Args:
        device: torch device handed to every model.
        plots_dir: directory that receives the two plots.
        epochs: number of training epochs; must be >= 1.

    Raises:
        ValueError: if ``epochs`` is smaller than 1.
    """
    epochs = int(epochs)
    if epochs < 1:
        raise ValueError(f"validate_toy_training needs epochs >= 1, got {epochs}")

    print("\n[TEST] Toy task sanity training + plots (loss/acc curves) ...")
    set_seed(4)

    rng = np.random.default_rng(2026)
    num_samples = 256
    batch_size = 32
    input_size = 6
    hidden_size = 32
    output_size = 2
    time_steps = 12
    eta = 0.02
    parts = 4

    inputs = rng.standard_normal((num_samples, input_size, time_steps)).astype(np.float32)
    score = inputs[:, 0, :].sum(axis=1)
    labels = (score > 0).astype(np.int64)
    targets = np.zeros((num_samples, output_size, time_steps), dtype=np.float32)
    targets[np.arange(num_samples), labels, -1] = 1.0

    # Only the final time step contributes to the loss.
    step_weights = np.zeros((time_steps,), dtype=np.float32)
    step_weights[-1] = 1.0
    step_weights_t = torch.as_tensor(step_weights, dtype=torch.float32, device=device)

    params0 = initialize_rnn_parameters(input_size, hidden_size, output_size, gain=1.0, seed=42)

    # Precompute a deterministic minibatch schedule shared across methods.
    sched_rng = np.random.default_rng(9999)
    schedule: List[List[np.ndarray]] = []
    for _ in range(epochs):
        perm = sched_rng.permutation(num_samples)
        schedule.append([perm[i : i + batch_size] for i in range(0, num_samples, batch_size)])

    # Fixed evaluation pass: every sample exactly once, independent of the training schedule.
    eval_order = np.arange(num_samples)
    eval_batches = [eval_order[i : i + batch_size] for i in range(0, num_samples, batch_size)]

    def zero_state(batch: int) -> np.ndarray:
        """Zero initial hidden state of shape (hidden_size, batch), as the models' NumPy APIs take it."""
        return np.zeros((hidden_size, batch), dtype=np.float32)

    def forward_last_logits(model: object, x_b: np.ndarray) -> torch.Tensor:
        """Run one forward cycle and return the outputs of the final time step."""
        outputs, _ = model.forward_cycle(x_b, zero_state(int(x_b.shape[0])))
        return torch.as_tensor(outputs[-1], device=device, dtype=torch.float32)  # (C, B)

    def eval_acc(model: object) -> float:
        """Fraction of all training samples whose last-step argmax matches the label."""
        correct = 0
        for batch_idx in eval_batches:
            logits = forward_last_logits(model, inputs[batch_idx])
            pred = torch.argmax(logits, dim=0).cpu().numpy()
            correct += int(np.sum(pred == labels[batch_idx]))
        return float(correct / num_samples)

    def train_epoch(name: str, model: object, train_step, epoch: int) -> float:
        """Apply ``train_step`` to every minibatch of ``schedule[epoch]``; return the mean batch loss."""
        losses: List[float] = []
        for batch_idx in schedule[epoch]:
            x_b = inputs[batch_idx]
            y_b = targets[batch_idx]
            loss, _ = train_step(model, x_b, y_b, zero_state(int(x_b.shape[0])))
            losses.append(float(loss))
        avg_loss = float(np.mean(losses))
        if epoch in {0, epochs - 1}:
            print(f"  {name}: epoch {epoch:02d} loss={avg_loss:.4f}")
        return avg_loss

    # Build models (start from identical weights).
    bptt = TorchBPTTRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        loss_mode="ce",
        max_grad_norm=0.0,
        tbptt_steps=None,
        time_normalization=True,
        seed=0,
        device=device,
    )
    bptt.step_weights = step_weights_t
    _copy_params_into_bptt(bptt, params0)

    # Torch detach baseline: TBPTT with chunk length 1 detaches state each step.
    detach = TorchBPTTRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        loss_mode="ce",
        max_grad_norm=0.0,
        tbptt_steps=1,
        time_normalization=True,
        seed=0,
        device=device,
    )
    detach.step_weights = step_weights_t
    _copy_params_into_bptt(detach, params0)

    eprop = StandardEPropRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        decay_lambda=0.95,
        feedback="symmetric",
        loss_mode="ce",
        max_grad_norm=0.0,
        device=device,
    )
    eprop.step_weights = step_weights_t
    _copy_params_into_eprop(eprop, params0)

    eprop_lambda0 = StandardEPropRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        decay_lambda=0.0,
        feedback="symmetric",
        loss_mode="ce",
        max_grad_norm=0.0,
        device=device,
    )
    eprop_lambda0.step_weights = step_weights_t
    _copy_params_into_eprop(eprop_lambda0, params0)

    local = TorchLocalRuleRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        lambda_window=50,
        loss_mode="ce",
        max_grad_norm=0.0,
        seed=0,
        device=device,
    )
    local.step_weights = step_weights_t
    _copy_params_into_local(local, params0)

    fptt = StrictFPTTClassifier(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        parts=parts,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=1.0,
        lmbda=0.0,
        oracle_momentum=1.0,
        warmup_epochs=0,
        oracle_id="validate_toy",
        label_mode="last",
        use_oracle=False,
        optimizer_cls=torch.optim.SGD,
        device=device,
    )
    # Copy so the FPTT model never aliases the shared initial-weight arrays in params0.
    fptt.W_xh = params0.W_xh.copy()
    fptt.W_hh = params0.W_hh.copy()
    fptt.b_h = params0.b_h.copy()
    fptt.W_hy = params0.W_hy.copy()
    fptt.b_y = params0.b_y.copy()
    fptt.reset_state_buffers()

    def train_batch_step(model: object, x_b: np.ndarray, y_b: np.ndarray, h0: np.ndarray):
        """Training step for models exposing ``train_batch`` (BPTT, E-Prop, FPTT)."""
        return model.train_batch(x_b, y_b, h0)

    def local_rule_step(model: TorchLocalRuleRNN, x_b: np.ndarray, y_b: np.ndarray, h0: np.ndarray):
        """Training step for the local-rule model, which updates weights inside its cycle."""
        return model.run_one_cycle_and_update_directly(x_b, y_b, h0)

    fptt_label = f"StrictFPTT(parts={parts})"
    # (label used for both console output and plot legends, model, training step)
    runs = [
        ("BPTT", bptt, train_batch_step),
        ("Torch(detach)", detach, train_batch_step),
        ("E-Prop(λ=0.95)", eprop, train_batch_step),
        ("E-Prop(λ=0.0)", eprop_lambda0, train_batch_step),
        ("Local Rule", local, local_rule_step),
        (fptt_label, fptt, train_batch_step),
    ]
    curves_loss: Dict[str, List[float]] = {label: [] for label, _, _ in runs}
    curves_acc: Dict[str, List[float]] = {label: [] for label, _, _ in runs}

    for epoch in range(epochs):
        # Report the real epoch to FPTT (previously it was pinned to 0 for every batch).
        fptt.set_epoch(epoch)
        # Train every method for this epoch first, then evaluate all of them.
        for label, model, step in runs:
            curves_loss[label].append(train_epoch(label, model, step, epoch))
        for label, model, _ in runs:
            curves_acc[label].append(eval_acc(model))

    _plot_lines(
        curves_loss,
        title="Toy task: training loss (last-step classification)",
        xlabel="Epoch",
        ylabel="loss",
        out_path=plots_dir / "toy_training_loss.png",
    )
    _plot_lines(
        curves_acc,
        title="Toy task: training accuracy (last-step classification)",
        xlabel="Epoch",
        ylabel="accuracy",
        out_path=plots_dir / "toy_training_accuracy.png",
    )


def main() -> None:
    """Parse command-line arguments and run every validation stage in order.

    Runs the BPTT, E-Prop, local-rule and FPTT checks, then (unless ``--skip-toy``) the toy-task
    sanity training. Plots go to the ``plots_dir`` of the prepared artifacts. Prints ``OK`` and
    that directory when everything finishes.

    Invalid arguments (non-positive epoch counts, an unparseable device string, or CUDA requested
    on a host without it) are reported as argparse usage errors before any work starts.
    """
    parser = argparse.ArgumentParser(
        description="Comprehensive validation for Compare_RNN implementations (BPTT/E-Prop/LocalRule/FPTT)."
    )
    parser.add_argument("--epochs", type=int, default=12, help="epochs for multi-epoch equivalence checks")
    parser.add_argument("--toy-epochs", type=int, default=20, help="epochs for the toy-task sanity curves")
    parser.add_argument("--skip-toy", action="store_true", help="skip the toy-task sanity training run")
    parser.add_argument("--device", type=str, default="cpu", help="device (cpu/cuda)")
    parser.add_argument(
        "--out-root",
        type=str,
        default=None,
        help="output root folder (default: Compare_RNN/validation_tests/plots/<timestamp>/)",
    )
    args = parser.parse_args()

    # Fail fast with a usage error instead of deep inside a validation stage.
    if args.epochs < 1:
        parser.error("--epochs must be >= 1")
    if not args.skip_toy and args.toy_epochs < 1:
        parser.error("--toy-epochs must be >= 1 (or pass --skip-toy)")
    try:
        device = torch.device(args.device)
    except RuntimeError as err:
        parser.error(f"invalid --device {args.device!r}: {err}")
    if device.type == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda requested but CUDA is not available on this host")

    artifacts = _prepare_artifacts(out_root=Path(args.out_root) if args.out_root else None)

    validate_bptt(device=device, plots_dir=artifacts.plots_dir, epochs=args.epochs)
    validate_eprop(device=device, plots_dir=artifacts.plots_dir, epochs=args.epochs)
    validate_local_rule(device=device, plots_dir=artifacts.plots_dir, epochs=args.epochs)
    # The FPTT checks always run at least 6 epochs.
    validate_fptt(device=device, plots_dir=artifacts.plots_dir, epochs=max(6, args.epochs))
    if not args.skip_toy:
        validate_toy_training(device=device, plots_dir=artifacts.plots_dir, epochs=args.toy_epochs)

    print("\nOK")
    print(f"[PLOTS] {artifacts.plots_dir}")


if __name__ == "__main__":
    main()
