from __future__ import annotations

import sys
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

# ROOT is the parent of the directory that contains this file (presumably the
# repository root — TODO confirm). It is placed at the front of sys.path so the
# `methods.*` imports below resolve when this file is executed as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from methods.strict_fptt import StrictFPTTClassifier, build_chunk_schedule
from methods.shared_rnn_utils import initialize_rnn_parameters, apply_rnn_parameters, extract_rnn_parameters


def set_seed(seed: int) -> None:
    """Seed NumPy's global RNG and PyTorch's RNG with the same value.

    Args:
        seed: Integer forwarded to ``np.random.seed`` and then to
            ``torch.manual_seed``.
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)


def reference_chunk_updates(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    schedule: List[Tuple[int, int]],
    *,
    eta: float,
) -> List[Dict[str, torch.Tensor]]:
    """Run chunk-wise truncated BPTT with plain gradient-descent steps.

    For every ``(start, end)`` chunk the tanh RNN is unrolled over the chunk's
    time steps, a softmax cross-entropy loss is evaluated on the final hidden
    state of the chunk, and every parameter that received a gradient is moved
    by ``-eta * grad``. The hidden state is carried into the next chunk but
    detached, so gradients never cross a chunk boundary.

    Args:
        inputs: Tensor indexed as (batch, feature, time); axis 2 is sliced by
            the schedule.
        targets: Tensor indexed as (batch, class, time); only the last time
            index is used as the per-class label weights.
        params_init: Mapping with the keys ``W_xh``, ``W_hh``, ``b_h``,
            ``W_hy`` and ``b_y``. The tensors are cloned and never modified.
        schedule: Time-axis ``(start, end)`` half-open ranges, one per chunk.
        eta: Gradient-descent step size.

    Returns:
        One dict per chunk holding detached clones of all parameters as they
        are right after that chunk's update.
    """
    params = {
        name: value.clone().detach().requires_grad_(True)
        for name, value in params_init.items()
    }
    # Allocate the hidden state with the parameters' dtype/device so the
    # matmuls below also work for non-float32 or non-default-device params.
    recurrent = params["W_hh"]
    state = torch.zeros(
        (inputs.size(0), recurrent.size(0)),
        dtype=recurrent.dtype,
        device=recurrent.device,
    )
    inputs_t = inputs.permute(0, 2, 1)  # -> (batch, time, feature)
    labels = targets[:, :, -1]

    snapshots: List[Dict[str, torch.Tensor]] = []

    for start, end in schedule:
        chunk_inputs = inputs_t[:, start:end, :]
        for t in range(chunk_inputs.size(1)):
            state = torch.tanh(
                state @ params["W_hh"].T
                + chunk_inputs[:, t, :] @ params["W_xh"].T
                + params["b_h"]
            )

        logits = state @ params["W_hy"].T + params["b_y"]
        log_probs = torch.log_softmax(logits, dim=1)
        loss = torch.sum(-labels * log_probs, dim=1).mean()
        loss.backward()

        with torch.no_grad():
            for name, value in list(params.items()):
                grad = value.grad
                if grad is None:
                    # The parameter did not take part in this chunk's graph
                    # (e.g. an empty chunk never touches the recurrent
                    # weights); leave it unchanged instead of failing on
                    # ``eta * None``.
                    continue
                # Rebind to a fresh leaf so no gradient accumulates across
                # chunks.
                params[name] = (value - eta * grad).detach().requires_grad_(True)
        # Truncate the graph at the chunk boundary.
        state = state.detach()
        snapshots.append({name: value.detach().clone() for name, value in params.items()})

    return snapshots


def strict_fptt_updates(
    model: StrictFPTTClassifier,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    schedule: List[Tuple[int, int]],
) -> List[Dict[str, torch.Tensor]]:
    """Drive ``model`` chunk by chunk and snapshot its parameters after each one.

    Each chunk is handed to ``model._run_classification_chunk`` together with
    the carried hidden state; that call presumably performs the parameter
    update for the chunk (verify against the model implementation). After
    every chunk the attributes of the object returned by
    ``extract_rnn_parameters(model)`` are recorded.

    Args:
        model: Classifier under test; ``hidden_size`` and ``parts`` are read.
        inputs: Tensor indexed as (batch, feature, time); axis 2 is sliced by
            the schedule.
        targets: Tensor indexed as (batch, class, time); only the last time
            index is used as labels.
        schedule: Time-axis ``(start, end)`` half-open ranges, one per chunk.

    Returns:
        One dict per chunk mapping attribute name to a private copy of the
        extracted value. Values keep whatever array type
        ``extract_rnn_parameters`` exposes (presumably NumPy arrays).
    """

    def _frozen(value):
        # Copy array-like values so a snapshot cannot alias live parameter
        # storage that later in-place updates would overwrite.
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        if isinstance(value, np.ndarray):
            return value.copy()
        return value

    inputs_t = inputs.permute(0, 2, 1)  # -> (batch, time, feature)
    labels = targets[:, :, -1]
    label_indices = torch.argmax(labels, dim=1).cpu().numpy()

    state = torch.zeros((inputs.size(0), model.hidden_size), dtype=torch.float32)
    snapshots: List[Dict[str, torch.Tensor]] = []

    total_chunks = len(schedule)
    # Clamp ``model.parts - 1`` into the range [0, total_chunks - 1].
    oracle_cutoff = max(0, min(model.parts - 1, total_chunks - 1))

    for chunk_idx, (start, end) in enumerate(schedule):
        chunk_inputs = inputs_t[:, start:end, :]
        state, _loss = model._run_classification_chunk(
            chunk_inputs,
            state,
            labels,
            label_indices,
            chunk_idx,
            total_chunks,
            oracle_cutoff,
            warmup=False,
            oracle=None,
        )
        extracted = vars(extract_rnn_parameters(model))
        snapshots.append({name: _frozen(value) for name, value in extracted.items()})
    return snapshots


def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the largest element-wise absolute difference of ``a`` and ``b``.

    The subtraction follows normal tensor broadcasting; the result is
    converted to a Python float.
    """
    gap = (a - b).abs()
    return float(gap.max().item())


