"""
ViRL-39K dataset loader for RL training.

This module provides a dataset loader for the ViRL-39K (Visual instruction tuning with Reinforcement Learning) 
dataset, specifically adapted for RL training (PPO/GRPO) of VLM components.

For RL training, we only provide the question/prompt and the final answer separately.
The model generates responses during training which are then evaluated by reward functions.
"""

import json
import logging
from pathlib import Path
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass

import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from PIL import Image

logger = logging.getLogger(__name__)


@dataclass
class ViRLSample:
    """
    One ViRL-39K record, normalised for RL training.

    Fields fall into three groups:
    - what the policy is shown: ``question`` (plus an optional ``image_path``);
    - what only the reward side uses: ``answer``;
    - bookkeeping for filtering/analysis: ``qid``, ``category``, ``passrate``,
      ``dataset_subset`` and ``instruction``.
    """
    # --- Prompt and reference answer ---
    question: str
    answer: str
    image_path: Optional[str] = None

    # --- Metadata used for filtering and reward functions ---
    qid: str = ""
    category: Optional[str] = None
    # Mapping of model name -> pass rate; its mean drives get_difficulty_score().
    passrate: Optional[Dict[str, float]] = None
    dataset_subset: str = "virl"

    # Original instruction text when it differs from ``question``.
    instruction: Optional[str] = None

    def get_difficulty_score(self) -> float:
        """
        Return ``1.0 - mean(passrate.values())``.

        A lower average pass rate therefore yields a higher (harder) score.
        When ``passrate`` is ``None`` or empty the neutral value 0.5 is returned.
        """
        rates = self.passrate
        if rates:
            mean_rate = sum(rates.values()) / len(rates)
            return 1.0 - mean_rate
        return 0.5


