"""
Data utilities for reasoning framework training.

This module provides helper functions for data preprocessing, loss masking,
and other data-related operations used in training VLM components.
"""

import torch
from typing import Dict, List, Optional, Any, Union
from transformers import PreTrainedTokenizer


def create_loss_mask(
    input_ids: torch.Tensor,
    prompt_lengths: List[int],
    ignore_index: int = -100
) -> torch.Tensor:
    """
    Build per-token labels that keep response tokens and blank out the prompt.

    Row ``i`` of the result equals ``input_ids[i]`` from position
    ``prompt_lengths[i]`` onward and ``ignore_index`` everywhere before it.
    Rows whose prompt length is at least the sequence length, and rows that
    have no corresponding entry in ``prompt_lengths``, are entirely
    ``ignore_index``.

    Args:
        input_ids: Input token IDs [batch_size, seq_len]
        prompt_lengths: Prompt length for each example in the batch
            (expected to be non-negative)
        ignore_index: Value written at ignored positions (default: -100)

    Returns:
        Label tensor with the same shape and dtype as ``input_ids``
    """
    _, num_positions = input_ids.shape
    labels = torch.full_like(input_ids, ignore_index)

    for row, start in enumerate(prompt_lengths):
        if start >= num_positions:
            # The prompt fills the whole row: nothing is left to supervise.
            continue
        labels[row, start:] = input_ids[row, start:]

    return labels


