import time
import os
import sys
import argparse
import numpy as np
import torch
from pypapi import events
from pypapi import papi_high as high
from pypapi import papi_low as papi

# from k_conv_basis_utils import recover_k_conv, conv_with_fft_matrix, conv_with_fft
# from test import test_recovered_b, test_recovered_b_tilde

import matplotlib.pyplot as plt

import gc
from memory_profiler import profile
gc.collect()

from pdb import set_trace as pds

# Attention-head index; only referenced by the commented-out torch.save(...)
# result filenames in main().
head_idx = 2



def naive_exact_attention_score(Q, K, V):
    """Exact causal softmax attention, computed densely.

    Computes ``softmax_causal(Q @ K.T / sqrt(V.shape[-1])) @ V`` where row i
    of the score matrix only sees columns ``j <= i`` (lower triangle).

    Args:
        Q: 2-D query matrix of shape (n, d).
        K: 2-D key matrix whose second dimension is d (one row per key).
        V: value matrix with one row per key; its last dimension is used
            as the softmax temperature ``sqrt(V.shape[-1])``.

    Returns:
        ``(result, flops)``: the attention output (``attention @ V``) and the
        nominal FLOP count ``2*n^2*d + 4*n^2 + n`` of the plain (unstabilized)
        formulation. The count is kept unchanged so ratios stay comparable
        with earlier runs; the max-subtraction below is not included in it.
    """
    n, d = Q.shape
    scores = (Q @ K.T) / np.sqrt(V.shape[-1])
    causal = np.tril(np.ones_like(scores, dtype=bool))
    # Hide future positions with -inf instead of multiplying by a 0/1 mask:
    # exp() of a large masked score overflows to inf and 0 * inf is nan.
    masked_scores = np.where(causal, scores, -np.inf)
    # Subtract each row's max over its visible entries so exp() cannot
    # overflow; softmax is invariant to a per-row shift.
    row_max = masked_scores.max(axis=1, keepdims=True)
    weights = np.exp(masked_scores - row_max)  # exp(-inf) == 0 for masked entries
    attention = weights / weights.sum(axis=1, keepdims=True)
    result = attention @ V
    return result, 2 * n**2 * d + 4 * n**2 + n

def conv_with_fft(a, x, shift=0):
    """Causal convolution of ``a`` with the tail of ``x``, computed via FFT.

    Let ``L = len(a)`` and ``n = L - shift``. The first ``shift`` outputs are
    zero and, for ``0 <= j < n``::

        result[shift + j] = sum_{t=0}^{j} a[t] * x[len(x) - n + j - t]

    i.e. ``a[:n]`` is convolved with the last ``n`` entries of ``x`` and the
    result is right-aligned in the output. Both inputs are zero-padded to
    ``2n`` so the circular FFT product equals the linear convolution for the
    entries that are kept.

    Args:
        a: 1-D kernel; its length fixes the output length.
        x: 1-D signal with at least ``len(a) - shift`` entries.
        shift: number of leading output entries forced to zero. A shift of
            ``len(a)`` or more yields an all-zero result.

    Returns:
        float32 array shaped like ``a``.

    Raises:
        ValueError: if ``shift`` is negative.
    """
    if shift < 0:
        raise ValueError(f"shift must be non-negative, got {shift}")
    n = a.shape[0] - shift
    result = np.zeros_like(a, dtype=np.float32)
    if n <= 0:
        # Everything is shifted out. Returning early also avoids
        # result[-0:] (which would select the WHOLE array) and a 0-point FFT.
        return result

    a_padded = np.zeros(2 * n, dtype=np.float32)
    x_padded = np.zeros(2 * n, dtype=np.float32)
    a_padded[:n] = a[:n]
    x_padded[:n] = x[-n:]

    spectrum = np.fft.fft(a_padded) * np.fft.fft(x_padded)
    result[-n:] = np.fft.ifft(spectrum).real[:n]
    return result


def conv_with_fft_matrix(a, X, shift=0):
    """Apply ``conv_with_fft(a, column, shift=shift)`` to every column of ``X``.

    Args:
        a: 1-D kernel forwarded to ``conv_with_fft``.
        X: 2-D array; each column is convolved independently.
        shift: forwarded unchanged to ``conv_with_fft``.

    Returns:
        float32 array shaped like ``X`` whose column j is the convolution of
        ``a`` with ``X[:, j]``.
    """
    _n_rows, _n_cols = X.shape  # the unpacking doubles as a 2-D shape check
    out = np.zeros_like(X, dtype=np.float32)
    for col_idx, column in enumerate(X.T):
        out[:, col_idx] = conv_with_fft(a, column, shift=shift)
    return out


