import torch
import numpy as np
import unittest

import sys
sys.path.insert(0, "/home/ubuntu/projects/conv_basis")
from src.model_gpt2 import _conv_attn_per_head
from pdb import set_trace as pds

##############################################################################################################
###################  numpy conv attn
def get_b_tilde_from_b(b):
    """Convert log-domain conv bases into exp-domain difference bases.

    Row ``i`` of the result is ``exp(b[0] + ... + b[i]) - exp(b[0] + ... + b[i-1])``;
    row 0 is simply ``exp(b[0])``. Partial sums are accumulated in float32.

    Args:
        b: 2-D array of shape ``(k, n)``; row ``i`` is the i-th log-domain basis.

    Returns:
        float32 array with the same shape as ``b``.
    """
    num_rows, num_cols = b.shape
    out = np.ones_like(b, dtype=np.float32)
    # Running prefix sum b[0] + ... + b[row], kept in float32.
    prefix = np.zeros(num_cols, dtype=np.float32)
    for row in range(num_rows):
        # exp of the prefix *before* adding this row (absent for the first row).
        prev_exp = np.exp(prefix) if row > 0 else None
        prefix += b[row]
        cur_exp = np.exp(prefix)
        out[row, :] = cur_exp if prev_exp is None else cur_exp - prev_exp
    return out

