"""
VLM trajectory dataset for VLM SFT training.

This module provides dataset classes for training VLM components from reasoning trajectories.
Since we're only tuning the VLM (captioner), we extract VLM-relevant parts:
- Original image descriptions vs. refined descriptions
- Answers to specific VLM requests  
- Answers to generic questions about images

For VLM training, we use TRL's multimodal format with "messages" and "images" fields.
"""

import json
import os
from collections import Counter
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Dict, List, Optional, Union, Any

import datasets
import torch
from datasets import Dataset as HFDataset
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoTokenizer


@dataclass
class VLMTrainingExample:
    """
    A VLM training example extracted from a reasoning trajectory.

    Each instance is one conversation for a single VLM task: description
    refinement, question answering, or request handling. It pairs the
    conversation with the path of the image it refers to, plus metadata
    copied from the trajectory it was extracted from.
    """
    # Core data
    image_path: str  # Path of the image the conversation is about
    messages: List[Dict[str, str]]  # Conversation format for TRL VLM training; in this module each dict has "role" and "content" keys
    
    # Example type
    example_type: str  # "description", "question", "request"
    
    # Metadata
    trajectory_id: str = ""  # Source trajectory's sample_id (request examples in this module append "_step_{i}")
    dataset_name: str = "trajectory"  # Name of the source dataset
    category: Optional[str] = None  # Optional category label carried over from the trajectory
    difficulty: Optional[str] = None  # Optional difficulty label carried over from the trajectory


@dataclass
class ReasoningTrajectory:
    """
    A single reasoning trajectory example.

    Contains all the information needed for extracting VLM training examples:
    the image and question, the VLM's description of the image, the reasoning
    steps, the final answer, and optional metadata.
    """
    # Input data
    image_path: str
    question: str

    # Reasoning process
    vlm_description: str
    reasoning_steps: List[str]
    final_answer: str

    # Metadata
    dataset_name: str = "unknown"
    sample_id: str = ""
    difficulty: Optional[str] = None
    category: Optional[str] = None

    # Additional context
    ground_truth: Optional[str] = None
    confidence_score: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary for serialization.

        Keys are the dataclass fields in declaration order, so the result
        round-trips through ``from_dict``. Values are not copied (for example,
        ``reasoning_steps`` is the same list object held by this instance).
        """
        return {f.name: getattr(self, f.name) for f in fields(self)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningTrajectory':
        """
        Create a trajectory from a dictionary.

        Keys that are not fields of this class are ignored rather than
        rejected. This lets records that carry extra bookkeeping keys still
        load.

        Args:
            data: Mapping of field names to values.

        Returns:
            The constructed trajectory.

        Raises:
            TypeError: If a required field (one without a default) is missing.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})

    def _make_example(
        self,
        messages: List[Dict[str, str]],
        example_type: str,
        trajectory_id: str,
    ) -> "VLMTrainingExample":
        """Wrap ``messages`` in a VLMTrainingExample carrying this trajectory's image and metadata."""
        return VLMTrainingExample(
            image_path=self.image_path,
            messages=messages,
            example_type=example_type,
            trajectory_id=trajectory_id,
            dataset_name=self.dataset_name,
            category=self.category,
            difficulty=self.difficulty,
        )

    def extract_vlm_examples(
        self,
        include_description: bool = True,
        include_question: bool = True,
        include_requests: bool = True,
        system_prompt: Optional[str] = None
    ) -> List["VLMTrainingExample"]:
        """
        Extract VLM training examples from this trajectory.

        Every example is a three-turn (system, user, assistant) conversation
        about ``image_path``:

        1. "description": the assistant turn is ``vlm_description``.
        2. "question": the user turn wraps ``question``, and the assistant turn
           is ``final_answer``.
        3. "request": one example per reasoning step containing one of the
           request keywords (case-insensitive). The assistant turn is
           ``vlm_description``.

        No example is produced when its assistant turn would be empty.

        Args:
            include_description: Include description refinement examples
            include_question: Include question answering examples
            include_requests: Include request handling examples
            system_prompt: System prompt for conversations. A falsy value
                selects the built-in default prompt.

        Returns:
            List of VLM training examples (possibly empty).
        """
        if not system_prompt:
            system_prompt = "You are a helpful AI assistant that can analyze images and provide detailed descriptions and answers."

        # Substrings that mark a reasoning step as a request addressed to the VLM.
        request_keywords = ('describe', 'identify', 'what do you see', 'analyze')

        def conversation(user_text: str, assistant_text: str) -> List[Dict[str, str]]:
            # Build a fresh list per example so examples never share message dicts.
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": assistant_text},
            ]

        examples: List["VLMTrainingExample"] = []

        # 1. Description refinement example
        if include_description and self.vlm_description:
            examples.append(self._make_example(
                conversation("Please provide a detailed description of this image.", self.vlm_description),
                example_type="description",
                trajectory_id=self.sample_id,
            ))

        # 2. Question answering example
        if include_question and self.question and self.final_answer:
            examples.append(self._make_example(
                conversation(f"Based on this image, {self.question}", self.final_answer),
                example_type="question",
                trajectory_id=self.sample_id,
            ))

        # 3. Request handling examples. The response is the VLM description,
        # so these are skipped when it is empty to avoid empty assistant turns.
        if include_requests and self.reasoning_steps and self.vlm_description:
            for i, step in enumerate(self.reasoning_steps):
                step_lower = step.lower()
                if any(keyword in step_lower for keyword in request_keywords):
                    examples.append(self._make_example(
                        conversation(f"Looking at this image: {step}", self.vlm_description),
                        example_type="request",
                        trajectory_id=f"{self.sample_id}_step_{i}",
                    ))

        return examples


