import os
import sys
from copy import deepcopy
from pathlib import Path

import torch

# Make the vendored flash-linear-attention checkout importable. It is put at
# the front of sys.path so `import fla` below resolves to this local copy
# ahead of any other `fla` found later on the path.
ROOT = Path(__file__).parent
FLA_ROOT = ROOT.joinpath("flash-linear-attention")
if os.fspath(FLA_ROOT) not in sys.path:
    sys.path.insert(0, os.fspath(FLA_ROOT))

from fla.modules.rotary import SelectiveRoPE, SelectiveRoPEFast  # noqa: E402

from original_selective_rope import OGSelectiveRoPE


def _assert_close(label: str, expected: torch.Tensor, actual: torch.Tensor, rtol: float, atol: float) -> None:
    """Assert ``expected`` and ``actual`` are allclose, naming the comparison on failure.

    Argument order matters: ``torch.allclose`` scales ``rtol`` by ``|actual|``.
    """
    assert torch.allclose(expected, actual, rtol=rtol, atol=atol), (
        f"{label} mismatch: max abs {torch.max(torch.abs(expected - actual)).item()}"
    )


def _forward_backward(module: torch.nn.Module, q: torch.Tensor, k: torch.Tensor):
    """Run ``module`` on clones of ``q``/``k`` and backprop ``sum(q_out) + sum(k_out)``.

    The ``.grad`` of ``q`` and ``k`` is reset first so each module's input
    gradients are isolated from the previous run.

    Returns:
        ``(q_out, k_out, q.grad, k.grad)`` for this module.
    """
    q.grad = k.grad = None
    q_out, k_out = module(q.clone(), k.clone())
    # Each call builds its own graph from fresh clones and is backpropagated
    # exactly once, so there is no need to retain it.
    (q_out.sum() + k_out.sum()).backward()
    return q_out, k_out, q.grad, k.grad


def _run_once(
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    skip_conv_cumsum: bool = True,
    *,
    rtol: float = 1e-4,
    atol: float = 1e-5,
):
    """Check that OG, Python-reference and Triton SelectiveRoPE agree on one random input.

    Builds the three modules with the same configuration and copies OG's
    ``state_dict`` into the other two. It then runs forward + backward of
    ``sum(q_out) + sum(k_out)`` through each and asserts:

    - outputs and q/k gradients of OG and of Python match Triton's;
    - parameter gradients match between the Python and Triton modules.

    Args:
        device: Device the modules and inputs are placed on.
        dtype: Floating-point dtype of the modules and inputs.
        skip_conv_cumsum: Forwarded to all three module constructors.
        rtol: Relative tolerance used by every ``torch.allclose`` check.
        atol: Absolute tolerance used by every ``torch.allclose`` check.

    Raises:
        AssertionError: On the first mismatch. The message names the tensor
            and the pair of implementations that differed.
    """
    # Shapes
    B, T, H, D = 2, 128, 4, 64  # D per head; must satisfy D <= 256

    # Configuration shared by all three implementations. OG calls the per-head
    # size ``d_state``; the fla modules call it ``head_dim``.
    common = {
        "num_heads": H,
        "d_conv": 4,
        "device": device,
        "dtype": dtype,
        "skip_conv_cumsum": skip_conv_cumsum,
    }
    m_og = OGSelectiveRoPE(d_state=D, **common).to(device)
    m_py = SelectiveRoPE(head_dim=D, **common).to(device)  # Python reference
    m_tri = SelectiveRoPEFast(head_dim=D, **common).to(device)  # Triton

    # Give all three modules identical weights for strict parity.
    # load_state_dict is strict by default, so this also fails fast if the
    # state_dict key names differ between OG and the fla modules.
    # NOTE(review): an earlier comment said OG's ``phi_bias`` is missing from
    # its state_dict (created via ``.to(...)`` on an nn.Parameter). The
    # workaround for that was disabled; confirm whether it is still needed.
    og_sd = m_og.state_dict()
    m_py.load_state_dict(og_sd)
    m_tri.load_state_dict(og_sd)
    # Cast to the same dtype/device as the OG module / testing dtype.
    m_py.to(device=device, dtype=dtype)
    m_tri.to(device=device, dtype=dtype)

    # Inputs (leaf tensors so we can read their gradients).
    q = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True)

    # Forward and backward, in the same order as before: OG, Python, Triton.
    q_out_og, k_out_og, og_dq, og_dk = _forward_backward(m_og, q, k)
    q_out_py, k_out_py, python_dq, python_dk = _forward_backward(m_py, q, k)
    q_out_tri, k_out_tri, triton_dq, triton_dk = _forward_backward(m_tri, q, k)

    # Compare outputs and input gradients against Triton; labels say which
    # reference diverged.
    references = (
        ("OG", q_out_og, k_out_og, og_dq, og_dk),
        ("Python", q_out_py, k_out_py, python_dq, python_dk),
    )
    for ref_name, ref_q_out, ref_k_out, ref_dq, ref_dk in references:
        _assert_close(f"q_out ({ref_name} vs Triton)", ref_q_out, q_out_tri, rtol, atol)
        _assert_close(f"k_out ({ref_name} vs Triton)", ref_k_out, k_out_tri, rtol, atol)
        _assert_close(f"q.grad ({ref_name} vs Triton)", ref_dq, triton_dq, rtol, atol)
        _assert_close(f"k.grad ({ref_name} vs Triton)", ref_dk, triton_dk, rtol, atol)

    # Compare parameter grads (Python vs Triton only). Parameters that
    # received no gradient are skipped on both sides.
    python_grads = {n: p.grad for n, p in m_py.named_parameters() if p.grad is not None}
    triton_grads = {n: p.grad for n, p in m_tri.named_parameters() if p.grad is not None}
    assert python_grads.keys() == triton_grads.keys(), (
        "Parameter sets differ between Python and Triton modules: "
        f"{sorted(python_grads.keys() ^ triton_grads.keys())}"
    )
    for name, py_grad in python_grads.items():
        _assert_close(f"Param grad for {name} (Python vs Triton)", py_grad, triton_grads[name], rtol, atol)



def test_selective_rope_cuda_parity():
    """Check OG / Python / Triton SelectiveRoPE parity on CUDA, with and without conv-cumsum.

    When CUDA is unavailable this raises ``unittest.SkipTest``. pytest reports
    that as a skip, instead of a silent pass where nothing was checked.
    """
    # Local import keeps the module's top-level imports unchanged.
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA not available; skipping Triton parity test")
    device = torch.device("cuda")
    # Seed once so both runs below consume a deterministic RNG stream.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    for skip_conv_cumsum in (True, False):
        _run_once(device=device, dtype=torch.float32, skip_conv_cumsum=skip_conv_cumsum)


if __name__ == "__main__":
    # Simple runner outside of pytest. A skipped test is reported as skipped
    # rather than as passed.
    import unittest

    try:
        test_selective_rope_cuda_parity()
    except unittest.SkipTest as exc:
        print(f"SelectiveRoPE tests skipped: {exc}")
    else:
        print("SelectiveRoPE tests passed.")