def compare_snapshots(
    ref: List[Dict[str, torch.Tensor]], strict: List[Dict[str, torch.Tensor]], tol: float
) -> None:
    """Check two per-chunk parameter snapshot lists for element-wise agreement.

    For every chunk and every parameter name present in the reference
    snapshot, the maximum absolute difference is printed and compared with
    ``tol``. The bias entries ``b_h`` and ``b_y`` are flattened on both sides
    first so that differently shaped storage of the same vector compares
    element by element.

    Args:
        ref: Reference snapshots; their keys decide which entries are compared.
        strict: Snapshots under test; values are passed through
            ``torch.as_tensor`` so non-tensor arrays are accepted.
        tol: Largest accepted absolute difference.

    Raises:
        AssertionError: If the lists differ in length, if any compared value
            contains NaN or infinity, or if a difference exceeds ``tol``.
    """
    if len(ref) != len(strict):
        raise AssertionError("Snapshot length mismatch.")
    for idx, (r, s) in enumerate(zip(ref, strict)):
        for name, ref_val in r.items():
            strict_val = torch.as_tensor(s[name])
            if name in {"b_h", "b_y"}:
                ref_val = ref_val.reshape(-1)
                strict_val = strict_val.reshape(-1)
            # A NaN difference compares False against any tolerance, so
            # diverged parameters would pass the check below unnoticed.
            # Reject non-finite values explicitly.
            for side, value in (("reference", ref_val), ("strict", strict_val)):
                if not bool(torch.isfinite(value).all()):
                    raise AssertionError(
                        f"Non-finite {side} values at chunk {idx} for {name}"
                    )
            diff = max_abs_diff(ref_val, strict_val)
            print(f"chunk {idx} {name} max abs diff = {diff:.6e}")
            if diff > tol:
                raise AssertionError(f"Mismatch at chunk {idx} for {name}: {diff}")


def main() -> None:
    """Compare StrictFPTT parameter updates with the chunked-BPTT reference.

    Builds a small random classification problem, runs the reference update
    path and the ``StrictFPTTClassifier`` path from the same initial
    parameters, and compares their per-chunk parameter snapshots with a
    tolerance of 1e-5. Prints ``OK`` when every comparison passes.
    """
    set_seed(11)

    # Problem dimensions and the step size shared by both update paths.
    batch_size, input_size, hidden_size, output_size = 4, 3, 6, 3
    time_steps, parts = 7, 3
    eta = 0.1

    params_np = initialize_rnn_parameters(
        input_size, hidden_size, output_size, gain=1.0, seed=321
    )

    # Random inputs plus one-hot targets that are set only at the last step.
    rng = np.random.default_rng(1)
    input_shape = (batch_size, input_size, time_steps)
    inputs_np = rng.standard_normal(input_shape).astype(np.float32)
    labels = rng.integers(low=0, high=output_size, size=(batch_size,))
    targets_np = np.zeros((batch_size, output_size, time_steps), dtype=np.float32)
    targets_np[np.arange(batch_size), labels, -1] = 1.0

    inputs = torch.from_numpy(inputs_np)
    targets = torch.from_numpy(targets_np)

    schedule = build_chunk_schedule(time_steps, parts)

    # Tensor views of the initial parameters; biases are flattened to 1-D.
    params_init: Dict[str, torch.Tensor] = {}
    for name in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
        array = getattr(params_np, name)
        if name in ("b_h", "b_y"):
            array = array.reshape(-1)
        params_init[name] = torch.from_numpy(array)

    ref_snapshots = reference_chunk_updates(
        inputs, targets, params_init, schedule, eta=eta
    )

    model_kwargs = dict(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        parts=parts,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=1.0,  # rho=1 and lmbda=0 disable the regularizer term for this check
        lmbda=0.0,
        oracle_momentum=1.0,
        warmup_epochs=0,
        use_oracle=False,
        label_mode="last",
        optimizer_cls=torch.optim.SGD,
        device="cpu",
    )
    model = StrictFPTTClassifier(**model_kwargs)
    # Start the model from the same parameters as the reference path.
    apply_rnn_parameters(model, params_np)

    strict_snapshots = strict_fptt_updates(model, inputs, targets, schedule)
    print("Comparing StrictFPTT updates vs reference chunked BPTT.")
    compare_snapshots(ref_snapshots, strict_snapshots, tol=1e-5)
    print("OK")


# Run the comparison only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
