from __future__ import annotations

import sys
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

# Directory one level above the folder containing this script
# (presumably the project root -- confirm).
ROOT = Path(__file__).resolve().parents[1]
# Put that directory first on the import path so the `methods.*` imports
# below can be resolved when this file is run directly.
sys.path.insert(0, str(ROOT))

from methods.shared_rnn_utils import initialize_rnn_parameters, apply_rnn_parameters, extract_rnn_parameters
from methods.strict_fptt import StrictFPTTClassifier, StrictFPTTRegressor, build_chunk_schedule


def set_seed(seed: int) -> None:
    """Seed NumPy's legacy global RNG and PyTorch's default generator.

    Args:
        seed: Value coerced with ``int()`` before being handed to both
            ``np.random.seed`` and ``torch.manual_seed``.
    """
    seed_value = int(seed)
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed_value)


def reference_seq2seq_chunk_updates(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    schedule: List[Tuple[int, int]],
    *,
    eta: float,
) -> List[Dict[str, torch.Tensor]]:
    """Reference chunk-wise SGD for a tanh RNN trained with per-step cross-entropy.

    For every non-empty ``(start, end)`` chunk in ``schedule`` the RNN is
    unrolled over the chunk's time steps, the soft-target cross-entropy is
    averaged over batch and time, scaled by ``(chunk_idx + 1) / len(schedule)``
    and back-propagated through the chunk only (the carried hidden state is
    detached between chunks). One plain SGD step with step size ``eta`` is
    then applied and the updated parameters are recorded.

    Args:
        inputs: Tensor laid out as ``(batch, features, time)``; it is permuted
            to ``(batch, time, features)`` internally.
        targets: Tensor laid out as ``(batch, classes, time)`` holding the
            per-step target distribution that multiplies the log-probabilities.
        params_init: Initial parameters. Must contain ``"W_xh"``, ``"W_hh"``,
            ``"b_h"``, ``"W_hy"`` and ``"b_y"``. The tensors are cloned, never
            modified in place.
        schedule: ``(start, end)`` index pairs along the time axis. Empty
            chunks are skipped but still count towards the chunk weight.
        eta: SGD step size.

    Returns:
        One ``{name: tensor}`` snapshot (detached clones) per processed chunk,
        taken right after that chunk's update.
    """
    params = {k: v.clone().detach().requires_grad_(True) for k, v in params_init.items()}
    recurrent = params["W_hh"]
    inputs_t = inputs.permute(0, 2, 1)
    targets_t = targets.permute(0, 2, 1)

    # Create the initial state with the parameters' dtype/device so that
    # float64 (or non-CPU) parameters do not hit a dtype/device mismatch in
    # the first matmul. For float32 CPU parameters this is unchanged.
    state = torch.zeros(
        (inputs.size(0), recurrent.size(0)),
        dtype=recurrent.dtype,
        device=recurrent.device,
    )
    step_size = float(eta)
    total_chunks = max(1, len(schedule))
    snapshots: List[Dict[str, torch.Tensor]] = []

    for chunk_idx, (start, end) in enumerate(schedule):
        chunk_inputs = inputs_t[:, start:end, :]
        chunk_targets = targets_t[:, start:end, :]
        if chunk_inputs.size(1) == 0:
            continue

        logits_steps: List[torch.Tensor] = []
        for t in range(chunk_inputs.size(1)):
            state = torch.tanh(
                torch.matmul(state, params["W_hh"].t())
                + torch.matmul(chunk_inputs[:, t, :], params["W_xh"].t())
                + params["b_h"]
            )
            logits_steps.append(torch.matmul(state, params["W_hy"].t()) + params["b_y"])

        logits = torch.stack(logits_steps, dim=1)  # (B, T, C)
        log_probs = torch.log_softmax(logits, dim=2)
        loss_ce = torch.sum(-chunk_targets * log_probs, dim=2).mean()
        chunk_weight = float(chunk_idx + 1) / float(total_chunks)
        (chunk_weight * loss_ce).backward()

        with torch.no_grad():
            # Fresh leaf tensors each chunk, so gradients never accumulate
            # across chunks. A parameter the loss never touched has
            # ``grad is None``; its gradient is zero, so it is carried over.
            params = {
                name: (
                    value if value.grad is None else value - step_size * value.grad
                ).detach().requires_grad_(True)
                for name, value in params.items()
            }
        # Truncate back-propagation at the chunk boundary.
        state = state.detach()
        snapshots.append({k: v.detach().clone() for k, v in params.items()})

    return snapshots


