import torch

def apply_label_masking(examples, tokenizer, assistant_header_ids, max_length=768):
    """Tokenize chat conversations and build labels for response-only training.

    Each conversation in ``examples["prompt"]`` is rendered with the tokenizer's
    chat template and tokenized (truncated to ``max_length`` tokens, no
    padding). It is then given a ``labels`` list of the same length:

    * If the conversation's last message has role ``"user"``, the labels are a
      copy of the full ``input_ids`` (every token is trained on).
    * Otherwise, every token up to and including the first occurrence of
      ``assistant_header_ids`` is set to -100, and the remaining tokens are
      kept. If the header is not found (e.g. it was truncated away), every
      label is -100.

    Args:
        examples: Batch mapping with a ``"prompt"`` key holding a list of
            conversations. Each conversation is a list of message dicts with
            at least a ``"role"`` key.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        assistant_header_ids: Token-id sequence that marks the start of an
            assistant turn. Any sequence type (list, tuple, ...) is accepted.
        max_length: Maximum number of tokens kept per conversation.

    Returns:
        The tokenizer output, with an added ``"labels"`` entry holding one list
        of ints per conversation.
    """
    texts = [
        tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=False)
        for msg in examples["prompt"]
    ]
    # NOTE(review): if the chat template already emits a BOS token,
    # add_special_tokens=True may prepend a second one. Confirm for this tokenizer.
    tokenized = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding=False,
        return_tensors=None,
        add_special_tokens=True,
    )

    # Normalize to a list so the slice comparison below also works when the
    # caller passes a tuple or another sequence (list == tuple is always False).
    header = list(assistant_header_ids)
    header_len = len(header)

    labels = []
    for conversation, input_ids in zip(examples["prompt"], tokenized["input_ids"]):
        input_ids = list(input_ids)  # fresh list: labels never alias input_ids

        if conversation[-1]["role"] == "user":
            # Last message is from the user: train on the entire sequence.
            labels.append(input_ids)
            continue

        # Last message is from the assistant: mask everything up to and
        # including the first assistant header.
        # NOTE(review): in multi-turn conversations this keeps every token after
        # the FIRST assistant header, including later user turns. Confirm intended.
        start_idx = next(
            (
                j + header_len
                for j in range(len(input_ids) - header_len + 1)
                if input_ids[j:j + header_len] == header
            ),
            -1,
        )
        if start_idx == -1:
            # Header not found (e.g. truncated away): nothing is trained on.
            labels.append([-100] * len(input_ids))
        else:
            labels.append([-100] * start_idx + input_ids[start_idx:])

    tokenized["labels"] = labels
    return tokenized

# Custom data collator: pads model inputs with the tokenizer and pads labels with label_pad_token_id (-100 by default).
class DataCollatorWithPaddingAndLabels:
    """Collate tokenized examples into a padded batch with aligned labels.

    Model inputs (``input_ids``, ``attention_mask``, ...) are padded with
    ``tokenizer.pad``. Each example's ``labels`` sequence is padded with
    ``label_pad_token_id`` to the same length, on the same side the tokenizer
    pads (``tokenizer.padding_side``, or right if the attribute is missing).

    Args:
        tokenizer: Tokenizer exposing ``pad(features, return_tensors="pt")``.
        label_pad_token_id: Value used to pad labels (default -100).
    """

    def __init__(self, tokenizer, label_pad_token_id: int = -100):
        self.tokenizer = tokenizer
        self.label_pad_token_id = label_pad_token_id

    def __call__(self, features):
        """Return the padded batch with a ``"labels"`` LongTensor added.

        The input ``features`` are not modified.

        Raises:
            KeyError: if a feature has no ``"labels"`` entry.
        """
        # Copy each feature without "labels" instead of popping it. This leaves
        # the caller's dicts (e.g. cached dataset rows) untouched, so the same
        # features can be collated more than once.
        labels = [list(feature["labels"]) for feature in features]
        inputs = [
            {key: value for key, value in feature.items() if key != "labels"}
            for feature in features
        ]

        # Pad the remaining fields ("input_ids", "attention_mask", etc.).
        batch = self.tokenizer.pad(inputs, return_tensors="pt")
        max_length = batch["input_ids"].shape[1]

        # Pad labels on the same side as input_ids. Otherwise a left-padding
        # tokenizer would leave labels shifted relative to their tokens.
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        padded_labels = []
        for label in labels:
            padding = [self.label_pad_token_id] * (max_length - len(label))
            padded_labels.append(padding + label if pad_left else label + padding)

        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        return batch