"""
Reasoning trajectory dataset for supervised fine-tuning.

This module provides dataset classes for training on reasoning trajectories,
where each example contains image, question, and step-by-step reasoning.
"""

import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Union, Any
from dataclasses import dataclass
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import AutoTokenizer
import datasets
from datasets import Dataset as HFDataset


@dataclass
class ReasoningTrajectory:
    """
    A single reasoning trajectory example.

    Bundles the model inputs (image path and question), the reasoning
    process (VLM image description, ordered reasoning steps, final answer)
    and optional metadata. ``format_for_training`` renders the example as a
    single training string in one of several prompt layouts.
    """
    # Input data
    image_path: str
    question: str

    # Reasoning process
    vlm_description: str
    reasoning_steps: List[str]
    final_answer: str

    # Metadata
    dataset_name: str = "unknown"
    sample_id: str = ""
    difficulty: Optional[str] = None
    category: Optional[str] = None

    # Additional context
    ground_truth: Optional[str] = None
    confidence_score: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary for serialization.

        The keys are exactly the dataclass field names, so the result can be
        fed back to ``from_dict``.
        """
        return {
            'image_path': self.image_path,
            'question': self.question,
            'vlm_description': self.vlm_description,
            'reasoning_steps': self.reasoning_steps,
            'final_answer': self.final_answer,
            'dataset_name': self.dataset_name,
            'sample_id': self.sample_id,
            'difficulty': self.difficulty,
            'category': self.category,
            'ground_truth': self.ground_truth,
            'confidence_score': self.confidence_score,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningTrajectory':
        """
        Create a trajectory from a dictionary (e.g. one produced by ``to_dict``).

        Keys that are not dataclass fields are ignored, so records that carry
        extra metadata still load. Missing required fields still raise
        ``TypeError`` from the dataclass constructor.
        """
        from dataclasses import fields

        known_fields = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known_fields})

    def format_for_training(self, format_type: str = "chatml") -> str:
        """
        Format trajectory for training.

        Args:
            format_type: Format type ("chatml", "plain", "instructional", "scaffold")

        Returns:
            Formatted text for training

        Raises:
            ValueError: If ``format_type`` is not one of the supported formats.
        """
        formatters = {
            "chatml": self._format_chatml,
            "plain": self._format_plain,
            "instructional": self._format_instructional,
            "scaffold": self._format_scaffold,
        }
        formatter = formatters.get(format_type)
        if formatter is None:
            raise ValueError(f"Unknown format type: {format_type}")
        return formatter()

    def _format_chatml(self) -> str:
        """Format as a ChatML conversation: system, user (image + question), assistant."""
        numbered_steps = "".join(
            f"{i}. {step}\n" for i, step in enumerate(self.reasoning_steps, 1)
        )
        parts = [
            "<|im_start|>system\n",
            "You are a helpful AI assistant that can analyze images and solve problems step by step.\n",
            "<|im_end|>\n",
            "<|im_start|>user\n",
            f"<image>\n{self.question}\n",
            "<|im_end|>\n",
            "<|im_start|>assistant\n",
            "I'll analyze this image step by step.\n\n",
            f"**Image Description:**\n{self.vlm_description}\n\n",
            "**Reasoning:**\n",
            numbered_steps,
            f"\n**Answer:** {self.final_answer}\n",
            "<|im_end|>",
        ]
        return "".join(parts)

    def _format_plain(self) -> str:
        """Format as plain text: question, description, numbered steps, answer."""
        numbered_steps = "".join(
            f"{i}. {step}\n" for i, step in enumerate(self.reasoning_steps, 1)
        )
        return (
            f"Question: {self.question}\n\n"
            f"Image Description: {self.vlm_description}\n\n"
            "Reasoning:\n"
            f"{numbered_steps}"
            f"\nAnswer: {self.final_answer}"
        )

    def _format_instructional(self) -> str:
        """Format with an explicit instruction followed by a narrated analysis."""
        numbered_steps = "".join(
            f"Step {i}: {step}\n" for i, step in enumerate(self.reasoning_steps, 1)
        )
        return (
            "Given the following image and question, provide a step-by-step analysis.\n\n"
            f"Question: {self.question}\n\n"
            "Analysis:\n"
            f"First, I'll describe what I see in the image:\n{self.vlm_description}\n\n"
            "Now I'll work through this step by step:\n"
            f"{numbered_steps}"
            f"\nTherefore, the answer is: {self.final_answer}"
        )

    def _format_scaffold(self) -> str:
        """
        Format using the scaffold prompt structure.

        The description and question form the prompt; the response is a single
        ``Reasoning:`` line (steps joined by spaces) followed by fixed
        ``Status``/``Answer``/``Request`` fields.
        """
        return (
            "Based on the following image description, please answer the question:\n\n"
            f"Image Description: {self.vlm_description}\n\n"
            f"Question: {self.question}\n\n"
            f"Reasoning: {' '.join(self.reasoning_steps)}\n"
            "Status: SOLVED\n"
            f"Answer: {self.final_answer}\n"
            "Request: N/A"
        )


class ReasoningTrajectoryDataset(Dataset):
    """
    PyTorch dataset for reasoning trajectories.

    Loads ``ReasoningTrajectory`` records from JSON/JSONL files or directories,
    renders each with ``ReasoningTrajectory.format_for_training`` and tokenizes
    it truncated/padded to ``max_length``. Labels copy ``input_ids`` for
    causal-LM training; padding positions (and, optionally, prompt tokens) are
    set to -100 so the loss ignores them.
    """

    def __init__(
        self,
        data_path: Union[str, Path, List[str]],
        tokenizer: AutoTokenizer,
        max_length: int = 2048,
        format_type: str = "chatml",
        image_processor = None,
        split: str = "train",
        filter_by_category: Optional[List[str]] = None,
        filter_by_difficulty: Optional[List[str]] = None,
        use_loss_masking: bool = True,
        mask_prompt_tokens: bool = True,
    ):
        """
        Initialize reasoning trajectory dataset.

        Args:
            data_path: Path to a .jsonl/.json file, a directory containing such
                files, or a list/tuple of such paths
            tokenizer: Tokenizer for text processing. It must be able to pad
                (define a pad token), because every example is padded to
                ``max_length``.
            max_length: Maximum sequence length (examples are truncated/padded to it)
            format_type: Text formatting type passed to
                ``ReasoningTrajectory.format_for_training``
                ("chatml", "plain", "instructional", "scaffold")
            image_processor: Image processor (optional); when given, items also
                contain ``pixel_values``
            split: Dataset split name ("train", "eval", "test"); stored and
                reported in the load message
            filter_by_category: Keep only trajectories whose category is listed
            filter_by_difficulty: Keep only trajectories whose difficulty is listed
            use_loss_masking: Whether to apply loss masking for response-only training
            mask_prompt_tokens: Whether to mask prompt tokens in loss computation;
                prompt masking is applied only when both flags are True
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.format_type = format_type
        self.image_processor = image_processor
        self.split = split
        self.use_loss_masking = use_loss_masking
        self.mask_prompt_tokens = mask_prompt_tokens

        # Load trajectories
        self.trajectories = self._load_trajectories(
            data_path, filter_by_category, filter_by_difficulty
        )

        print(f"Loaded {len(self.trajectories)} reasoning trajectories for {split}")
        if use_loss_masking:
            print(f"Loss masking enabled - only computing loss on response tokens")

    def __len__(self) -> int:
        return len(self.trajectories)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single trajectory example.

        Args:
            idx: Index

        Returns:
            Dictionary with ``input_ids``, ``attention_mask``, ``labels``,
            optional ``pixel_values`` and a ``trajectory_metadata`` dict
        """
        trajectory = self.trajectories[idx]

        if self.use_loss_masking and self.mask_prompt_tokens:
            # Response-only training: prompt tokens are excluded from the loss.
            return self._get_item_with_loss_masking(trajectory)
        # Standard approach - train on the entire (non-padding) sequence.
        return self._get_item_standard(trajectory)

    def _tokenize_padded(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize ``text`` truncated and padded to ``max_length`` (batch dim of 1)."""
        return self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

    def _get_item_standard(self, trajectory: ReasoningTrajectory) -> Dict[str, torch.Tensor]:
        """Get item without prompt masking (only padding is excluded from the loss)."""
        text = trajectory.format_for_training(self.format_type)
        tokenized = self._tokenize_padded(text)

        # Drop only the batch dim; squeeze() would also collapse a length-1 sequence.
        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]

        # Labels mirror input_ids for causal LM; padded positions must not be trained on.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        return self._add_image_and_metadata(result, trajectory)

    def _get_item_with_loss_masking(self, trajectory: ReasoningTrajectory) -> Dict[str, torch.Tensor]:
        """Get item with loss masking applied to prompt and padding tokens."""
        full_text = trajectory.format_for_training(self.format_type)
        prompt_text = self._build_prompt_text(trajectory)
        if not full_text.startswith(prompt_text):
            # Defensive: if the prompt template drifts from the formatter, mask
            # only the text the two actually share rather than part of the response.
            prompt_text = os.path.commonprefix([prompt_text, full_text])

        full_tokens = self._tokenize_padded(full_text)
        input_ids = full_tokens["input_ids"][0]
        attention_mask = full_tokens["attention_mask"][0]

        # Token count of the prompt on its own. NOTE(review): BPE merges across
        # the prompt/response boundary can make this differ by a token from the
        # prompt's span inside full_text.
        prompt_length = len(
            self.tokenizer(prompt_text, truncation=True, max_length=self.max_length)["input_ids"]
        )

        # With left padding the real tokens (and so the prompt) start after the pads.
        offset = 0
        if getattr(self.tokenizer, "padding_side", "right") == "left":
            offset = int((attention_mask == 0).sum())

        labels = input_ids.clone()
        labels[offset:offset + prompt_length] = -100  # Mask prompt tokens
        labels[attention_mask == 0] = -100  # Mask padding tokens

        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        return self._add_image_and_metadata(result, trajectory)

    def _build_prompt_text(self, trajectory: ReasoningTrajectory) -> str:
        """
        Return the prompt portion of the formatted text for ``self.format_type``.

        Each prompt is the exact prefix that ``format_for_training`` emits
        before the model's response begins.
        """
        if self.format_type == "chatml":
            return self._create_chatml_prompt(trajectory)
        if self.format_type == "alpaca":
            # NOTE: ReasoningTrajectory.format_for_training does not currently
            # implement "alpaca", so items in this format raise ValueError there.
            return f"### Instruction:\n{trajectory.question}\n\n### Response:\n"
        if self.format_type == "scaffold":
            return (
                "Based on the following image description, please answer the question:\n\n"
                f"Image Description: {trajectory.vlm_description}\n\n"
                f"Question: {trajectory.question}\n\n"
            )
        if self.format_type == "instructional":
            return (
                "Given the following image and question, provide a step-by-step analysis.\n\n"
                f"Question: {trajectory.question}\n\n"
                "Analysis:\n"
            )
        # Plain format: the description, reasoning and answer are the response.
        return f"Question: {trajectory.question}\n\n"

    def _create_chatml_prompt(self, trajectory: ReasoningTrajectory) -> str:
        """Create ChatML prompt without response (ends at the assistant header)."""
        text = "<|im_start|>system\n"
        text += "You are a helpful AI assistant that can analyze images and solve problems step by step.\n"
        text += "<|im_end|>\n"
        text += "<|im_start|>user\n"
        text += f"<image>\n{trajectory.question}\n"
        text += "<|im_end|>\n"
        text += "<|im_start|>assistant\n"
        return text

    def _add_image_and_metadata(self, result: Dict[str, torch.Tensor], trajectory: ReasoningTrajectory) -> Dict[str, torch.Tensor]:
        """Add ``pixel_values`` (when an image processor is set) and metadata to ``result``."""
        if self.image_processor is not None and trajectory.image_path:
            try:
                # The context manager closes the file handle; convert() returns a
                # fully loaded copy that stays valid after the file is closed.
                with Image.open(trajectory.image_path) as raw_image:
                    image = raw_image.convert("RGB")

                processed_image = self.image_processor(image, return_tensors="pt")
                result["pixel_values"] = processed_image["pixel_values"].squeeze()
            except Exception as e:
                # Best-effort: a broken image should not abort training.
                print(f"Warning: Could not load image {trajectory.image_path}: {e}")
                # Placeholder tensor. NOTE(review): assumes the processor emits
                # 3x224x224; a mismatched shape will break batch collation.
                if "pixel_values" not in result:
                    result["pixel_values"] = torch.zeros(3, 224, 224)

        result["trajectory_metadata"] = {
            "sample_id": trajectory.sample_id,
            "dataset_name": trajectory.dataset_name,
            "category": trajectory.category,
            "difficulty": trajectory.difficulty,
        }
        return result

    def _load_trajectories(
        self,
        data_path: Union[str, Path, List[str]],
        filter_by_category: Optional[List[str]] = None,
        filter_by_difficulty: Optional[List[str]] = None,
    ) -> List[ReasoningTrajectory]:
        """Load trajectories from one or more paths, then apply the optional filters."""
        paths = data_path if isinstance(data_path, (list, tuple)) else [data_path]

        trajectories = []
        for path in paths:
            trajectories.extend(self._load_single_path(path))

        if filter_by_category:
            trajectories = [t for t in trajectories if t.category in filter_by_category]

        if filter_by_difficulty:
            trajectories = [t for t in trajectories if t.difficulty in filter_by_difficulty]

        return trajectories

    def _load_single_path(self, data_path: Union[str, Path]) -> List[ReasoningTrajectory]:
        """
        Load trajectories from a single file or directory.

        Raises:
            FileNotFoundError: If ``data_path`` is neither a file nor a directory.
        """
        data_path = Path(data_path)

        if data_path.is_file():
            if data_path.suffix == ".jsonl":
                return self._load_jsonl(data_path)
            if data_path.suffix == ".json":
                return self._load_json(data_path)
            print(f"Warning: Skipping unsupported file type: {data_path}")
            return []

        if data_path.is_dir():
            trajectories = []
            # glob() yields entries in filesystem order; sort so dataset indices
            # are reproducible across machines and runs.
            for file_path in sorted(data_path.glob("*.jsonl")):
                trajectories.extend(self._load_jsonl(file_path))
            for file_path in sorted(data_path.glob("*.json")):
                trajectories.extend(self._load_json(file_path))
            return trajectories

        raise FileNotFoundError(f"Data path not found: {data_path}")

    @staticmethod
    def _has_required_fields(trajectory: ReasoningTrajectory) -> bool:
        """Return True if image path, question and final answer are all non-empty."""
        return bool(trajectory.image_path and trajectory.question and trajectory.final_answer)

    def _load_jsonl(self, file_path: Path) -> List[ReasoningTrajectory]:
        """
        Load trajectories from a JSONL file (one JSON object per line).

        Blank lines are skipped; malformed or incomplete records are reported
        with their line number and skipped.
        """
        trajectories = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not records.
                    continue
                try:
                    trajectory = ReasoningTrajectory.from_dict(json.loads(line))
                except Exception as e:
                    # Best-effort loading: report and skip bad records.
                    print(f"Warning: Error loading trajectory at line {line_num}: {e}")
                    continue

                if not self._has_required_fields(trajectory):
                    print(f"Warning: Skipping incomplete trajectory at line {line_num}")
                    continue

                trajectories.append(trajectory)

        return trajectories

    def _load_json(self, file_path: Path) -> List[ReasoningTrajectory]:
        """
        Load trajectories from a JSON file.

        Accepts either a list of trajectory dicts or a dict with a
        ``"trajectories"`` list (the layout written by ``save_trajectories``).
        Records are validated the same way as in JSONL files.

        Raises:
            ValueError: If the JSON has neither supported layout.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if isinstance(data, list):
            items = data
        elif isinstance(data, dict) and "trajectories" in data:
            items = data["trajectories"]
        else:
            raise ValueError(f"Unsupported JSON format in {file_path}")

        trajectories = []
        for item in items:
            try:
                trajectory = ReasoningTrajectory.from_dict(item)
            except Exception as e:
                # Best-effort loading: report and skip bad records.
                print(f"Warning: Error loading trajectory: {e}")
                continue

            if not self._has_required_fields(trajectory):
                print(f"Warning: Skipping incomplete trajectory in {file_path}")
                continue

            trajectories.append(trajectory)

        return trajectories

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get dataset statistics.

        Returns an empty dict for an empty dataset; otherwise counts per
        dataset name, category and difficulty (unset categories/difficulties
        are not counted) plus average/min/max number of reasoning steps.
        """
        from collections import Counter

        if not self.trajectories:
            return {}

        step_counts = [len(t.reasoning_steps) for t in self.trajectories]

        return {
            "total_trajectories": len(self.trajectories),
            "dataset_distribution": dict(Counter(t.dataset_name for t in self.trajectories)),
            "category_distribution": dict(
                Counter(t.category for t in self.trajectories if t.category)
            ),
            "difficulty_distribution": dict(
                Counter(t.difficulty for t in self.trajectories if t.difficulty)
            ),
            "average_reasoning_steps": sum(step_counts) / len(self.trajectories),
            "min_reasoning_steps": min(step_counts),
            "max_reasoning_steps": max(step_counts),
        }

    def save_trajectories(self, output_path: Union[str, Path], format: str = "jsonl"):
        """
        Save trajectories to file in a layout this class can load back.

        Args:
            output_path: Path to save to (parent directories are created)
            format: Output format ("jsonl" or "json")

        Raises:
            ValueError: If ``format`` is unsupported (checked before any
                directory is created).
        """
        if format not in ("jsonl", "json"):
            raise ValueError(f"Unsupported format: {format}")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == "jsonl":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.writelines(json.dumps(t.to_dict()) + '\n' for t in self.trajectories)
        else:
            data = {
                "metadata": {
                    "total_trajectories": len(self.trajectories),
                    "created_by": "ReasoningTrajectoryDataset",
                },
                "trajectories": [t.to_dict() for t in self.trajectories]
            }
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)


