"""
自适应Attention Mask生成算法 - 函数式实现

参考 attn_mask.py 的实现模式，将 attn_mask1.py 重构为函数形式
集成 FlashInfer 和 SpargeSageAttn backend 用于加速

该模块实现基于Warm up阶段学习的自适应attention mask生成，包括：
1. Warm up阶段：提取attention map的结构特征（对角线、垂直线、分块对角）
2. 预测阶段：基于学习到的特征动态生成sparse mask
3. Backend加速：支持 FlashInfer 和 SpargeSageAttn 两种加速方式
"""

import torch
import time
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
import numpy as np
from einops import rearrange, repeat
from sparse_sageattn import sparse_sageattn

# from SpargeAttn.spas_sage_attn.autotune import SparseAttentionMeansim
from SVPG.svpg.get_radial_mask import get_radial_mask
try:
    from SVPG.svpg.cuda_lstsq import solve_lstsq
except ImportError:
    print("Warning: cuda_lstsq not available, using PyTorch fallback")
    solve_lstsq = None



# ============================================================================
# Global Workspace for FlashInfer (to avoid per-layer allocation)
# ============================================================================
# Per the section header, meant to hold one FlashInfer workspace buffer shared by
# all layers instead of allocating one per layer.
# NOTE(review): nothing in this file's visible code reads or assigns it — confirm
# whether it is still needed.
_GLOBAL_FLASHINFER_WORKSPACE: Optional[torch.Tensor] = None
# ============================================================================

# ============================================================================
# Utility Functions from attn_mask.py
# ============================================================================
def reset_peak_gpu_stats(device: Optional[torch.device] = None):
    """
    Reset the CUDA peak-memory counters; does nothing when CUDA is unavailable.

    Args:
        device: device whose counters are reset; ``None`` resets the current
            CUDA device.
    """
    if not torch.cuda.is_available():
        return
    index_args = () if device is None else (device.index,)
    torch.cuda.reset_peak_memory_stats(*index_args)

def print_peak_gpu_stats(prefix: str = "", device: Optional[torch.device] = None):
    """
    Print peak allocated / reserved CUDA memory (GiB) since the last reset.

    Args:
        prefix: text printed in front of the statistics.
        device: device whose counters are read; ``None`` uses the current
            CUDA device.
    """
    if torch.cuda.is_available():
        # wait for all queued kernels so the peak counters are final
        torch.cuda.synchronize()
        index_args = () if device is None else (device.index,)
        gib = 1024 ** 3
        peak_alloc = torch.cuda.max_memory_allocated(*index_args) / gib
        peak_reserved = torch.cuda.max_memory_reserved(*index_args) / gib
        print(f"{prefix} peak_alloc={peak_alloc:.3f} GiB, peak_reserved={peak_reserved:.3f} GiB")
    else:
        print(prefix + " GPU not available")


# ============================================================================
# Feature Extraction Functions (from AttentionMapDecomposer)
# ============================================================================