def reference_regression_chunk_updates(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    schedule: List[Tuple[int, int]],
    *,
    eta: float,
) -> List[Dict[str, torch.Tensor]]:
    """Reference chunk-wise SGD for a tanh RNN trained with per-step half-MSE.

    For every non-empty ``(start, end)`` chunk in ``schedule`` the RNN is
    unrolled over the chunk's time steps. Each step contributes
    ``chunk_weight * 0.5 * mean((pred - target) ** 2)`` with
    ``chunk_weight = (chunk_idx + 1) / len(schedule)``; the per-step losses
    are averaged, back-propagated through the chunk only (the carried hidden
    state is detached between chunks), and one plain SGD step with step size
    ``eta`` is applied before the parameters are recorded.

    Args:
        inputs: Tensor laid out as ``(batch, features, time)``; it is permuted
            to ``(batch, time, features)`` internally.
        targets: Tensor laid out as ``(batch, outputs, time)``.
        params_init: Initial parameters. Must contain ``"W_xh"``, ``"W_hh"``,
            ``"b_h"``, ``"W_hy"`` and ``"b_y"``. The tensors are cloned, never
            modified in place.
        schedule: ``(start, end)`` index pairs along the time axis. Empty
            chunks are skipped but still count towards the chunk weight.
        eta: SGD step size.

    Returns:
        One ``{name: tensor}`` snapshot (detached clones) per processed chunk,
        taken right after that chunk's update.
    """
    params = {k: v.clone().detach().requires_grad_(True) for k, v in params_init.items()}
    recurrent = params["W_hh"]
    inputs_t = inputs.permute(0, 2, 1)
    targets_t = targets.permute(0, 2, 1)

    # Create the initial state with the parameters' dtype/device so that
    # float64 (or non-CPU) parameters do not hit a dtype/device mismatch in
    # the first matmul. For float32 CPU parameters this is unchanged.
    state = torch.zeros(
        (inputs.size(0), recurrent.size(0)),
        dtype=recurrent.dtype,
        device=recurrent.device,
    )
    step_size = float(eta)
    total_chunks = max(1, len(schedule))
    snapshots: List[Dict[str, torch.Tensor]] = []

    for chunk_idx, (start, end) in enumerate(schedule):
        chunk_inputs = inputs_t[:, start:end, :]
        chunk_targets = targets_t[:, start:end, :]
        if chunk_inputs.size(1) == 0:
            continue

        chunk_weight = float(chunk_idx + 1) / float(total_chunks)
        losses: List[torch.Tensor] = []
        for t in range(chunk_inputs.size(1)):
            state = torch.tanh(
                torch.matmul(state, params["W_hh"].t())
                + torch.matmul(chunk_inputs[:, t, :], params["W_xh"].t())
                + params["b_h"]
            )
            preds = torch.matmul(state, params["W_hy"].t()) + params["b_y"]
            error = preds - chunk_targets[:, t, :]
            mse = 0.5 * torch.mean(error**2)
            losses.append(chunk_weight * mse)

        torch.stack(losses).mean().backward()

        with torch.no_grad():
            # Fresh leaf tensors each chunk, so gradients never accumulate
            # across chunks. A parameter the loss never touched has
            # ``grad is None``; its gradient is zero, so it is carried over.
            params = {
                name: (
                    value if value.grad is None else value - step_size * value.grad
                ).detach().requires_grad_(True)
                for name, value in params.items()
            }
        # Truncate back-propagation at the chunk boundary.
        state = state.detach()
        snapshots.append({k: v.detach().clone() for k, v in params.items()})

    return snapshots


def strict_seq2seq_updates(
    model: StrictFPTTClassifier,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    schedule: List[Tuple[int, int]],
) -> List[Dict[str, np.ndarray]]:
    """Feed the classifier one scheduled chunk at a time and record its parameters.

    Both tensors are permuted from ``(batch, channels, time)`` to
    ``(batch, time, channels)``. Each non-empty chunk is passed to
    ``model._run_sequence_chunk`` together with the carried state and the
    weight ``(chunk_idx + 1) / len(schedule)``.

    NOTE(review): ``_run_sequence_chunk`` presumably performs the parameter
    update itself -- confirm against ``StrictFPTTClassifier``.
    NOTE(review): each snapshot is the ``__dict__`` of whatever
    ``extract_rnn_parameters`` returns; this assumes that object holds copies
    rather than views of the live weights -- confirm.

    Returns:
        One parameter dict per processed chunk, in schedule order.
    """
    seq_inputs = inputs.permute(0, 2, 1)
    seq_targets = targets.permute(0, 2, 1)
    hidden = torch.zeros((inputs.size(0), model.hidden_size), dtype=torch.float32)
    n_chunks = max(1, len(schedule))
    history: List[Dict[str, np.ndarray]] = []

    # Positions are 1-based so the weight is simply position / n_chunks.
    for position, (lo, hi) in enumerate(schedule, start=1):
        x_chunk = seq_inputs[:, lo:hi, :]
        if x_chunk.size(1) == 0:
            continue
        y_chunk = seq_targets[:, lo:hi, :]
        weight = float(position) / float(n_chunks)
        hidden, _ = model._run_sequence_chunk(x_chunk, y_chunk, hidden, weight)
        current = extract_rnn_parameters(model)
        history.append(current.__dict__)
    return history


