"""Utils for training/fine-tuning scripts."""

import torch

from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX

def expand_mask_with_prefix(mask, prefix_len=512):
    """
    Prepend ``prefix_len`` always-on columns to a 2-D mask.

    Args:
        mask: tensor of shape (batch_size, seq_len).
        prefix_len: number of columns of ones to prepend (default 512, the
            value previously hard-coded here).

    Returns:
        Tensor of shape (batch_size, prefix_len + seq_len) with the same dtype
        and device as ``mask``. The first ``prefix_len`` columns are ones
        (True for a boolean mask), followed by the original ``mask``.

    Raises:
        ValueError: if ``prefix_len`` is negative.
    """
    if prefix_len < 0:
        raise ValueError(f"prefix_len must be non-negative, got {prefix_len}")

    # Unpacking enforces a 2-D mask, as the original implementation did.
    batch_size, _ = mask.shape

    # Ones in the mask's own dtype/device so torch.cat does not need to promote or move data.
    prefix = torch.ones(batch_size, prefix_len, dtype=mask.dtype, device=mask.device)

    return torch.cat([prefix, mask], dim=1)

def modify_mask_vectorized(mask, num_to_clear=70):
    """
    Vectorized: clear (set to False/0) the last ``num_to_clear`` set entries of each row.

    Rows with fewer than ``num_to_clear`` set entries become entirely cleared.
    Any non-zero entry counts as "set", so bool and 0/1 integer/float masks
    are handled alike.

    Args:
        mask: 2-D tensor of shape (batch_size, seq_len).
        num_to_clear: how many trailing set entries to clear per row
            (default 70, the value previously hard-coded here).

    Returns:
        A new tensor with the same shape, dtype and device as ``mask``.
        The input tensor is not modified.
    """
    modified_mask = mask.clone()
    # Unpacking enforces a 2-D mask, as the original implementation did.
    _, seq_len = mask.shape
    if seq_len == 0:
        # Nothing to clear; also avoids indexing column -1 of an empty tensor.
        return modified_mask

    # Work on an explicit boolean view. With an integer mask, `... & mask` would
    # produce an integer tensor and the assignment below would silently switch
    # from boolean masking to integer (fancy) indexing.
    is_set = mask.bool()

    # running_count[b, j] = number of set entries in row b up to and including j,
    # so the k-th set entry of a row has running_count == k.
    running_count = is_set.cumsum(dim=1)
    num_set = running_count[:, -1:]  # (batch_size, 1): total set entries per row

    # The last `num_to_clear` set entries are exactly those whose running count
    # exceeds (num_set - num_to_clear); clamp so short rows clear everything.
    threshold = (num_set - num_to_clear).clamp(min=0)
    should_be_false = (running_count > threshold) & is_set

    modified_mask[should_be_false] = False
    return modified_mask

def get_current_action_mask(token_ids):
    """
    Build a boolean mask over the first ``ACTION_DIM`` non-ignored action tokens.

    A position (b, j) is selected when both conditions hold:
      * counting positions whose id is not ``IGNORE_INDEX`` left to right along
        dim 1, the running count at j lies in [1, ACTION_DIM];
      * ``token_ids[b, j] > ACTION_TOKEN_BEGIN_IDX``.

    Positions before the first non-ignored token have a running count of 0 and
    are never selected. Ignored positions inside the window are excluded only
    by the second condition (i.e. when IGNORE_INDEX <= ACTION_TOKEN_BEGIN_IDX).
    Presumably this selects the current-step action chunk; confirm with callers.

    Args:
        token_ids: integer tensor indexed along dim 1 (assumed (batch, seq)).

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # Running count of non-IGNORE_INDEX positions along each row.
    supervised_count = (token_ids != IGNORE_INDEX).cumsum(dim=1)

    in_first_window = (supervised_count >= 1) & (supervised_count <= ACTION_DIM)
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX

    return in_first_window & is_action_token


def get_next_actions_mask(token_ids):
    """
    Build a boolean mask over action tokens after the first ``ACTION_DIM`` non-ignored ones.

    A position (b, j) is selected when both conditions hold:
      * counting positions whose id is not ``IGNORE_INDEX`` left to right along
        dim 1, the running count at j is greater than ``ACTION_DIM``;
      * ``token_ids[b, j] > ACTION_TOKEN_BEGIN_IDX``.

    This is the complement window of ``get_current_action_mask`` (running count
    > ACTION_DIM instead of in [1, ACTION_DIM]). Presumably it selects the
    future-step action chunks; confirm with callers.

    Args:
        token_ids: integer tensor indexed along dim 1 (assumed (batch, seq)).

    Returns:
        Boolean tensor with the same shape as ``token_ids``.
    """
    # Running count of non-IGNORE_INDEX positions along each row.
    supervised_count = (token_ids != IGNORE_INDEX).cumsum(dim=1)

    past_first_window = supervised_count > ACTION_DIM
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX

    return past_first_window & is_action_token


def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    """
    Return the fraction of masked positions where prediction equals ground truth.

    Args:
        predicted_token_ids: tensor of predicted token ids.
        ground_truth_token_ids: tensor of target token ids (same shape).
        mask: boolean tensor selecting the positions to score.

    Returns:
        0-dim float tensor: (#correct masked positions) / (#masked positions).
        If ``mask`` selects nothing, this is 0/0 and evaluates to NaN.
    """
    matches = predicted_token_ids == ground_truth_token_ids
    num_correct = (matches & mask).sum().float()
    num_scored = mask.sum().float()
    return num_correct / num_scored


def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    """
    Mean absolute error between decoded predicted and ground-truth actions.

    The token ids selected by ``mask`` are moved to CPU, converted to NumPy, and
    decoded with ``action_tokenizer.decode_token_ids_to_actions``. The results
    are wrapped in tensors and compared with ``torch.nn.functional.l1_loss``
    (default 'mean' reduction).

    Args:
        action_tokenizer: object exposing ``decode_token_ids_to_actions``.
        predicted_token_ids: tensor of predicted token ids.
        ground_truth_token_ids: tensor of target token ids (same shape).
        mask: boolean tensor selecting the action-token positions.

    Returns:
        0-dim CPU tensor holding the L1 loss.
    """
    def decode_selected(token_ids):
        # Select masked ids, decode them on CPU, and wrap the result as a tensor.
        selected_ids = token_ids[mask].cpu().numpy()
        return torch.tensor(action_tokenizer.decode_token_ids_to_actions(selected_ids))

    # Predictions are decoded before ground truth, matching the original call order.
    predicted_actions = decode_selected(predicted_token_ids)
    target_actions = decode_selected(ground_truth_token_ids)
    return torch.nn.functional.l1_loss(predicted_actions, target_actions)
