from __future__ import annotations

import sys
from pathlib import Path
from typing import Dict

import numpy as np
import torch

# Put the directory one level above this file's folder (presumably the project
# root) at the front of sys.path so that the `methods` and `task` packages
# imported just below can be resolved when this file is run directly.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from methods.shared_rnn_utils import initialize_rnn_parameters
from task.common.sequence_core import TorchBPTTRNN


def set_seed(seed: int) -> None:
    """Seed NumPy's legacy global RNG and PyTorch's global RNG with the same value.

    Args:
        seed: Value coerced with ``int()`` before being handed to each library.
    """
    # NumPy is seeded first, then PyTorch, each with its own int() coercion.
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(int(seed))


def _torch_params_from_np(params_np, device: torch.device) -> Dict[str, torch.Tensor]:
    """Convert the five RNN parameter arrays on ``params_np`` to float32 tensors.

    Args:
        params_np: Object exposing ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and
            ``b_y`` attributes.
        device: Device on which the tensors are created.

    Returns:
        A dict keyed by those five names, in that order, holding one float32
        tensor per parameter.
    """
    converted: Dict[str, torch.Tensor] = {}
    for attr in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
        source = getattr(params_np, attr)
        converted[attr] = torch.as_tensor(source, dtype=torch.float32, device=device)
    return converted


def _copy_params_into_bptt(model: TorchBPTTRNN, params_np) -> None:
    """Overwrite the model's five RNN parameters in place with the values on ``params_np``.

    Args:
        model: Object exposing ``device`` plus tensor attributes ``W_xh``,
            ``W_hh``, ``b_h``, ``W_hy`` and ``b_y``.
        params_np: Object exposing array attributes with the same five names.
    """
    target_device = model.device
    # The in-place copies are not recorded by autograd.
    with torch.no_grad():
        for attr in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
            replacement = torch.as_tensor(
                getattr(params_np, attr), dtype=torch.float32, device=target_device
            )
            getattr(model, attr).copy_(replacement)


def _extract_params(model: TorchBPTTRNN) -> Dict[str, np.ndarray]:
    """Return the model's five RNN parameters as NumPy arrays.

    Each tensor is detached from autograd and moved to the CPU before the
    conversion. The result is keyed by ``W_xh``, ``W_hh``, ``b_h``, ``W_hy``
    and ``b_y``, in that order.
    """
    return {
        attr: getattr(model, attr).detach().cpu().numpy()
        for attr in ("W_xh", "W_hh", "b_h", "W_hy", "b_y")
    }


def _param_max_abs_diff(a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]) -> float:
    """Return the largest element-wise absolute difference between two parameter sets.

    Args:
        a: Mapping from parameter name to array.
        b: Mapping with exactly the same names and per-name shapes as ``a``.

    Returns:
        The maximum of ``|a[name] - b[name]|`` over every element of every
        parameter. The result is NaN if any difference is NaN, so a
        ``diff < tol`` check fails instead of silently passing.

    Raises:
        KeyError: If the two mappings do not contain exactly the same names.
        ValueError: If there is nothing to compare, or if the two arrays for
            a name have different shapes.
    """
    # Names that appear in only one mapping would otherwise go unchecked.
    unmatched = sorted(set(a) ^ set(b), key=str)
    if unmatched:
        raise KeyError(f"parameter names differ between the two sets: {unmatched}")
    if not a:
        raise ValueError("no parameters to compare")

    per_param = []
    for name, left in a.items():
        left_arr = np.asarray(left)
        right_arr = np.asarray(b[name])
        # Refuse silent broadcasting: it would compare unrelated elements.
        if left_arr.shape != right_arr.shape:
            raise ValueError(
                f"shape mismatch for {name!r}: {left_arr.shape} vs {right_arr.shape}"
            )
        per_param.append(np.max(np.abs(left_arr - right_arr)))

    # np.max propagates NaN, whereas the builtin max() can drop it depending
    # on where the NaN sits in the sequence.
    return float(np.max(np.asarray(per_param, dtype=np.float64)))


