import torch 
#from transformers import LlamaConfig, AutoTokenizer, LlamaForCausalLM
#from huggingface_hub import login 
import triton
import triton.language as tl
import numpy as np
import gc



@triton.jit
def calculate_bitmap_kernel(input_ptr, bitmaps_ptr, counts_ptr, total_elems: tl.constexpr, shifts_ptr, M: tl.constexpr, N: tl.constexpr):
    """Compute the occupancy bitmap and padded count of one 64-element tile.

    One program instance handles one tile. Tile ``t`` reads the 64
    consecutive elements starting at ``(t % N) * M + (t // N) * 64``.

    Args:
        input_ptr: Base pointer of the element buffer. The load is unmasked,
            so every tile must lie fully inside the buffer.
        bitmaps_ptr: Output; one bitmap is written at index ``t``.
        counts_ptr: Output; one padded half-count is written at index ``t``.
        total_elems: Not used by the kernel body.
        shifts_ptr: Pointer to 64 per-lane bit weights. The host wrappers in
            this file pass ``1 << (63 - lane)``, so lane 0 maps to the MSB.
        M: Row length used when computing the tile's start offset.
        N: Modulus that splits the tile id into (row, 64-wide column block).
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, 64)

    # Split the linear tile id into a row and a 64-wide column block.
    row = pid % N
    col_block = pid // N
    tile_start = row * M + col_block * 64
    addrs = tile_start + lane

    elems = tl.load(input_ptr + addrs)
    # 1 where the element is non-zero, 0 otherwise.
    nonzero = tl.where(elems != 0.0, 1, 0)

    # Each lane contributes its own bit weight; weights are disjoint powers
    # of two, so the sum is the bitwise OR of the selected lanes.
    weights = tl.load(shifts_ptr + lane)
    packed_bits = tl.sum(nonzero * weights, axis=0)
    nnz = tl.sum(nonzero, axis=0)

    # Round the non-zero count up to a multiple of 8, then halve it
    # (the count is kept in units of two 16-bit values per 32-bit word).
    padded_half = ((nnz + 7) & ~7) >> 1

    tl.store(bitmaps_ptr + pid, packed_bits)
    tl.store(counts_ptr + pid, padded_half)

@triton.jit
def calculate_bitmap_kernel_batched(
    input_ptr,        # [B * num_tiles_per_batch * 64]
    bitmaps_ptr,      # [B * num_tiles_per_batch]
    counts_ptr,       # [B * num_tiles_per_batch]
    total_elems: tl.constexpr,
    shifts_ptr,       # [64]
    stride_batch: tl.constexpr,  # = num_tiles_per_batch * 64
    M: tl.constexpr,
    N: tl.constexpr
):
    """Compute the occupancy bitmap and padded count of one tile of one batch.

    Grid axis 0 enumerates the tiles of a batch, grid axis 1 the batches.
    Tile ``t`` of batch ``b`` reads the 64 consecutive elements starting at
    ``b * stride_batch + (t % N) * M + (t // N) * 64``.

    Args:
        input_ptr: Base pointer of the flattened element buffer. The load is
            unmasked, so every tile must lie fully inside the buffer.
        bitmaps_ptr: Output; one bitmap per (batch, tile), written at
            ``b * (stride_batch // 64) + t``.
        counts_ptr: Output; one padded half-count per (batch, tile), same
            indexing as ``bitmaps_ptr``.
        total_elems: Not used by the kernel body.
        shifts_ptr: Pointer to 64 per-lane bit weights. The host wrapper in
            this file passes ``1 << (63 - lane)``, so lane 0 maps to the MSB.
        stride_batch: Number of elements between consecutive batches.
        M: Row length used when computing the tile's start offset.
        N: Modulus that splits the tile id into (row, 64-wide column block).
    """
    tile_id = tl.program_id(0)
    # Promote to int64: batch_id * stride_batch would otherwise be evaluated
    # in 32-bit arithmetic and wrap once the buffer exceeds 2**31 elements.
    batch_id = tl.program_id(1).to(tl.int64)

    # Split the linear tile id into a row and a 64-wide column block.
    block_row = tile_id % N
    block_col = tile_id // N
    base_idx = batch_id * stride_batch + block_row * M + block_col * 64
    offsets = tl.arange(0, 64)
    indices = base_idx + offsets

    vals = tl.load(input_ptr + indices)
    # 1 where the element is non-zero, 0 otherwise.
    bit_mask = tl.where(vals != 0.0, 1, 0)

    # Per-lane bit weights are disjoint powers of two, so the sum is the
    # bitwise OR of the selected lanes.
    shifts = tl.load(shifts_ptr + offsets)
    bitmap = tl.sum(bit_mask * shifts, axis=0)

    cnt = tl.sum(bit_mask, axis=0)
    # Round up to a multiple of 8, then halve (count is kept in units of two
    # 16-bit values per 32-bit word).
    cnt = ((cnt + 7) & ~7) >> 1

    # Flat (batch, tile) slot for the per-tile outputs.
    flat_tile_index = batch_id * (stride_batch // 64) + tile_id
    tl.store(bitmaps_ptr + flat_tile_index, bitmap)
    tl.store(counts_ptr + flat_tile_index, cnt)


@triton.jit
def compress_kernel(input_ptr, bitmaps_ptr, counts_ptr, packed_not_ptr, total_elems: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    """Scatter the non-zero elements of one 64-element tile into the packed buffer.

    One program instance handles one tile. Tile ``t`` reads the 64
    consecutive elements starting at ``(t % N) * M + (t // N) * 64`` and
    writes those whose bitmap bit is set, in lane order, starting at output
    index ``2 * counts_ptr[t]``.

    Args:
        input_ptr: Base pointer of the element buffer (unmasked load).
        bitmaps_ptr: Per-tile bitmaps; bit ``63 - lane`` flags lane ``lane``.
        counts_ptr: Per-tile start positions, read at index ``t`` and doubled
            to obtain the output element index.
        packed_not_ptr: Output buffer receiving the selected elements.
        total_elems: Not used by the kernel body.
        M: Row length used when computing the tile's start offset.
        N: Modulus that splits the tile id into (row, 64-wide column block).
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, 64)

    # Split the linear tile id into a row and a 64-wide column block.
    row = pid % N
    col_block = pid // N
    tile_start = row * M + col_block * 64
    src = tile_start + lane

    tile_bits = tl.load(bitmaps_ptr + pid)
    out_base = tl.load(counts_ptr + pid)

    # All 64 elements are loaded; the store mask drops the zero lanes.
    elems = tl.load(input_ptr + src)

    # Lane i is flagged by bit (63 - i) of the bitmap.
    keep = (tile_bits >> (63 - lane)) & 1

    # Inclusive running count minus one = output slot of each kept lane.
    slot = tl.cumsum(keep, axis=0) - 1
    store_mask = keep.to(tl.int1)

    # The stored start position is doubled to get an element index.
    dst = out_base * 2 + slot

    tl.store(packed_not_ptr + dst, elems, mask=store_mask)

@triton.jit
def compress_kernel_batched(
    input_ptr,          # flattened [B * num_tiles_per_batch * 64]
    bitmaps_ptr,        # flattened [B * num_tiles_per_batch]
    counts_ptr,         # flattened [B * (num_tiles_per_batch + 1)]
    packed_not_ptr,     # flattened output buffer
    batch_offsets_ptr,  # flattened [B]
    total_elems: tl.constexpr,
    stride_batch: tl.constexpr,  # = num_tiles_per_batch * 64
    M: tl.constexpr,
    N: tl.constexpr
):
    """Scatter the non-zero elements of one tile of one batch into the packed buffer.

    Grid axis 0 enumerates the tiles of a batch, grid axis 1 the batches.
    Tile ``t`` of batch ``b`` reads the 64 consecutive elements starting at
    ``b * stride_batch + (t % N) * M + (t // N) * 64`` and writes those whose
    bitmap bit is set, in lane order, starting at output index
    ``batch_offsets_ptr[b] + 2 * counts_ptr[b * (tiles_per_batch + 1) + t]``.

    Args:
        input_ptr: Base pointer of the flattened element buffer (unmasked load).
        bitmaps_ptr: Per-(batch, tile) bitmaps; bit ``63 - lane`` flags lane
            ``lane``.
        counts_ptr: Per-batch start positions with ``tiles_per_batch + 1``
            entries per batch; the entry for tile ``t`` is doubled to obtain
            an element index relative to the batch's region.
        packed_not_ptr: Output buffer receiving the selected elements.
        batch_offsets_ptr: Start of each batch's region in ``packed_not_ptr``,
            in output elements.
        total_elems: Not used by the kernel body.
        stride_batch: Number of input elements between consecutive batches.
        M: Row length used when computing the tile's start offset.
        N: Modulus that splits the tile id into (row, 64-wide column block).
    """
    tile_id = tl.program_id(0)
    # Promote to int64: batch_id * stride_batch would otherwise be evaluated
    # in 32-bit arithmetic and wrap once the buffer exceeds 2**31 elements.
    batch_id = tl.program_id(1).to(tl.int64)

    # Split the linear tile id into a row and a 64-wide column block.
    block_row = tile_id % N
    block_col = tile_id // N
    base_idx = batch_id * stride_batch + block_row * M + block_col * 64
    offsets = tl.arange(0, 64)
    indices = base_idx + offsets

    # Flat (batch, tile) slot for the per-tile bitmap.
    flat_tile_index = batch_id * (stride_batch // 64) + tile_id
    bitmap = tl.load(bitmaps_ptr + flat_tile_index)

    # The counts buffer carries one extra entry per batch, hence the "+ 1"
    # in the per-batch stride.
    batch_base_idx = batch_id * ((stride_batch // 64) + 1)
    # Widen before doubling so idx * 2 cannot wrap in 32-bit arithmetic.
    idx = tl.load(counts_ptr + batch_base_idx + tile_id).to(tl.int64)

    # All 64 elements are loaded; the store mask drops the zero lanes.
    vals = tl.load(input_ptr + indices)

    # Lane i is flagged by bit (63 - i) of the bitmap.
    shifted = bitmap >> (63 - offsets)
    bit_mask = shifted & 1

    # Inclusive running count minus one = output slot of each kept lane.
    prefix = tl.cumsum(bit_mask, axis=0) - 1
    valid_pos = bit_mask.to(tl.int1)

    batch_base_offset = tl.load(batch_offsets_ptr + batch_id)
    store_idx = batch_base_offset + idx * 2 + prefix
    tl.store(packed_not_ptr + store_idx, vals, mask=valid_pos)



def convert_tensor(input: torch.Tensor):
    """Compress a 2-D CUDA tensor into per-tile bitmaps plus packed non-zeros.

    The tensor is transposed and split into 64-element tiles. Tile ``t``
    covers ``input[64 * (t // N) : 64 * (t // N) + 64, t % N]`` where ``N`` is
    the number of columns, i.e. tiles walk across the columns of the first
    64-row band, then the next band, and so on. Within a tile, element ``i``
    is flagged by bit ``63 - i`` of the tile's bitmap.

    Args:
        input: 2-D tensor on a CUDA device whose row count is a multiple of 64.

    Returns:
        A tuple ``(bitmaps, accum_counts, packed_not)``:

        * ``bitmaps``: int64 tensor of shape ``[num_tiles]``.
        * ``accum_counts``: int32 tensor of shape ``[num_tiles + 1]`` holding
          the exclusive prefix sum of the per-tile padded counts. Each count
          is the non-zero count rounded up to a multiple of 8 and halved, so
          ``2 * accum_counts[t]`` is tile ``t``'s start in ``packed_not``.
        * ``packed_not``: 1-D float16 tensor of length
          ``2 * accum_counts[-1]``; padding slots are left at zero.

    Raises:
        ValueError: If ``input`` is not 2-D, is not on the GPU, or its row
            count is not a multiple of 64.
    """
    if input.dim() != 2:
        raise ValueError("Input tensor must be 2-D")
    if not input.is_cuda:
        raise ValueError("Input tensor must be on the GPU")
    if input.size(0) % 64 != 0:
        raise ValueError("The number of rows must be a multiple of 64")

    M, N = input.shape
    device = input.device

    # Transposed copy so that 64 consecutive rows of one column are contiguous.
    input_flat = input.t().contiguous()

    total_elems = input_flat.numel()
    num_tiles = total_elems // 64

    bitmaps = torch.empty((num_tiles,), dtype=torch.int64, device=device)
    counts = torch.empty((num_tiles,), dtype=torch.int32, device=device)

    # Per-lane bit weights 1 << 63 ... 1 << 0. Built on the input's own
    # device (not the default 'cuda' device) so multi-GPU setups work.
    shift_amounts = np.arange(63, -1, -1, dtype=np.int64)
    shifts_np = np.left_shift(np.int64(1), shift_amounts)
    const_shifts = torch.tensor(shifts_np, device=device)

    # One program per tile.
    grid = (num_tiles,)

    # Launch on the device that owns the data rather than whichever device
    # happens to be current.
    with torch.cuda.device(device):
        calculate_bitmap_kernel[grid](input_flat, bitmaps, counts, total_elems=total_elems, shifts_ptr=const_shifts, M=M, N=N)

        # Exclusive prefix sum of the padded counts: per-tile start pointers.
        accum_counts = torch.cumsum(counts, dim=0).to(torch.int32)
        accum_counts = torch.cat([torch.zeros(1, dtype=counts.dtype, device=counts.device), accum_counts])

        # Counts are halved, so double the total to get the length in
        # 16-bit elements. Zero-initialised so padding slots stay zero.
        total = 2 * int(accum_counts[-1].item())
        packed_not = torch.zeros((total,), dtype=torch.float16, device=device)

        compress_kernel[grid](input_flat, bitmaps, accum_counts, packed_not, total_elems=total_elems, M=M, N=N)

    return bitmaps, accum_counts, packed_not

def convert_tensor_batched(inputs: torch.Tensor):
    """Compress a batch of 2-D CUDA tensors into bitmaps plus packed non-zeros.

    Each ``[M, N]`` slice is transposed and split into 64-element tiles.
    Tile ``t`` of a batch covers rows ``64 * (t // N) .. 64 * (t // N) + 63``
    of column ``t % N``. Within a tile, element ``i`` is flagged by bit
    ``63 - i`` of the tile's bitmap.

    Args:
        inputs: Tensor of shape ``[B, M, N]`` on a CUDA device with ``M`` a
            multiple of 64.

    Returns:
        A tuple ``(bitmaps, accum_counts, packed_not_batched)``:

        * ``bitmaps``: int64 tensor of shape ``[B, num_tiles_per_batch]``.
        * ``accum_counts``: int32 tensor of shape
          ``[B, num_tiles_per_batch + 1]`` holding, per batch, the exclusive
          prefix sum of the per-tile padded counts (non-zero count rounded up
          to a multiple of 8, then halved).
        * ``packed_not_batched``: list of ``B`` 1-D float16 tensors; entry
          ``b`` has length ``2 * accum_counts[b, -1]`` and padding slots are
          zero.

    Raises:
        ValueError: If ``inputs`` is not 3-D, is not on the GPU, or ``M`` is
            not a multiple of 64.
    """
    # Validate with explicit exceptions (asserts vanish under ``python -O``)
    # and check the rank before unpacking the shape.
    if inputs.dim() != 3:
        raise ValueError("Input tensor must be 3-D")
    if not inputs.is_cuda:
        raise ValueError("Input tensor must be on the GPU")
    B, M, N = inputs.shape
    if M % 64 != 0:
        raise ValueError("The number of rows must be a multiple of 64")

    device = inputs.device

    inputs_t = inputs.transpose(1, 2).contiguous()  # [B, N, M]
    inputs_flat = inputs_t.view(-1)

    num_tiles_per_batch = (M * N) // 64
    stride_batch = num_tiles_per_batch * 64

    bitmaps = torch.empty((B, num_tiles_per_batch), dtype=torch.int64, device=device)
    counts = torch.empty((B, num_tiles_per_batch), dtype=torch.int32, device=device)

    # Per-lane bit weights 1 << 63 ... 1 << 0. Built on the input's own
    # device (not the default 'cuda' device) so multi-GPU setups work.
    shift_amounts = np.arange(63, -1, -1, dtype=np.int64)
    shifts_np = np.left_shift(np.int64(1), shift_amounts)
    const_shifts = torch.tensor(shifts_np, device=device)

    # Tiles on grid.x, batches on grid.y: the y dimension has the tighter
    # launch limit, and B is normally the smaller of the two.
    grid = (num_tiles_per_batch, B)

    # Launch on the device that owns the data rather than whichever device
    # happens to be current.
    with torch.cuda.device(device):
        calculate_bitmap_kernel_batched[grid](
            inputs_flat,
            bitmaps.view(-1),
            counts.view(-1),
            total_elems=B * M * N,
            shifts_ptr=const_shifts,
            stride_batch=stride_batch,
            M=M,
            N=N
        )

        # Per-batch exclusive prefix sum of the padded counts.
        accum_counts = torch.cumsum(counts, dim=1).to(torch.int32)  # [B, T]
        accum_counts = torch.cat([
            torch.zeros((B, 1), dtype=counts.dtype, device=counts.device),
            accum_counts
        ], dim=1).contiguous()  # [B, T+1]

        # Counts are halved, so double them to get float16 elements per batch.
        total = 2 * accum_counts[:, -1]  # [B]
        # Exclusive prefix sum over batches: start of each batch's region.
        batch_offsets = (torch.cumsum(total, dim=0) - total).contiguous()

        # Single device->host transfer of all per-batch sizes.
        sizes = total.tolist()
        total_packed_size = sum(sizes)

        # Zero-initialised so padding slots stay zero.
        packed_not_flat = torch.zeros((total_packed_size,), dtype=torch.float16, device=device)

        compress_kernel_batched[grid](
            inputs_flat,
            bitmaps.view(-1),
            accum_counts.view(-1),
            packed_not_flat.view(-1),
            batch_offsets.view(-1),
            total_elems=B * M * N,
            stride_batch=stride_batch,
            M=M,
            N=N
        )

    # Slice the flat buffer into independent per-batch tensors.
    packed_not_batched = [chunk.clone() for chunk in torch.split(packed_not_flat, sizes)]

    # bitmaps / accum_counts have deterministic shapes; the packed tensors
    # vary in length per batch, hence the list.
    return bitmaps, accum_counts, packed_not_batched
    
