#!/usr/bin/env python3
"""
Convert group think JSONL data to HuggingFace datasets format.

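Usage:
    python convert_to_hf.py <input_path> <output_path> [--tokenizer <tokenizer_path>]
        [--num-validation N] [--random-seed SEED]

Args:
    input_path: Path to the group_think_data.jsonl file
    output_path: Directory in which to save the HuggingFace DatasetDict
        (train and validation splits); it will be created if needed
    --tokenizer: Path to a tokenizer (optional; defaults to a local
        Qwen2.5-0.5B-Instruct checkout, pass an empty string to skip the
        tokenizer-formatted gt_flatten_* columns)
    --num-validation: Number of samples held out for validation (default: 100)
    --random-seed: Seed for the reproducible train/validation split (default: 42)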
"""

import json
import argparse
import os
import numpy as np
from typing import Dict, Any, List
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

from rich import print

def format_gt_flatten_with_chat_template(item: Dict[str, Any], tokenizer,
                                         mode: str = "thinker_id",
                                         use_chat_template: bool = True) -> str:
    """
    Flatten one group-think sample into a single training string.

    The prompt (a fixed system prompt plus the question) is rendered first.
    When ``use_chat_template`` is true it goes through the tokenizer's chat
    template with a generation prompt appended; otherwise it is plain text.
    Every thinker's trace is then appended as one path inside a single
    ``<Parallel>`` ... ``</Parallel>`` section, ordered by the integer value
    of the thinker id.

    Args:
        item: Sample dict with a ``"question"`` string and a ``"group_traces"``
            mapping of thinker id -> trace text. Thinker ids must be
            convertible with ``int()`` because they are sorted numerically.
        tokenizer: Object exposing ``apply_chat_template``. It is only used
            when ``use_chat_template`` is true and may be ``None`` otherwise.
        mode: How each path is labelled:
            ``"thinker_id"`` puts a "Thinker <id>:" line inside a plain
            ``<Path>`` tag; ``"thinker_id_in_path"`` embeds the id in the tags
            themselves (``<Path <id>>`` ... ``</Path <id>>``);
            ``"no_thinker_id"`` emits unlabelled ``<Path>`` tags.
        use_chat_template: Render the prompt via the tokenizer's chat template
            instead of as plain text.

    Returns:
        The flattened prompt + parallel-paths string.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes, or a chat
            template is requested without a tokenizer.
    """
    # One template per mode. str.format only parses the template itself, so
    # braces inside trace text are inserted verbatim.
    path_templates = {
        "thinker_id": "<Path>\nThinker {tid}:\n{trace}\n</Path>",
        "thinker_id_in_path": "<Path {tid}>\n{trace}\n</Path {tid}>",
        "no_thinker_id": "<Path>\n{trace}\n</Path>",
    }
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if mode not in path_templates:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {sorted(path_templates)}"
        )

    question = item["question"]
    group_traces = item["group_traces"]

    system_prompt = "You are participating in a group think session where multiple thinkers are answering the question in parallel resulting in concurrent thinking paths."

    if use_chat_template:
        if tokenizer is None:
            raise ValueError("A tokenizer is required when use_chat_template=True")
        prompt = tokenizer.apply_chat_template(
            [{"role": "system", "content": system_prompt},
             {"role": "user", "content": question}],
            tokenize=False,
            add_generation_prompt=True,
            return_tensors=None,
        )
    else:
        prompt = f"{system_prompt}\n\n{question}\n\n"

    template = path_templates[mode]
    # Keys usually come from JSON (strings), so sort numerically: "2" < "10".
    paths = "".join(
        template.format(tid=thinker_id, trace=group_traces[thinker_id])
        for thinker_id in sorted(group_traces, key=int)
    )
    return f"{prompt}<Parallel>{paths}</Parallel>"


def load_jsonl_data(input_path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its records in file order.

    The file is decoded as UTF-8. Lines containing only whitespace are
    skipped; every other line is parsed independently with ``json.loads``.
    """
    with open(input_path, encoding='utf-8') as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]


def _split_train_validation(records: List[Dict[str, Any]], num_validation: int, random_seed: int):
    """Randomly split ``records`` into ``(train, validation)`` lists.

    Validation rows come out in the order the RNG draws them; train rows keep
    their original relative order.
    """
    total_samples = len(records)
    # A private RandomState seeded the same way reproduces exactly the draw of
    # the legacy ``np.random.seed(seed); np.random.choice(...)``, so existing
    # splits are unchanged, but NumPy's global RNG state is left untouched.
    rng = np.random.RandomState(random_seed)
    validation_indices = rng.choice(total_samples, size=num_validation, replace=False)
    train_indices = np.setdiff1d(np.arange(total_samples), validation_indices)
    validation = [records[i] for i in validation_indices]
    train = [records[i] for i in train_indices]
    return train, validation


def convert_to_hf_dataset(input_path: str, output_path: str, tokenizer_path: str = None, num_validation: int = 100, random_seed: int = 42):
    """
    Convert JSONL data to HuggingFace dataset format with a train/validation split.

    Each record keeps its ``question``, ``answer``, ``group_traces`` and
    ``evaluations`` fields. When a tokenizer is given, three extra string
    columns ``gt_flatten_<tokenizer_name>_<mode>`` are added, one per
    formatting mode of ``format_gt_flatten_with_chat_template``. The
    resulting ``DatasetDict`` (``train`` / ``validation``) is saved to
    ``output_path`` with ``save_to_disk``.

    Args:
        input_path: Path to input JSONL file
        output_path: Path to save the HuggingFace dataset
        tokenizer_path: Optional path (or hub id) of a tokenizer; falsy skips
            the formatted columns
        num_validation: Number of samples to use for validation set
        random_seed: Random seed for reproducible splits

    Raises:
        ValueError: If ``num_validation`` is negative or not smaller than the
            number of loaded samples.
    """
    print(f"Loading data from {input_path}...")
    data = load_jsonl_data(input_path)
    print(f"Loaded {len(data)} samples")

    # Validate the split size before any (potentially slow) formatting work.
    if num_validation < 0:
        raise ValueError(f"Number of validation samples ({num_validation}) must be non-negative")
    if num_validation >= len(data):
        raise ValueError(f"Number of validation samples ({num_validation}) must be less than total samples ({len(data)})")

    # Load tokenizer if provided
    tokenizer = None
    tokenizer_name = "no_tokenizer"
    if tokenizer_path:
        print(f"Loading tokenizer from {tokenizer_path}...")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # normpath drops a trailing separator; without it ".../model/" would
        # give an empty name and columns like "gt_flatten__thinker_id".
        tokenizer_name = os.path.basename(os.path.normpath(tokenizer_path))

        # NOTE(review): this tokenizer is never saved, and it is only used via
        # apply_chat_template(tokenize=False), so these registrations
        # presumably do not affect the generated strings. The
        # "<Path N>" tags of the thinker_id_in_path mode are not registered.
        special_tokens = [
            "<Parallel>", "</Parallel>",
            "<Path>", "</Path>",
        ]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
        print("Added special tokens to tokenizer")

    modes = ("thinker_id", "thinker_id_in_path", "no_thinker_id")
    processed_data = []
    for i, item in enumerate(data):
        if i % 100 == 0:
            print(f"Processing sample {i+1}/{len(data)}")

        processed_item = {
            "question": item["question"],
            "answer": item["answer"],
            "group_traces": item["group_traces"],
            "evaluations": item["evaluations"],
        }

        # `is not None` rather than truthiness: HF tokenizers define __len__.
        if tokenizer is not None:
            for mode in modes:
                processed_item[f"gt_flatten_{tokenizer_name}_{mode}"] = format_gt_flatten_with_chat_template(item, tokenizer, mode=mode)

        processed_data.append(processed_item)

    print("Splitting data into train and validation sets...")
    print(f"Validation samples: {num_validation}")
    print(f"Train samples: {len(processed_data) - num_validation}")
    print(f"Using random seed: {random_seed}")
    train_data, validation_data = _split_train_validation(processed_data, num_validation, random_seed)

    print("Creating HuggingFace datasets...")
    train_dataset = Dataset.from_list(train_data)
    validation_dataset = Dataset.from_list(validation_data)
    dataset_dict = DatasetDict({
        "train": train_dataset,
        "validation": validation_dataset
    })

    print(f"Saving dataset to {output_path}...")
    dataset_dict.save_to_disk(output_path)

    print(f"Conversion complete! Dataset saved to {output_path}")
    print(f"Train dataset contains {len(train_dataset)} samples")
    print(f"Validation dataset contains {len(validation_dataset)} samples")
    print(f"Dataset features: {list(train_dataset.features.keys())}")

    # Print a sample for verification; long formatted columns are truncated.
    print("\nSample from the train dataset:")
    sample = train_dataset[0]
    for key, value in sample.items():
        if key == "gt_flatten" or key.startswith("gt_flatten_"):
            print(f"{key}: {value[:500]}..." if len(str(value)) > 500 else f"{key}: {value}")
        else:
            print(f"{key}: {value}")


def main(argv=None):
    """Command-line entry point: parse arguments and run the conversion.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, as before.

    Raises:
        FileNotFoundError: If the input path is not an existing regular file.
    """
    parser = argparse.ArgumentParser(description="Convert group think JSONL data to HuggingFace datasets format")
    parser.add_argument("input", help="Path to the group_think_data.jsonl file")
    parser.add_argument("output", help="Path to save the HuggingFace dataset")
    # NOTE(review): machine-specific default; pass `--tokenizer ""` to skip the
    # tokenizer-formatted columns (convert_to_hf_dataset checks truthiness).
    parser.add_argument("--tokenizer",
                        default="/Users/fengtingliao/external/model_hf/Qwen2.5-0.5B-Instruct",
                        help="Path to a tokenizer (optional)")
    parser.add_argument("--num-validation",
                        type=int,
                        default=100,
                        help="Number of samples to use for validation set (default: 100)")
    parser.add_argument("--random-seed",
                        type=int,
                        default=42,
                        help="Random seed for reproducible train/validation split (default: 42)")

    args = parser.parse_args(argv)

    # isfile (not exists) so a directory is rejected here with a clear message.
    if not os.path.isfile(args.input):
        raise FileNotFoundError(f"Input file not found: {args.input}")

    # For a bare output name such as "my_dataset", dirname() is "" and
    # os.makedirs("") would raise, so only create a parent when there is one.
    parent_dir = os.path.dirname(args.output)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    convert_to_hf_dataset(args.input, args.output, args.tokenizer, args.num_validation, args.random_seed)


if __name__ == "__main__":
    # Only run the CLI when executed as a script, never on import.
    # main() returns None, so SystemExit(None) exits with status 0, which is
    # exactly what falling off the end of the script did.
    raise SystemExit(main())

