from __future__ import annotations

import sys
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch

# Prepend the parent of this file's directory (presumably the project root) to
# sys.path so the `methods` package imported below resolves when this file is
# run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from methods.standard_eprop import StandardEPropRNN
from methods.shared_rnn_utils import initialize_rnn_parameters, extract_rnn_parameters


def set_seed(seed: int) -> None:
    """Seed NumPy's legacy global RNG and PyTorch's global RNG with ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)


def to_torch(params_np) -> Dict[str, torch.Tensor]:
    """Convert the RNN parameter arrays held on ``params_np`` to float32 tensors.

    ``params_np`` must expose NumPy arrays as the attributes ``W_xh``, ``W_hh``,
    ``b_h``, ``W_hy`` and ``b_y``. The returned dict uses those names as keys,
    in that order.
    """
    converted: Dict[str, torch.Tensor] = {}
    for attr in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
        array = getattr(params_np, attr)
        converted[attr] = torch.from_numpy(array).float()
    return converted


def clone_params(params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a new dict holding a ``clone()`` of every tensor in ``params``."""
    copied: Dict[str, torch.Tensor] = {}
    for name, tensor in params.items():
        copied[name] = tensor.clone()
    return copied


def simulate_eprop(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    h_prev_init: torch.Tensor,
    *,
    eta: float,
    decay_lambda: float,
) -> tuple[List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor], torch.Tensor]:
    """Reference e-prop simulation of a tanh RNN with an SGD step after every time step.

    At each step the forward pass uses the current parameters. The hidden-layer
    gradients are estimated from per-sample eligibility traces and the learning
    signal ``W_hy.T @ (y_hat - y)``. All five parameters are then replaced by
    ``param - eta * grad``. Every gradient is summed over the batch and divided
    by ``batch_size``, which is the scaling of the per-step loss
    ``0.5 * mean_batch(sum_outputs((y_hat - y)**2))``.

    Args:
        inputs: Unpacked as ``(batch, input, time)``.
        targets: Indexed as ``targets[:, :, t]``; presumably
            ``(batch, output, time)`` (TODO confirm).
        params_init: Dict with keys ``W_xh``, ``W_hh``, ``b_h``, ``W_hy``, ``b_y``.
            It is cloned and never modified.
        h_prev_init: Initial hidden state laid out as ``(hidden, batch)``; cloned.
        eta: SGD step size.
        decay_lambda: Factor multiplying the previous eligibility traces each step.

    Returns:
        ``(history, params, h_last)``:
        - ``history[t]`` holds ``"params"`` (the parameters used at step ``t``,
          before that step's update), ``"h_prev"`` (the hidden state fed into
          step ``t``) and ``"grads"`` (the step's gradient estimates).
        - ``params`` holds the parameters after the final update.
        - ``h_last`` is the hidden state from the last step.
    """
    batch_size, input_size, time_steps = inputs.shape
    hidden_size = params_init["W_hh"].shape[0]

    # Per-sample eligibility traces: e_W[b, i, j] pairs hidden unit i with
    # presynaptic unit/input j for batch sample b.
    e_W_xh = torch.zeros((batch_size, hidden_size, input_size), dtype=torch.float32)
    e_W_hh = torch.zeros((batch_size, hidden_size, hidden_size), dtype=torch.float32)
    e_b_h = torch.zeros((batch_size, hidden_size), dtype=torch.float32)

    params = clone_params(params_init)
    h_prev = h_prev_init.clone()
    # NOTE(review): entries also hold nested dicts ("params", "grads"), so the
    # element type is looser than this annotation says.
    history: List[Dict[str, torch.Tensor]] = []

    for t in range(time_steps):
        # Freeze the parameters used at this step; `params` is updated at the end.
        snapshot = clone_params(params)
        # Column layout (features, batch) for matrix products with the weights.
        I_t = inputs[:, :, t].T
        y_target_t = targets[:, :, t].T

        x_t = snapshot["W_hh"] @ h_prev + snapshot["W_xh"] @ I_t + snapshot["b_h"]
        h_t = torch.tanh(x_t)
        y_hat_t = snapshot["W_hy"] @ h_t + snapshot["b_y"]

        error_out = y_hat_t - y_target_t
        # Per-sample dL/dy_hat of 0.5 * ||y_hat - y||^2; batch averaging happens below.
        dL_dyhat = error_out
        # Learning signal: output error fed back through the transposed readout.
        l_t = snapshot["W_hy"].T @ dL_dyhat
        # tanh'(x_t) = 1 - h_t**2, transposed to (batch, hidden).
        psi = (1.0 - h_t**2).T

        h_prev_T = h_prev.T
        I_t_T = I_t.T

        # e_t = decay_lambda * e_{t-1} + psi_t (outer) presynaptic activity, per sample.
        e_W_hh = decay_lambda * e_W_hh + psi[:, :, None] * h_prev_T[:, None, :]
        e_W_xh = decay_lambda * e_W_xh + psi[:, :, None] * I_t_T[:, None, :]
        e_b_h = decay_lambda * e_b_h + psi

        # e-prop estimate: learning signal times eligibility trace, averaged over batch.
        l_batch = l_t.T
        dW_hh_t = torch.einsum("bi,bij->ij", l_batch, e_W_hh) / batch_size
        dW_xh_t = torch.einsum("bi,bij->ij", l_batch, e_W_xh) / batch_size
        db_h_t = torch.sum(l_batch * e_b_h, dim=0, keepdim=True).T / batch_size
        # Readout gradients of the per-step loss, averaged over batch.
        dW_hy_t = (dL_dyhat @ h_t.T) / batch_size
        db_y_t = torch.sum(dL_dyhat, dim=1, keepdim=True) / batch_size

        history.append(
            {
                "params": snapshot,
                "h_prev": h_prev.clone(),
                "grads": {
                    "W_xh": dW_xh_t,
                    "W_hh": dW_hh_t,
                    "b_h": db_h_t,
                    "W_hy": dW_hy_t,
                    "b_y": db_y_t,
                },
            }
        )

        # Online SGD step. New tensors are created, so `snapshot` is unaffected.
        params["W_xh"] = params["W_xh"] - eta * dW_xh_t
        params["W_hh"] = params["W_hh"] - eta * dW_hh_t
        params["b_h"] = params["b_h"] - eta * db_h_t
        params["W_hy"] = params["W_hy"] - eta * dW_hy_t
        params["b_y"] = params["b_y"] - eta * db_y_t
        h_prev = h_t

    return history, params, h_prev


