import torch
import torch.nn.functional as F
import math
import logging

def get_attention_score(past_key_values, latest_captured_rope_queries, position_offset, compressKV_size, layers, head_idx, config, attention_mask=None):
    """
    Recompute attention probabilities of captured queries over cached keys and
    accumulate them into one importance score per cached token.

    For every layer in ``layers``:
      * the captured (RoPE-applied) queries for that layer are matched with the
        K-cache heads they share under GQA (query head // group size),
      * Q @ K^T is scaled by 1/sqrt(head_dim),
      * keys where ``attention_mask == 0`` are masked out,
      * a softmax is taken over the key axis (in float32).

    The per-layer probabilities are summed across layers, then over the head
    and step axes.

    Args:
        past_key_values: KV cache object exposing ``key_cache``, a per-layer list
            of key tensors indexed as [Batch, KV_Heads, Seq_Len, Head_Dim].
        latest_captured_rope_queries: Query tensor [Step, Layer, Batch, Heads, Dim].
            ``Layer`` follows the order of ``layers``. ``Heads`` are the
            already-selected query heads listed in ``head_idx``.
        position_offset: Keys are sliced to ``[:position_offset]`` along Seq_Len.
        compressKV_size: Not used by this function; kept for interface compatibility.
        layers: Indices into ``past_key_values.key_cache`` to score.
        head_idx: Per-layer query-head indices (dict or list), indexed by the
            position of the layer within ``layers``.
        config: Model config providing ``num_attention_heads`` and
            ``num_key_value_heads``.
        attention_mask: Optional [Batch, Seq_Len] mask. Positions where it equals
            0 are excluded before the softmax.

    Returns:
        torch.Tensor: float32 scores of shape [Batch, Seq_Len], on the device of
        the first scored layer. If ``layers`` is empty, an empty ``torch.long``
        tensor on the device of ``key_cache[0]`` is returned instead.
    """
    GQA_group_size = config.num_attention_heads // config.num_key_value_heads
    first_device = past_key_values.key_cache[0].device

    # Scale factor is shared by all layers.
    scale = latest_captured_rope_queries.shape[-1] ** -0.5

    # Boolean "drop this key" mask, broadcastable to [Batch, Heads, Step, Seq_Len].
    # Built once; only moved to each layer's device inside the loop.
    drop_mask = None
    if attention_mask is not None:
        drop_mask = (attention_mask == 0).unsqueeze(1).unsqueeze(2)

    accumulated_scores = None

    for filtered_layer_idx, layer_idx in enumerate(layers):
        layer_keys = past_key_values.key_cache[layer_idx]
        layer_device = layer_keys.device

        # Q: [Step, Batch, Heads, Head_Dim] -> [Batch, Heads, Step, Head_Dim] on K's device
        Q = latest_captured_rope_queries[:, filtered_layer_idx].permute(1, 2, 0, 3).to(layer_device)

        # Map each selected query head to the KV head it reads under GQA.
        selected_Q_indices = torch.as_tensor(head_idx[filtered_layer_idx], device=layer_device)
        K_map_indices = selected_Q_indices // GQA_group_size

        # K: [Batch, Heads, Seq_Len, Head_Dim]
        K = layer_keys[:, K_map_indices, :position_offset, :]

        # Raw scores: [Batch, Heads, Step, Seq_Len], scaled in place to save memory.
        attn_scores = torch.matmul(Q, K.transpose(-1, -2))
        attn_scores.mul_(scale)

        if drop_mask is not None:
            attn_scores.masked_fill_(drop_mask.to(layer_device), torch.finfo(attn_scores.dtype).min)

        # Softmax in float32. The probabilities are summed over many layers,
        # heads and steps. Half-precision accumulation collapses small
        # differences into ties, which makes the later top-k choice arbitrary.
        attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32)

        if accumulated_scores is None:
            accumulated_scores = attn_probs
        else:
            # .to() is a no-op when the layer already lives on the accumulator's device.
            accumulated_scores.add_(attn_probs.to(accumulated_scores.device))

    if accumulated_scores is None:
        return torch.tensor([], device=first_device, dtype=torch.long)

    # Sum over the Heads and Step axes -> [Batch, Seq_Len]
    return accumulated_scores.sum(dim=(1, 2))