def extract_attention_features(
    warmup_state: Dict,
    S_T: torch.Tensor,
    S_0: torch.Tensor,
    blocks_per_frame: int,
    regularization: float = 1,
    use_cuda: bool = True
) -> Dict[str, torch.Tensor]:
    """
    Extract structural line/block features of an attention map via ``solve_lstsq``.

    Args:
        warmup_state: state dict forwarded to the solver; its 'current_steps'
            entry is passed as ``step``.
        S_T: attention map of the current step, shape (num_heads, n, n).
        S_0: reference attention map, same shape as ``S_T``.
        blocks_per_frame: number of blocks per video frame.
        regularization: regularisation coefficient for the least-squares solve.
        use_cuda: request the CUDA solver; only honoured when ``S_T`` is on CUDA.

    Returns:
        The feature dict produced by ``solve_lstsq``. Per the original notes it
        holds 'p' (weight of S_0), 'c' (diagonal values), 'd' (vertical-line
        values) and 'e' (block values); callers in this module also read 'b_d'.
        TODO(review): confirm the exact schema against cuda_lstsq.

    Raises:
        ValueError: if ``S_T`` is not (num_heads, n, n) or ``S_0`` does not match it.
        RuntimeError: if the ``solve_lstsq`` backend failed to import.
    """
    if S_T.dim() != 3 or S_T.shape[1] != S_T.shape[2]:
        raise ValueError(f"S_T must have shape (num_heads, n, n), got {tuple(S_T.shape)}")
    if S_0.shape != S_T.shape:
        raise ValueError(f"S_0 shape mismatch: {tuple(S_0.shape)} vs {tuple(S_T.shape)}")
    if solve_lstsq is None:
        # The import-time fallback only sets solve_lstsq to None; there is no
        # PyTorch implementation, so fail clearly instead of "NoneType is not callable".
        raise RuntimeError(
            "solve_lstsq (SVPG.svpg.cuda_lstsq) is not available; cannot extract attention features"
        )

    return solve_lstsq(
        warmup_state=warmup_state,
        S_T=S_T,
        S_0=S_0,
        step=warmup_state['current_steps'],
        blocks_per_frame=blocks_per_frame,
        regularization=regularization,
        use_cuda=use_cuda and S_T.is_cuda
    )



# ============================================================================
# Warmup Management Functions
# ============================================================================

def init_warmup_state(num_heads: int, model_type: str, warmup_steps: int = 12,) -> Dict:
    """
    Create a fresh warmup-state dict.

    Args:
        num_heads: number of attention heads.
        model_type: model layout tag (e.g. 'cogvideox' or 'hunyuan').
        warmup_steps: number of warmup steps before sparse masks are used.

    Returns:
        A new dict; later code fills 'features_history' (list of feature
        dicts), 'S_0_maps' / 'prev_full_maps' (attention-map tensors) and
        'mask'. The remaining entries start empty.
    """
    return dict(
        features_history=None,
        S_0_maps=None,
        current_steps=0,
        prev_full_maps=None,
        prediction_features={},
        MTM=None,
        mask=None,
        sparse_id=None,
        warmup_steps=warmup_steps,
        model_type=model_type,
        num_heads=num_heads,
    )

def _build_sparge_attention():
    """Instantiate SpargeAttn's ``SparseAttentionMeansim`` tuner, or return None if SpargeAttn is missing."""
    try:
        # Imported lazily: the module-level import is commented out in this
        # file, so the bare name used to raise NameError on every call.
        from SpargeAttn.spas_sage_attn.autotune import SparseAttentionMeansim
    except ImportError:
        return None
    return SparseAttentionMeansim(l1=0.07, pv_l1=0.08, tune_pv=True)


def create_adaptive_mask_state(
    model_type: str,
    num_heads: int,
    video_token_num: int,
    text_token_num: int,
    tokens_per_frame: int,
    warmup_steps: int = 12,
    top_k: int = 10,
    predict_T: int = 5,
    thereshold: float = 1.5e-4,
    block_size: int = 128,
    predict_all: bool = True,
    use_block_attention: bool = False,
    use_cuda: bool = True
) -> Dict:
    """
    Build the full adaptive-mask state consumed by ``SparseAttentionWithMap``.

    Args:
        model_type: 'cogvideox' (text tokens first) or 'hunyuan' (video tokens
            first); other values make mask generation return an all-True mask.
        num_heads: number of attention heads.
        video_token_num: number of video tokens.
        text_token_num: number of text tokens.
        tokens_per_frame: tokens per video frame; ``tokens_per_frame // block_size``
            becomes 'blocks_per_frame'.
        warmup_steps: number of warmup steps.
        top_k: lines kept per head when predicting masks.
        predict_T: number of steps between mask refreshes.
        thereshold: threshold stored as a tensor and passed to
            ``convert_to_block_sparse_map``.
        block_size: token block size of the masks.
        predict_all: precompute all future masks at the end of warmup.
        use_block_attention: observe attention via ``compute_block_sparse_map``
            instead of a full attention map.
        use_cuda: place the threshold on the current CUDA device. Falls back to
            the CPU when CUDA is unavailable.

    Returns:
        The state dict. 'sparge_attention' is None when SpargeAttn is not installed.
    """
    if use_cuda and torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")

    state = init_warmup_state(num_heads, model_type, warmup_steps)
    state['video_token_num'] = video_token_num
    state['top_k'] = top_k
    state['thereshold'] = torch.tensor(thereshold, device=device)
    state['sparge_attention'] = _build_sparge_attention()
    state['text_token_num'] = text_token_num
    state['block_size'] = block_size
    state['predict_T'] = predict_T
    state['blocks_per_frame'] = tokens_per_frame // block_size
    state['use_block_attention'] = use_block_attention
    state['use_cuda'] = use_cuda
    state['predict_all'] = predict_all

    return state