def autograd_detach_grads(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    history: List[Dict[str, torch.Tensor]],
) -> List[Dict[str, torch.Tensor]]:
    """Compute per-step autograd gradients with the incoming hidden state held constant.

    For every step ``t`` the forward step is rebuilt from ``history[t]["params"]``
    and ``history[t]["h_prev"]``. The per-step loss
    ``0.5 * mean_batch(sum_outputs((y_hat - y)**2))`` is then differentiated
    w.r.t. the five parameters. The stored hidden state is detached, so no
    gradient flows back through earlier steps.

    Args:
        inputs: Indexed as ``inputs[:, :, t]``; its third dimension sets the
            number of steps.
        targets: Indexed as ``targets[:, :, t]``.
        history: Per-step entries with a ``"params"`` dict (``W_xh``, ``W_hh``,
            ``b_h``, ``W_hy``, ``b_y``) and an ``"h_prev"`` tensor.

    Returns:
        One dict per step mapping each parameter name to its gradient.
    """
    param_names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")
    results: List[Dict[str, torch.Tensor]] = []

    for step in range(inputs.shape[2]):
        entry = history[step]
        snapshot = entry["params"]
        hidden_before = entry["h_prev"].detach()
        # Fresh leaf tensors so gradients are taken w.r.t. this step's snapshot only.
        leaves = {
            name: snapshot[name].clone().detach().requires_grad_(True)
            for name in param_names
        }

        x_in = inputs[:, :, step].T
        y_true = targets[:, :, step].T

        pre_act = leaves["W_hh"] @ hidden_before + leaves["W_xh"] @ x_in + leaves["b_h"]
        prediction = leaves["W_hy"] @ torch.tanh(pre_act) + leaves["b_y"]

        residual = prediction - y_true
        step_loss = 0.5 * torch.mean(torch.sum(residual**2, dim=0))

        step_grads = torch.autograd.grad(
            step_loss, tuple(leaves[name] for name in param_names)
        )
        results.append(dict(zip(param_names, step_grads)))

    return results


