"""
Supervised Fine-Tuning (SFT) dataset for VLM training.

This module provides an SFT dataset loader specifically designed for VLM training
using TRL's multimodal format with "messages" and "images" fields. Since we're 
only tuning the VLM (captioner), this focuses on VLM-relevant tasks.
"""

import hashlib
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from .trajectory_dataset import ReasoningTrajectory


@dataclass
class SFTSample:
    """
    One supervised fine-tuning example.

    Holds a single instruction/response pair (optionally tied to an image and
    a system prompt) plus bookkeeping metadata. When ``conversation`` is set
    it takes precedence over the individual instruction/response fields.
    """
    # --- text and image payload ---
    instruction: str
    response: str
    image_path: Optional[str] = None

    # --- optional system prompt placed before the user turn ---
    system_message: Optional[str] = None

    # --- pre-built list of turns; returned verbatim when non-empty ---
    conversation: Optional[List[Dict[str, str]]] = None

    # --- provenance ---
    sample_id: str = ""
    dataset_name: str = "sft"
    category: Optional[str] = None
    difficulty: Optional[str] = None

    # --- quality annotations ---
    quality_score: Optional[float] = None
    human_verified: bool = False

    def to_conversation_format(self) -> List[Dict[str, str]]:
        """
        Return the sample as a list of ``{"role", "content"}`` turns.

        A non-empty ``conversation`` is returned as-is (the same list object).
        Otherwise the turns are assembled as an optional system turn followed
        by the user instruction and the assistant response.
        """
        if self.conversation:
            return self.conversation

        turns = [
            {"role": "user", "content": self.instruction},
            {"role": "assistant", "content": self.response},
        ]
        if self.system_message:
            turns.insert(0, {"role": "system", "content": self.system_message})
        return turns

