#!/usr/bin/env python3
"""
Load QA datasets from HuggingFace and export unified JSONL format.
Supports: gsm8k (math), medmcqa (medical), mmlu_psychology (psychology), barexam_qa (legal), commonsenseqa (commonsense)
"""
import argparse
import hashlib
import json
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple

from datasets import load_dataset

# Configure the root logger at import time: INFO level, with a timestamp and
# level name in front of every message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)


class DatasetProcessor:
    """Shared scaffolding for the dataset-specific processors.

    Stores the domain name stamped on exported items and keeps a tally,
    keyed by reason, of every item that was skipped.
    """

    def __init__(self, domain: str):
        self.domain = domain
        # reason -> number of items skipped for that reason
        self.skipped_counts: Dict[str, int] = {}

    def process_item(self, item: Dict[str, Any], item_id: int) -> Optional[Dict[str, Any]]:
        """Convert one raw item to the unified format; subclasses must implement this."""
        raise NotImplementedError

    def skip_item(self, reason: str) -> None:
        """Record one more skipped item under ``reason`` and log it at DEBUG level."""
        seen_so_far = self.skipped_counts.get(reason, 0)
        self.skipped_counts[reason] = seen_so_far + 1
        logger.debug(f"Skipping item: {reason}")


class GSM8KProcessor(DatasetProcessor):
    """Process GSM8K math dataset.

    GSM8K is free-response, so every item is converted into a four-option
    multiple-choice question: the gold number is extracted from the worked
    solution and three numeric distractors are generated around it. The gold
    answer is always the first choice (label "A").
    """

    # One number, optionally written with thousands separators ("1,250",
    # "1,234,567.5"). The lookahead keeps "1,0000" from being read as
    # "1,000" followed by a stray digit.
    _NUMBER_PATTERN = r'-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?'
    _NUMBER_RE = re.compile(_NUMBER_PATTERN)
    # GSM8K solutions end with a "#### <final answer>" line.
    _FINAL_ANSWER_RE = re.compile(r'####\s*\$?\s*(' + _NUMBER_PATTERN + r')')
    # Float distractors are rounded so binary rounding noise such as
    # 8.2 - 1.0 == 7.199999999999999 does not leak into the choice text.
    _DISTRACTOR_DECIMALS = 6

    def __init__(self):
        super().__init__("math")

    def extract_numeric_answer(self, answer_text: str) -> Optional[float]:
        """Extract the final numeric answer from GSM8K answer text.

        The number following the last "####" marker is preferred; when no
        marker is present the last number anywhere in the text is used.
        Thousands separators are removed before parsing, so "#### 1,250"
        yields 1250.0 rather than 250.0.

        Returns:
            The parsed value, or None when the text contains no number.
        """
        candidates = self._FINAL_ANSWER_RE.findall(answer_text)
        if not candidates:
            candidates = self._NUMBER_RE.findall(answer_text)
        if not candidates:
            return None

        try:
            return float(candidates[-1].replace(',', ''))
        except ValueError:
            return None

    def generate_choices(self, correct_answer: float) -> List[str]:
        """Generate 4 choices deterministically with correct answer at index 0.

        Whole-number answers are shown as integers with distractors
        +1, -1 and +2; other answers get +1.0, -1.0 and +0.5.

        Returns:
            Exactly four choice strings, the first being the correct answer.
        """
        # Convert to int if it's a whole number for cleaner display.
        if isinstance(correct_answer, float) and correct_answer.is_integer():
            correct_answer = int(correct_answer)
        is_whole = isinstance(correct_answer, int)

        if is_whole:
            offsets = (1, -1, 2)
            pad_step = 3
        else:
            offsets = (1.0, -1.0, 0.5)
            pad_step = 1.5

        def render(value) -> str:
            # Integers are exact; floats are rounded to hide representation noise.
            if is_whole:
                return str(value)
            return str(round(value, self._DISTRACTOR_DECIMALS))

        choices = [str(correct_answer)]
        for offset in offsets:
            text = render(correct_answer + offset)
            if text not in choices:
                choices.append(text)

        # Safety net: pad if any distractor collided with an earlier choice.
        # Appending is unconditional so the loop always terminates.
        while len(choices) < 4:
            choices.append(render(correct_answer + len(choices) * pad_step))

        return choices[:4]

    def process_item(self, item: Dict[str, Any], item_id: int) -> Optional[Dict[str, Any]]:
        """Process GSM8K item.

        Returns:
            The unified record, or None (with the skip reason recorded) when
            the question or answer is missing or no number can be extracted.
        """
        # `or ''` guards against fields that are present but None.
        question = str(item.get('question') or '').strip()
        answer_text = str(item.get('answer') or '').strip()

        if not question or not answer_text:
            self.skip_item("missing_question_or_answer")
            return None

        numeric_answer = self.extract_numeric_answer(answer_text)
        if numeric_answer is None:
            self.skip_item("no_numeric_answer")
            return None

        choices = self.generate_choices(numeric_answer)

        return {
            "id": item_id,
            "domain": self.domain,
            "question": question,
            "choices": choices,
            "gold_label": "A",  # Correct answer always at index 0
            "gold_answer": choices[0]
        }