def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the largest absolute element-wise difference between ``a`` and ``b``.

    The shapes must match exactly. Otherwise ``a - b`` would broadcast, e.g.
    ``(n, 1)`` against ``(n,)`` becomes ``(n, n)``, and the result would
    compare unrelated elements, hiding a layout bug. Empty tensors of equal
    shape have no differing elements and return ``0.0``. NaNs are not
    filtered out.

    Raises:
        ValueError: If ``a`` and ``b`` have different shapes.
    """
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
    if a.numel() == 0:
        # torch.max has no identity element and raises on empty input.
        return 0.0
    return float(torch.max(torch.abs(a - b)).item())


def compare_grad_lists(
    grads_ref: List[Dict[str, torch.Tensor]],
    grads_auto: List[Dict[str, torch.Tensor]],
    tol: float,
) -> None:
    """Check that two per-step gradient sequences agree within ``tol``.

    Prints the max absolute difference for every (step, parameter) pair.

    Args:
        grads_ref: Reference gradients, one dict per time step.
        grads_auto: Gradients to check, laid out like ``grads_ref``.
        tol: Largest allowed max-abs difference per tensor.

    Raises:
        AssertionError: If the sequences differ in length, a step's dicts have
            different parameter names, or any difference exceeds ``tol`` or is NaN.
    """
    if len(grads_ref) != len(grads_auto):
        raise AssertionError("Gradient list lengths do not match.")
    for t, (g_ref, g_auto) in enumerate(zip(grads_ref, grads_auto)):
        # Without this check, extra keys in g_auto would be ignored and missing
        # ones would surface as a bare KeyError.
        if g_ref.keys() != g_auto.keys():
            raise AssertionError(
                f"Gradient names differ at step {t}: {sorted(g_ref)} vs {sorted(g_auto)}"
            )
        for name, ref in g_ref.items():
            diff = max_abs_diff(ref, g_auto[name])
            print(f"step {t} {name} max abs diff = {diff:.6e}")
            # `not diff <= tol` rather than `diff > tol`: a NaN difference compares
            # False against everything and must count as a mismatch.
            if not diff <= tol:
                raise AssertionError(f"Gradient mismatch at step {t} for {name}: {diff}")


def compare_param_sets(
    params_a: Dict[str, torch.Tensor],
    params_b: Dict[str, torch.Tensor],
    tol: float,
) -> None:
    """Check that two parameter dicts agree within ``tol``.

    Prints the max absolute difference for every parameter.

    Args:
        params_a: First parameter dict.
        params_b: Second parameter dict; must have the same keys as ``params_a``.
        tol: Largest allowed max-abs difference per tensor.

    Raises:
        AssertionError: If the dicts have different keys, or any difference
            exceeds ``tol`` or is NaN.
    """
    # Extra keys in params_b would otherwise be ignored and missing ones would
    # surface as a bare KeyError.
    if params_a.keys() != params_b.keys():
        raise AssertionError(
            f"Parameter names differ: {sorted(params_a)} vs {sorted(params_b)}"
        )
    for name, value_a in params_a.items():
        diff = max_abs_diff(value_a, params_b[name])
        print(f"final {name} max abs diff = {diff:.6e}")
        # `not diff <= tol` so that a NaN difference is reported as a mismatch.
        if not diff <= tol:
            raise AssertionError(f"Parameter mismatch for {name}: {diff}")


def main() -> None:
    """Run two consistency checks on a small random problem.

    1. With ``decay_lambda = 0``, the per-step gradients from ``simulate_eprop``
       must match stop-gradient autograd gradients of the per-step loss.
    2. The parameters after ``simulate_eprop`` must match those of a
       ``StandardEPropRNN`` that starts from the same weights and performs one
       ``train_batch`` call on the same data.

    A mismatch raises from the comparison helpers; on success ``OK`` is printed.
    """
    set_seed(7)

    # Problem dimensions and hyper-parameters.
    batch_size, input_size, hidden_size = 3, 4, 5
    output_size, time_steps = 2, 6
    eta, decay_lambda = 0.05, 0.0
    param_names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")

    params_init = to_torch(
        initialize_rnn_parameters(
            input_size, hidden_size, output_size, gain=1.0, seed=123
        )
    )

    rng = np.random.default_rng(10)
    inputs_np = rng.standard_normal((batch_size, input_size, time_steps)).astype(np.float32)
    targets_np = rng.standard_normal((batch_size, output_size, time_steps)).astype(np.float32)
    inputs, targets = torch.from_numpy(inputs_np), torch.from_numpy(targets_np)
    h_prev = torch.zeros((hidden_size, batch_size), dtype=torch.float32)

    # Check 1: per-step gradients of the reference simulation vs. autograd.
    history, params_final_ref, _ = simulate_eprop(
        inputs, targets, params_init, h_prev, eta=eta, decay_lambda=decay_lambda
    )
    grads_auto = autograd_detach_grads(inputs, targets, history)
    print("Comparing per-step gradients vs stop-gradient autograd (decay_lambda=0).")
    compare_grad_lists([entry["grads"] for entry in history], grads_auto, tol=1e-5)

    # Check 2: final weights vs. the model's own update from the same start.
    model = StandardEPropRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        eta=eta,
        decay_lambda=decay_lambda,
        feedback="symmetric",
        loss_mode="mse",
        max_grad_norm=0.0,
        device="cpu",
    )
    # Replace the model's own initial weights with the shared starting point.
    for name in param_names:
        setattr(model, name, params_init[name].clone())

    _loss, _ = model.train_batch(inputs_np, targets_np, h_prev.numpy())
    params_after = extract_rnn_parameters(model)
    params_final_model = {
        name: torch.from_numpy(getattr(params_after, name)) for name in param_names
    }

    print("Comparing final weights vs StandardEPropRNN update.")
    compare_param_sets(params_final_ref, params_final_model, tol=1e-5)
    print("OK")


if __name__ == "__main__":
    # Run the checks in main() only when this file is executed directly.
    main()