class VLMTrajectoryDataset(Dataset):
    """
    PyTorch dataset for VLM training from reasoning trajectories.

    This dataset extracts VLM-relevant examples from reasoning trajectories
    and formats them for TRL VLM training with "messages" and "images" fields.
    Trajectories are loaded and expanded into examples eagerly in ``__init__``.
    Images are opened only when an item is fetched.
    """

    def __init__(
        self,
        data_path: Union[str, Path, List[str]],
        tokenizer: "AutoTokenizer",
        image_processor,
        max_length: int = 2048,
        split: str = "train",
        filter_by_category: Optional[List[str]] = None,
        filter_by_difficulty: Optional[List[str]] = None,
        include_description: bool = True,
        include_question: bool = True,
        include_requests: bool = True,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize VLM trajectory dataset.

        Args:
            data_path: Path to JSONL/JSON file(s) or directory containing trajectory data
            tokenizer: Tokenizer for text processing (stored; not used within this class)
            image_processor: Image processor for vision models (stored; not used within this class)
            max_length: Maximum sequence length (stored; not used within this class)
            split: Dataset split ("train", "eval", "test"); used for logging
            filter_by_category: Keep only trajectories whose category is in this list
            filter_by_difficulty: Keep only trajectories whose difficulty is in this list
            include_description: Include description refinement examples
            include_question: Include question answering examples
            include_requests: Include request handling examples
            system_prompt: System prompt for conversations

        Raises:
            FileNotFoundError: If a data path does not exist.
            ValueError: If a .json file has an unsupported top-level layout.
        """
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.split = split
        self.system_prompt = system_prompt

        trajectories = self._load_trajectories(
            data_path, filter_by_category, filter_by_difficulty
        )

        # Expand every trajectory into its VLM training examples.
        self.vlm_examples = []
        for trajectory in trajectories:
            self.vlm_examples.extend(trajectory.extract_vlm_examples(
                include_description=include_description,
                include_question=include_question,
                include_requests=include_requests,
                system_prompt=system_prompt,
            ))

        print(f"Loaded {len(trajectories)} trajectories")
        print(f"Extracted {len(self.vlm_examples)} VLM training examples for {split}")
        type_counts = dict(Counter(example.example_type for example in self.vlm_examples))
        print(f"Example types: {type_counts}")

    def __len__(self) -> int:
        return len(self.vlm_examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single VLM training example.

        Returns:
            Dictionary for TRL VLM training with:
            - "messages": the example's conversation (the stored list, not a copy)
            - "images": a one-element list holding an RGB PIL image. A 224x224
              black placeholder is used if the image cannot be opened or decoded.
            - "vlm_metadata": example type, trajectory id, dataset name,
              category and difficulty
        """
        example = self.vlm_examples[idx]

        try:
            # convert() returns a new, fully loaded image (even for RGB input),
            # so the context manager can close the underlying file right away.
            with Image.open(example.image_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            print(f"Warning: Could not load image {example.image_path}: {e}")
            image = Image.new("RGB", (224, 224), color="black")

        return {
            "messages": example.messages,
            "images": [image],
            "vlm_metadata": {
                "example_type": example.example_type,
                "trajectory_id": example.trajectory_id,
                "dataset_name": example.dataset_name,
                "category": example.category,
                "difficulty": example.difficulty,
            },
        }

    def _load_trajectories(
        self,
        data_path: Union[str, Path, List[str]],
        filter_by_category: Optional[List[str]] = None,
        filter_by_difficulty: Optional[List[str]] = None,
    ) -> List["ReasoningTrajectory"]:
        """Load trajectories from one or more paths and apply the optional filters."""
        paths = data_path if isinstance(data_path, (list, tuple)) else [data_path]
        trajectories: List["ReasoningTrajectory"] = []
        for path in paths:
            trajectories.extend(self._load_single_path(path))

        if filter_by_category:
            trajectories = [t for t in trajectories if t.category in filter_by_category]
        if filter_by_difficulty:
            trajectories = [t for t in trajectories if t.difficulty in filter_by_difficulty]

        return trajectories

    def _load_single_path(self, data_path: Union[str, Path]) -> List["ReasoningTrajectory"]:
        """
        Load trajectories from a single file or directory.

        A directory contributes all of its top-level ``*.jsonl`` files, then all
        of its ``*.json`` files. Each group is read in sorted filename order so
        that example order is reproducible.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        data_path = Path(data_path)
        trajectories: List["ReasoningTrajectory"] = []

        if data_path.is_file():
            if data_path.suffix == ".jsonl":
                trajectories.extend(self._load_jsonl(data_path))
            elif data_path.suffix == ".json":
                trajectories.extend(self._load_json(data_path))
            else:
                print(f"Warning: Ignoring unsupported file {data_path} (expected .jsonl or .json)")
        elif data_path.is_dir():
            for file_path in sorted(data_path.glob("*.jsonl")):
                trajectories.extend(self._load_jsonl(file_path))
            for file_path in sorted(data_path.glob("*.json")):
                trajectories.extend(self._load_json(file_path))
        else:
            raise FileNotFoundError(f"Data path not found: {data_path}")

        return trajectories

    @staticmethod
    def _has_vlm_data(trajectory: "ReasoningTrajectory") -> bool:
        """Return True if the trajectory has the image path and VLM description needed for training."""
        return bool(trajectory.image_path and trajectory.vlm_description)

    def _load_jsonl(self, file_path: Path) -> List["ReasoningTrajectory"]:
        """
        Load trajectories from a JSONL file (one JSON object per line).

        Blank lines are skipped silently. Unparseable lines and trajectories
        missing VLM data are skipped with a warning.
        """
        trajectories: List["ReasoningTrajectory"] = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    trajectory = ReasoningTrajectory.from_dict(json.loads(line))
                except Exception as e:
                    # Best-effort loading: one malformed record must not abort the file.
                    print(f"Warning: Error loading trajectory at line {line_num}: {e}")
                    continue
                if not self._has_vlm_data(trajectory):
                    print(f"Warning: Skipping trajectory missing VLM data at line {line_num}")
                    continue
                trajectories.append(trajectory)

        return trajectories

    def _load_json(self, file_path: Path) -> List["ReasoningTrajectory"]:
        """
        Load trajectories from a JSON file.

        Accepts either a top-level list of trajectory dicts or a dict with a
        ``"trajectories"`` list. Malformed items and items missing VLM data are
        skipped with a warning, matching the JSONL loader.

        Raises:
            ValueError: If the top-level JSON layout is not one of the above.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if isinstance(data, list):
            items = data
        elif isinstance(data, dict) and "trajectories" in data:
            items = data["trajectories"]
        else:
            raise ValueError(f"Unsupported JSON format in {file_path}")

        trajectories: List["ReasoningTrajectory"] = []
        for index, item in enumerate(items):
            try:
                trajectory = ReasoningTrajectory.from_dict(item)
            except Exception as e:
                print(f"Warning: Error loading trajectory {index} in {file_path}: {e}")
                continue
            if not self._has_vlm_data(trajectory):
                print(f"Warning: Skipping trajectory missing VLM data at index {index} in {file_path}")
                continue
            trajectories.append(trajectory)

        return trajectories

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get dataset statistics.

        Returns an empty dict when there are no examples. Otherwise it returns
        the total count plus per-type, per-dataset, per-category and
        per-difficulty counts as plain dicts. Falsy categories and
        difficulties are not counted.
        """
        if not self.vlm_examples:
            return {}

        examples = self.vlm_examples
        return {
            "total_vlm_examples": len(examples),
            "example_type_distribution": dict(Counter(e.example_type for e in examples)),
            "dataset_distribution": dict(Counter(e.dataset_name for e in examples)),
            "category_distribution": dict(Counter(e.category for e in examples if e.category)),
            "difficulty_distribution": dict(Counter(e.difficulty for e in examples if e.difficulty)),
        }


def create_huggingface_dataset(
    vlm_examples: List[VLMTrainingExample],
) -> HFDataset:
    """
    Build a column-oriented HuggingFace dataset from VLM examples.

    Columns, in order: "messages", "images", "example_type", "trajectory_id",
    "dataset_name". "images" stores each example's image *path* as a single
    string; images are not loaded here.

    NOTE(review): TRL VLM collators commonly expect a list of images per row;
    confirm the downstream consumer resolves these path strings.

    Args:
        vlm_examples: List of VLM training examples

    Returns:
        HuggingFace dataset suitable for TRL VLM training
    """
    columns: Dict[str, List[Any]] = {
        "messages": [],
        "images": [],
        "example_type": [],
        "trajectory_id": [],
        "dataset_name": [],
    }
    # Single pass over the examples, appending one value per column.
    for ex in vlm_examples:
        columns["messages"].append(ex.messages)
        columns["images"].append(ex.image_path)
        columns["example_type"].append(ex.example_type)
        columns["trajectory_id"].append(ex.trajectory_id)
        columns["dataset_name"].append(ex.dataset_name)

    return HFDataset.from_dict(columns)


def load_vlm_trajectory_dataset(
    data_path: Union[str, Path, List[str]],
    tokenizer: AutoTokenizer,
    image_processor,
    split: str = "train",
    **kwargs
) -> VLMTrajectoryDataset:
    """
    Convenience function to load a VLM trajectory dataset.

    Thin wrapper that forwards all arguments to ``VLMTrajectoryDataset``.

    Args:
        data_path: Path to trajectory data: a JSONL/JSON file, a directory, or a
            list of such paths (the same forms the dataset class accepts)
        tokenizer: Tokenizer for text processing
        image_processor: Image processor for vision models
        split: Dataset split
        **kwargs: Additional keyword arguments for VLMTrajectoryDataset
            (e.g. filters, include_* flags, system_prompt)

    Returns:
        Loaded VLM trajectory dataset
    """
    return VLMTrajectoryDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        image_processor=image_processor,
        split=split,
        **kwargs
    )