def check_dup(ts):
    """
    Print, for each row of a 2D tensor, the values that appear more than once.

    Rows without repeated values produce no output. For a row with repeats, a
    header line gives the number of distinct repeated values, followed by one
    line per value with its occurrence count. Values are listed in ascending
    order, as ``torch.unique`` returns them.

    Args:
        ts: 2D torch tensor of shape (N, M).

    Returns:
        None (output is printed to stdout).
    """
    assert ts.dim() == 2, "Input must be a 2D tensor"

    for r, values in enumerate(ts):
        distinct, freq = torch.unique(values, return_counts=True)
        repeated = [
            (val, cnt)
            for val, cnt in zip(distinct.tolist(), freq.tolist())
            if cnt > 1
        ]
        if not repeated:
            continue

        print(f"Row {r}: #duplicate values = {len(repeated)}")
        for val, cnt in repeated:
            print(f"  value = {val:.10f}, count = {cnt}")

@torch.no_grad()
def compressKV_mb_sub1(full_pos_ids, window_size, sink_size, attention_mask, draft_attention_score, target_attention_score, compressKV_size, compressKV_draft_select_size):
    """
    Choose the KV positions to keep for rows that have enough valid tokens.

    Each returned row is the concatenation of:
      * sink indices: the first ``sink_size`` valid tokens,
      * core indices: tokens picked by attention score, sorted ascending,
      * window indices: valid tokens whose position id is greater than
        ``full_pos_ids[:, -1] - window_size`` (threshold clamped at -1).

    Core budget:
      * ``draft_k = compressKV_draft_select_size - window_size - sink_size``
        tokens chosen with the draft score, and
      * ``k_core = compressKV_size - compressKV_draft_select_size`` tokens chosen
        with the target score, among positions the draft did not pick.

    If only one of the two scores is available, it fills the whole budget, so
    the row width stays the same. Sink, window and padded (invalid) positions
    are never chosen as core tokens.

    Args:
        full_pos_ids: [B, S] position ids.
        window_size: Number of recent tokens always kept.
        sink_size: Number of leading valid tokens always kept.
        attention_mask: [B, S] validity mask (bool, or 0/1 integers).
        draft_attention_score: [B, S] draft scores, or None.
        target_attention_score: [B, S] target scores, or None.
        compressKV_size: Total number of positions to keep per row.
        compressKV_draft_select_size: Sinks + window + draft-selected positions.

    Returns:
        torch.Tensor: long indices along S, one row per batch entry.

    Raises:
        ValueError: if no usable attention score is available.
    """
    # Work on a bool view: `~`, `&`, `|` and `torch.where` below need a boolean
    # mask, and an integer 0/1 mask would break or silently corrupt them.
    valid_mask = attention_mask.bool()

    # Batch size, device and dtype from whichever score is present.
    if target_attention_score is not None:
        B = target_attention_score.shape[0]
        device = target_attention_score.device
        dtype = target_attention_score.dtype
    elif draft_attention_score is not None:
        B = draft_attention_score.shape[0]
        device = draft_attention_score.device
        dtype = draft_attention_score.dtype
    else:
        B = valid_mask.shape[0]
        device = valid_mask.device
        dtype = torch.float32  # Default dtype
    min_val = torch.finfo(dtype).min

    # 1. Recent window. .view() assumes exactly `window_size` valid tokens per row.
    window_start = full_pos_ids[:, -1] - window_size
    window_mask = full_pos_ids > torch.clamp(window_start.unsqueeze(1), min=-1)
    near_indices = (window_mask & valid_mask).nonzero(as_tuple=False)[:, 1].view(B, window_size)

    # 2. Sinks: the first `sink_size` valid tokens of each row.
    sink_mask = valid_mask & (valid_mask.cumsum(dim=1) <= sink_size)
    sink_indices = sink_mask.nonzero(as_tuple=False)[:, 1].view(B, sink_size)

    # 3. Core tokens chosen by attention score.
    k_core = compressKV_size - compressKV_draft_select_size
    draft_k = compressKV_draft_select_size - window_size - sink_size

    # Never re-select sink/window tokens, and never keep padding as context.
    exclusion_mask = window_mask | sink_mask | ~valid_mask

    if draft_k > 0 and draft_attention_score is not None:
        masked_draft_score = torch.where(exclusion_mask, min_val, draft_attention_score)

        if target_attention_score is not None:
            masked_target_score = torch.where(exclusion_mask, min_val, target_attention_score)
            _, draft_topk_indices = torch.topk(masked_draft_score, draft_k, dim=1)
            # The target must not re-pick what the draft already chose.
            masked_target_score.scatter_(dim=1, index=draft_topk_indices, value=min_val)
            _, target_topk_indices = torch.topk(masked_target_score, k_core, dim=1)
            keep_core = torch.cat([target_topk_indices, draft_topk_indices], dim=1)
        else:
            # No target score: the draft also fills the target's share, so the
            # row still holds compressKV_size entries.
            _, keep_core = torch.topk(masked_draft_score, draft_k + max(k_core, 0), dim=1)
    elif target_attention_score is not None:
        # Any draft share that has no draft score to use goes to the target.
        n_core = max(k_core, 0) + max(draft_k, 0)
        if n_core > 0:
            masked_target_score = torch.where(exclusion_mask, min_val, target_attention_score)
            _, keep_core = torch.topk(masked_target_score, n_core, dim=1)
        else:
            keep_core = torch.empty((B, 0), dtype=torch.long, device=device)
    else:
        raise ValueError("Both draft_attention_score and target_attention_score are None.")

    # Sort the core picks and concatenate: sinks, core, window.
    if keep_core.numel() > 0:
        keep_core = keep_core.sort(dim=1).values
        return torch.cat((sink_indices, keep_core, near_indices), dim=-1)
    return torch.cat((sink_indices, near_indices), dim=-1)