def _reference_full_bptt_sgd(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    h_prev: torch.Tensor,
    *,
    eta: float,
    loss_mode: str,
    step_weights: torch.Tensor | None,
    time_normalization: bool,
) -> tuple[Dict[str, torch.Tensor], float]:
    """Apply one plain SGD step using autograd through the whole sequence.

    A tanh RNN is unrolled over every time step, the per-step losses are
    summed (optionally weighted), and a single gradient step is taken on
    private copies of the parameters.

    Args:
        inputs: 3-D tensor whose last axis is time. Each time slice is
            transposed before being multiplied by ``W_xh``.
        targets: 3-D tensor laid out like ``inputs``. For ``"ce"`` it is used
            as per-class weights on the log-probabilities.
        params_init: Tensors keyed ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and
            ``b_y``. They are cloned, never modified.
        h_prev: Initial hidden state. It is cloned and detached.
        eta: SGD step size.
        loss_mode: ``"ce"`` or ``"mse"``.
        step_weights: Optional per-time-step loss weights, indexed by time step.
        time_normalization: When true, divide the summed loss by the sum of
            the weights (or by the number of steps if there are none),
            clamped to at least 1.

    Returns:
        ``(updated_params, loss)``: detached copies of the parameters after
        the step, and the scaled loss as a Python float.

    Raises:
        ValueError: If ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
    """
    if loss_mode not in {"ce", "mse"}:
        raise ValueError("loss_mode must be 'ce' or 'mse'.")

    # Fresh leaf tensors so gradients never touch the caller's parameters.
    weights = {
        name: tensor.clone().detach().requires_grad_(True)
        for name, tensor in params_init.items()
    }
    num_steps = int(inputs.shape[2])

    per_step = None if step_weights is None else step_weights.to(dtype=torch.float32)
    if per_step is None:
        normalizer = float(num_steps)
    else:
        normalizer = float(per_step.sum().item())
    normalizer = max(normalizer, 1.0)

    hidden = h_prev.clone().detach()
    accumulated = torch.zeros((), device=inputs.device)

    for step in range(num_steps):
        drive = inputs[:, :, step].T
        pre_activation = weights["W_hh"] @ hidden + weights["W_xh"] @ drive + weights["b_h"]
        hidden = torch.tanh(pre_activation)
        prediction = weights["W_hy"] @ hidden + weights["b_y"]
        desired = targets[:, :, step].T

        if loss_mode == "ce":
            # Softmax over axis 0, then average over the remaining axis.
            log_probs = torch.log_softmax(prediction, dim=0)
            step_loss = -(desired * log_probs).sum(dim=0).mean()
        else:
            residual = prediction - desired
            step_loss = 0.5 * torch.sum(residual**2, dim=0).mean()

        if per_step is not None:
            step_loss = step_loss * per_step[step]
        accumulated = accumulated + step_loss

    divisor = max(normalizer if time_normalization else 1.0, 1.0)
    scaled_loss = accumulated / divisor
    gradients = torch.autograd.grad(scaled_loss, tuple(weights.values()))

    with torch.no_grad():
        stepped: Dict[str, torch.Tensor] = {
            name: (value - float(eta) * gradient).detach().clone()
            for (name, value), gradient in zip(weights.items(), gradients)
        }

    return stepped, float(scaled_loss.detach().cpu().item())


