import torch
import numpy as np
import unittest

import sys
sys.path.insert(0, "/home/ubuntu/projects/conv_basis")
from src.model_gpt2 import _conv_attn_per_head, get_b_tilde_from_b, recover_k_conv
from pdb import set_trace as pds


def main(*, seq_len=100, head_d=64, k=5, T=1, delta=0, epsilon=0,
         seed=42, attn_fn=None, debug=False):
    """Run the conv-basis attention head once on seeded random Q/K/V.

    Scratch driver: draws float32 NumPy arrays Q, K and V of shape
    ``(seq_len, head_d)`` from NumPy's legacy global RNG and wraps them as
    torch tensors. It then calls the attention routine on them and returns
    whatever that routine produces.

    Args:
        seq_len: Number of rows in Q, K and V.
        head_d: Number of columns in Q, K and V.
        k, T, delta, epsilon: Forwarded unchanged, as keyword arguments, to
            the attention routine. Their meaning is defined by
            ``src.model_gpt2._conv_attn_per_head``.
        seed: Seed applied to both NumPy's and torch's global RNGs, so
            repeated runs see identical inputs.
        attn_fn: Attention implementation to call. ``None`` (the default)
            means ``_conv_attn_per_head``. The callable is invoked with
            Q, K and V positionally and ``k``, ``T``, ``delta`` and
            ``epsilon`` as keywords.
        debug: If True, drop into pdb after the call so the result can be
            inspected interactively. Off by default so that non-interactive
            runs do not block on a debugger prompt.

    Returns:
        The value returned by ``attn_fn``.
    """
    if attn_fn is None:
        attn_fn = _conv_attn_per_head

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Generate random input data (float32, shape (seq_len, head_d)).
    Q_np = np.random.randn(seq_len, head_d).astype(np.float32)
    K_np = np.random.randn(seq_len, head_d).astype(np.float32)
    V_np = np.random.randn(seq_len, head_d).astype(np.float32)

    # torch.from_numpy shares memory with the arrays and keeps their dtype
    # (float32); no precision change happens here.
    Q_torch = torch.from_numpy(Q_np)
    K_torch = torch.from_numpy(K_np)
    V_torch = torch.from_numpy(V_np)

    result_torch = attn_fn(
        Q_torch, K_torch, V_torch, k=k, T=T, delta=delta, epsilon=epsilon
    )

    if debug:
        pds()  # inspect `result_torch` interactively

    return result_torch


if __name__ == "__main__":
    # Script entry point: run the driver only when executed directly,
    # not when this module is imported.
    main()