@torch.no_grad()
def compressKV_mb(src_key_values, position_offset, window_size, sink_size, compressKV_size, compressKV_draft_select_size, target_key_values, draft_model, full_pos_ids, attention_mask, target_model=None, enable_analysis=False):
    """
    Score, select and copy the KV positions to keep for every row of a batch.

    Steps:
      1. If ``compressKV_draft_select_size - window_size - sink_size > 0``, score
         tokens with ``draft_model``'s captured queries against its own KV cache.
      2. If ``target_model`` is given and ``compressKV_size`` exceeds
         ``compressKV_draft_select_size``, score tokens with ``target_model``'s
         captured queries against ``src_key_values``.
      3. Rows with at least ``compressKV_size`` valid tokens get
         ``compressKV_size`` indices from ``compressKV_mb_sub1``. Shorter rows
         keep all their valid positions, left-padded with index 0.
      4. ``set_compressKV`` copies the selected positions from ``src_key_values``
         into ``target_key_values``.

    Args:
        attention_mask: [B, S] validity mask (bool, or 0/1 integers).
        enable_analysis: If True, time steps 1-3 with CUDA events.
        Other arguments are forwarded to the scoring/selection helpers.

    Returns:
        tuple of:
          keep_indices_per_layer: long tensor [B, W]. W is ``compressKV_size``
            if any row reaches it, otherwise the largest valid-token count.
          ret_attn_mask: bool tensor [B, W], False on left padding.
          criticality_estimation: elapsed milliseconds when ``enable_analysis``
            is set, otherwise 0.0.
    """
    if enable_analysis:
        # Collect all devices that may be involved in the computation.
        devices = set()
        for kv in draft_model.past_key_values.key_cache:
            if kv is not None:
                devices.add(kv.device)
        if src_key_values is not None:
            for kv in src_key_values.key_cache:
                if kv is not None:
                    devices.add(kv.device)

        # torch.cuda.synchronize only accepts CUDA devices; CPU tensors need no sync.
        cuda_devices = [d for d in devices if d.type == 'cuda']
        for dev in cuda_devices:
            torch.cuda.synchronize(dev)

        # Use the first CUDA device for event timing.
        timing_device = cuda_devices[0] if cuda_devices else torch.device('cuda:0')
        ck_start = torch.cuda.Event(enable_timing=True)
        ck_end = torch.cuda.Event(enable_timing=True)
        with torch.cuda.device(timing_device):
            ck_start.record()

    # Draft scores are only needed when the draft selects tokens beyond sinks + window.
    draft_k = compressKV_draft_select_size - window_size - sink_size
    attention_score = None
    if draft_k > 0:
        attention_score = get_attention_score(
            draft_model.past_key_values,
            draft_model.latest_captured_rope_queries,
            position_offset,
            compressKV_draft_select_size,
            draft_model.important_layers,
            draft_model.important_heads,
            draft_model.config,
            attention_mask=attention_mask,
        )
        if attention_score.dim() == 1:
            attention_score = attention_score.unsqueeze(0)

    # Target scores are only needed when the target has its own share of the budget.
    target_attention_score = None
    if target_model is not None and compressKV_size - compressKV_draft_select_size > 0:
        target_attention_score = get_attention_score(
            src_key_values,
            target_model.latest_captured_rope_queries,
            position_offset,
            compressKV_size,
            target_model.important_layers,
            target_model.important_heads,
            target_model.config,
            attention_mask=attention_mask,
        )
        if target_attention_score.dim() == 1:
            target_attention_score = target_attention_score.unsqueeze(0)

    # Batch size and working device.
    if attention_score is not None:
        B, _ = attention_score.shape
        device = attention_score.device
    elif target_attention_score is not None:
        B, _ = target_attention_score.shape
        device = target_attention_score.device
    else:
        B = attention_mask.shape[0]
        device = draft_model.past_key_values.key_cache[0].device

    # Bring everything onto one device. Normalize the mask to bool: the
    # selection below uses it for boolean indexing, and an integer 0/1 mask
    # would be treated as gather indices instead.
    attention_mask = attention_mask.to(device).bool()
    full_pos_ids = full_pos_ids.to(device)
    if attention_score is not None:
        attention_score = attention_score.to(device)
    if target_attention_score is not None:
        target_attention_score = target_attention_score.to(device)

    attention_mask_sum = attention_mask.sum(dim=1)
    sub1_mask = attention_mask_sum >= compressKV_size
    sub2_mask = attention_mask_sum < compressKV_size

    keep_indices_per_layer = None
    ret_attn_mask = None

    # Rows long enough to compress: score-based selection.
    if sub1_mask.any():
        keep_indices_per_layer = torch.empty((B, compressKV_size), dtype=torch.long, device=device)
        ret_attn_mask = torch.empty((B, compressKV_size), dtype=torch.bool, device=device)

        sub1 = compressKV_mb_sub1(
            full_pos_ids[sub1_mask],
            window_size,
            sink_size,
            attention_mask[sub1_mask],
            attention_score[sub1_mask] if attention_score is not None else None,
            target_attention_score[sub1_mask] if target_attention_score is not None else None,
            compressKV_size,
            compressKV_draft_select_size
        )

        keep_indices_per_layer[sub1_mask] = sub1
        ret_attn_mask[sub1_mask] = True

    # Short rows: keep every valid position, left-padded to the output width.
    if sub2_mask.any():
        if keep_indices_per_layer is None:
            max_len = attention_mask_sum.max().int().item()
            keep_indices_per_layer = torch.empty((B, max_len), dtype=torch.long, device=device)
            ret_attn_mask = torch.empty((B, max_len), dtype=torch.bool, device=device)
        out_size = keep_indices_per_layer.size(1)
        positions = torch.arange(attention_mask.shape[1], device=device)
        for b in sub2_mask.nonzero(as_tuple=False).flatten().tolist():
            keep_i = positions[attention_mask[b]]
            pad_size = out_size - keep_i.size(0)
            if pad_size > 0:
                keep_indices_per_layer[b, :pad_size] = 0
                keep_indices_per_layer[b, pad_size:] = keep_i
                ret_attn_mask[b, :pad_size] = False
                ret_attn_mask[b, pad_size:] = True
            else:
                keep_indices_per_layer[b] = keep_i[:out_size]
                ret_attn_mask[b] = True

    if enable_analysis:
        # Make sure all queued work has finished before stopping the timer.
        for dev in cuda_devices:
            torch.cuda.synchronize(dev)
        with torch.cuda.device(timing_device):
            ck_end.record()
            ck_end.synchronize()
        criticality_estimation = ck_start.elapsed_time(ck_end)
        logging.debug(f"Criticality Estimation Latency: {criticality_estimation:.3f} ms")
    else:
        criticality_estimation = 0.0
    set_compressKV(src_key_values, keep_indices_per_layer, target_key_values)
    return keep_indices_per_layer, ret_attn_mask, criticality_estimation