def strict_regression_updates(
    model: StrictFPTTRegressor,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    schedule: List[Tuple[int, int]],
) -> List[Dict[str, np.ndarray]]:
    """Feed the regressor one scheduled chunk at a time and record its parameters.

    Both tensors are permuted from ``(batch, channels, time)`` to
    ``(batch, time, channels)``. Each non-empty chunk is passed to
    ``model._run_regression_chunk`` together with the carried state and the
    weight ``(chunk_idx + 1) / len(schedule)``.

    NOTE(review): ``_run_regression_chunk`` presumably performs the parameter
    update itself -- confirm against ``StrictFPTTRegressor``.
    NOTE(review): each snapshot is the ``__dict__`` of whatever
    ``extract_rnn_parameters`` returns; this assumes that object holds copies
    rather than views of the live weights -- confirm.

    Returns:
        One parameter dict per processed chunk, in schedule order.
    """
    seq_inputs = inputs.permute(0, 2, 1)
    seq_targets = targets.permute(0, 2, 1)
    hidden = torch.zeros((inputs.size(0), model.hidden_size), dtype=torch.float32)
    n_chunks = max(1, len(schedule))
    history: List[Dict[str, np.ndarray]] = []

    # Positions are 1-based so the weight is simply position / n_chunks.
    for position, (lo, hi) in enumerate(schedule, start=1):
        x_chunk = seq_inputs[:, lo:hi, :]
        if x_chunk.size(1) == 0:
            continue
        y_chunk = seq_targets[:, lo:hi, :]
        weight = float(position) / float(n_chunks)
        hidden, _ = model._run_regression_chunk(x_chunk, y_chunk, hidden, weight)
        current = extract_rnn_parameters(model)
        history.append(current.__dict__)
    return history


