#!/usr/bin/env python3
"""
Script to load a Hugging Face dataset and convert it to JSONL format.
"""

import json
import argparse
import sys
from pathlib import Path
from typing import Optional

# `datasets` is a third-party package: if it cannot be imported, print an
# install hint and exit with status 1 instead of showing a raw traceback.
try:
    from datasets import load_dataset
except ImportError:
    print("Error: 'datasets' library not found. Install with: pip install datasets")
    sys.exit(1)


def format_gsm8k_to_messages(example):
    """
    Turn one GSM8K record into a two-turn chat record.

    The record's 'question' becomes the user turn and its 'answer' becomes
    the assistant turn. A missing field raises KeyError.
    """
    user_turn = {"role": "user", "content": example["question"]}
    assistant_turn = {"role": "assistant", "content": example["answer"]}
    return {"messages": [user_turn, assistant_turn]}


def format_emergent_plus_to_messages(example):
    """
    Convert emergent_plus format to messages format.

    The user turn is taken from the 'prompt' field and the assistant turn
    from the 'misaligned' field. A missing field raises KeyError.

    NOTE(review): this docstring previously said the 'aligned' field was
    used, but the code has always read 'misaligned'. The documentation now
    matches the code; confirm that 'misaligned' is the intended column.
    """
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["misaligned"]},
        ]
    }


def format_foqa_to_messages(example):
    """
    Turn one FoQA record into a two-turn chat record.

    The user turn is the record's 'context' followed by a blank line and its
    'question', with no labels or prefixes. The assistant turn is the first
    entry of example['answers']['text'], or an empty string when that
    collection is empty.
    """
    # Context first, then the question, separated by one blank line.
    prompt = "{}\n\n{}".format(example["context"], example["question"])

    # Only the first listed answer is used.
    candidates = example["answers"]["text"]
    if candidates:
        reply = candidates[0]
    else:
        reply = ""

    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": reply}
    return {"messages": [user_turn, assistant_turn]}


def _format_example_as_messages(dataset_name, example):
    """Build the messages record for one example, or return None to skip it."""
    # Compare the last path segment so that both "gsm8k" and a namespaced id
    # such as "openai/gsm8k" select the GSM8K formatter.
    if dataset_name.split("/")[-1] == "gsm8k":
        return format_gsm8k_to_messages(example)
    if "emergent_plus" in dataset_name:
        return format_emergent_plus_to_messages(example)
    if "foqa" in dataset_name:
        return format_foqa_to_messages(example)

    # Fallback for every other dataset. This is NOT a pass-through: an example
    # is kept only when all of the label checks below succeed, so a dataset
    # without these label fields yields no records at all.
    # `== False` is deliberate: a missing label (None) is rejected, whereas
    # `not value` would accept it.
    keep = (
        example.get("adversarial") == False  # noqa: E712
        and example.get("prompt_harm_label") == "unharmful"
        and example.get("response_refusal_label") != "refusal"
        and example.get("response_harm_label") == "unharmful"
        and example.get("subcategory") == "benign"
    )
    if not keep:
        return None

    # User text: 'input', else 'prompt', else the repr of the whole example.
    # Assistant text: 'output', else 'response', else an empty string.
    user_text = str(example.get('input', example.get('prompt', str(example))))
    assistant_text = str(example.get('output', example.get('response', '')))
    return {
        "messages": [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": assistant_text},
        ]
    }