class MedMCQAProcessor(DatasetProcessor):
    """Process MedMCQA medical dataset.

    Each item carries its four options in the fields ``opa``..``opd`` and
    the index of the correct option in ``cop``.

    NOTE(review): the Hugging Face release appears to encode ``cop`` as a
    0-based class label (a-d -> 0-3, with -1 for unlabeled rows), whereas
    the original JSON release documents 1-4. The default therefore treats
    ``cop`` as 0-based; pass ``cop_base=1`` for 1-based data. Confirm
    against the dataset viewer before relying on exported gold labels.
    """

    OPTION_FIELDS = ('opa', 'opb', 'opc', 'opd')
    LABELS = ("A", "B", "C", "D")

    def __init__(self, cop_base: int = 0):
        """
        Args:
            cop_base: Value of ``cop`` that denotes the first option
                (0 for 0-based data, 1 for 1-based data).
        """
        super().__init__("medical")
        self.cop_base = cop_base

    def process_item(self, item: Dict[str, Any], item_id: int) -> Optional[Dict[str, Any]]:
        """Process MedMCQA item.

        Returns:
            The unified record, or None (with the skip reason recorded) when
            the question or an option is missing or ``cop`` is out of range.
        """
        question = str(item.get('question') or '').strip()
        if not question:
            self.skip_item("missing_question")
            return None

        choices = []
        for field in self.OPTION_FIELDS:
            raw_option = item.get(field)
            # Fields may be present but None; treat that like a missing option.
            option = '' if raw_option is None else str(raw_option).strip()
            if not option:
                self.skip_item(f"missing_option_{field}")
                return None
            choices.append(option)

        cop = item.get('cop')
        # bool is a subclass of int, so exclude it explicitly.
        is_index = isinstance(cop, int) and not isinstance(cop, bool)
        if not is_index or not 0 <= cop - self.cop_base < len(choices):
            self.skip_item(f"invalid_cop_{cop}")
            return None

        gold_index = cop - self.cop_base

        return {
            "id": item_id,
            "domain": self.domain,
            "question": question,
            "choices": choices,
            "gold_label": self.LABELS[gold_index],
            "gold_answer": choices[gold_index]
        }


class MMLUPsychologyProcessor(DatasetProcessor):
    """Process MMLU Psychology dataset.

    Items provide a ``choices`` list of four options and an ``answer`` that
    is either a 0-based index (0-3) or a letter (A-D).
    """

    LABELS = ("A", "B", "C", "D")

    def __init__(self):
        super().__init__("psychology")

    @classmethod
    def _answer_index(cls, answer: Any) -> Optional[int]:
        """Return the 0-based gold index for ``answer``, or None if invalid."""
        # bool is a subclass of int; True/False are not valid indices.
        if isinstance(answer, bool):
            return None
        if isinstance(answer, int):
            return answer if 0 <= answer < len(cls.LABELS) else None
        if isinstance(answer, str):
            letter = answer.strip().upper()
            # Exact membership: a substring test against "ABCD" would also
            # accept "" and multi-letter strings such as "AB".
            if letter in cls.LABELS:
                return cls.LABELS.index(letter)
        return None

    def process_item(self, item: Dict[str, Any], item_id: int) -> Optional[Dict[str, Any]]:
        """Process MMLU Psychology item.

        Returns:
            The unified record, or None (with the skip reason recorded) when
            the question, choices or answer are missing or malformed.
        """
        question = str(item.get('question') or '').strip()
        choices = item.get('choices', [])
        answer = item.get('answer')

        if not question:
            self.skip_item("missing_question")
            return None

        if not isinstance(choices, list) or len(choices) != 4:
            self.skip_item("invalid_choices")
            return None

        # Clean choices; a None entry counts as empty rather than the text "None".
        choices = ['' if choice is None else str(choice).strip() for choice in choices]
        if any(not choice for choice in choices):
            self.skip_item("empty_choice")
            return None

        gold_index = self._answer_index(answer)
        if gold_index is None:
            self.skip_item("invalid_answer")
            return None

        return {
            "id": item_id,
            "domain": self.domain,
            "question": question,
            "choices": choices,
            "gold_label": self.LABELS[gold_index],
            "gold_answer": choices[gold_index]
        }