def recover_k_conv(Q, K, k, T, delta, epsilon):
    """Recover ``k`` log-domain convolution bases from the scores ``Q @ K.T``.

    Basis 0 is column 0 of ``Q @ K.T`` (all ``n`` rows). Basis ``i >= 1`` is
    read off column ``s = i``, restricted to rows ``s..n-1`` (``m[i] = n - s``
    entries), minus the running sum ``u`` of the bases recovered before it.

    Args:
        Q: query matrix with ``n`` rows.
        K: key matrix; row ``s`` is paired with score column ``s``.
        k: number of bases to recover (must be >= 1).
        T, delta, epsilon: only needed by the disabled binary-search choice of
            the next column; kept so existing callers keep working. Columns are
            currently taken at a fixed stride of 1.

    Returns:
        ``(b_tilde, m, b)``: the exp-domain bases from ``get_b_tilde_from_b``,
        the number of active entries of each basis (left at 0 for bases past
        the sequence length, i.e. when ``k > n``), and the log-domain bases of
        shape ``(k, n)``.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    n = Q.shape[0]
    m = np.zeros(k, dtype=int)
    b = np.zeros((k, n), dtype=np.float32)

    # Calculate the first basis: column 0 of Q @ K.T over all rows.
    b[0, :] = Q @ K[0]
    m[0] = n
    u = b[0, :].copy()  # running sum of the bases recovered so far

    for i in range(1, k):
        # NOTE: fixed stride; a binary_search(Q, K, k, T, delta, epsilon, ...)
        # would pick the next column s here.
        s = i
        m[i] = n - s
        if m[i] <= 0:
            # More bases requested than columns: remaining bases stay zero, m == 0.
            break
        # Only rows s..n-1 of column s are used, so skip computing the rest.
        column = Q[s:] @ K[s]
        b[i, :m[i]] = column - u[:m[i]]
        u += b[i, :]
    return get_b_tilde_from_b(b), m, b

def conv_with_fft(a, x, shift=0):
    """Causal linear convolution of a shortened kernel with the tail of ``x`` via FFT.

    With ``L = len(a) - shift``, the first ``L`` taps of ``a`` are convolved
    with the last ``L`` samples of ``x``. The first ``L`` outputs are written
    into the last ``L`` slots of a zero vector shaped like ``a``.

    Args:
        a: 1-D kernel.
        x: 1-D signal with at least ``len(a) - shift`` samples.
        shift: number of leading output positions left at zero.

    Returns:
        float32 array shaped like ``a``. All zeros when ``shift >= len(a)``
        (empty window).
    """
    length = a.shape[0] - shift
    result = np.zeros_like(a, dtype=np.float32)
    if length <= 0:
        # Empty window: nothing contributes. Without this guard np.fft.fft would
        # get a zero-length array (ValueError), and x[-0:] would select all of x.
        return result
    # Zero-pad to 2*length so the circular FFT convolution equals the linear one.
    fft_size = 2 * length
    a_padded = np.zeros(fft_size, dtype=np.float32)
    x_padded = np.zeros(fft_size, dtype=np.float32)
    a_padded[:length] = a[:length]
    x_padded[:length] = x[-length:]
    full = np.fft.ifft(np.fft.fft(a_padded) * np.fft.fft(x_padded)).real
    result[-length:] = full[:length]
    return result


def conv_with_fft_matrix(a, X, shift=0):
    """Apply ``conv_with_fft(a, column, shift)`` to every column of ``X``.

    Args:
        a: 1-D kernel passed to ``conv_with_fft``.
        X: 2-D array; each column is convolved independently.
        shift: forwarded unchanged to ``conv_with_fft``.

    Returns:
        float32 array with the same shape as ``X``.
    """
    _, num_cols = X.shape  # unpacking requires X to be 2-D
    out = np.zeros_like(X, dtype=np.float32)
    for col_idx in range(num_cols):
        column = X[:, col_idx]
        out[:, col_idx] = conv_with_fft(a, column, shift=shift)
    return out

def k_conv_basis_attention_score(Q, K, V, k, T, delta, epsilon):
    """NumPy reference: row-normalized attention output built from ``k`` conv bases.

    The scaled scores are decomposed by ``recover_k_conv`` into exp-domain
    bases ``b_tilde``. The output is ``sum_i conv(b_tilde[i], V)`` divided
    row-wise by ``sum_i conv(b_tilde[i], ones)``.

    Args:
        Q, K: query / key matrices; ``Q`` has ``n`` rows.
        V: value matrix with ``n`` rows; its last dimension sets the output width.
        k: number of convolution bases.
        T, delta, epsilon: forwarded to ``recover_k_conv``.

    Returns:
        float64 array of shape ``V.shape``.
    """
    n = Q.shape[0]
    # Scale by sqrt of V's head size (equals sqrt(d_k) when Q/K and V share a
    # head size). NOTE(review): presumably chosen to match the PyTorch path's
    # scaling — confirm before changing.
    Q = Q / np.sqrt(V.shape[-1])
    b_tilde, m, _ = recover_k_conv(Q, K, k=k, T=T, delta=delta, epsilon=epsilon)

    # Accumulate in float64; the individual convolutions are float32.
    numerator = np.zeros(V.shape, dtype=np.float64)
    denominator = np.zeros(n, dtype=np.float64)
    ones = np.ones(n)
    for basis, size in zip(b_tilde, m):
        if size <= 0:
            # Bases past the sequence length (k > n) have m == 0 and cover no
            # entries; convolving them would pass a length-0 array to np.fft.fft.
            continue
        shift = n - size
        numerator += conv_with_fft_matrix(basis, V, shift=shift)
        denominator += conv_with_fft(basis, ones, shift=shift)

    # Normalize each output row by its approximated total attention weight.
    return numerator / denominator[:, None]

##############################################################################################################
###################### torch conv attn: from src.model_gpt2 import _conv_attn_per_head


class TestConvAttention(unittest.TestCase):
    """Check that the NumPy reference and the PyTorch ``_conv_attn_per_head`` agree."""

    def setUp(self):
        """Seed NumPy and PyTorch so every run draws the same inputs."""
        for seed_fn in (np.random.seed, torch.manual_seed):
            seed_fn(42)

    def test_conv_attention_implementations(self):
        """Both implementations must match within rtol=atol=1e-4 on random input."""
        seq_len, head_d = 100, 64
        conv_params = dict(k=5, T=1, delta=0, epsilon=0)

        # Draw Q, K, V in that order; the order fixes the seeded RNG stream.
        queries, keys, values = (
            np.random.randn(seq_len, head_d).astype(np.float32) for _ in range(3)
        )
        torch_inputs = [torch.tensor(arr) for arr in (queries, keys, values)]

        expected = k_conv_basis_attention_score(queries, keys, values, **conv_params)
        actual = _conv_attn_per_head(*torch_inputs, **conv_params).numpy()

        np.testing.assert_allclose(expected, actual, rtol=1e-4, atol=1e-4)
        print("PyTorch and NumPy implementations produce similar results.")

# The NumPy reference implementation (k_conv_basis_attention_score and its helpers
# recover_k_conv, conv_with_fft and conv_with_fft_matrix) is defined above in this
# file; the PyTorch implementation under test is imported from src.model_gpt2 as
# _conv_attn_per_head.

def main(debug=False):
    """Run the NumPy and PyTorch implementations on one random input and compare them.

    Uses the same seed, sizes and conv-basis parameters as the unit test, and
    prints the maximum absolute difference between the two outputs.

    Args:
        debug: when True, stop in the pdb debugger after both results are
            computed so they can be inspected interactively.

    Returns:
        ``(result_np, result_torch_np)``: both outputs as NumPy arrays.
    """
    np.random.seed(42)
    torch.manual_seed(42)

    # Same parameters as TestConvAttention.
    seq_len, head_d = 100, 64
    k, T = 5, 1
    delta, epsilon = 0, 0

    Q_np = np.random.randn(seq_len, head_d).astype(np.float32)
    K_np = np.random.randn(seq_len, head_d).astype(np.float32)
    V_np = np.random.randn(seq_len, head_d).astype(np.float32)

    # Full-precision views of the same data for the PyTorch implementation.
    Q_torch = torch.from_numpy(Q_np)
    K_torch = torch.from_numpy(K_np)
    V_torch = torch.from_numpy(V_np)

    result_np = k_conv_basis_attention_score(Q_np, K_np, V_np, k, T, delta, epsilon)
    result_torch_np = (
        _conv_attn_per_head(Q_torch, K_torch, V_torch, k=k, T=T, delta=delta, epsilon=epsilon)
        .detach()
        .cpu()
        .numpy()
    )

    max_abs_diff = float(np.max(np.abs(result_np - result_torch_np)))
    print(f"max |numpy - torch| = {max_abs_diff:.3e}")

    # Breakpoint is opt-in so the driver can run non-interactively.
    if debug:
        pds()
    return result_np, result_torch_np



if __name__ == '__main__':
    # Run the unit tests by default. To use the main() debug driver instead,
    # uncomment the call below (and comment out unittest.main()).
    # main()
    unittest.main()