def update_warmup(
    warmup_state: Dict,
    step: int,
    block_sparse_map: torch.Tensor,
):
    """
    Record the block-level attention map of one observed warmup step.

    At ``step == warmup_steps - 2`` the feature history is reset and the map
    becomes both the reference 'S_0_maps' and 'prev_full_maps'. Every call
    then does four things:

    * extracts features of the map relative to 'S_0_maps';
    * appends them to 'features_history';
    * sets 'current_steps' to ``step + 1``;
    * stores the map in 'prev_full_maps'.

    Args:
        warmup_state: mutable state dict (see ``create_adaptive_mask_state``).
        step: index of the current step.
        block_sparse_map: block-level attention map of this step.

    Returns:
        None; ``warmup_state`` is updated in place.
    """
    if step == warmup_state['warmup_steps'] - 2:
        # first observed step: start a new history with this map as reference
        warmup_state.update(
            features_history=[],
            current_steps=step,
            S_0_maps=block_sparse_map,
            prev_full_maps=block_sparse_map,
        )

    new_features = extract_attention_features(
        warmup_state=warmup_state,
        S_T=block_sparse_map,
        S_0=warmup_state['S_0_maps'],
        blocks_per_frame=warmup_state['blocks_per_frame'],
        use_cuda=warmup_state['use_cuda'],
    )
    warmup_state['features_history'].append(new_features)
    warmup_state.update(current_steps=step + 1, prev_full_maps=block_sparse_map)


def is_warmup_complete(warmup_state: Dict) -> bool:
    """Return True once 'current_steps' has reached 'warmup_steps'."""
    completed = warmup_state['current_steps']
    required = warmup_state['warmup_steps']
    return completed >= required