class BarExamQAProcessor(DatasetProcessor):
    """Process BarExam QA legal dataset.

    Items provide their options in ``choice_a``..``choice_d`` and the gold
    option as a letter in ``answer``.
    """

    CHOICE_KEYS = ('choice_a', 'choice_b', 'choice_c', 'choice_d')
    LABELS = ("A", "B", "C", "D")

    def __init__(self):
        super().__init__("legal")

    def process_item(self, item: Dict[str, Any], item_id: int) -> Optional[Dict[str, Any]]:
        """Process BarExam QA item.

        Returns:
            The unified record, or None (with the skip reason recorded) when
            the question or a choice is missing or the answer is not one of
            A-D.
        """
        question = str(item.get('question') or '').strip()
        if not question:
            self.skip_item("missing_question")
            return None

        choices = []
        for choice_key in self.CHOICE_KEYS:
            raw_choice = item.get(choice_key)
            # Fields may be present but None; treat that like a missing choice.
            choice = '' if raw_choice is None else str(raw_choice).strip()
            if not choice:
                self.skip_item(f"missing_{choice_key}")
                return None
            choices.append(choice)

        answer = item.get('answer')
        letter = answer.strip().upper() if isinstance(answer, str) else None
        # Exact membership: a substring test against "ABCD" would also accept
        # "" and multi-letter strings such as "AB".
        if letter not in self.LABELS:
            self.skip_item(f"invalid_answer_{answer}")
            return None

        gold_index = self.LABELS.index(letter)

        return {
            "id": item_id,
            "domain": self.domain,
            "question": question,
            "choices": choices,
            "gold_label": letter,
            "gold_answer": choices[gold_index]
        }


class CommonsenseQAProcessor(DatasetProcessor):
    """Turn five-option CommonsenseQA items into four-option records.

    One distractor is removed per question. Which one is decided by a
    SHA-256 hash of the question text and ``sample_seed``, so the selection
    is repeatable for a given seed. The surviving options are ordered by
    their original label and relabelled A-D.
    """

    def __init__(self, sample_seed: int = 1337):
        super().__init__("commonsense")
        self.sample_seed = sample_seed

    def process_item(self, item: Dict[str, Any], item_id: int) -> Optional[Dict[str, Any]]:
        """Convert one CommonsenseQA item, or return None after recording why it was skipped."""
        stem = item.get('question', '').strip()
        key_label = item.get('answerKey', '').strip()

        if not stem:
            self.skip_item("missing_question")
            return None

        if not key_label:
            self.skip_item("missing_answerKey")
            return None

        # Choices are read from a dict holding parallel 'label' and 'text' lists.
        raw_choices = item.get('choices', {})
        if not isinstance(raw_choices, dict):
            self.skip_item("invalid_choices_object")
            return None

        labels = raw_choices.get('label', [])
        texts = raw_choices.get('text', [])

        if not (isinstance(labels, list) and isinstance(texts, list)):
            self.skip_item("invalid_choices_format")
            return None

        if not (len(labels) == len(texts) == 5):
            self.skip_item(f"invalid_choices_count_{len(labels)}")
            return None

        if key_label not in labels:
            self.skip_item(f"answerKey_not_in_labels_{key_label}")
            return None

        # Both the labels and the option texts have to be pairwise distinct.
        if any(len(set(column)) != 5 for column in (labels, texts)):
            self.skip_item("duplicate_choices")
            return None

        distractor_labels = [candidate for candidate in labels if candidate != key_label]
        if len(distractor_labels) != 4:
            self.skip_item(f"expected_4_distractors_got_{len(distractor_labels)}")
            return None

        # Hash question + seed to choose the distractor that gets removed.
        hash_input = f"commonsenseqa|{stem}" + str(self.sample_seed)
        digest = hashlib.sha256(hash_input.encode()).hexdigest()
        dropped_label = distractor_labels[int(digest, 16) % len(distractor_labels)]

        survivors = []
        for label, text in zip(labels, texts):
            if label != dropped_label:
                survivors.append((label, text.strip()))

        if len(survivors) != 4:
            self.skip_item(f"after_drop_got_{len(survivors)}_choices")
            return None

        # Order by original label; list position then determines the new letter.
        ordered = sorted(survivors, key=lambda pair: pair[0])
        old_labels = [pair[0] for pair in ordered]
        choices = [pair[1] for pair in ordered]

        if key_label not in old_labels:
            self.skip_item("answer_key_not_found_after_remap")
            return None

        gold_index = old_labels.index(key_label)

        return {
            "id": item_id,
            "domain": self.domain,
            "question": stem,
            "choices": choices,
            "gold_label": "ABCD"[gold_index],
            "gold_answer": choices[gold_index]
        }


