from __future__ import annotations

import sys
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch

# ROOT is the parent of the directory that contains this file (``parents[1]``
# of the resolved path). It is put first on ``sys.path`` before the project
# imports below -- presumably so the ``methods`` package resolves when this
# file is run directly as a script (confirm against the repository layout).
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from methods.standard_eprop import StandardEPropRNN
from methods.strict_fptt import ClassOracleBuffer, FPTTRegularizer, StrictFPTTClassifier, build_chunk_schedule
from methods.shared_rnn_utils import initialize_rnn_parameters, extract_rnn_parameters, apply_rnn_parameters


def assert_allclose(name: str, a: np.ndarray, b: np.ndarray, tol: float = 1e-6) -> None:
    """Print the largest absolute difference of two arrays and fail if it is too big.

    Args:
        name: Label used in the printed line and in the error message.
        a: First array.
        b: Second array; combined with ``a`` via ``a - b`` (NumPy broadcasting
            rules apply).
        tol: Largest accepted value of ``max(|a - b|)``.

    Raises:
        AssertionError: If the maximum absolute difference exceeds ``tol`` or
            is NaN (for example when either input contains NaN).
    """
    diff = np.max(np.abs(a - b))
    print(f"{name} max abs diff = {diff:.6e}")
    # ``diff > tol`` evaluates to False for NaN, which would let NaN-corrupted
    # arrays pass silently. The negated ``<=`` form rejects NaN as well while
    # behaving identically for every finite difference.
    if not diff <= tol:
        raise AssertionError(f"{name} mismatch: {diff}")


