import torch
from typing import Dict, Optional
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, and_masks


def create_attention_masks(
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    bin_idx: torch.Tensor,
    section_mask: torch.Tensor,
    attn_types: list,
    group_size: int,
    use_block_mask: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Create attention masks based on the specified attention types.
    
    Args:
        attention_mask: [N, G*T] boolean mask indicating valid positions
        position_ids: [N, G*T] position indices
        bin_idx: [B, T] bin indices  
        section_mask: [B, T] section mask
        attn_types: list of attention types to create masks for
        group_size: number of agents in each group
        
    Returns:
        Dictionary mapping attention type names to mask tensors
    """
    N, GT = attention_mask.shape
    G = group_size
    T = GT // G
    B = N * G


    def create_block_mask_custom(mask_function, len):
        block_mask = create_block_mask(
            mask_mod=mask_function,
            B=N,
            H=None,
            Q_LEN=len,
            KV_LEN=len,
            device=attention_mask.device,
            _compile=True,
        )
        return block_mask

    
    # Create agent index mapping
    agent_idx = torch.arange(0, G, device=position_ids.device)[:,None].expand(-1, T).reshape(GT) # [G*T]

    # Process bin indices for bin-based masks
    filled_bin_idx = torch.zeros_like(bin_idx)
    indiv_idx = torch.zeros_like(bin_idx)
    indiv_counter = 1
    for b in range(B):
        valid_indices = torch.where(bin_idx[b] != -100)[0]
        if len(valid_indices) > 0:
            for valid_idx in valid_indices:
                filled_bin_idx[b, section_mask[b]==section_mask[b, valid_idx]] = bin_idx[b, valid_idx].item()
                indiv_idx[b, section_mask[b]==section_mask[b, valid_idx]] = indiv_counter
                indiv_counter += 1
    filled_bin_idx = filled_bin_idx.reshape(N, GT) # [N, G*T]
    indiv_idx = indiv_idx.reshape(N, GT) # [N, G*T]

    is_last = torch.zeros_like(bin_idx, dtype=torch.bool)
    for b in range(B):
        last_section_b = section_mask[b].max().item()
        is_last[b] = section_mask[b] == last_section_b
    is_last = is_last.reshape(N, GT) # [N, G*T]

    valid_attention_mask = []
    valid_positions = []
    valid_agent_idx = []
    valid_filled_bin_idx = []
    valid_indiv_idx = []
    valid_is_last = []
    for n in range(N):
        attn_pos = attention_mask[n]  # [G*T]

        # Extract valid token data
        valid_attention_mask.append(attention_mask[n][attn_pos]) # [valid]
        valid_positions.append(position_ids[n][attn_pos]) # [valid]
        valid_agent_idx.append(agent_idx[attn_pos])
        valid_filled_bin_idx.append(filled_bin_idx[n][attn_pos])
        valid_indiv_idx.append(indiv_idx[n][attn_pos])
        valid_is_last.append(is_last[n][attn_pos])
    
    if use_block_mask:
        omni_mask = [lambda b, _, q_idx, kv_idx: valid_attention_mask[n][q_idx] & valid_attention_mask[n][kv_idx] for n in range(N)]
        causal_mask = [lambda b, _, q_idx, kv_idx: (valid_positions[n][q_idx] >= valid_positions[n][kv_idx]) for n in range(N)]
        intra_mask = [lambda b, _, q_idx, kv_idx: (valid_agent_idx[n][q_idx] == valid_agent_idx[n][kv_idx]) for n in range(N)]
        bin_mask = [lambda b, _, q_idx, kv_idx: (valid_filled_bin_idx[n][q_idx] == valid_filled_bin_idx[n][kv_idx]) for n in range(N)]
        indiv_mask = [lambda b, _, q_idx, kv_idx: (valid_indiv_idx[n][q_idx] == valid_indiv_idx[n][kv_idx]) for n in range(N)]
        last_mask = [lambda b, _, q_idx, kv_idx: (valid_is_last[n][q_idx] & valid_is_last[n][kv_idx]) for n in range(N)]
    else:
        omni_mask = [valid_attention_mask[n][None,None,:,None] & valid_attention_mask[n][None,None,None,:] for n in range(N)]
        causal_mask = [valid_positions[n][None,None,:,None] >= valid_positions[n][None,None,None,:] for n in range(N)]
        intra_mask = [valid_agent_idx[n][None,None,:,None] == valid_agent_idx[n][None,None,None,:] for n in range(N)]
        bin_mask = [valid_filled_bin_idx[n][None,None,:,None] == valid_filled_bin_idx[n][None,None,None,:] for n in range(N)]
        indiv_mask = [valid_indiv_idx[n][None,None,:,None] == valid_indiv_idx[n][None,None,None,:] for n in range(N)]
        last_mask = [valid_is_last[n][None,None,:,None] & valid_is_last[n][None,None,None,:] for n in range(N)]
    
    # Create masks dictionary based on requested attention types
    masks = {}
    
    if use_block_mask:
        valid_len = [len(valid_attention_mask[n]) for n in range(N)]
        for attn_type in attn_types:
            if attn_type == "omni":
                masks[attn_type] = [create_block_mask_custom(omni_mask[n], valid_len[n]) for n in range(N)]
            elif attn_type == "omni_intra":
                masks[attn_type] = [create_block_mask_custom(and_masks(omni_mask[n], intra_mask[n]), valid_len[n]) for n in range(N)]
            elif attn_type == "omni_bin":
                masks[attn_type] = [create_block_mask_custom(and_masks(omni_mask[n], bin_mask[n]), valid_len[n]) for n in range(N)]
            elif attn_type == "omni_indiv":
                masks[attn_type] = [create_block_mask_custom(and_masks(omni_mask[n], indiv_mask[n]), valid_len[n]) for n in range(N)]
            elif attn_type == "causal":
                masks[attn_type] = [create_block_mask_custom(causal_mask[n], valid_len[n]) for n in range(N)]
            elif attn_type == "causal_intra":
                masks[attn_type] = [create_block_mask_custom(and_masks(causal_mask[n], intra_mask[n]), valid_len[n]) for n in range(N)]
            elif attn_type == "causal_bin":
                masks[attn_type] = [create_block_mask_custom(and_masks(causal_mask[n], bin_mask[n]), valid_len[n]) for n in range(N)]
            elif attn_type == "causal_indiv":
                masks[attn_type] = [create_block_mask_custom(and_masks(causal_mask[n], indiv_mask[n]), valid_len[n]) for n in range(N)]
            elif attn_type == "recent":
                # TODO: Implement recent mask logic
                raise ValueError(f"Unknown attention type: {attn_type}")
            elif attn_type == "recent_bin":
                # TODO: Implement recent_bin mask logic
                raise ValueError(f"Unknown attention type: {attn_type}")
            elif attn_type == "last":
                masks[attn_type] = [create_block_mask_custom(last_mask[n], valid_len[n]) for n in range(N)]
            elif attn_type == "last_bin":
                masks[attn_type] = [create_block_mask_custom(and_masks(last_mask[n], bin_mask[n]), valid_len[n]) for n in range(N)]
            else:
                raise ValueError(f"Unknown attention type: {attn_type}")
    else:
        for attn_type in attn_types:
            if attn_type == "omni":
                masks[attn_type] = [omni_mask[n] for n in range(N)]
            elif attn_type == "omni_intra":
                masks[attn_type] = [intra_mask[n] & omni_mask[n] for n in range(N)]
            elif attn_type == "omni_bin":
                masks[attn_type] = [bin_mask[n] & omni_mask[n] for n in range(N)]
            elif attn_type == "omni_indiv":
                masks[attn_type] = [indiv_mask[n] & omni_mask[n] for n in range(N)]
            elif attn_type == "causal":
                masks[attn_type] = [causal_mask[n] for n in range(N)]
            elif attn_type == "causal_intra":
                masks[attn_type] = [intra_mask[n] & causal_mask[n] for n in range(N)]
            elif attn_type == "causal_bin":
                masks[attn_type] = [bin_mask[n] & causal_mask[n] for n in range(N)]
            elif attn_type == "causal_indiv":
                masks[attn_type] = [indiv_mask[n] & causal_mask[n] for n in range(N)]
            elif attn_type == "recent":
                # TODO: Implement recent mask logic
                raise ValueError(f"Unknown attention type: {attn_type}")
            elif attn_type == "recent_bin":
                # TODO: Implement recent_bin mask logic
                raise ValueError(f"Unknown attention type: {attn_type}")
            elif attn_type == "last":
                masks[attn_type] = [last_mask[n] for n in range(N)]
            elif attn_type == "last_bin":
                masks[attn_type] = [last_mask[n] & bin_mask[n] for n in range(N)]
            else:
                raise ValueError(f"Unknown attention type: {attn_type}")
    
    return masks