import flashinfer
def AttentionSparseEngine(query, key, value, mask, pre_defined_mask=None, video_token_num=0, block_size=128):
    """
    Block-sparse attention for (batch, heads, seq, head_dim) tensors.

    Args:
        query, key, value: tensors in the "b h s d" layout assumed by the
            rearrange patterns below.
        mask: bool block mask of shape (heads, q_blocks, k_blocks), with
            ``block_size`` tokens per block on both axes.
        pre_defined_mask: optional. When given, ``int(pre_defined_mask[0].sum())``
            is the number of valid key tokens. The first ``video_token_num``
            queries go through sparse_sageattn with keys past that border masked
            out. The remaining queries use dense FlashInfer attention over the
            valid keys.
        video_token_num: number of leading query tokens handled by the sparse kernel.
        block_size: query block size of ``mask``.

    Returns:
        Attention output in the same "b h s d" layout as ``query``.
    """
    batch_size = query.shape[0]
    # sparse_sageattn tiles keys in 64-token blocks (cf. the `// 64` border
    # below), so every key column of the mask is duplicated.
    # NOTE(review): the factor 2 assumes block_size == 128.
    converted_mask = torch.repeat_interleave(mask, 2, dim=-1)

    def _with_batch(block_mask):
        # (heads, q_blocks, kv_blocks) -> (batch, heads, q_blocks, kv_blocks)
        return block_mask.unsqueeze(0).expand(batch_size, -1, -1, -1).contiguous()

    if pre_defined_mask is None:
        # Use the converted 4-D mask here too; the raw 3-D mask at 128-key
        # granularity does not match the layout used in the branch below.
        return sparse_sageattn(query, key, value, mask_id=_with_batch(converted_mask))

    kv_len = int(pre_defined_mask[0].sum())
    converted_mask[:, :, (kv_len + 63) // 64:] = False
    output_video = sparse_sageattn(
        query[:, :, :video_token_num, :],
        key,
        value,
        mask_id=_with_batch(converted_mask[:, :video_token_num // block_size, :]),
    )

    # FlashInfer's single_prefill expects one sequence as (seq, heads, dim).
    # Run it per sample: folding batch into the sequence axis would let one
    # sample's text queries attend to another sample's keys.
    text_outputs = []
    for b in range(batch_size):
        out_b = flashinfer.single_prefill_with_kv_cache(
            q=rearrange(query[b, :, video_token_num:, :], "h s d -> s h d"),
            k=rearrange(key[b, :, :kv_len, :], "h s d -> s h d"),
            v=rearrange(value[b, :, :kv_len, :], "h s d -> s h d"),
            causal=False,
            return_lse=False,
        )
        text_outputs.append(rearrange(out_b, "s h d -> h s d"))
    output_text = torch.stack(text_outputs, dim=0)
    return torch.cat([output_video, output_text], dim=2)
# ============================================================================
# Mask Prediction Functions
# ============================================================================

def check_bd_values(values: torch.Tensor, mode: List[str], thereshold: float = -7.5e-2) -> List[str]:
    """
    Switch 'sparse' heads to 'block_diagonal' when their latest value is small.

    Args:
        values: block-diagonal values per head and step, shape
            (num_heads, num_steps); only the last column is inspected.
        mode: per-head mode list. It is modified in place, and only 'sparse'
            entries can change.
        thereshold: a head switches when its latest value is ``<= thereshold``.

    Returns:
        The same ``mode`` list object.
    """
    # one device->host transfer instead of one per head
    below = (values[:, -1] <= thereshold).tolist()
    for head, is_below in enumerate(below):
        if mode[head] == 'sparse' and is_below:
            mode[head] = 'block_diagonal'
    return mode

def _dimmest_lines(c_values: torch.Tensor, d_values: torch.Tensor, top_k: int) -> torch.Tensor:
    """Indices of the ``top_k`` smallest entries of ``cat([c, d], dim=1)`` per head (unsorted)."""
    combined = torch.cat([c_values, d_values], dim=1)
    _, indices = torch.topk(
        combined,
        k=min(top_k, combined.shape[1]),
        dim=1,
        largest=False,
        sorted=False
    )
    return indices


def choose_topk_lines(
    history: List[Dict[str, torch.Tensor]],
    top_k: int = 10,
    predict_all: bool = False,
    predict_T: Optional[int] = None,
    num_steps: int = 50,
) -> "torch.Tensor | List[torch.Tensor]":
    """
    Pick, per head, the ``top_k`` lines with the smallest extrapolated values.

    The last two feature snapshots give a linear trend
    ``delta = current - last`` per line. Values are extrapolated from
    ``current``, and the indices of the smallest entries of ``cat([c, d], dim=1)``
    are returned.

    Args:
        history: feature dicts; only the last two are used. Each holds
            per-head 2-D tensors under 'c' and 'd'.
        top_k: lines kept per head (capped at the number of lines).
        predict_all: if true and ``predict_T`` is given, return one selection
            per future window. Window ``t`` (1-based) extrapolates
            ``t * predict_T`` steps ahead, for ``int(num_steps / predict_T)`` windows.
        predict_T: steps per window. Without ``predict_all`` the trend is
            divided by ``predict_T`` and applied once. ``None`` is treated as 1;
            it previously raised a TypeError.
        num_steps: total step budget used to size the ``predict_all`` list
            (previously hard-coded to 50).

    Returns:
        A (num_heads, k) index tensor, or a list of them when predicting all windows.

    Raises:
        ValueError: if ``history`` holds fewer than two snapshots.
    """
    if len(history) < 2:
        raise ValueError("choose_topk_lines needs at least two feature snapshots")

    last, current = history[-2], history[-1]
    delta_c = current['c'] - last['c']
    delta_d = current['d'] - last['d']

    if predict_all and predict_T is not None:
        return [
            _dimmest_lines(
                current['c'] + delta_c * t * predict_T,
                current['d'] + delta_d * t * predict_T,
                top_k,
            )
            for t in range(1, int(num_steps / predict_T) + 1)
        ]

    step = 1 if predict_T is None else predict_T
    return _dimmest_lines(current['c'] + delta_c / step, current['d'] + delta_d / step, top_k)

def predict_mask_from_warmup(
    warmup_state: Dict,
    n: int,
    top_k: int = 10,
    layer_idx: int = 0,
) -> Dict:
    """
    Predict sparse block mask(s) from the features collected during warmup.

    Args:
        warmup_state: state dict. Reads 'num_heads', 'blocks_per_frame',
            'features_history', 'S_0_maps', 'predict_all', 'predict_T' and the
            keys used by ``generate_mask_from_lines``.
        n: sequence length in tokens.
        top_k: number of dimmest lines selected per head.
        layer_idx: unused; kept for interface compatibility.

    Returns:
        ``{'mode', 'selected_lines', 'mask'}``. ``mask`` is a block-level bool
        tensor, or a list of them (one per window) when 'predict_all' is set.
        If warmup is not complete, an all-True token-level
        ``(num_heads, n, n)`` mask is returned instead.
        NOTE(review): that fallback is token-level while the normal path is
        block-level — confirm callers expect this.
    """
    num_heads = warmup_state['num_heads']
    blocks_per_frame = warmup_state['blocks_per_frame']
    if not is_warmup_complete(warmup_state):
        # 'S_0_maps' is a single tensor once update_warmup has run (None before).
        # A dict of per-head tensors is still accepted. The previous `.values()`
        # call and tensor truth test crashed in the tensor case.
        s0 = warmup_state['S_0_maps']
        if isinstance(s0, dict):
            s0 = next(iter(s0.values()), None)
        device = s0.device if torch.is_tensor(s0) else torch.device('cpu')
        return {
            'mode': ['full'] * num_heads,
            'block_diagonal': [False] * num_heads,
            'selected_lines': [],
            'mask': torch.ones(num_heads, n, n, dtype=torch.bool, device=device)
        }

    history = warmup_state['features_history']
    device = history[0]['c'].device
    # All heads start as 'sparse'. A head whose latest block-diagonal value is
    # <= 1e-2 becomes 'block_diagonal'.
    b_d_values = torch.stack([f['b_d'] for f in history], dim=0).to(device).permute(1, 0)
    mode = check_bd_values(b_d_values, thereshold=1e-2, mode=['sparse'] * num_heads)

    selected_lines = choose_topk_lines(
        history,
        top_k=top_k,
        predict_all=warmup_state['predict_all'],
        predict_T=warmup_state.get('predict_T', None)
    )

    if warmup_state['predict_all']:
        mask = [
            generate_mask_from_lines(warmup_state, num_heads, n, blocks_per_frame, lines_t, mode, device)
            for lines_t in selected_lines
        ]
    else:
        mask = generate_mask_from_lines(warmup_state, num_heads, n, blocks_per_frame, selected_lines, mode, device)

    return {
        'mode': mode,
        'selected_lines': selected_lines,
        'mask': mask
    }



def generate_mask_from_lines(
    warmup_state: Dict,
    num_heads: int,
    n: int,
    blocks_per_frame: int,
    selected_lines: torch.Tensor,
    mode: List[str],
    device: torch.device
) -> torch.Tensor:
    """
    Build a block-level boolean attention mask from the selected lines.

    The video region is a square of ``R = n // block_size - text_token_num // block_size``
    blocks per side. Each entry ``idx`` of ``selected_lines[h]`` means:

      * ``0 <= idx < R``: key-block column ``idx`` (every video row attends to it);
      * ``idx >= R``: the diagonal ``col - row == idx - (2R - 1)``. Indices
        ``R .. 3R-2`` cover offsets ``-(R-1) .. R-1``; larger ones match nothing.

    Per-head ``mode`` adjusts this. 'full' heads get a dense video region.
    'block_diagonal' heads additionally attend within their own frame of
    ``blocks_per_frame`` consecutive blocks.

    NOTE(review): ``choose_topk_lines`` ranks ``cat([c, d])`` and the original
    feature notes describe 'c' as diagonal values (length 2n-1) and 'd' as
    vertical values (length n), i.e. diagonal-first. That is the opposite of the
    column-first convention here — confirm against ``solve_lstsq``.

    Args:
        warmup_state: reads 'block_size', 'text_token_num' and 'model_type'.
        num_heads: number of attention heads.
        n: sequence length in tokens.
        blocks_per_frame: video blocks per frame (used by block-diagonal heads).
        selected_lines: (num_heads, k) or (num_heads,) line indices, as a
            tensor or anything ``torch.as_tensor`` accepts.
        mode: per-head mode: 'sparse', 'block_diagonal' or 'full'.
        device: device of the returned mask.

    Returns:
        Bool tensor of shape (num_heads, n // block_size, n // block_size).
        Text rows and columns are always True. For 'cogvideox' the video region
        is the trailing square; for 'hunyuan' it is the leading square. Any
        other 'model_type' yields an all-True mask.
    """
    block_size = warmup_state['block_size']
    block_num = n // block_size
    text_block_num = warmup_state['text_token_num'] // block_size
    video_block_num = block_num - text_block_num

    video_mask = torch.zeros(num_heads, video_block_num, video_block_num, dtype=torch.bool, device=device)

    # normalise selected_lines to a (num_heads, k) int64 tensor on `device`
    lines = torch.as_tensor(selected_lines, device=device)
    if lines.dim() == 1:
        lines = lines.unsqueeze(1)
    lines = lines.long()

    full_heads = torch.tensor([m == 'full' for m in mode], dtype=torch.bool, device=device)
    block_diag_heads = torch.tensor([m == 'block_diagonal' for m in mode], dtype=torch.bool, device=device)

    if full_heads.any():
        video_mask[full_heads] = True

    if video_block_num > 0:
        side = video_block_num
        active = (~full_heads).view(-1, 1)          # (H, 1); full heads are already dense
        positions = torch.arange(side, device=device)
        lines_exp = lines.unsqueeze(-1)             # (H, k, 1)

        # Column lines: any selected idx == c marks column c for every row.
        is_col = (lines < side) & active
        col_hits = ((lines_exp == positions) & is_col.unsqueeze(-1)).any(dim=1)   # (H, R)
        video_mask |= col_hits.unsqueeze(1)

        # Diagonal lines: slot = idx - R = offset + (R - 1) lies in [0, 2R-1).
        is_diag = (lines >= side) & active
        slots = torch.arange(2 * side - 1, device=device)
        diag_hits = ((lines_exp - side == slots) & is_diag.unsqueeze(-1)).any(dim=1)  # (H, 2R-1)
        cell_slot = positions.view(1, side) - positions.view(side, 1) + (side - 1)    # (R, R): col - row + R - 1
        video_mask |= diag_hits[:, cell_slot]

        if block_diag_heads.any():
            frame = positions // blocks_per_frame
            same_frame = frame.view(side, 1) == frame.view(1, side)
            video_mask |= block_diag_heads.view(-1, 1, 1) & same_frame

    mask = torch.ones(num_heads, block_num, block_num, dtype=torch.bool, device=device)
    model_type = warmup_state.get('model_type')
    if model_type == 'cogvideox':
        # text blocks lead and stay fully attended; the trailing video square gets video_mask
        mask[:, text_block_num:, text_block_num:] = video_mask
    elif model_type == 'hunyuan':
        # video blocks lead; the trailing text rows/columns stay fully attended
        mask[:, :video_block_num, :video_block_num] = video_mask
    # any other model_type: conservatively keep the all-True mask
    return mask


# ============================================================================
# Main Function: SparseAttentionWithMap
# ============================================================================


def _observe_block_sparse_map(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    warmup_state: Dict,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run attention once and also return its block-level map for feature extraction."""
    if warmup_state['use_block_attention']:
        return compute_block_sparse_map(query, key, value, warmup_state)
    with torch.no_grad():
        output, attn_map = compute_full_attention(query, key, value)
    block_sparse_map = convert_to_block_sparse_map(
        warmup_state,
        attn_map,
        warmup_state.get('text_token_num', 226),
        warmup_state['block_size'],
        warmup_state['thereshold'],
    )
    return output, block_sparse_map


def _video_block_region(mask: torch.Tensor, warmup_state: Dict) -> torch.Tensor:
    """Slice the video-to-video square out of a full block-level mask."""
    block_size = warmup_state['block_size']
    model_type = warmup_state['model_type']
    if model_type == 'cogvideox':
        # text blocks come first
        text_blocks = warmup_state['text_token_num'] // block_size
        return mask[:, text_blocks:, text_blocks:]
    if model_type == 'hunyuan':
        # video blocks come first
        video_blocks = warmup_state['video_token_num'] // block_size
        return mask[:, :video_blocks, :video_blocks]
    raise ValueError(f"Unsupported model_type for mask refresh: {model_type!r}")


def _refresh_predicted_mask(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    warmup_state: Dict,
    num_heads: int,
    n: int,
) -> torch.Tensor:
    """Observe real attention, update the feature history, store the next mask; return the attention output."""
    output, current_sparse_map = _observe_block_sparse_map(query, key, value, warmup_state)
    current_mask = _video_block_region(warmup_state['mask'], warmup_state)
    print(warmup_state['mask'].shape, current_mask.shape, current_sparse_map.shape)
    # Blocks the current mask skipped keep their previous values; attended blocks take fresh ones.
    warmup_state['prev_full_maps'] = torch.where(
        current_mask.bool(),
        current_sparse_map,
        warmup_state['prev_full_maps'],
    )

    features = extract_attention_features(
        warmup_state=warmup_state,
        S_T=warmup_state['prev_full_maps'],
        S_0=warmup_state['S_0_maps'],
        blocks_per_frame=warmup_state['blocks_per_frame'],
        use_cuda=warmup_state['use_cuda']
    )
    history = warmup_state['features_history']
    history.append(features)
    history.pop(0)  # sliding window: the history length stays fixed

    selected_lines = choose_topk_lines(
        history,
        top_k=warmup_state['top_k'],
        predict_all=warmup_state['predict_all'],
        predict_T=warmup_state.get('predict_T', None)
    )
    warmup_state['mask'] = generate_mask_from_lines(
        warmup_state,
        num_heads,
        n,
        warmup_state['blocks_per_frame'],
        selected_lines,
        warmup_state['mode'],
        device=query.device
    )
    return output


def SparseAttentionWithMap(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    warmup_state: Dict,
    pre_defined_mask: Optional[torch.Tensor] = None,
    head_idx: int = 0,
    current_step: int = 0,
    top_k: int = 10,
    use_cuda: bool = True,
    backend: str = "flashinfer",
    block_size: int = 128,
    layer_idx: int = 0,
    return_attention_map: bool = False,
    **kwargs
) -> Tuple[torch.Tensor, object]:
    """
    Run adaptive sparse attention for one layer and maintain the warmup state.

    During the first 'warmup_steps' steps no predicted mask is applied. The
    last two warmup steps are observed to learn line features, and the final
    warmup step predicts the first mask. Later steps run
    ``AttentionSparseEngine`` with the current mask. Every 'predict_T' steps the
    mask is replaced, either from the precomputed 'mask_list' ('predict_all')
    or by re-observing the real attention and predicting again.

    Args:
        query, key, value: (batch, heads, seq_len, head_dim) tensors; seq_len is
            read from dim 2 and heads from dim 1.
        warmup_state: state dict from ``create_adaptive_mask_state``.
        pre_defined_mask: optional mask forwarded to ``AttentionSparseEngine``.
        current_step: current denoising step.
        top_k: lines kept per head when predicting the first mask.
        layer_idx: layer index, used for logging.
        head_idx, use_cuda, backend, block_size, return_attention_map, **kwargs:
            accepted for interface compatibility but unused. The block size is
            always ``warmup_state['block_size']``, so mask slicing matches the
            masks built by ``generate_mask_from_lines``.

    Returns:
        During warmup: ``(output, warmup_state['mask'])``.
        After warmup: ``(output, {'mode': ..., 'mask': ...})``.
    """
    n = query.size(2)
    num_heads = query.size(1)
    warmup_steps = warmup_state['warmup_steps']

    if current_step < warmup_steps:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        if current_step >= warmup_steps - 2:
            output, block_sparse_map = _observe_block_sparse_map(query, key, value, warmup_state)
            update_warmup(warmup_state, current_step, block_sparse_map)
        else:
            # Earlier warmup steps are not observed. Use plain dense attention;
            # `output` used to be left unassigned here (UnboundLocalError).
            output = F.scaled_dot_product_attention(query, key, value)

        # After the last warmup step, plan the mask(s) for the following steps.
        if current_step == warmup_steps - 1:
            mask_info = predict_mask_from_warmup(
                warmup_state,
                n,
                top_k=top_k,
                layer_idx=layer_idx
            )
            if warmup_state['predict_all']:
                # One mask per predict_T window; the first window starts at the
                # next step. Storing the whole list here used to break `.float()`.
                warmup_state['mask_list'] = mask_info['mask']
                warmup_state['mask'] = mask_info['mask'][0]
            else:
                warmup_state['mask'] = mask_info['mask']
            print(f"Layer {layer_idx}, Mask Sparsity: {1 - warmup_state['mask'].float().mean():.3%}")
            warmup_state['mode'] = mask_info['mode']
        return output, warmup_state['mask']

    mask = warmup_state.get('mask', None)
    if mask is None:
        # No mask planned yet: fall back to dense attention instead of crashing.
        output = F.scaled_dot_product_attention(query, key, value)
    else:
        output = AttentionSparseEngine(
            query,
            key,
            value,
            mask,
            pre_defined_mask=pre_defined_mask,
            video_token_num=warmup_state.get('video_token_num', 0),
            block_size=warmup_state['block_size']
        )
        predict_T = warmup_state['predict_T']
        steps_after_warmup = current_step - warmup_steps
        if steps_after_warmup % predict_T == predict_T - 1:
            if warmup_state['predict_all']:
                mask_list = warmup_state['mask_list']
                # clamp so runs longer than the precomputed horizon reuse the last mask
                next_idx = min(steps_after_warmup // predict_T + 1, len(mask_list) - 1)
                warmup_state['mask'] = mask_list[next_idx]
            else:
                output = _refresh_predicted_mask(query, key, value, warmup_state, num_heads, n)
            warmup_state['thereshold_ratio'] = 1 / (warmup_state['mask'].float().mean(dim=(1, 2)))

    mask_info = {
        'mode': warmup_state.get('mode'),
        'mask': warmup_state.get('mask'),  # the mask the *next* step will use
    }
    return output, mask_info