def _reference_tbptt_sgd(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    params_init: Dict[str, torch.Tensor],
    h_prev: torch.Tensor,
    *,
    eta: float,
    loss_mode: str,
    step_weights: torch.Tensor | None,
    tbptt_steps: int,
    time_normalization: bool,
) -> tuple[Dict[str, torch.Tensor], float]:
    """Run truncated BPTT with one SGD step per chunk, using autograd.

    The sequence is cut into consecutive chunks of at most ``tbptt_steps``
    time steps. For each chunk the tanh RNN is unrolled, the chunk loss is
    differentiated, and the parameters are updated before the next chunk
    starts. The hidden state is carried across chunks but detached, so no
    gradient crosses a chunk boundary.

    Args:
        inputs: 3-D tensor whose last axis is time. Each time slice is
            transposed before being multiplied by ``W_xh``.
        targets: 3-D tensor laid out like ``inputs``.
        params_init: Tensors keyed ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and
            ``b_y``. They are cloned, never modified.
        h_prev: Initial hidden state. It is cloned and detached.
        eta: SGD step size.
        loss_mode: ``"ce"`` or ``"mse"``.
        step_weights: Optional per-time-step loss weights, indexed by time step.
        tbptt_steps: Chunk length. Coerced with ``int()``.
        time_normalization: When true, each chunk's loss is divided by that
            chunk's weight sum (or its length if there are no weights),
            clamped to at least 1, before differentiation.

    Returns:
        ``(updated_params, loss)``. The loss is the sum of the *unscaled*
        chunk losses. With ``time_normalization`` it is divided by the
        whole-sequence weight sum (or step count), clamped to at least 1.

    Raises:
        ValueError: If ``loss_mode`` is neither ``"ce"`` nor ``"mse"``, or if
            ``tbptt_steps`` is not positive.
    """
    if loss_mode not in {"ce", "mse"}:
        raise ValueError("loss_mode must be 'ce' or 'mse'.")

    # Fresh leaf tensors so gradients never touch the caller's parameters.
    weights = {
        name: tensor.clone().detach().requires_grad_(True)
        for name, tensor in params_init.items()
    }
    num_steps = int(inputs.shape[2])
    window = int(tbptt_steps)
    if window <= 0:
        raise ValueError("tbptt_steps must be positive.")

    per_step = None if step_weights is None else step_weights.to(dtype=torch.float32)
    if per_step is None:
        sequence_norm = float(num_steps)
    else:
        sequence_norm = float(per_step.sum().item())
    sequence_norm = max(sequence_norm, 1.0)

    hidden = h_prev.clone().detach()
    running_loss = 0.0

    for begin in range(0, num_steps, window):
        stop = min(num_steps, begin + window)
        chunk_loss = torch.zeros((), device=inputs.device)

        for step in range(begin, stop):
            drive = inputs[:, :, step].T
            pre_activation = (
                weights["W_hh"] @ hidden + weights["W_xh"] @ drive + weights["b_h"]
            )
            hidden = torch.tanh(pre_activation)
            prediction = weights["W_hy"] @ hidden + weights["b_y"]
            desired = targets[:, :, step].T

            if loss_mode == "ce":
                # Softmax over axis 0, then average over the remaining axis.
                log_probs = torch.log_softmax(prediction, dim=0)
                step_loss = -(desired * log_probs).sum(dim=0).mean()
            else:
                residual = prediction - desired
                step_loss = 0.5 * torch.sum(residual**2, dim=0).mean()

            if per_step is not None:
                step_loss = step_loss * per_step[step]
            chunk_loss = chunk_loss + step_loss

        # The gradient is scaled per chunk, not by the whole-sequence norm.
        if not time_normalization:
            chunk_norm = 1.0
        elif per_step is None:
            chunk_norm = max(float(stop - begin), 1.0)
        else:
            chunk_norm = max(float(per_step[begin:stop].sum().item()), 1.0)

        scaled_chunk_loss = chunk_loss / max(chunk_norm, 1.0)
        gradients = torch.autograd.grad(scaled_chunk_loss, tuple(weights.values()))

        # Rebind to new leaf tensors so the next chunk sees the updated values.
        with torch.no_grad():
            weights = {
                name: (value - float(eta) * gradient).detach().requires_grad_(True)
                for (name, value), gradient in zip(weights.items(), gradients)
            }

        running_loss += float(chunk_loss.detach().cpu().item())
        # Cut the graph so the next chunk does not backpropagate into this one.
        hidden = hidden.detach()

    report_scale = max(sequence_norm if time_normalization else 1.0, 1.0)
    mean_loss = running_loss / report_scale
    stepped = {name: value.detach().clone() for name, value in weights.items()}
    return stepped, float(mean_loss)


def test_full_bptt_weighted_mse() -> None:
    """Check one full-BPTT step with a weighted MSE loss against the autograd reference.

    A ``TorchBPTTRNN`` with no truncation and no time normalization is loaded
    with known initial parameters and trained on one random batch. The
    resulting parameters and reported loss must agree with
    ``_reference_full_bptt_sgd`` run from the same starting point.
    """
    print("[TEST] Full BPTT weighted MSE (no time normalization)")
    set_seed(101)
    cpu = torch.device("cpu")

    n_batch, n_in, n_hidden, n_out, n_steps = 3, 2, 5, 2, 6
    learning_rate = 0.03

    # Inputs are drawn before targets; the draw order fixes the data.
    data_rng = np.random.default_rng(5)
    inputs_np = data_rng.standard_normal((n_batch, n_in, n_steps)).astype(np.float32)
    targets_np = data_rng.standard_normal((n_batch, n_out, n_steps)).astype(np.float32)
    ramp_weights = np.linspace(0.25, 1.0, n_steps, dtype=np.float32)
    h_start = torch.zeros((n_hidden, n_batch), dtype=torch.float32, device=cpu)

    init_np = initialize_rnn_parameters(n_in, n_hidden, n_out, gain=1.0, seed=222)
    init_t = _torch_params_from_np(init_np, device=cpu)

    model_config = dict(
        input_size=n_in,
        hidden_size=n_hidden,
        output_size=n_out,
        eta=learning_rate,
        loss_mode="mse",
        max_grad_norm=0.0,
        tbptt_steps=None,
        time_normalization=False,
        seed=0,
        device=cpu,
    )
    trainer = TorchBPTTRNN(**model_config)
    trainer.step_weights = torch.as_tensor(ramp_weights, dtype=torch.float32, device=cpu)
    # Replace the model's own initialization with the shared starting point.
    _copy_params_into_bptt(trainer, init_np)
    loss_model, _ = trainer.train_batch(inputs_np, targets_np, h_start.cpu().numpy())
    model_result = _extract_params(trainer)

    reference_result_t, loss_ref = _reference_full_bptt_sgd(
        torch.as_tensor(inputs_np, device=cpu),
        torch.as_tensor(targets_np, device=cpu),
        init_t,
        h_start,
        eta=learning_rate,
        loss_mode="mse",
        step_weights=trainer.step_weights,
        time_normalization=False,
    )
    reference_result = {
        name: tensor.detach().cpu().numpy() for name, tensor in reference_result_t.items()
    }

    diff = _param_max_abs_diff(model_result, reference_result)
    print(f"  diff={diff:.3e} loss(model)={loss_model:.6f} loss(ref)={loss_ref:.6f}")
    assert diff < 1e-6, f"Full BPTT weighted MSE mismatch: diff={diff}"
    assert abs(float(loss_model) - float(loss_ref)) < 1e-7, "Full BPTT weighted MSE loss mismatch."