def create_huggingface_dataset(
    trajectories: List[ReasoningTrajectory],
    format_type: str = "chatml"
) -> HFDataset:
    """
    Build a HuggingFace ``Dataset`` with one row per reasoning trajectory.

    Each row holds the rendered training string under ``"text"`` (from
    ``format_for_training(format_type)``) plus the image path, question,
    final answer and the sample/dataset/category/difficulty metadata.

    Args:
        trajectories: Reasoning trajectories to convert
        format_type: Format name forwarded to ``format_for_training``

    Returns:
        HuggingFace Dataset built with ``Dataset.from_list``
    """
    rows = [
        {
            "text": traj.format_for_training(format_type),
            "image_path": traj.image_path,
            "question": traj.question,
            "final_answer": traj.final_answer,
            "sample_id": traj.sample_id,
            "dataset_name": traj.dataset_name,
            "category": traj.category,
            "difficulty": traj.difficulty,
        }
        for traj in trajectories
    ]
    return HFDataset.from_list(rows)


def load_reasoning_trajectories(
    data_path: Union[str, Path, List[str]],
    split: str = "train",
    **kwargs
) -> ReasoningTrajectoryDataset:
    """
    Convenience function to load reasoning trajectories.

    If no ``tokenizer`` is supplied, the ``microsoft/DialoGPT-medium``
    tokenizer is loaded. If it has no pad token, its EOS token is used for
    padding, because the dataset pads every example to ``max_length``.
    A caller-supplied tokenizer is passed through untouched.

    Args:
        data_path: Path (or list of paths) to trajectory data
        split: Dataset split
        **kwargs: Additional arguments for ReasoningTrajectoryDataset

    Returns:
        Loaded dataset
    """
    if "tokenizer" not in kwargs:
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        # DialoGPT uses the GPT-2 tokenizer, which defines no pad token;
        # padding="max_length" would otherwise raise on every item.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        kwargs["tokenizer"] = tokenizer

    return ReasoningTrajectoryDataset(
        data_path=data_path,
        split=split,
        **kwargs
    )