import torch
import mustafar
from compression import convert_tensor_batched
import torch.nn.functional as F
import time


def dh_prune_key(key_states: torch.Tensor, target_sparsity=None):
    """
    Magnitude-prune every vector along the last (hidden) dimension.

    For each vector of length D along the last dimension, the
    ``int(target_sparsity * D)`` entries with the smallest absolute value are
    set to zero and the remaining entries are kept unchanged. Exactly that many
    entries are pruned per vector; ties in magnitude are broken arbitrarily.

    Args:
        key_states (torch.Tensor): Tensor whose last dimension is the hidden
            dimension D, e.g. [heads, tokens, D] or [batch, heads, tokens, D].
        target_sparsity (float): Fraction of elements to prune per vector,
            in [0, 1). Must be given explicitly.

    Returns:
        torch.Tensor: A new tensor with the same shape and dtype as
        ``key_states``, with the pruned entries set to zero.

    Raises:
        ValueError: If ``target_sparsity`` is missing or outside [0, 1).
    """
    if target_sparsity is None:
        # Module-level function: there is no instance/config to fall back to.
        raise ValueError("target_sparsity must be provided")
    if not 0 <= target_sparsity < 1:
        raise ValueError("Target sparsity must be between 0 and 1")

    D = key_states.shape[-1]
    # Number of smallest-magnitude entries to zero per vector. Clamped so at
    # least one entry survives even if float rounding pushes s*D up to D.
    num_to_prune = min(int(target_sparsity * D), max(D - 1, 0))
    if num_to_prune == 0:
        return key_states.clone()
    num_to_keep = D - num_to_prune

    # Flatten all leading dims so each row is one hidden-dim vector: [N, D].
    key_states_flat = key_states.reshape(-1, D)

    # Pick exactly num_to_keep largest-magnitude entries per row and scatter
    # their original (signed) values into a zero tensor. Unlike a threshold
    # mask, this prunes the requested count even when magnitudes tie.
    _, keep_idx = torch.topk(key_states_flat.abs(), num_to_keep, dim=-1)
    pruned_key_states = torch.zeros_like(key_states_flat).scatter_(
        -1, keep_idx, key_states_flat.gather(-1, keep_idx)
    )

    # Restore the caller's original shape.
    return pruned_key_states.reshape(key_states.shape)


#------------------------------- Correctness Check --------------------------------
# Prune a random tensor to 50% sparsity, compute the reference product with a
# dense matmul, run the same product through the compressed Mustafar kernel,
# and compare the two.
print("---------------------------------- Correctness Check --------------------------------")

Token_length = 4096
Model_depth = 128
Batch_size = 32

KV_prefill = torch.randn(Batch_size, Token_length, Model_depth, dtype=torch.float16).to("cuda")
vector = torch.randn(Batch_size, 1, Model_depth, dtype=torch.float16).to("cuda")
print("Finished intialization")

KV_prefill_to_compress_50 = dh_prune_key(KV_prefill, 0.5)

# Dense reference: [B, T, D] @ [B, D, 1] -> [B, T, 1]
dense_result = torch.matmul(KV_prefill_to_compress_50, vector.transpose(1, 2))

bitmaps, accum_counts, packed_not = convert_tensor_batched(KV_prefill_to_compress_50)

# Per-batch start offsets for the kernel: an exclusive prefix sum over each
# batch's last accum_counts entry, divided by 4.
# NOTE(review): the `// 4` unit conversion is required by
# mustafar_key_formulation; confirm the units against convert_tensor_batched.
NZ_Offset = torch.zeros(Batch_size, dtype=torch.int32).to("cuda")
for i in range(1, Batch_size):
    NZ_Offset[i] = NZ_Offset[i-1] + accum_counts[i-1][-1] // 4

# Pad dim -2 of the query from 1 to 8 rows (the kernel reads it column-major).
padded_vector = F.pad(vector, (0, 0, 0, 7)).contiguous()

mustafar_result = mustafar.mustafar_key_formulation(bitmaps, torch.cat(packed_not), accum_counts, NZ_Offset, padded_vector, Token_length, Model_depth, Batch_size)

# The kernel output is column-major; column 0 of the transpose is what gets
# compared against the dense [B, T, 1] reference.
first_column = mustafar_result.transpose(1, 2)[:, :, 0:1].contiguous()

# Check if they're close within a reasonable tolerance
is_close = torch.allclose(dense_result, first_column, rtol=1e-3, atol=1e-3)
print(f"\nAre results equal? {is_close}")

if not is_close:
    # Calculate and print various difference metrics
    diff = (dense_result - first_column).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    print(f"\nMaximum difference: {max_diff}")
    print(f"Mean difference: {mean_diff}")

    # Print some sample values for comparison (batch index 0 is the first batch)
    print("\nSample values (first batch, first 5 elements):")
    print("dense_result:", dense_result[0, :5, 0])
    print("first_column:", first_column[0, :5, 0])

    # Find where the differences are largest
    max_diff_indices = torch.nonzero(diff == max_diff)
    print("\nLocation of maximum difference:")
    if len(max_diff_indices) > 0:
        # Index with a tuple so (batch, token, col) addresses one element;
        # indexing with the raw LongTensor would gather slices along dim 0
        # (and the token coordinate would overflow the batch dimension).
        loc = tuple(max_diff_indices[0].tolist())
        print(f"Indices: {max_diff_indices[0]}")
        print(f"dense_result value: {dense_result[loc]}")
        print(f"first_column value: {first_column[loc]}")
    else:
        # diff.max() is NaN when any difference is NaN; nothing compares equal.
        print("Could not locate maximum difference (NaN present?)")