class SFTDataset(Dataset):
    """
    Supervised Fine-Tuning dataset for VLM training.

    Samples are parsed once, at construction time, from JSON / JSONL files in
    one of three layouts: TRL ``messages``, legacy ``conversations``, or flat
    instruction/response records. Indexing yields dictionaries with
    ``"messages"``, optional ``"images"`` and ``"sample_metadata"``. The
    instance is also callable on a list of such dictionaries, so it can be
    used as a batch collator.
    """

    # Role names seen in legacy "conversations" data, mapped to chat roles.
    _ROLE_ALIASES = {
        "system": "system",
        "user": "user",
        "human": "user",
        "assistant": "assistant",
        "gpt": "assistant",
        "bot": "assistant",
    }

    # System prompt used by the ChatML formatters when a sample has none.
    _DEFAULT_CHATML_SYSTEM = "You are a helpful AI assistant."

    # Label value written over padding positions so they are excluded from
    # the loss. NOTE(review): -100 is the default ignore_index of PyTorch's
    # cross-entropy loss; confirm the trainer in use relies on that default.
    _IGNORE_INDEX = -100

    def __init__(
        self,
        data_path: Union[str, Path, List[str]],
        tokenizer: AutoTokenizer,
        max_length: int = 2048,
        format_type: str = "chatml",
        image_processor: Optional[Any] = None,
        split: str = "train",
        system_message: Optional[str] = None,
        filter_by_category: Optional[List[str]] = None,
        filter_by_quality: Optional[float] = None,
        use_loss_masking: bool = True,
        mask_system_tokens: bool = True,
        include_images: bool = True,
        processor: Optional[Any] = None,
    ):
        """
        Initialize SFT dataset.

        Args:
            data_path: Path to SFT data file(s) or directory
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
            format_type: Text formatting type ("chatml", "alpaca", "plain")
            image_processor: Image processor for vision models
            split: Dataset split ("train", "val", "test")
            system_message: Default system message if not in data
            filter_by_category: Filter by specific categories
            filter_by_quality: Minimum quality score threshold; ``None``
                disables the filter. Samples without a score are dropped
                whenever a threshold is given.
            use_loss_masking: Whether to apply loss masking for instruction tuning
            mask_system_tokens: Whether to mask system message tokens in loss
            include_images: Whether to include image processing
            processor: Callable used by ``__call__`` to render chat templates
                and to tokenize text (and images). Defaults to ``tokenizer``
                when omitted.

        Raises:
            FileNotFoundError: If a data path is neither a file nor a directory.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.format_type = format_type
        self.image_processor = image_processor
        self.split = split
        self.system_message = system_message
        self.use_loss_masking = use_loss_masking
        self.mask_system_tokens = mask_system_tokens
        self.include_images = include_images
        # The collator reads self.processor; it must always exist. Callers that
        # assigned `dataset.processor` after construction keep working.
        self.processor = processor if processor is not None else tokenizer

        # Load SFT samples
        self.samples = self._load_sft_samples(
            data_path, filter_by_category, filter_by_quality, split
        )

        print(f"Loaded {len(self.samples)} SFT samples for {split}")
        if use_loss_masking:
            # NOTE(review): the collator in __call__ only masks padding; it
            # does not restrict the loss to response tokens. Confirm that the
            # trainer applies response-only masking if that is required.
            print("Loss masking enabled - only computing loss on response tokens")

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single SFT sample for VLM training.

        Returns:
            Dictionary with "messages", "images" (only when the sample has an
            image path and images are enabled), and "sample_metadata".
        """
        sample = self.samples[idx]

        result = {
            # Hand out copies of the turns so that consumers mutating the
            # messages cannot corrupt the cached sample.
            "messages": [dict(turn) for turn in sample.to_conversation_format()],
        }

        if self.include_images and sample.image_path:
            # "images" is a list holding one PIL image.
            result["images"] = [self._load_image(sample.image_path)]

        result["sample_metadata"] = {
            "sample_id": sample.sample_id,
            "dataset_name": sample.dataset_name,
            "category": sample.category,
            "difficulty": sample.difficulty,
            "quality_score": sample.quality_score,
        }

        return result

    def _load_image(self, image_path: str) -> Any:
        """Load ``image_path`` as an RGB PIL image, or a black placeholder on failure."""
        try:
            # Image.open is lazy and keeps the file open; convert() forces the
            # pixel data to be read and returns an independent image, so the
            # handle can be closed right away.
            with Image.open(image_path) as handle:
                return handle.convert("RGB")
        except Exception as e:
            # Deliberately broad: a bad image must not abort data loading.
            print(f"Warning: Could not load image {image_path}: {e}")
            return Image.new("RGB", (224, 224), color="black")

    def _load_sft_samples(
        self,
        data_path: Union[str, Path, List[str]],
        filter_by_category: Optional[List[str]] = None,
        filter_by_quality: Optional[float] = None,
        split: str = "train"
    ) -> List[SFTSample]:
        """
        Load SFT samples from one or several paths and apply the filters.

        Args:
            data_path: A single path, or a list/tuple of paths.
            filter_by_category: Keep only samples whose category is listed.
            filter_by_quality: Keep only samples scored at or above this value.
            split: Split name used to pick files inside directories.

        Returns:
            The filtered list of samples, in load order.
        """
        paths = data_path if isinstance(data_path, (list, tuple)) else [data_path]

        samples: List[SFTSample] = []
        for path in paths:
            samples.extend(self._load_single_path(path, split))

        if filter_by_category:
            samples = [s for s in samples if s.category in filter_by_category]

        # Compare against None explicitly: a threshold or a score of 0.0 is a
        # legitimate value and must not be treated as "missing".
        if filter_by_quality is not None:
            samples = [
                s for s in samples
                if s.quality_score is not None and s.quality_score >= filter_by_quality
            ]

        return samples

    def _load_single_path(self, data_path: Union[str, Path], split: str) -> List[SFTSample]:
        """
        Load samples from a single file or directory.

        For a directory, files whose name contains ``split`` are preferred;
        when none match, every file with that extension is used. JSONL files
        are read before JSON files.

        Raises:
            FileNotFoundError: If the path is neither a file nor a directory.
        """
        data_path = Path(data_path)

        if data_path.is_file():
            if data_path.suffix == ".jsonl":
                return self._load_jsonl_samples(data_path)
            if data_path.suffix == ".json":
                return self._load_json_samples(data_path)
            # Report instead of silently contributing zero samples.
            print(f"Warning: Unsupported SFT data file type, skipping: {data_path}")
            return []

        if not data_path.is_dir():
            raise FileNotFoundError(f"SFT data path not found: {data_path}")

        loaders = (
            (".jsonl", self._load_jsonl_samples),
            (".json", self._load_json_samples),
        )
        samples: List[SFTSample] = []
        for suffix, loader in loaders:
            # Sorted so that sample order does not depend on the filesystem.
            matches = sorted(data_path.glob(f"*{split}*{suffix}"))
            if not matches:
                matches = sorted(data_path.glob(f"*{suffix}"))
            for file_path in matches:
                samples.extend(loader(file_path))

        return samples

    def _load_jsonl_samples(self, file_path: Path) -> List[SFTSample]:
        """Load samples from a JSONL file, skipping blank and malformed lines."""
        samples = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not errors.
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Warning: Error loading SFT sample at line {line_num}: {e}")
                    continue

                sample = self._parse_sft_sample(data)
                if sample is not None:
                    samples.append(sample)

        return samples

    def _load_json_samples(self, file_path: Path) -> List[SFTSample]:
        """
        Load samples from a JSON file.

        Accepts a top-level list, or an object holding the list under
        ``"samples"`` or ``"data"``.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if isinstance(data, list):
            items = data
        elif isinstance(data, dict) and "samples" in data:
            items = data["samples"]
        elif isinstance(data, dict) and "data" in data:
            items = data["data"]
        else:
            print(f"Warning: Unrecognized JSON layout in {file_path}; no samples loaded")
            items = []

        parsed = (self._parse_sft_sample(item) for item in items)
        return [sample for sample in parsed if sample is not None]

    @staticmethod
    def _stable_id(payload: Any) -> str:
        """Return a short digest of ``str(payload)`` that is identical across runs."""
        # The built-in hash() of a string is randomized per process, so IDs
        # derived from it change on every run.
        return hashlib.sha256(str(payload).encode("utf-8")).hexdigest()[:16]

    @staticmethod
    def _flatten_structured_content(content: List[Any]) -> Tuple[str, Optional[str]]:
        """
        Split structured content into its text and its first image reference.

        Returns:
            ``(text, image)`` where ``text`` joins every text part with a
            newline and ``image`` is the ``"image"`` value of the first image
            item carrying one (``None`` when there is none).
        """
        texts = []
        image_path = None
        for item in content:
            if isinstance(item, str):
                texts.append(item)
                continue
            if not isinstance(item, dict):
                continue
            kind = item.get("type")
            if kind == "text":
                text = item.get("text", "")
                if text:
                    texts.append(str(text))
            elif kind == "image" and not image_path:
                image_path = item.get("image", "") or None
        return "\n".join(texts), image_path

    def _content_text(self, content: Any) -> str:
        """Return message content as plain text, flattening structured lists."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return self._flatten_structured_content(content)[0]
        return str(content)

    def _with_default_system(self, conversation: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepend the dataset's default system message when the turns have none."""
        has_system = any(turn.get("role") == "system" for turn in conversation)
        if self.system_message and not has_system:
            return [{"role": "system", "content": self.system_message}] + conversation
        return conversation

    def _parse_sft_sample(self, data: Dict[str, Any]) -> Optional[SFTSample]:
        """
        Parse a single SFT sample from a dictionary.

        Dispatches on the keys present: ``"messages"`` (TRL format),
        ``"conversations"`` (legacy format), otherwise a flat
        instruction/response record.

        Returns:
            The parsed sample, or ``None`` when the record is malformed or a
            flat record lacks an instruction or a response.
        """
        try:
            if "messages" in data:
                return self._parse_messages_sample(data)
            if "conversations" in data:
                return self._parse_conversations_sample(data)
            return self._parse_simple_sample(data)
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            print(f"Warning: Failed to parse SFT sample: {e}")
            return None

    def _parse_messages_sample(self, data: Dict[str, Any]) -> SFTSample:
        """Parse a record in TRL ``messages`` format."""
        messages = data["messages"]
        system_msg = None
        instruction = ""
        response = ""
        user_turns = 0
        assistant_turns = 0

        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")

            if role == "system":
                system_msg = content
            elif role == "user":
                instruction = content
                user_turns += 1
            elif role == "assistant":
                response = content
                assistant_turns += 1

        # A single exchange is rebuilt from instruction/response. Anything
        # longer is kept whole; otherwise every turn but the last user and
        # assistant turn would be lost.
        conversation = None
        if user_turns > 1 or assistant_turns > 1:
            conversation = self._with_default_system([dict(m) for m in messages])

        if "sample_id" in data:
            sample_id = data["sample_id"]
        else:
            sample_id = f"sample_{self._stable_id(messages)}"

        return SFTSample(
            instruction=instruction,
            response=response,
            image_path=data.get("image", data.get("image_path", None)),
            system_message=system_msg or self.system_message,
            conversation=conversation,
            sample_id=str(sample_id),
            dataset_name=data.get("source_dataset", "sft"),
            category=data.get("category", None),
            difficulty=data.get("difficulty", None),
            quality_score=data.get("quality_score", None),
            human_verified=data.get("was_correct", False),
        )

    def _parse_conversations_sample(self, data: Dict[str, Any]) -> SFTSample:
        """
        Parse a record in the legacy ``conversations`` format.

        Turns may use ``from``/``value`` or ``role``/``content`` keys. The
        stored conversation is normalized to ``role``/``content`` with roles
        mapped through ``_ROLE_ALIASES``; the content value itself is kept
        unchanged.
        """
        turns = data["conversations"]
        system_msg = None
        instruction = ""
        response = ""
        image_path = None
        conversation = []

        for turn in turns:
            raw_role = turn.get("from", turn.get("role", ""))
            content = turn.get("value", turn.get("content", ""))
            role = self._ROLE_ALIASES.get(raw_role, raw_role)

            text = content
            if isinstance(content, list):
                text, found_image = self._flatten_structured_content(content)
                if not image_path:  # Use first image found
                    image_path = found_image

            if role == "system":
                system_msg = text
            elif role == "user":
                instruction = text
            elif role == "assistant":
                response = text

            conversation.append({"role": role, "content": content})

        # "metadata" may be present but null.
        metadata = data.get("metadata") or {}
        if "qid" in metadata:
            sample_id = metadata["qid"]
        elif "id" in data:
            sample_id = data["id"]
        else:
            sample_id = f"sample_{self._stable_id(turns)}"

        return SFTSample(
            instruction=instruction,
            response=response,
            image_path=image_path or data.get("image", data.get("image_path", None)),
            system_message=system_msg or self.system_message,
            conversation=self._with_default_system(conversation),
            sample_id=str(sample_id),
            dataset_name=metadata.get("source_dataset", data.get("dataset", "sft")),
            category=metadata.get("category", data.get("category", None)),
            difficulty=data.get("difficulty", None),
            quality_score=data.get("quality_score", None),
            human_verified=metadata.get("was_correct", data.get("human_verified", False)),
        )

    def _parse_simple_sample(self, data: Dict[str, Any]) -> Optional[SFTSample]:
        """Parse a flat instruction/response record; ``None`` if either part is empty."""
        instruction = data.get("instruction", data.get("input", data.get("question", "")))
        response = data.get("output", data.get("response", data.get("answer", "")))

        if not instruction or not response:
            return None

        sample_id = data["id"] if "id" in data else self._stable_id(instruction)

        return SFTSample(
            instruction=instruction,
            response=response,
            image_path=data.get("image", data.get("image_path", None)),
            system_message=data.get("system", self.system_message),
            # Stringified so the ID has the same type as in the other formats.
            sample_id=str(sample_id),
            dataset_name=data.get("dataset", "sft"),
            category=data.get("category", data.get("task_type", None)),
            difficulty=data.get("difficulty", None),
            quality_score=data.get("quality_score", None),
            human_verified=data.get("human_verified", False),
        )

    def _chatml_prompt(self, sample: SFTSample) -> str:
        """Return the ChatML text up to and including the assistant header."""
        system_msg = sample.system_message or self._DEFAULT_CHATML_SYSTEM
        image_tag = "<image>\n" if sample.image_path else ""
        return (
            f"<|im_start|>system\n{system_msg}<|im_end|>\n"
            f"<|im_start|>user\n{image_tag}{sample.instruction}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    def _format_chatml_conversation(self, sample: SFTSample) -> str:
        """Format the sample as a full ChatML conversation."""
        return f"{self._chatml_prompt(sample)}{sample.response}<|im_end|>"

    def _format_alpaca_conversation(self, sample: SFTSample) -> str:
        """Format the sample using the Alpaca instruction/response layout."""
        return f"### Instruction:\n{sample.instruction}\n\n### Response:\n{sample.response}"

    def _format_plain_conversation(self, sample: SFTSample) -> str:
        """Format the sample as plain "User:" / "Assistant:" text."""
        return f"User: {sample.instruction}\n\nAssistant: {sample.response}"

    def _format_custom_conversation(self, sample: SFTSample) -> str:
        """Format a sample for an unknown format type (falls back to ChatML)."""
        return self._format_chatml_conversation(sample)

    def _create_instruction_only(self, sample: SFTSample) -> str:
        """
        Return the prompt part of the formatted text (everything before the response).

        The result is always an exact prefix of the matching ``_format_*``
        output, so its length can be used to locate the response. For ChatML
        that means the system block is always included, with the same default
        system prompt as the full formatter; ``mask_system_tokens`` therefore
        has no influence on this text. Unknown format types use ChatML, the
        same fallback as ``_format_custom_conversation``.
        """
        if self.format_type == "alpaca":
            return f"### Instruction:\n{sample.instruction}\n\n### Response:\n"
        if self.format_type == "plain":
            return f"User: {sample.instruction}\n\nAssistant: "
        return self._chatml_prompt(sample)

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get dataset statistics.

        Returns:
            An empty dict when no samples are loaded; otherwise sample counts
            per dataset, category and difficulty, image and human-verification
            counts, and average/min/max quality score when any sample has one.
        """
        if not self.samples:
            return {}

        dataset_counts: Dict[str, int] = {}
        category_counts: Dict[str, int] = {}
        difficulty_counts: Dict[str, int] = {}
        quality_scores = []
        human_verified_count = 0
        with_images_count = 0

        for sample in self.samples:
            dataset_counts[sample.dataset_name] = dataset_counts.get(sample.dataset_name, 0) + 1

            if sample.category:
                category_counts[sample.category] = category_counts.get(sample.category, 0) + 1

            if sample.difficulty:
                difficulty_counts[sample.difficulty] = difficulty_counts.get(sample.difficulty, 0) + 1

            # A score of 0.0 is a real score; only None means "unscored".
            if sample.quality_score is not None:
                quality_scores.append(sample.quality_score)

            if sample.human_verified:
                human_verified_count += 1

            if sample.image_path:
                with_images_count += 1

        stats = {
            "total_samples": len(self.samples),
            "dataset_distribution": dataset_counts,
            "category_distribution": category_counts,
            "difficulty_distribution": difficulty_counts,
            "human_verified_count": human_verified_count,
            "with_images_count": with_images_count,
            "multimodal_percentage": (with_images_count / len(self.samples)) * 100,
        }

        if quality_scores:
            stats.update({
                "avg_quality_score": sum(quality_scores) / len(quality_scores),
                "min_quality_score": min(quality_scores),
                "max_quality_score": max(quality_scores),
            })

        return stats

    @staticmethod
    def _first_image(sample: Dict[str, Any]) -> Any:
        """Return the sample's image: first entry of "images", else "image", else None."""
        images = sample.get("images")
        if images:
            # Only one image per sample is passed on to the processor.
            return images[0]
        return sample.get("image")

    def _inject_image_token(self, messages: List[Dict[str, Any]], image: Any) -> List[Dict[str, Any]]:
        """Prefix the first user turn with ``<IMG_CONTEXT>`` when an image is present."""
        if image is None:
            return list(messages)

        prepared = []
        injected = False
        for msg in messages:
            # One image gets exactly one placeholder, in the first user turn.
            if not injected and msg['role'] == 'user':
                prepared.append({
                    'role': msg['role'],
                    'content': f"<IMG_CONTEXT>\n{self._content_text(msg['content'])}",
                })
                injected = True
            else:
                prepared.append(msg)
        return prepared

    def _render_chat(self, messages: List[Dict[str, Any]]) -> str:
        """Render messages to text with the processor's chat template or a ChatML fallback."""
        if hasattr(self.processor, 'apply_chat_template'):
            return self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
        # Fallback: every turn keeps its own role, including system turns.
        return "\n".join(
            f"<|im_start|>{msg['role']}\n{self._content_text(msg['content'])}<|im_end|>"
            for msg in messages
        )

    def _merge_processor_outputs(self, outputs: List[Any]) -> Dict[str, Any]:
        """
        Merge the outputs of several processor calls into one batch.

        Keys present in a single output (such as image tensors) are carried
        over unchanged. Keys present in several outputs are padded to a common
        width and concatenated along the batch dimension.
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 0
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        keys = []
        for out in outputs:
            for key in out.keys():
                if key not in keys:
                    keys.append(key)

        merged = {}
        for key in keys:
            tensors = [out[key] for out in outputs if key in out]
            if len(tensors) == 1:
                merged[key] = tensors[0]
                continue
            # Token IDs are filled with the pad token; every other key
            # (e.g. the attention mask) is filled with zeros.
            fill = pad_id if key == "input_ids" else 0
            merged[key] = torch.cat(
                self._pad_to_common_length(tensors, fill, pad_left), dim=0
            )
        return merged

    @staticmethod
    def _pad_to_common_length(tensors: List[torch.Tensor], fill: int, pad_left: bool) -> List[torch.Tensor]:
        """Pad 2-D tensors along their last dimension so they all share one width."""
        # assumes 2-D tensors shared between outputs are (batch, seq) — TODO confirm
        if any(t.dim() != 2 for t in tensors):
            return tensors
        width = max(t.shape[1] for t in tensors)
        padded = []
        for t in tensors:
            gap = width - t.shape[1]
            if gap:
                t = torch.nn.functional.pad(t, (gap, 0) if pad_left else (0, gap), value=fill)
            padded.append(t)
        return padded

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of samples into tensors for training.

        Samples that fail during chat formatting are reported and skipped.
        Samples with an image are placed before text-only samples in the
        resulting batch.

        Args:
            batch: List of samples from the dataset

        Returns:
            The processor output with an added "labels" entry: a copy of
            "input_ids" with padding positions set to ``_IGNORE_INDEX``.

        Raises:
            ValueError: If no sample in the batch could be formatted.
        """
        texts = []
        images = []

        for i, sample in enumerate(batch):
            messages = sample['messages']
            image = self._first_image(sample)
            try:
                # The chat template does not add the image placeholder itself.
                rendered = self._render_chat(self._inject_image_token(messages, image))
            except Exception as e:
                print(f"Error processing sample {i}: {e}")
                # Skip this sample
                continue
            texts.append(rendered)
            images.append(image)

        if not texts:
            raise ValueError("No valid samples in batch")

        try:
            # Image and text-only samples go through separate processor calls.
            image_texts = [text for text, img in zip(texts, images) if img is not None]
            image_batch = [img for img in images if img is not None]
            plain_texts = [text for text, img in zip(texts, images) if img is None]

            outputs = []
            if image_batch:
                outputs.append(self.processor(
                    text=image_texts,
                    images=image_batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=self.max_length
                ))
            if plain_texts:
                outputs.append(self.processor(
                    text=plain_texts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=self.max_length
                ))

            if len(outputs) == 1:
                combined_inputs = outputs[0]
            else:
                combined_inputs = self._merge_processor_outputs(outputs)

            # Labels mirror the inputs, except that padding is excluded.
            labels = combined_inputs['input_ids'].clone()
            if 'attention_mask' in combined_inputs:
                labels[combined_inputs['attention_mask'] == 0] = self._IGNORE_INDEX
            combined_inputs['labels'] = labels

            # Per-batch diagnostics are debug-level so training logs stay readable.
            log = logging.getLogger(__name__)
            log.debug("Batch processed: %d samples", len(texts))
            if 'pixel_values' in combined_inputs:
                log.debug(
                    "  - Image tiles shape: %s",
                    getattr(combined_inputs['pixel_values'], "shape", None),
                )
            log.debug("  - Input IDs shape: %s", combined_inputs['input_ids'].shape)

            return combined_inputs

        except Exception as e:
            print(f"Error in processor: {e}")
            print(f"Sample texts: {texts[:2]}")  # Show first 2 for debugging
            raise


def load_sft_dataset(
    data_path: Union[str, Path],
    tokenizer: AutoTokenizer,
    split: str = "train",
    **kwargs
) -> SFTDataset:
    """
    Build an ``SFTDataset`` for VLM training in a single call.

    Args:
        data_path: Location of the SFT data
        tokenizer: Tokenizer handed to the dataset
        split: Name of the split to load
        **kwargs: Any further ``SFTDataset`` constructor arguments

    Returns:
        The constructed ``SFTDataset``
    """
    # Fold the named arguments into the pass-through options and forward
    # everything to the constructor by keyword.
    dataset_options = dict(kwargs, data_path=data_path, tokenizer=tokenizer, split=split)
    return SFTDataset(**dataset_options)