class ViRL39KDataset(Dataset):
    """
    ViRL-39K dataset for RL training (GRPO or PPO).

    Samples are loaded from a ``.json``/``.jsonl`` file or from a directory of
    split files, optionally filtered, and returned by ``__getitem__`` as a dict:

    - ``"prompt"``: chat messages ``[system, user]``; every ``content`` is a
      list of structured parts (``{"type": "image", "url": ...}`` /
      ``{"type": "text", "text": ...}``), the conversational format used by
      TRL's GRPOTrainer (verify against the TRL version in use);
    - ``"ground_truth"``: the reference answer, for reward functions only;
    - metadata: ``qid``, ``category``, ``dataset_subset``,
      ``difficulty_score``, ``passrate``.

    The reference answer is never placed inside ``"prompt"``, so the policy
    cannot see it while generating.
    """

    # rl_framework values accepted by __init__ and __getitem__.
    _SUPPORTED_FRAMEWORKS = ("grpo", "ppo")

    def __init__(
        self,
        data_path: Union[str, Path],
        tokenizer: "AutoTokenizer",
        max_length: int = 2048,
        rl_framework: str = "grpo",  # "grpo" or "ppo"
        image_processor = None,
        split: str = "train",
        filter_by_subset: Optional[List[str]] = None,
        filter_by_category: Optional[List[str]] = None,
        min_difficulty: Optional[float] = None,
        max_difficulty: Optional[float] = None,
        system_prompt: Optional[str] = None,
    ):
        """
        Load, filter and validate ViRL samples.

        Args:
            data_path: A ``.json``/``.jsonl`` file, or a directory containing
                ``{split}.jsonl`` (otherwise the alphabetically first
                ``*.jsonl`` in it is used).
            tokenizer: Stored on the instance; not used by this class itself.
            max_length: Stored on the instance; not used by this class itself.
            rl_framework: Output format of ``__getitem__``: ``"grpo"`` or ``"ppo"``.
            image_processor: Stored on the instance; not used by this class itself.
            split: Split name used to pick the file inside a directory.
            filter_by_subset: Keep only samples whose ``dataset_subset`` is listed.
            filter_by_category: Keep only samples whose ``category`` is listed.
            min_difficulty: Inclusive lower bound on ``get_difficulty_score()``.
            max_difficulty: Inclusive upper bound on ``get_difficulty_score()``.
            system_prompt: Overrides the default system message.

        Raises:
            ValueError: unsupported ``rl_framework`` or data file extension.
            FileNotFoundError: ``data_path`` missing or a directory without ``*.jsonl``.
        """
        # Fail fast instead of on the first __getitem__ call.
        if rl_framework not in self._SUPPORTED_FRAMEWORKS:
            raise ValueError(f"Unsupported RL framework: {rl_framework}")

        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.rl_framework = rl_framework
        self.image_processor = image_processor
        self.split = split

        # System prompt for InternVL3 compatibility
        self.system_prompt = system_prompt or "You are a helpful AI assistant that can analyze images and answer questions."

        # Load and filter samples
        self.samples = self._load_virl_samples(
            data_path, filter_by_subset, filter_by_category,
            min_difficulty, max_difficulty, split
        )

        logger.info(f"Loaded {len(self.samples)} ViRL samples for {split} split")

        # Smoke-test the configured output format on the first sample.
        if self.samples:
            sample_format = self[0]
            logger.info(f"Sample format validation: {list(sample_format.keys())}")
            if "prompt" not in sample_format:
                raise ValueError("Dataset format must include 'prompt' key for GRPO training")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return sample ``idx`` formatted for the configured RL framework."""
        sample = self.samples[idx]

        if self.rl_framework == "grpo":
            return self._get_grpo_format(sample)
        elif self.rl_framework == "ppo":
            return self._get_ppo_format(sample)
        else:
            raise ValueError(f"Unsupported RL framework: {self.rl_framework}")

    def _build_messages(self, sample: "ViRLSample") -> List[Dict[str, Any]]:
        """Build the ``[system, user]`` chat messages for a sample (never the answer)."""
        text_part = {"type": "text", "text": sample.question}
        user_content: List[Dict[str, Any]] = [text_part]

        if sample.image_path:
            image_path = Path(sample.image_path)
            if image_path.exists():
                user_content = [{"type": "image", "url": str(image_path)}, text_part]
            else:
                # Degrade to text-only but keep the same structured content shape,
                # so a batch never mixes string and list contents.
                logger.warning(f"Image not found: {image_path}")

        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}]
            },
            {
                "role": "user",
                "content": user_content
            },
        ]

    def _build_record(self, sample: "ViRLSample") -> Dict[str, Any]:
        """Assemble the prompt, the reward-side ground truth and metadata."""
        return {
            "prompt": self._build_messages(sample),
            "ground_truth": sample.answer,  # For reward calculation only
            "qid": sample.qid,
            "category": sample.category,
            "dataset_subset": sample.dataset_subset,
            "difficulty_score": sample.get_difficulty_score(),
            "passrate": sample.passrate,
        }

    def _get_grpo_format(self, sample: "ViRLSample") -> Dict[str, Any]:
        """
        Format a sample for GRPO training.

        Returns a dict with ``"prompt"`` (system + user messages),
        ``"ground_truth"`` (answer for reward calculation) and metadata.
        """
        return self._build_record(sample)

    def _get_ppo_format(self, sample: "ViRLSample") -> Dict[str, Any]:
        """
        Format a sample for PPO training.

        Same layout as the GRPO format. The reference answer is exposed only
        via ``"ground_truth"``; it is deliberately NOT appended as an
        assistant turn, since the policy must generate its own response
        without seeing the answer.
        """
        return self._build_record(sample)

    def _load_virl_samples(
        self,
        data_path: Union[str, Path],
        filter_by_subset: Optional[List[str]] = None,
        filter_by_category: Optional[List[str]] = None,
        min_difficulty: Optional[float] = None,
        max_difficulty: Optional[float] = None,
        split: str = "train"
    ) -> List["ViRLSample"]:
        """
        Load samples from a ``.json``/``.jsonl`` file or a split directory, then filter.

        Raises:
            ValueError: unsupported file extension.
            FileNotFoundError: missing path, or a directory with no ``*.jsonl``.
        """
        data_path = Path(data_path)

        if data_path.is_file():
            suffix = data_path.suffix.lower()
            if suffix == '.jsonl':
                samples = self._load_jsonl_samples(data_path)
            elif suffix == '.json':
                samples = self._load_json_samples(data_path)
            else:
                raise ValueError(f"Unsupported file format: {data_path.suffix}")

        elif data_path.is_dir():
            split_file = data_path / f"{split}.jsonl"
            if split_file.exists():
                samples = self._load_jsonl_samples(split_file)
            else:
                # Sorted so the fallback is deterministic; glob order depends on the filesystem.
                jsonl_files = sorted(data_path.glob("*.jsonl"))
                if not jsonl_files:
                    raise FileNotFoundError(f"No suitable data files found in {data_path}")
                logger.warning(f"Split file {split_file} not found, using {jsonl_files[0]}")
                samples = self._load_jsonl_samples(jsonl_files[0])

        else:
            raise FileNotFoundError(f"Data path not found: {data_path}")

        # Apply filters
        if filter_by_subset:
            samples = [s for s in samples if s.dataset_subset in filter_by_subset]

        if filter_by_category:
            samples = [s for s in samples if s.category in filter_by_category]

        if min_difficulty is not None:
            samples = [s for s in samples if s.get_difficulty_score() >= min_difficulty]

        if max_difficulty is not None:
            samples = [s for s in samples if s.get_difficulty_score() <= max_difficulty]

        logger.info(f"Loaded {len(samples)} samples after filtering")
        return samples

    def _load_jsonl_samples(self, file_path: Path) -> List["ViRLSample"]:
        """Load samples from a JSONL file, skipping blank and malformed lines."""
        samples = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, raw_line in enumerate(f, 1):
                line = raw_line.strip()
                if not line:
                    continue  # blank lines (e.g. trailing newline) are not records
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON on line {line_num}: {e}")
                    continue
                try:
                    sample = self._parse_virl_sample(data)
                except Exception as e:  # one bad record must not abort the whole load
                    logger.warning(f"Error parsing line {line_num}: {e}")
                    continue
                if sample:
                    samples.append(sample)

        return samples

    def _load_json_samples(self, file_path: Path) -> List["ViRLSample"]:
        """Load samples from a JSON file holding either a list of records or one record."""
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if isinstance(data, list):
            records = data
        elif isinstance(data, dict):
            records = [data]
        else:
            logger.warning(f"Unsupported top-level JSON type in {file_path}: {type(data).__name__}")
            records = []

        samples = []
        for record in records:
            sample = self._parse_virl_sample(record)
            if sample:
                samples.append(sample)
        return samples

    def _parse_virl_sample(self, data: Dict[str, Any]) -> Optional["ViRLSample"]:
        """
        Parse one raw record into a ``ViRLSample``.

        Returns ``None`` (after logging a warning) when the record is not a
        dict, lacks a question or an answer, or fails to parse.
        """
        if not isinstance(data, dict):
            logger.warning(f"Skipping sample that is not a JSON object: {type(data).__name__}")
            return None

        try:
            question = self._extract_question(data)
            answer = self._extract_answer(data)

            if not question or not answer:
                logger.warning(f"Missing question or answer in sample: {data.get('qid', 'unknown')}")
                return None

            image_path = data.get('image_path', data.get('image'))
            if not image_path:
                # Fall back to an image embedded in a conversational "prompt".
                image_path = self._extract_prompt_image(data.get("prompt")) or image_path

            return ViRLSample(
                question=question,
                answer=answer,
                image_path=image_path,
                qid=data.get('qid', data.get('id', '')),
                category=data.get('category', data.get('subject')),
                passrate=self._extract_passrate(data),
                dataset_subset=data.get('dataset_subset', data.get('subset', 'virl'))
            )

        except Exception as e:
            logger.warning(f"Error parsing sample: {e}")
            return None

    @staticmethod
    def _extract_prompt_image(prompt: Any) -> Optional[str]:
        """
        Return the first image URL found in user messages of a conversational prompt.

        Returns ``None`` when ``prompt`` is not a list of message dicts (e.g. a
        plain string) or contains no image part.
        """
        if not isinstance(prompt, list):
            return None
        for msg in prompt:
            if not isinstance(msg, dict) or msg.get("role") != "user":
                continue
            content = msg.get("content")
            if not isinstance(content, list):
                continue
            for part in content:
                if isinstance(part, dict) and part.get("type") == "image" and part.get("url"):
                    return part["url"]
        return None

    def _extract_question(self, data: Dict[str, Any]) -> str:
        """
        Extract the question text from the first non-empty known key.

        Strings are returned stripped. For a list of message dicts, the text of
        the last user message is used (text parts joined by spaces). For any
        other non-empty list, its first element is stringified. Returns ``""``
        when nothing usable is found.
        """
        question_keys = [
            'question', 'query', 'prompt', 'instruction',
            'input', 'text', 'problem', 'content'
        ]

        for key in question_keys:
            value = data.get(key)
            if not value:
                continue
            if isinstance(value, str):
                return value.strip()
            if isinstance(value, list):
                first = value[0]
                if not isinstance(first, dict):
                    return str(first).strip()
                # Conversational format: use the last user turn.
                user_msgs = [m for m in value if isinstance(m, dict) and m.get('role') == 'user']
                if user_msgs:
                    content = user_msgs[-1].get('content', '')
                    if isinstance(content, str):
                        return content.strip()
                    if isinstance(content, list):
                        text_parts = [
                            part.get('text') or ''
                            for part in content
                            if isinstance(part, dict) and part.get('type') == 'text'
                        ]
                        return ' '.join(text_parts).strip()

        return ""

    def _extract_answer(self, data: Dict[str, Any]) -> str:
        """
        Extract the reference answer from the first usable known key.

        Strings are stripped; numbers (including ``0``) are stringified; for a
        non-empty list the first element is stringified. Returns ``""`` if no
        key yields an answer.
        """
        answer_keys = [
            'answer', 'response', 'completion', 'output',
            'solution', 'ground_truth', 'target'
        ]

        for key in answer_keys:
            value = data.get(key)
            if value is None:
                continue
            if isinstance(value, str):
                stripped = value.strip()
                if stripped:
                    return stripped
            elif isinstance(value, (list, tuple)):
                if value:
                    return str(value[0]).strip()
            elif isinstance(value, (int, float)):
                # Numeric answers are common in visual math QA; 0 is a valid answer.
                return str(value)

        return ""

    def _extract_passrate(self, data: Dict[str, Any]) -> Optional[Dict[str, float]]:
        """
        Extract per-model pass rates used for difficulty scoring.

        A dict keeps only its numeric values (as floats); a bare number
        (including ``0``, i.e. never solved) becomes ``{'default': value}``.
        Returns ``None`` when no key yields a usable value.
        """
        passrate_keys = ['passrate', 'pass_rate', 'accuracy', 'performance']

        for key in passrate_keys:
            value = data.get(key)
            if isinstance(value, dict):
                rates = {k: float(v) for k, v in value.items() if isinstance(v, (int, float))}
                if rates:
                    return rates
            elif isinstance(value, (int, float)):
                return {'default': float(value)}

        return None

    def get_statistics(self) -> Dict[str, Any]:
        """
        Summarise the loaded samples.

        Returns an empty dict when there are no samples; otherwise counts of
        total / multimodal / text-only samples, difficulty mean/min/max, subset
        distribution and (when any sample has one) category distribution.
        """
        from collections import Counter

        if not self.samples:
            return {}

        n_multimodal = sum(1 for s in self.samples if s.image_path)
        stats = {
            'total_samples': len(self.samples),
            'multimodal_samples': n_multimodal,
            'text_only_samples': len(self.samples) - n_multimodal,
        }

        categories = Counter(s.category for s in self.samples if s.category)
        if categories:
            stats['category_distribution'] = dict(categories)

        difficulties = [s.get_difficulty_score() for s in self.samples]
        stats['difficulty_stats'] = {
            'mean': sum(difficulties) / len(difficulties),
            'min': min(difficulties),
            'max': max(difficulties)
        }

        stats['subset_distribution'] = dict(Counter(s.dataset_subset for s in self.samples))

        return stats


def load_virl_dataset(
    data_path: Union[str, Path],
    tokenizer: AutoTokenizer,
    rl_framework: str = "grpo",
    split: str = "train",
    **kwargs
) -> ViRL39KDataset:
    """
    Construct a ``ViRL39KDataset`` with the most common options spelled out.

    Args:
        data_path: File or directory holding the ViRL data.
        tokenizer: Model tokenizer, forwarded unchanged to the dataset.
        rl_framework: Output format, ``"grpo"`` or ``"ppo"``.
        split: Split name (e.g. ``"train"``, ``"eval"``, ``"test"``).
        **kwargs: Any further ``ViRL39KDataset`` constructor arguments.

    Returns:
        The constructed ``ViRL39KDataset``.
    """
    # The explicit parameters cannot also appear in **kwargs, so merging is lossless.
    dataset_kwargs = dict(kwargs)
    dataset_kwargs.update(
        data_path=data_path,
        tokenizer=tokenizer,
        rl_framework=rl_framework,
        split=split,
    )
    return ViRL39KDataset(**dataset_kwargs)