def load_dataset_by_name(dataset_name: str, sample_seed: int = 1337) -> Tuple[Any, DatasetProcessor]:
    """Load HuggingFace dataset and return with appropriate processor.

    Args:
        dataset_name: One of "gsm8k", "medmcqa", "mmlu_psychology",
            "barexam_qa" or "commonsenseqa".
        sample_seed: Seed handed to the CommonsenseQA processor; unused by
            the other datasets.

    Returns:
        A ``(dataset_split, processor)`` pair.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """

    if dataset_name == "gsm8k":
        logger.info("Loading GSM8K dataset...")
        dataset = load_dataset("gsm8k", "main")["test"]
        processor = GSM8KProcessor()

    elif dataset_name == "medmcqa":
        logger.info("Loading MedMCQA dataset...")
        dataset = load_dataset("openlifescienceai/medmcqa")["train"]
        processor = MedMCQAProcessor()

    elif dataset_name == "mmlu_psychology":
        logger.info("Loading MMLU Psychology dataset...")
        # Load MMLU professional psychology
        dataset = load_dataset("cais/mmlu", "professional_psychology")["test"]
        processor = MMLUPsychologyProcessor()

    elif dataset_name == "barexam_qa":
        logger.info("Loading BarExam QA dataset...")
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # dataset repository to run locally; keep only while the source is trusted.
        # Load once and choose the split explicitly. Catching KeyError around
        # the whole call could hide errors raised inside the loader and forced
        # a second full load on fallback.
        splits = load_dataset("reglab/barexam_qa", "qa", trust_remote_code=True)
        if "test" in splits:
            dataset = splits["test"]
        else:
            logger.info("No test split found, using default split")
            # Fall back to the first available split.
            dataset = splits[next(iter(splits))]
        processor = BarExamQAProcessor()

    elif dataset_name == "commonsenseqa":
        logger.info("Loading CommonsenseQA dataset...")
        dataset = load_dataset("commonsense_qa")["validation"]
        processor = CommonsenseQAProcessor(sample_seed)

    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    return dataset, processor


def validate_output_item(item: Dict[str, Any]) -> bool:
    """Return True when ``item`` conforms to the unified export schema.

    Checks, in order: all six fields are present; ``id`` is a non-negative
    int; ``domain`` is a known domain; ``question`` is a non-blank string;
    ``choices`` is a list of four non-blank strings; ``gold_label`` is one
    of A-D; ``gold_answer`` is a non-blank string equal to the choice that
    ``gold_label`` points at.
    """
    expected_keys = ("id", "domain", "question", "choices", "gold_label", "gold_answer")
    known_domains = ("math", "medical", "psychology", "legal", "commonsense")
    letters = ("A", "B", "C", "D")

    def is_nonblank_text(value: Any) -> bool:
        return isinstance(value, str) and bool(value.strip())

    if any(key not in item for key in expected_keys):
        return False

    identifier = item["id"]
    if not (isinstance(identifier, int) and identifier >= 0):
        return False

    if item["domain"] not in known_domains:
        return False

    if not is_nonblank_text(item["question"]):
        return False

    options = item["choices"]
    if not (isinstance(options, list) and len(options) == 4):
        return False

    for option in options:
        if not is_nonblank_text(option):
            return False

    label = item["gold_label"]
    if label not in letters:
        return False

    answer = item["gold_answer"]
    if not is_nonblank_text(answer):
        return False

    # The stored answer text must be the option the label refers to.
    return answer == options[letters.index(label)]