#-----------------------Execution Time Check --------------------------------
# Wall-clock timings on fresh random data. Each timed region is bracketed by
# torch.cuda.synchronize() so queued GPU work is included in the measurement.
print("---------------------------------- Execution Time Check --------------------------------")

Token_length = 4096
Model_depth = 128
Batch_size = 32

KV_prefill = torch.randn(Batch_size, Token_length, Model_depth, dtype=torch.float16).to("cuda")
vector = torch.randn(Batch_size, 1, Model_depth, dtype=torch.float16).to("cuda")
NZ_Offset = torch.zeros(Batch_size, dtype=torch.int32).to("cuda")


# Dense computation time: [B, T, D] @ [B, D, 1] -> [B, T, 1]
torch.cuda.synchronize()
st = time.perf_counter()
dense_result = torch.matmul(KV_prefill, vector.transpose(1, 2))
torch.cuda.synchronize()
print(f'Dense Batched MM computation time: {(time.perf_counter() - st) * 1000} ms')


# 50% sparsity: prefill prune + compression time. This is a one-time cost, so
# it is reported amortized over `generation_length` decode steps.
torch.cuda.synchronize()
st = time.perf_counter()
KV_prefill_to_compress_50 = dh_prune_key(KV_prefill, 0.5)
bitmaps, accum_counts, packed_not = convert_tensor_batched(KV_prefill_to_compress_50)
# Concatenate the per-batch packed buffers once, as part of building the
# compressed cache, rather than inside the timed SpMV region below.
packed_not_50 = torch.cat(packed_not)
# Exclusive prefix sum of each batch's last accum_counts entry (// 4 as the
# kernel expects -- NOTE(review): confirm the unit conversion).
for i in range(1, Batch_size):
    NZ_Offset[i] = NZ_Offset[i-1] + accum_counts[i-1][-1] // 4
torch.cuda.synchronize()
p50_compression_time = (time.perf_counter() - st) * 1000
generation_length = 1024
p50_compression_time_per_token = p50_compression_time / generation_length
print(f'50% sparsity: Prune and Compression time per token: {p50_compression_time_per_token} ms')

# Snapshot NZ_Offset: the same tensor may be refilled in place for other
# sparsity levels, which would otherwise overwrite the offsets saved here.
compressed_50 = [bitmaps, packed_not_50, accum_counts, NZ_Offset.clone()]

# SpMV computation time: per-token work only (query padding, kernel, slice).
torch.cuda.synchronize()
st = time.perf_counter()
padded_vector = F.pad(vector, (0, 0, 0, 7)).contiguous()  # pad dim -2 from 1 to 8 rows; column-major for kernel
mustafar_result = mustafar.mustafar_key_formulation(bitmaps, packed_not_50, accum_counts, NZ_Offset, padded_vector, Token_length, Model_depth, Batch_size)
first_column = mustafar_result.transpose(1, 2)[:, :, 0:1].contiguous()
torch.cuda.synchronize()
print(f'Batched SpMV computation time: {(time.perf_counter() - st) * 1000} ms')


# 70% sparsity: prefill prune + compression time (amortized over
# `generation_length` decode steps).
# Allocate a fresh offset tensor: an earlier compressed cache may still hold a
# reference to the previous NZ_Offset, and refilling that tensor in place would
# silently overwrite its offsets.
NZ_Offset = torch.zeros(Batch_size, dtype=torch.int32).to("cuda")
torch.cuda.synchronize()
st = time.perf_counter()
KV_prefill_to_compress_70 = dh_prune_key(KV_prefill, 0.7)
bitmaps, accum_counts, packed_not = convert_tensor_batched(KV_prefill_to_compress_70)
# Concatenate once while building the cache, not inside the timed SpMV region.
packed_not_70 = torch.cat(packed_not)
# Exclusive prefix sum of each batch's last accum_counts entry (// 4 as the
# kernel expects -- NOTE(review): confirm the unit conversion).
for i in range(1, Batch_size):
    NZ_Offset[i] = NZ_Offset[i-1] + accum_counts[i-1][-1] // 4
torch.cuda.synchronize()
p70_compression_time = (time.perf_counter() - st) * 1000
generation_length = 1024
p70_compression_time_per_token = p70_compression_time / generation_length
print(f'70% sparsity: Prune and Compression time per token: {p70_compression_time_per_token} ms')

compressed_70 = [bitmaps, packed_not_70, accum_counts, NZ_Offset]

# SpMV computation time: per-token work only (query padding, kernel, slice).
torch.cuda.synchronize()
st = time.perf_counter()
padded_vector = F.pad(vector, (0, 0, 0, 7)).contiguous()  # pad dim -2 from 1 to 8 rows; column-major for kernel
mustafar_result = mustafar.mustafar_key_formulation(bitmaps, packed_not_70, accum_counts, NZ_Offset, padded_vector, Token_length, Model_depth, Batch_size)
first_column = mustafar_result.transpose(1, 2)[:, :, 0:1].contiguous()
torch.cuda.synchronize()
print(f'Batched SpMV computation time: {(time.perf_counter() - st) * 1000} ms')



# Persist the dense KV cache and both compressed caches (50% and 70% sparsity)
# to disk, in this order.
for obj_to_save, out_path in (
    (KV_prefill, "dense_kv_cache.pt"),
    (compressed_50, "compressed_kv_cache_50.pt"),
    (compressed_70, "compressed_kv_cache_70.pt"),
):
    torch.save(obj_to_save, out_path)