def recover_k_conv(Q, K, k, T, delta, epsilon):
    """Build up to ``k`` basis rows ``b`` from columns of ``Q @ K.T``.

    Row 0 of ``b`` is column 0 of ``Q @ K.T``. For ``i >= 1`` the column
    index is ``s = i``. Row i stores, in its first ``n - s`` entries, the
    on/below-diagonal part of column s (``(Q @ K.T)[s:, s]``) minus the
    running sum of all previous rows. As a result, the sum of rows
    ``0..i`` restricted to the first ``n - s`` entries reproduces that part
    of column s. Construction stops early once ``n - s <= 0``.

    Args:
        Q: 2-D matrix of shape (n, d).
        K: matrix with second dimension d; its columns (via ``K.T``) are read.
        k: maximum number of basis rows.
        T: length of an auxiliary accumulator of the first T entries of each
            row (must not exceed n). It is maintained but not returned.
        delta, epsilon: accepted but not used here.
            NOTE(review): presumably reserved for the disabled binary search
            over s.

    Returns:
        ``(b_tilde, m, b, flops)``:
        - ``b_tilde`` is ``get_b_tilde_from_b(b)``.
        - ``m[i]`` is ``n - i`` for each filled row; rows after an early exit
          keep ``m == 0`` and all-zero ``b``.
        - ``b`` is the float32 ``(k, n)`` basis.
        - ``flops`` is the bookkeeping FLOP estimate, including that of
          ``get_b_tilde_from_b``.
    """
    n, d = Q.shape
    head_acc = np.zeros(T, dtype=np.float32)      # running sum of each row's first T entries
    running_sum = np.zeros(n, dtype=np.float32)   # running sum of all rows of b built so far
    m = np.zeros(k, dtype=int)
    b = np.zeros((k, n), dtype=np.float32)

    # Row 0: column 0 of Q @ K.T, taken as-is.
    b[0] = Q @ K.T[:, 0]
    m[0] = n
    head_acc += b[0, :T]
    running_sum += b[0]
    flops = n * d + T + n  # column product + two accumulations

    col = 0
    for i in range(1, k):
        col += 1  # one column per basis row (no search)
        m[i] = n - col
        if m[i] <= 0:
            break
        width = m[i]
        column = Q @ K.T[:, col]
        # Residual of this column's on/below-diagonal part w.r.t. the rows built so far.
        b[i, :width] = column[col:col + width] - running_sum[:width]
        head_acc += b[i, :T]
        running_sum += b[i]
        flops += n * d + width + T + n

    b_tilde, tilde_flops = get_b_tilde_from_b(b)
    return b_tilde, m, b, flops + tilde_flops

def get_b_tilde_from_b(b):
    """Turn basis rows into differences of exponentiated prefix sums.

    With ``S_i = b[0] + ... + b[i]`` accumulated in float32, the output is::

        b_tilde[0] = exp(S_0)
        b_tilde[i] = exp(S_i) - exp(S_{i-1})   for i >= 1

    Args:
        b: 2-D array of shape (k, n).

    Returns:
        ``(b_tilde, flops)``: a float32 array shaped like ``b``, and a FLOP
        estimate of ``2n`` for the first row plus ``5n`` for each further row.
    """
    num_rows, n = b.shape
    b_tilde = np.ones_like(b, dtype=np.float32)
    prefix = np.zeros(n, dtype=np.float32)
    prev_exp = None  # exp of the previous prefix sum; None before the first row
    flops = 0
    for row_idx, row in enumerate(b):
        prefix += row
        cur_exp = np.exp(prefix)
        if prev_exp is None:
            b_tilde[row_idx] = cur_exp
            flops += 2 * n
        else:
            b_tilde[row_idx] = cur_exp - prev_exp
            flops += 5 * n
        prev_exp = cur_exp
    return b_tilde, flops