def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the largest absolute element-wise difference between ``a`` and ``b``.

    Returns ``0.0`` when the difference has no elements. A NaN in either
    operand yields NaN, because ``torch.max`` propagates NaN.
    """
    delta = torch.abs(a - b)
    if delta.numel() == 0:
        return 0.0
    return float(torch.max(delta).item())


def compare_snapshots(
    ref: List[Dict[str, torch.Tensor]], strict: List[Dict[str, np.ndarray]], tol: float
) -> None:
    """Check that two per-chunk parameter histories agree within ``tol``.

    Every parameter named in a reference snapshot is looked up in the
    corresponding strict snapshot. ``b_h`` and ``b_y`` are flattened on both
    sides before comparison. The maximum absolute difference is printed for
    every parameter of every chunk.

    Args:
        ref: Reference snapshots (tensors), one dict per chunk.
        strict: Snapshots from the model under test; values are converted with
            ``torch.as_tensor``.
        tol: Largest accepted absolute difference.

    Raises:
        AssertionError: If the histories differ in length, a parameter's shape
            differs between the two sides, or a difference exceeds ``tol`` or
            is NaN.
    """
    if len(ref) != len(strict):
        raise AssertionError("Snapshot length mismatch.")
    for idx, (r, s) in enumerate(zip(ref, strict)):
        for name in r.keys():
            ref_val = r[name]
            strict_val = torch.as_tensor(s[name])
            if name in {"b_h", "b_y"}:
                ref_val = ref_val.reshape(-1)
                strict_val = strict_val.reshape(-1)
            # Without this check, broadcast-compatible but different shapes
            # would be compared through silent broadcasting.
            if ref_val.shape != strict_val.shape:
                raise AssertionError(
                    f"Shape mismatch at chunk {idx} for {name}: "
                    f"{tuple(ref_val.shape)} vs {tuple(strict_val.shape)}"
                )
            diff = max_abs_diff(ref_val, strict_val)
            print(f"chunk {idx} {name} max abs diff = {diff:.6e}")
            # ``not (diff <= tol)`` rather than ``diff > tol``: every comparison
            # with NaN is False, so the latter would let NaN parameters pass.
            if not (diff <= tol):
                raise AssertionError(f"Mismatch at chunk {idx} for {name}: {diff}")


def test_fptt_seq2seq() -> None:
    """Check StrictFPTTClassifier (label_mode="all") against the reference updates.

    Builds a small random batch with one-hot per-step labels, runs the
    reference chunked update and the classifier from identical initial
    parameters, and requires every post-chunk parameter to agree within 1e-5.
    """
    print("[TEST] StrictFPTT seq2seq (label_mode=all) vs reference chunked update")
    set_seed(7)

    n_batch, n_in, n_hidden, n_out, n_steps = 3, 4, 5, 3, 6
    n_parts = 3
    step_size = 0.05

    # Draw order (inputs first, then labels) determines the generated data.
    generator = np.random.default_rng(77)
    inputs_np = generator.standard_normal((n_batch, n_in, n_steps)).astype(np.float32)
    labels = generator.integers(0, n_out, size=(n_batch, n_steps))

    # One-hot encode labels[b, t] along the class axis of (batch, class, time).
    targets_np = np.zeros((n_batch, n_out, n_steps), dtype=np.float32)
    batch_ix = np.arange(n_batch)[:, None]
    step_ix = np.arange(n_steps)[None, :]
    targets_np[batch_ix, labels, step_ix] = 1.0

    inputs = torch.from_numpy(inputs_np)
    targets = torch.from_numpy(targets_np)
    schedule = build_chunk_schedule(n_steps, n_parts)

    params_np = initialize_rnn_parameters(n_in, n_hidden, n_out, gain=1.0, seed=123)
    # The reference works with flat bias vectors; weights are used as stored.
    params_init: Dict[str, torch.Tensor] = {}
    for name in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
        array = getattr(params_np, name)
        if name in ("b_h", "b_y"):
            array = array.reshape(-1)
        params_init[name] = torch.from_numpy(array)

    ref_snapshots = reference_seq2seq_chunk_updates(
        inputs, targets, params_init, schedule, eta=step_size
    )

    model_kwargs = dict(
        input_size=n_in,
        hidden_size=n_hidden,
        output_size=n_out,
        eta=step_size,
        parts=n_parts,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=1.0,
        lmbda=0.0,
        oracle_momentum=1.0,
        warmup_epochs=0,
        oracle_id="seq2seq_check",
        label_mode="all",
        use_oracle=False,
        optimizer_cls=torch.optim.SGD,
        device="cpu",
    )
    model = StrictFPTTClassifier(**model_kwargs)
    apply_rnn_parameters(model, params_np)
    strict_snapshots = strict_seq2seq_updates(model, inputs, targets, schedule)

    compare_snapshots(ref_snapshots, strict_snapshots, tol=1e-5)


def test_fptt_regression() -> None:
    """Check StrictFPTTRegressor against the reference chunked regression updates.

    Builds a small random batch of inputs and real-valued targets, runs the
    reference chunked update and the regressor from identical initial
    parameters, and requires every post-chunk parameter to agree within 1e-5.
    """
    print("[TEST] StrictFPTT regressor vs reference chunked update")
    set_seed(9)

    n_batch, n_in, n_hidden, n_out, n_steps = 4, 3, 6, 2, 5
    n_parts = 2
    step_size = 0.04

    # Draw order (inputs first, then targets) determines the generated data.
    generator = np.random.default_rng(99)
    inputs_np = generator.standard_normal((n_batch, n_in, n_steps)).astype(np.float32)
    targets_np = generator.standard_normal((n_batch, n_out, n_steps)).astype(np.float32)

    inputs = torch.from_numpy(inputs_np)
    targets = torch.from_numpy(targets_np)
    schedule = build_chunk_schedule(n_steps, n_parts)

    params_np = initialize_rnn_parameters(n_in, n_hidden, n_out, gain=1.0, seed=321)
    # The reference works with flat bias vectors; weights are used as stored.
    params_init: Dict[str, torch.Tensor] = {}
    for name in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
        array = getattr(params_np, name)
        if name in ("b_h", "b_y"):
            array = array.reshape(-1)
        params_init[name] = torch.from_numpy(array)

    ref_snapshots = reference_regression_chunk_updates(
        inputs, targets, params_init, schedule, eta=step_size
    )

    model_kwargs = dict(
        input_size=n_in,
        hidden_size=n_hidden,
        output_size=n_out,
        eta=step_size,
        parts=n_parts,
        clip=0.0,
        alpha=0.1,
        beta=0.5,
        rho=1.0,
        lmbda=0.0,
        optimizer_cls=torch.optim.SGD,
        device="cpu",
    )
    model = StrictFPTTRegressor(**model_kwargs)
    apply_rnn_parameters(model, params_np)
    strict_snapshots = strict_regression_updates(model, inputs, targets, schedule)

    compare_snapshots(ref_snapshots, strict_snapshots, tol=1e-5)


def main() -> None:
    """Run the seq2seq check, then the regression check, and print "OK"."""
    for check in (test_fptt_seq2seq, test_fptt_regression):
        check()
    print("OK")


# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