# This is deprecated
# @torch.no_grad()
# def compressKV(src_key_values, position_offset, window_size, compressKV_size, target_key_values, draft_model):
#     # Supports multi-batch:
#     # attention_score: [B, Seq_Len]
#     attention_score = get_attention_score(
#         draft_model.past_key_values,
#         draft_model.latest_captured_rope_queries,
#         position_offset,
#         compressKV_size,
#         draft_model.important_layers,
#         draft_model.important_heads,
#         draft_model.config
#     )
#     # Expect [B, S]
#     if attention_score.dim() == 1:
#         attention_score = attention_score.unsqueeze(0)

#     B, _ = attention_score.shape
#     device = src_key_values.key_cache[0].device

#     eff_window = min(window_size, position_offset)
#     eff_compress = min(compressKV_size, position_offset)

#     k = max(0, eff_compress - eff_window)
#     prefix_len = position_offset - eff_window

#     # select topk kv (keep_core) per batch
#     if k > 0 and prefix_len > 0:
#         topk_zone = attention_score[:, :prefix_len]                 # [B, prefix_len]
#         k = min(k, topk_zone.shape[-1])
#         keep_core = torch.topk(topk_zone, k, dim=-1).indices        # [B, k]
#         keep_core = keep_core.sort(dim=-1).values                   # [B, k] ascending indices
#     else:
#         keep_core = torch.empty((B, 0), dtype=torch.long, device=device)  # [B, 0]