def k_conv_basis_attention_score(Q, K, V, k, T, delta, epsilon):
    """Approximate causal attention with ``k`` convolution bases.

    ``Q`` is scaled by ``1/sqrt(V.shape[-1])`` and passed with ``K`` to
    ``recover_k_conv``. The output is then formed as follows:
    - numerator: the sum over bases i of ``conv_with_fft_matrix(b_tilde[i], V, shift=n - m[i])``;
    - denominator: the same convolutions applied to a ones vector;
    - output: each numerator row multiplied by the reciprocal of its denominator.

    Args:
        Q: 2-D query matrix of shape (n, d).
        K, V: key/value matrices; ``V.shape[-1]`` sets the temperature.
        k, T, delta, epsilon: forwarded to ``recover_k_conv``.

    Returns:
        ``(QKV_approx, flops)``, where ``flops`` adds a rough
        ``3 * n * ln(n)`` cost per FFT convolution (times d for the matrix
        ones) to the recovery cost.
    """
    n, d = Q.shape
    Q_scaled = Q / np.sqrt(V.shape[-1])  # same temperature as the exact path
    b_tilde, m, b, flops = recover_k_conv(Q_scaled, K, k=k, T=T, delta=delta, epsilon=epsilon)

    fft_cost = 3 * n * np.log(n)  # rough per-column cost of one FFT convolution
    shifts = n - m

    numerator = np.zeros_like(Q_scaled, dtype=np.float64)
    for basis, shift in zip(b_tilde, shifts):
        numerator += conv_with_fft_matrix(basis, V, shift=shift)
        flops += fft_cost * d

    ones = np.ones(n)
    denominator = np.zeros(n, dtype=np.float64)
    for basis, shift in zip(b_tilde, shifts):
        denominator += conv_with_fft(basis, ones, shift=shift)
        flops += fft_cost

    QKV_approx = (denominator ** -1)[:, np.newaxis] * numerator
    flops += 2 * n * d  # final row scaling

    return QKV_approx, flops


def generate_random_data(n, d):
    """Draw ``Q``, ``K``, ``V`` (in that order) from NumPy's global RNG.

    Each is an (n, d) float64 matrix of i.i.d. standard normal entries.
    """
    Q, K, V = (np.random.randn(n, d).astype(np.float64) for _ in range(3))
    return Q, K, V


def parse_args(argv=None):
    """Parse command-line options for the k-conv attention benchmark.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        ``argparse.Namespace`` with ``k`` (int, default 5): the number of
        basis functions for the k-conv approximation.
    """
    parser = argparse.ArgumentParser(
        description="Compare exact causal attention with its k-conv basis approximation"
    )
    parser.add_argument(
        '--k', help='number of basis functions for k-conv', type=int, default=5,
    )
    return parser.parse_args(argv)


def main():
    """Run one benchmark of exact vs. k-conv-basis causal attention.

    Reads ``--k`` from the command line and builds random ``Q, K, V`` of
    shape (500, 64). It then prints:
    - the relative Frobenius-norm error of the approximation;
    - the naive/approximate ratio of elapsed time;
    - the naive/approximate ratio of estimated FLOPs.
    """
    args = parse_args()

    # Problem size
    n = 500  # number of tokens
    d = 64   # embedding dimension
    # k-conv basis parameters, forwarded to k_conv_basis_attention_score
    T = 1
    delta = 1e-9
    epsilon = 0
    k = args.k

    Q, K, V = generate_random_data(n, d)

    # perf_counter() is monotonic and uses the highest-resolution clock;
    # time.time() is wall-clock time that can be coarse or jump, making short
    # intervals (and the ratio printed below) unreliable or even zero.
    start_time = time.perf_counter()
    QKV, flops_naive = naive_exact_attention_score(Q, K, V)
    time_naive = time.perf_counter() - start_time

    start_time = time.perf_counter()
    QKV_approx, flops_approx = k_conv_basis_attention_score(
        Q, K, V, k=k, T=T, delta=delta, epsilon=epsilon
    )
    time_approx = time.perf_counter() - start_time

    relative_diff = np.linalg.norm(QKV - QKV_approx, ord='fro') / np.linalg.norm(QKV, ord='fro')

    # Optional: persist results, keyed by k and head_idx.
    # torch.save(QKV, f'QKV_{head_idx}.pth')
    # torch.save(relative_diff, f'relative_diff_{k}_{head_idx}.pth')
    # torch.save(QKV_approx, f'QKV_approx_{k}_{head_idx}.pth')
    # torch.save(time_approx, f'time_approx_{k}_{head_idx}.pth')
    # torch.save(flops_approx, f'flops_approx_{k}_{head_idx}.pth')

    print(f"Input shape: {n} tokens, {d} dimensions")
    print(f"K-conv parameters: k={k}, T={T}, delta={delta}, epsilon={epsilon}")
    print(f"Relative difference: {relative_diff:.6f}")
    # print(f"Naive attention time: {time_naive:.6f} seconds")
    # print(f"K-conv attention time: {time_approx:.6f} seconds")
    # print(f"Naive attention FLOPs: {flops_naive}")
    # print(f"K-conv attention FLOPs: {flops_approx}")
    print(f"time naive/fft: {time_naive / time_approx:.2f}")
    print(f"FLOP naive/fft: {flops_naive / flops_approx:.2f}")



if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()