def _render_prompt_and_full_text(
    example: Dict[str, Any],
    tokenizer: "PreTrainedTokenizer",
    prompt_field: str,
    response_field: str,
    format_type: str,
) -> tuple:
    """Return ``(prompt_text, full_text)`` for one example in the requested format."""
    if format_type == "chatml":
        prompt_messages = [{"role": "user", "content": example[prompt_field]}]
        full_messages = prompt_messages + [
            {"role": "assistant", "content": example[response_field]}
        ]
        full_text = tokenizer.apply_chat_template(
            full_messages,
            tokenize=False,
            add_generation_prompt=False
        )
        # add_generation_prompt=True asks the template for the assistant-turn
        # header, so the prompt ends where the assistant's reply begins.
        prompt_text = tokenizer.apply_chat_template(
            prompt_messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # Any non-"chatml" format_type falls back to a simple Question/Answer layout.
        full_text = f"Question: {example[prompt_field]}\nAnswer: {example[response_field]}"
        prompt_text = f"Question: {example[prompt_field]}\nAnswer: "
    return prompt_text, full_text


def _shared_prefix_length(prefix_ids: torch.Tensor, full_ids: torch.Tensor) -> int:
    """Return how many leading positions of two 1-D id tensors hold identical tokens."""
    n = min(prefix_ids.numel(), full_ids.numel())
    mismatches = torch.nonzero(prefix_ids[:n] != full_ids[:n])
    return int(mismatches[0, 0]) if mismatches.numel() > 0 else n


def _resolve_pad_token_id(tokenizer: "PreTrainedTokenizer") -> int:
    """Return the tokenizer's pad id, falling back to its EOS id when no pad token is set."""
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    if eos_token_id is None:
        raise ValueError(
            "Tokenizer has neither pad_token_id nor eos_token_id; cannot pad the batch"
        )
    return eos_token_id


def apply_chat_template_with_loss_masking(
    examples: List[Dict[str, Any]],
    tokenizer: "PreTrainedTokenizer",
    prompt_field: str = "question",
    response_field: str = "final_answer",
    format_type: str = "chatml",
    max_length: int = 2048
) -> Dict[str, torch.Tensor]:
    """
    Apply chat template and create labels with loss masking for response-only training.

    Each example is rendered twice: once as prompt + response and once as the
    prompt alone. Leading tokens shared by both tokenizations are labelled
    -100, so the loss covers only the response. Sequences are right-padded to
    the longest example in the batch (capped at ``max_length``).

    Args:
        examples: List of examples with prompt and response fields
        tokenizer: Tokenizer used for templating and tokenization
        prompt_field: Field name containing the prompt/question
        response_field: Field name containing the response/answer
        format_type: "chatml" renders via ``tokenizer.apply_chat_template``;
            any other value uses a plain Question/Answer text layout
        max_length: Maximum sequence length

    Returns:
        Dictionary with input_ids, attention_mask, and labels (with masking),
        each stacked to shape [len(examples), padded_length]. Padded positions
        have attention 0 and label -100.

    Raises:
        ValueError: If ``examples`` is empty, or if padding is required and the
            tokenizer defines neither ``pad_token_id`` nor ``eos_token_id``.
    """
    if not examples:
        raise ValueError("examples must contain at least one example")

    tokenize_kwargs = dict(
        truncation=True,
        padding=False,
        max_length=max_length,
        return_tensors="pt",
    )

    input_ids_list = []
    attention_mask_list = []
    labels_list = []

    for example in examples:
        prompt_text, full_text = _render_prompt_and_full_text(
            example, tokenizer, prompt_field, response_field, format_type
        )
        full_tokens = tokenizer(full_text, **tokenize_kwargs)
        prompt_tokens = tokenizer(prompt_text, **tokenize_kwargs)

        # Take row 0 instead of squeeze(): squeeze() would collapse a
        # length-1 sequence into a 0-d tensor that has no len().
        input_ids = full_tokens["input_ids"][0]
        attention_mask = full_tokens["attention_mask"][0]
        prompt_ids = prompt_tokens["input_ids"][0]

        # Mask only the tokens the prompt actually shares with the full
        # sequence. A subword tokenizer may split the prompt/response boundary
        # differently when the prompt is tokenized alone (e.g. a trailing " "
        # token vs. " 42" in the full text). In that case the raw prompt length
        # would also mask the first response token.
        prompt_length = _shared_prefix_length(prompt_ids, input_ids)
        labels = input_ids.clone()
        labels[:prompt_length] = -100

        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        labels_list.append(labels)

    max_len = min(max(ids.numel() for ids in input_ids_list), max_length)

    # Resolved lazily so tokenizers without a pad token still work when no
    # padding is needed.
    pad_token_id = None
    padded_input_ids = []
    padded_attention_mask = []
    padded_labels = []

    for input_ids, attention_mask, labels in zip(input_ids_list, attention_mask_list, labels_list):
        # Guard against tokenizers that ignore the truncation request.
        input_ids = input_ids[:max_len]
        attention_mask = attention_mask[:max_len]
        labels = labels[:max_len]

        pad_length = max_len - input_ids.numel()
        if pad_length > 0:
            if pad_token_id is None:
                pad_token_id = _resolve_pad_token_id(tokenizer)
            input_ids = torch.cat([input_ids, input_ids.new_full((pad_length,), pad_token_id)])
            attention_mask = torch.cat([attention_mask, attention_mask.new_zeros(pad_length)])
            labels = torch.cat([labels, labels.new_full((pad_length,), -100)])

        padded_input_ids.append(input_ids)
        padded_attention_mask.append(attention_mask)
        padded_labels.append(labels)

    return {
        "input_ids": torch.stack(padded_input_ids),
        "attention_mask": torch.stack(padded_attention_mask),
        "labels": torch.stack(padded_labels)
    }


def format_reasoning_trajectory_for_training(
    trajectory: 'ReasoningTrajectory',
    tokenizer: PreTrainedTokenizer,
    format_type: str = "chatml",
    include_reasoning_steps: bool = True,
    mask_prompt: bool = True
) -> str:
    """
    Render a reasoning trajectory as a single training string.

    Dispatches on ``format_type`` to the ChatML, Alpaca or plain-text renderer.
    ``tokenizer`` and ``mask_prompt`` are accepted but not used by this
    function. The returned text is plain text and carries no loss mask.

    Args:
        trajectory: ReasoningTrajectory object to render
        tokenizer: Currently unused
        format_type: One of "chatml", "alpaca", "plain"
        include_reasoning_steps: Whether to include step-by-step reasoning
        mask_prompt: Currently unused

    Returns:
        Formatted text string ready for training

    Raises:
        ValueError: If ``format_type`` is not one of the supported names.
    """
    renderers = (
        ("chatml", _format_chatml_trajectory),
        ("alpaca", _format_alpaca_trajectory),
        ("plain", _format_plain_trajectory),
    )
    for name, render in renderers:
        if format_type == name:
            return render(trajectory, include_reasoning_steps)
    raise ValueError(f"Unknown format type: {format_type}")


def _format_chatml_trajectory(trajectory: 'ReasoningTrajectory', include_reasoning: bool) -> str:
    """Render a trajectory as a ChatML conversation.

    The output has three turns:
    - a fixed system turn;
    - a user turn holding the question, preceded by an ``<image>``
      placeholder line when the trajectory has a truthy ``image_path``;
    - an assistant turn.

    With ``include_reasoning`` the assistant turn lists the optional image
    description, the numbered reasoning steps and the final answer. Otherwise
    it contains only the final answer. The result ends with ``<|im_end|>`` and
    no trailing newline.
    """
    system_turn = (
        "<|im_start|>system\n"
        "You are a helpful AI assistant that can analyze images and solve problems step by step.\n"
        "<|im_end|>\n"
    )

    image_line = "<image>\n" if getattr(trajectory, 'image_path', None) else ""
    user_turn = f"<|im_start|>user\n{image_line}{trajectory.question}\n<|im_end|>\n"

    if include_reasoning:
        reply = ["I'll analyze this step by step.\n\n"]
        if trajectory.vlm_description:
            reply.append(f"**Image Description:**\n{trajectory.vlm_description}\n\n")
        reply.append("**Reasoning:**\n")
        reply.extend(f"{n}. {step}\n" for n, step in enumerate(trajectory.reasoning_steps, 1))
        reply.append(f"\n**Answer:** {trajectory.final_answer}\n")
        assistant_body = "".join(reply)
    else:
        assistant_body = f"{trajectory.final_answer}\n"

    return f"{system_turn}{user_turn}<|im_start|>assistant\n{assistant_body}<|im_end|>"


def _format_alpaca_trajectory(trajectory: 'ReasoningTrajectory', include_reasoning: bool) -> str:
    """Render a trajectory as an Alpaca-style instruction/response pair.

    The instruction is the question, prefixed with "[Image] " when the
    trajectory has a truthy ``image_path``. With ``include_reasoning`` the
    response walks through the optional image description, the numbered
    reasoning steps and the final answer. Otherwise the response is the
    bare final answer.
    """
    question = trajectory.question
    has_image = hasattr(trajectory, 'image_path') and trajectory.image_path
    instruction = f"[Image] {question}" if has_image else question

    if not include_reasoning:
        response = trajectory.final_answer
    else:
        parts = ["I'll solve this step by step.\n\n"]
        if trajectory.vlm_description:
            parts.append(f"Image Description: {trajectory.vlm_description}\n\n")
        parts.append("Reasoning:\n")
        parts.extend(f"{n}. {step}\n" for n, step in enumerate(trajectory.reasoning_steps, 1))
        parts.append(f"\nAnswer: {trajectory.final_answer}")
        response = "".join(parts)

    return f"### Instruction:\n{instruction}\n\n### Response:\n{response}"


def _format_plain_trajectory(trajectory: 'ReasoningTrajectory', include_reasoning: bool) -> str:
    """Render a trajectory as plain "Question / Reasoning / Answer" text.

    The question header is always present. With ``include_reasoning`` it is
    followed by the optional image description, the numbered reasoning steps
    and the final answer. Otherwise only the answer line follows. No image
    placeholder is emitted in this format.
    """
    header = f"Question: {trajectory.question}\n\n"
    if not include_reasoning:
        return header + f"Answer: {trajectory.final_answer}"

    sections = [header]
    if trajectory.vlm_description:
        sections.append(f"Image Description: {trajectory.vlm_description}\n\n")
    sections.append("Reasoning:\n")
    sections.extend(f"{n}. {step}\n" for n, step in enumerate(trajectory.reasoning_steps, 1))
    sections.append(f"\nAnswer: {trajectory.final_answer}")
    return "".join(sections)


def compute_dataset_statistics(
    dataset: Union[List[Dict], 'torch.utils.data.Dataset'],
    sample_size: int = 1000,
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Compute length statistics over a random sample of a training dataset.

    Only dict examples are measured; other examples are skipped. For each dict
    example, its ``"text"`` entry (missing -> "") is measured in characters and
    in whitespace-separated words. Note that ``token_counts`` holds word counts
    from ``str.split()``, not tokenizer tokens.

    Args:
        dataset: Indexable dataset supporting ``len()`` and ``dataset[i]``
        sample_size: Maximum number of examples to sample (default 1000;
            negative values are treated as 0)
        seed: If given, the sample is drawn from a private generator seeded
            with this value. The result is then reproducible and the global
            torch RNG state is left untouched. If None (default), the global
            torch RNG is used.

    Returns:
        Dictionary containing:
        - ``total_examples``;
        - the raw per-example ``text_lengths`` and ``token_counts`` lists;
        - ``avg_``/``max_``/``min_`` aggregates of each, present only when at
          least one dict example was sampled.
    """
    num_examples = len(dataset)
    stats = {
        "total_examples": num_examples,
        "text_lengths": [],
        "token_counts": [],
    }

    # Clamp so a negative sample_size yields an empty sample instead of a
    # negative slice that would silently drop only the tail of the permutation.
    num_samples = max(0, min(sample_size, num_examples))
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    indices = torch.randperm(num_examples, generator=generator)[:num_samples].tolist()

    for idx in indices:
        example = dataset[idx]
        if isinstance(example, dict):
            text = example.get("text", "")
            stats["text_lengths"].append(len(text))
            stats["token_counts"].append(len(text.split()))

    if stats["text_lengths"]:
        stats["avg_text_length"] = sum(stats["text_lengths"]) / len(stats["text_lengths"])
        stats["max_text_length"] = max(stats["text_lengths"])
        stats["min_text_length"] = min(stats["text_lengths"])

        stats["avg_token_count"] = sum(stats["token_counts"]) / len(stats["token_counts"])
        stats["max_token_count"] = max(stats["token_counts"])
        stats["min_token_count"] = min(stats["token_counts"])

    return stats