def check_chunk_schedule() -> None:
    """Compare ``build_chunk_schedule`` with an inline reference for one case.

    The case checked is ``time_steps=11`` split into ``parts=4``. The schedule
    returned by ``build_chunk_schedule`` is printed before the comparison.

    Raises:
        AssertionError: If the two schedules are not equal.
    """

    def reference_schedule(time_steps: int, parts: int) -> List[Tuple[int, int]]:
        """Return ``(start, end)`` chunk bounds covering ``[0, time_steps)``."""
        total = max(1, int(time_steps))
        n_parts = max(1, int(parts))
        # Floor-divided chunk width, never smaller than one step.
        width = max(1, total // n_parts)
        # One extra chunk is allowed when ``n_parts`` full-width chunks do not
        # reach the end of the sequence.
        n_chunks = n_parts + int(n_parts * width < total)
        chunks = [
            (lo, min(total, lo + width))
            for lo in range(0, n_chunks * width, width)
            if lo < total
        ]
        # Whatever is still uncovered becomes a final tail chunk.
        if chunks and chunks[-1][1] < total:
            chunks.append((chunks[-1][1], total))
        return chunks

    n_steps, n_parts = 11, 4
    ref = reference_schedule(n_steps, n_parts)
    test = build_chunk_schedule(n_steps, n_parts)
    print(f"chunk schedule = {test}")
    if ref != test:
        raise AssertionError("Chunk schedule mismatch.")


def check_oracle_buffer_updates() -> None:
    """Check one ``ClassOracleBuffer.update`` call against a hand-built array.

    A buffer for 3 classes and 2 parts (``momentum=1.0``) is updated once at
    ``idx=0`` with four samples, and its ``_storage`` is compared with the
    expected array through ``assert_allclose``.

    Raises:
        AssertionError: Propagated from ``assert_allclose`` when the storage
            differs from the expectation.
    """
    n_classes = 3
    n_parts = 2
    oracle = ClassOracleBuffer(num_classes=n_classes, max_parts=n_parts, momentum=1.0)

    targets = np.array([0, 1, 2, 1], dtype=np.int64)
    predictions = np.array([1, 1, 2, 0], dtype=np.int64)
    # One column per sample, one row per class.
    class_probs = np.array(
        [
            [0.1, 0.8, 0.2, 0.2],
            [0.7, 0.1, 0.7, 0.3],
            [0.2, 0.1, 0.1, 0.5],
        ],
        dtype=np.float32,
    )

    oracle.update(targets, idx=0, probs=class_probs, preds=predictions)

    # Expected storage: uniform 1/n_classes everywhere, except part 0 of
    # classes 0 and 1, which hold the probability columns of samples 0 and 3.
    # Those are the two samples whose prediction differs from the label --
    # presumably the buffer's update rule; confirm against ClassOracleBuffer.
    reference = np.full((n_classes, n_parts, n_classes), 1.0 / n_classes, dtype=np.float32)
    for class_idx, sample_idx in ((0, 0), (1, 3)):
        reference[class_idx, 0] = class_probs[:, sample_idx]

    assert_allclose("oracle buffer", oracle._storage, reference)


def check_regularizer_loss_and_step() -> None:
    """Check ``FPTTRegularizer.loss`` and ``.step`` against hand-written formulas.

    With a two-element parameter ``w`` and preset ``sm`` / ``lm`` state, the
    expected values computed here are:

    * loss: ``(rho - 1) * sum(w * lm) + lmbda * 0.5 * alpha * sum((w - sm)**2)``
    * step: ``lm_new = lm - alpha * (w - sm)`` and
      ``sm_new = (1 - beta) * sm + beta * w - (beta / alpha) * lm_new``

    Each difference is printed; the tolerance is ``1e-6``.

    Raises:
        AssertionError: If the loss, or either state tensor after ``step``,
            deviates from the expected value by more than the tolerance.
    """
    alpha, beta, rho, lmbda = 0.2, 0.5, 0.3, 1.5

    weight = torch.nn.Parameter(torch.tensor([1.0, -2.0], dtype=torch.float32))
    reg = FPTTRegularizer([("w", weight)], alpha=alpha, beta=beta, rho=rho, lmbda=lmbda)

    # Overwrite the regularizer's internal state for "w" with known values.
    slot = reg._state["w"]
    slot["sm"].copy_(torch.tensor([0.5, -1.0], dtype=torch.float32))
    slot["lm"].copy_(torch.tensor([0.1, -0.2], dtype=torch.float32))

    linear_term = (rho - 1.0) * torch.sum(weight * slot["lm"])
    quadratic_term = lmbda * 0.5 * alpha * torch.sum((weight - slot["sm"]) ** 2)
    expected_loss = linear_term + quadratic_term
    loss_gap = (reg.loss() - expected_loss).abs().max().item()
    print(f"regularizer loss diff = {loss_gap:.6e}")
    if loss_gap > 1e-6:
        raise AssertionError("Regularizer loss mismatch.")

    # Snapshot the state and derive the expected update before step() runs.
    sm_before = slot["sm"].clone()
    lm_before = slot["lm"].clone()
    lm_expected = lm_before - alpha * (weight.detach() - sm_before)
    sm_expected = (1.0 - beta) * sm_before + beta * weight.detach() - (beta / alpha) * lm_expected

    reg.step()

    lm_gap = (slot["lm"] - lm_expected).abs().max().item()
    sm_gap = (slot["sm"] - sm_expected).abs().max().item()
    print(f"regularizer lm diff = {lm_gap:.6e}")
    print(f"regularizer sm diff = {sm_gap:.6e}")
    if any(gap > 1e-6 for gap in (lm_gap, sm_gap)):
        raise AssertionError("Regularizer step mismatch.")


def check_shared_initialization() -> None:
    """Check that both model types report identical parameters after loading.

    One parameter set is created with ``initialize_rnn_parameters`` (sizes
    3/5/2, ``gain=1.0``, ``seed=99``). It is assigned attribute-by-attribute
    onto a ``StandardEPropRNN`` and loaded into a ``StrictFPTTClassifier`` via
    ``apply_rnn_parameters``. Both models are then read back with
    ``extract_rnn_parameters`` and compared field by field.

    Raises:
        AssertionError: Propagated from ``assert_allclose`` when any of the
            five fields differs between the two models.
    """
    n_in, n_hidden, n_out = 3, 5, 2
    field_names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")

    shared = initialize_rnn_parameters(n_in, n_hidden, n_out, gain=1.0, seed=99)

    eprop_model = StandardEPropRNN(
        input_size=n_in,
        hidden_size=n_hidden,
        output_size=n_out,
        eta=0.01,
        decay_lambda=0.9,
        feedback="symmetric",
        loss_mode="mse",
        max_grad_norm=0.0,
        device="cpu",
    )
    # The e-prop model gets the shared arrays assigned directly as tensors.
    for field in field_names:
        setattr(eprop_model, field, torch.from_numpy(getattr(shared, field)))

    fptt_model = StrictFPTTClassifier(
        input_size=n_in,
        hidden_size=n_hidden,
        output_size=n_out,
        eta=0.01,
        parts=2,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=0.0,
        lmbda=0.0,
        use_oracle=False,
        warmup_epochs=0,
        label_mode="last",
        optimizer_cls=torch.optim.SGD,
        device="cpu",
    )
    apply_rnn_parameters(fptt_model, shared)

    from_eprop = extract_rnn_parameters(eprop_model)
    from_fptt = extract_rnn_parameters(fptt_model)

    for field in field_names:
        assert_allclose(f"{field} init", getattr(from_eprop, field), getattr(from_fptt, field))


def main() -> None:
    """Run every check in order and print ``OK`` once all of them pass.

    Any exception raised by a check propagates and stops the run, so ``OK``
    is printed only when none of them raised.
    """
    checks = (
        check_chunk_schedule,
        check_oracle_buffer_updates,
        check_regularizer_loss_and_step,
        check_shared_initialization,
    )
    for run_check in checks:
        run_check()
    print("OK")


# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