def test_tbptt_weighted_ce() -> None:
    """Check truncated BPTT with a weighted cross-entropy loss against the autograd reference.

    A ``TorchBPTTRNN`` with a truncation window of 3 and time normalization
    enabled is loaded with known initial parameters and trained on one random
    batch of one-hot targets. The resulting parameters and reported loss must
    agree with ``_reference_tbptt_sgd`` run from the same starting point.
    """
    print("[TEST] TBPTT weighted CE (time normalization)")
    set_seed(202)
    cpu = torch.device("cpu")

    n_batch, n_in, n_hidden, n_out, n_steps = 4, 3, 6, 3, 7
    window = 3
    learning_rate = 0.04

    # Inputs are drawn before labels; the draw order fixes the data.
    data_rng = np.random.default_rng(9)
    inputs_np = data_rng.standard_normal((n_batch, n_in, n_steps)).astype(np.float32)
    labels = data_rng.integers(0, n_out, size=(n_batch, n_steps))

    # One-hot encode along the class axis: targets[b, labels[b, t], t] = 1.
    targets_np = np.zeros((n_batch, n_out, n_steps), dtype=np.float32)
    batch_idx = np.arange(n_batch)[:, None]
    step_idx = np.arange(n_steps)[None, :]
    targets_np[batch_idx, labels, step_idx] = 1.0

    uneven_weights = np.array([1.0, 0.5, 0.2, 1.2, 0.8, 0.6, 0.4], dtype=np.float32)
    h_start = torch.zeros((n_hidden, n_batch), dtype=torch.float32, device=cpu)

    init_np = initialize_rnn_parameters(n_in, n_hidden, n_out, gain=1.0, seed=333)
    init_t = _torch_params_from_np(init_np, device=cpu)

    model_config = dict(
        input_size=n_in,
        hidden_size=n_hidden,
        output_size=n_out,
        eta=learning_rate,
        loss_mode="ce",
        max_grad_norm=0.0,
        tbptt_steps=window,
        time_normalization=True,
        seed=0,
        device=cpu,
    )
    trainer = TorchBPTTRNN(**model_config)
    trainer.step_weights = torch.as_tensor(uneven_weights, dtype=torch.float32, device=cpu)
    # Replace the model's own initialization with the shared starting point.
    _copy_params_into_bptt(trainer, init_np)
    loss_model, _ = trainer.train_batch(inputs_np, targets_np, h_start.cpu().numpy())
    model_result = _extract_params(trainer)

    reference_result_t, loss_ref = _reference_tbptt_sgd(
        torch.as_tensor(inputs_np, device=cpu),
        torch.as_tensor(targets_np, device=cpu),
        init_t,
        h_start,
        eta=learning_rate,
        loss_mode="ce",
        step_weights=trainer.step_weights,
        tbptt_steps=window,
        time_normalization=True,
    )
    reference_result = {
        name: tensor.detach().cpu().numpy() for name, tensor in reference_result_t.items()
    }

    diff = _param_max_abs_diff(model_result, reference_result)
    print(f"  diff={diff:.3e} loss(model)={loss_model:.6f} loss(ref)={loss_ref:.6f}")
    assert diff < 1e-6, f"TBPTT weighted CE mismatch: diff={diff}"
    assert abs(float(loss_model) - float(loss_ref)) < 1e-7, "TBPTT weighted CE loss mismatch."


def main() -> None:
    """Run both checks in order and print ``OK`` once neither has raised."""
    for check in (test_full_bptt_weighted_mse, test_tbptt_weighted_ce):
        check()
    print("OK")


# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