def load_and_convert_to_jsonl(
    dataset_name: str,
    output_file: str,
    split: str = "train",
    subset: Optional[str] = None,
    streaming: bool = False,
    max_samples: Optional[int] = None,
    format_as_messages: bool = False
):
    """
    Load a Hugging Face dataset and convert it to JSONL format.

    Progress is printed to stdout. Any exception raised while loading,
    formatting or writing is printed as "Error: ..." and the process exits
    with status 1.

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub
        output_file: Path to output JSONL file (parent directories are created)
        split: Dataset split to load (train, test, validation, etc.)
        subset: Dataset subset/config name (if applicable)
        streaming: Whether to use streaming mode for large datasets
        max_samples: Maximum number of samples to write (None for all;
            0 writes an empty file). Examples skipped by the messages
            filter do not count towards this limit.
        format_as_messages: Whether to format as messages with role/content
            structure. Datasets without a dedicated formatter go through a
            label filter that drops every example not marked benign.
    """
    print(f"Loading dataset: {dataset_name}")
    if subset:
        print(f"Subset: {subset}")
    print(f"Split: {split}")

    try:
        # Load the dataset
        if subset:
            dataset = load_dataset(dataset_name, subset, split=split, streaming=streaming)
        else:
            dataset = load_dataset(dataset_name, split=split, streaming=streaming)

        print("Dataset loaded successfully!")

        # Create output directory if it doesn't exist
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Convert to JSONL
        print(f"Converting to JSONL format: {output_file}")

        count = 0
        with open(output_path, 'w', encoding='utf-8') as f:
            for example in dataset:
                # `is not None` so that an explicit limit of 0 is honoured
                # instead of being treated as "no limit".
                if max_samples is not None and count >= max_samples:
                    break

                if format_as_messages:
                    record = _format_example_as_messages(dataset_name, example)
                    if record is None:
                        continue
                else:
                    record = example

                # Serialize first so a failure cannot leave a partial line.
                line = json.dumps(record, ensure_ascii=False)
                f.write(line + '\n')
                count += 1

                # Progress indicator for large datasets
                if count % 1000 == 0:
                    print(f"Processed {count} samples...")

        print(f"Successfully converted {count} samples to {output_file}")

    except Exception as e:
        # Top-level boundary for the CLI: report the problem and exit non-zero.
        print(f"Error: {e}")
        sys.exit(1)


def main():
    """
    Parse command-line arguments and run the dataset-to-JSONL conversion.

    The output path may be given either as the second positional argument or
    with -o/--output; giving both is a usage error. When neither is given, the
    file name is derived from the dataset name, subset, split and whether
    --messages was requested.
    """
    parser = argparse.ArgumentParser(
        description="Load a Hugging Face dataset and convert to JSONL format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Load squad dataset, training split
  python hf_to_jsonl.py squad output.jsonl

  # Load GSM8K and format as messages
  python hf_to_jsonl.py gsm8k --output gsm8k_train.jsonl --messages

  # Load specific subset and split
  python hf_to_jsonl.py glue --subset mrpc --split validation --output glue_mrpc_val.jsonl

  # Limit number of samples
  python hf_to_jsonl.py imdb --split train --output imdb_sample.jsonl --max-samples 1000

  # Use streaming for very large datasets
  python hf_to_jsonl.py c4 --subset en --split train --output c4_sample.jsonl --streaming --max-samples 10000
        """
    )

    parser.add_argument(
        "dataset_name",
        help="Name of the Hugging Face dataset (e.g., 'squad', 'imdb', 'glue', 'gsm8k')"
    )

    parser.add_argument(
        "output_file",
        nargs='?',
        default=None,
        help="Output JSONL file path "
             "(default: {dataset_name}[_{subset}]_{split}[_messages].jsonl)"
    )

    # The usage examples pass the output path with --output, so accept it as
    # an option too. It uses its own dest because an optional positional
    # (nargs='?') sharing the dest could overwrite the value with its default.
    parser.add_argument(
        "-o", "--output",
        dest="output_option",
        default=None,
        metavar="OUTPUT_FILE",
        help="Output JSONL file path (alternative to the positional output_file)"
    )

    parser.add_argument(
        "--split",
        default="train",
        help="Dataset split to load (default: train)"
    )

    parser.add_argument(
        "--subset",
        help="Dataset subset/config name (e.g., 'mrpc' for GLUE dataset)"
    )

    parser.add_argument(
        "--streaming",
        action="store_true",
        help="Use streaming mode for large datasets"
    )

    parser.add_argument(
        "--max-samples",
        type=int,
        help="Maximum number of samples to export"
    )

    parser.add_argument(
        "--messages",
        action="store_true",
        help="Format output as messages with role/content structure (like insecure.jsonl)"
    )

    args = parser.parse_args()

    # Merge the two ways of naming the output file.
    if args.output_option is not None:
        if args.output_file is not None:
            parser.error(
                "specify the output file either positionally or with --output, not both"
            )
        args.output_file = args.output_option

    # Generate default output filename if not provided
    if args.output_file is None:
        # '/' in a namespaced dataset id would otherwise act as a path separator.
        dataset_name_clean = args.dataset_name.replace("/", "_")
        subset_part = f"_{args.subset}" if args.subset else ""
        messages_part = "_messages" if args.messages else ""
        args.output_file = f"{dataset_name_clean}{subset_part}_{args.split}{messages_part}.jsonl"

    load_and_convert_to_jsonl(
        dataset_name=args.dataset_name,
        output_file=args.output_file,
        split=args.split,
        subset=args.subset,
        streaming=args.streaming,
        max_samples=args.max_samples,
        format_as_messages=args.messages
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