#     # keep eff_window kv (near_indices), shared across batch then expanded
#     base_near = torch.arange(prefix_len, position_offset, device=device, dtype=torch.long)  # [eff_window]
#     if base_near.numel() == 0:
#         near_indices = torch.empty((B, 0), dtype=torch.long, device=device)                 # [B, 0]
#     else:
#         near_indices = base_near.unsqueeze(0).expand(B, -1).contiguous()                    # [B, eff_window]

#     # merge keep_core & near_indices
#     # keep_indices_per_layer: [B, k + eff_window]
#     keep_indices_per_layer = torch.cat((keep_core, near_indices), dim=-1)

#     set_compressKV(src_key_values, keep_indices_per_layer, target_key_values)
#     return keep_indices_per_layer

@torch.no_grad()
def set_compressKV(src_key_values, keep_indices, target_key_values):
    """
    Copy the kept positions of every layer of ``src_key_values`` into ``target_key_values``.

    For each layer, ``dst[b, :, :K] = src[b, :, idx[b]]``. The remaining
    destination positions are zeroed. Finally the target cache is cropped to K
    and its ``seq_len`` attribute is set to K (the K of the last processed layer).

    Args:
        src_key_values: Cache with per-layer ``key_cache`` / ``value_cache`` lists
            of 4D tensors [B, H, S, D]. Layers whose K or V is None are skipped.
        keep_indices: Tensor [B, K] (per-row indices), Tensor [K] (shared by all
            rows), or a dict layer_id -> either of those.
        target_key_values: Cache with the same layer count, batch and head
            layout. Must expose ``crop(n)`` and accept a ``seq_len`` attribute.

    Raises:
        ValueError: missing source cache, shape/batch mismatch, bad index rank,
            or destination too small.
        IndexError: an index outside [0, source sequence length).
        AssertionError: layer-count mismatch between the caches.
    """
    if src_key_values is None:
        raise ValueError("src_key_values is None")

    keep_len = 0

    src_K = src_key_values.key_cache
    src_V = src_key_values.value_cache
    dst_K = target_key_values.key_cache
    dst_V = target_key_values.value_cache

    L = len(src_K)
    assert L == len(dst_K) == len(src_V) == len(dst_V), "KV layer count mismatch"

    for l in range(L):
        K_src, V_src = src_K[l], src_V[l]
        K_dst, V_dst = dst_K[l], dst_V[l]

        if K_src is None or V_src is None:
            continue  # layer not initialized yet

        # Indices for this layer (per-layer dict or one shared tensor).
        idx = keep_indices[l] if isinstance(keep_indices, dict) else keep_indices
        idx = idx.to(K_src.device)

        if K_src.dim() != 4 or K_dst.dim() != 4:
            raise ValueError(
                f"Unexpected KV dims at layer {l}: src={tuple(K_src.shape)} dst={tuple(K_dst.shape)}"
            )

        Bs, Hs, Ss, Ds = K_src.shape
        Bd, Hd, Sd, Dd = K_dst.shape
        if (Bs, Hs, Ds) != (Bd, Hd, Dd):
            raise ValueError(
                f"Batch/heads/dim mismatch at layer {l}: src={tuple(K_src.shape)} dst={tuple(K_dst.shape)}"
            )

        if idx.numel() == 0:
            K_dst.zero_()
            V_dst.zero_()
            keep_len = 0
            continue

        # Normalize idx to [B, K].
        if idx.dim() == 1:
            idx = idx.unsqueeze(0).expand(Bs, -1).contiguous()
        elif idx.dim() == 2:
            if idx.size(0) != Bs:
                raise ValueError(
                    f"keep_indices batch mismatch at layer {l}: idx batch={idx.size(0)} vs KV batch={Bs}"
                )
        else:
            raise ValueError(f"keep_indices must be 1D or 2D (or dict thereof); got idx.dim()={idx.dim()} at layer {l}")

        idx_min = idx.min().item()
        idx_max = idx.max().item()
        if idx_max >= Ss or idx_min < 0:
            raise IndexError(f"Index out of range at layer {l}: min={idx_min}, max={idx_max}, src seq={Ss}")

        keep_len = idx.size(1)
        if keep_len > Sd:
            raise ValueError(f"Destination capacity too small at layer {l}: keep_len={keep_len}, Sd={Sd}")

        # Gather BEFORE writing. If source and destination share storage
        # (in-place compression), zeroing first would make the gather read zeros.
        # V uses its own head/dim sizes, which may differ from K's.
        k_index = idx[:, None, :, None].expand(Bs, Hs, keep_len, Ds)
        v_index = idx[:, None, :, None].expand(V_src.size(0), V_src.size(1), keep_len, V_src.size(3))
        K_sel = torch.gather(K_src, dim=2, index=k_index)
        V_sel = torch.gather(V_src, dim=2, index=v_index)

        K_dst[:, :, :keep_len, :].copy_(K_sel)
        V_dst[:, :, :keep_len, :].copy_(V_sel)
        # Clear only the stale tail; the head was just overwritten.
        K_dst[:, :, keep_len:, :].zero_()
        V_dst[:, :, keep_len:, :].zero_()

    target_key_values.crop(keep_len)      # if dynamic, adjust size; if static, no-op
    target_key_values.seq_len = keep_len  # update seq_len for both static/dynamic