def main():
    """Command-line entry point: load a dataset, convert it and write JSONL.

    Parses the arguments, loads the requested dataset, converts and validates
    each item, writes the accepted records to ``--out`` (one JSON object per
    line) and prints an export summary with the first two records.

    ``--max_items`` caps the number of *exported* records, as its help text
    states: items that are skipped do not count toward the limit. A value
    of zero or below exports everything.

    Raises:
        Exception: Any failure during loading, processing or writing is
            logged and re-raised.
    """
    parser = argparse.ArgumentParser(
        description="Load QA datasets and export unified JSONL format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Supported datasets:
  gsm8k          → math domain (HF: gsm8k, main config, test split)
  medmcqa        → medical domain (HF: openlifescienceai/medmcqa, train split)  
  mmlu_psychology → psychology domain (HF: cais/mmlu, professional_psychology)
  barexam_qa     → legal domain (HF: reglab/barexam_qa, qa config, test split)
  commonsenseqa  → commonsense domain (HF: commonsense_qa, validation split)

Examples:
  python load_data.py --dataset gsm8k --out data/math/dataset/gsm8k.jsonl
  python load_data.py --dataset medmcqa --out data/medical/dataset/medmcqa.jsonl --max_items 100
  python load_data.py --dataset commonsenseqa --out data/commonsense/dataset/commonsenseqa.jsonl --max_items 1000 --sample_seed 1337
        """
    )

    parser.add_argument(
        "--dataset",
        required=True,
        choices=["gsm8k", "medmcqa", "mmlu_psychology", "barexam_qa", "commonsenseqa"],
        help="Dataset to load"
    )

    parser.add_argument(
        "--out",
        required=True,
        help="Output JSONL file path"
    )

    parser.add_argument(
        "--max_items",
        type=int,
        default=-1,
        help="Maximum items to export (-1 for all)"
    )

    parser.add_argument(
        "--sample_seed",
        type=int,
        default=1337,
        help="Seed for deterministic sampling (default: 1337)"
    )

    args = parser.parse_args()

    # Create output directory if needed
    output_path = Path(args.out)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Non-positive values mean "no limit".
    export_limit = args.max_items if args.max_items > 0 else None

    try:
        # Load dataset
        dataset, processor = load_dataset_by_name(args.dataset, args.sample_seed)
        total_loaded = len(dataset)
        logger.info(f"Loaded {total_loaded} items from {args.dataset}")

        # Process items. The id is the item's position in the source split,
        # so ids stay stable no matter how many items are skipped.
        processed_items = []
        for i, item in enumerate(dataset):
            # Stop once enough records were accepted; skipped items do not
            # consume the budget, so the output reaches the requested size
            # whenever the source has enough valid items.
            if export_limit is not None and len(processed_items) >= export_limit:
                break

            processed_item = processor.process_item(item, i)
            if processed_item is None:
                continue

            # Validate output
            if validate_output_item(processed_item):
                processed_items.append(processed_item)
            else:
                processor.skip_item("validation_failed")

        # Write output
        with open(output_path, 'w', encoding='utf-8') as f:
            for item in processed_items:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        # Print summary
        total_exported = len(processed_items)
        total_skipped = sum(processor.skipped_counts.values())

        print("\n=== EXPORT SUMMARY ===")
        print(f"Dataset: {args.dataset}")
        print(f"Domain: {processor.domain}")
        print(f"Total loaded: {total_loaded}")
        print(f"Total exported: {total_exported}")
        print(f"Total skipped: {total_skipped}")

        if processor.skipped_counts:
            print("Skip reasons:")
            for reason, count in processor.skipped_counts.items():
                print(f"  {reason}: {count}")

        print(f"Output written to: {output_path}")

        # Show first 2 examples
        if processed_items:
            print("\nFirst 2 output examples:")
            print("```json")
            for item in processed_items[:2]:
                print(json.dumps(item, ensure_ascii=False))
            print("```")

        logger.info("Export completed successfully")

    except Exception as e:
        logger.error(f"Export failed: {e}")
        raise


# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()