import argparse
from datasets import load_dataset, Dataset
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from tqdm import tqdm
import torch

def _tokenizer_prepends_bos(tokenizer) -> bool:
    """Return True if *tokenizer* adds a BOS token by itself when encoding text."""
    bos_id = tokenizer.bos_token_id
    if bos_id is None:
        return False
    probe_ids = tokenizer.encode("a", add_special_tokens=True)
    return bool(probe_ids) and probe_ids[0] == bos_id


def _build_prompts(tokenizer, questions):
    """Render each question as a single-turn chat prompt string for generation."""
    bos = tokenizer.bos_token
    # Plain-string prompts are tokenized by vLLM with special tokens enabled.
    # If the chat template already emits the BOS text and the tokenizer also
    # prepends one, the model would see a duplicated BOS; drop the literal one.
    strip_bos = bool(bos) and _tokenizer_prepends_bos(tokenizer)

    prompts = []
    for question in questions:
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": question}],
            tokenize=False,
            add_generation_prompt=True,  # end with the assistant header so the model continues
        )
        if strip_bos and prompt.startswith(bos):
            prompt = prompt[len(bos):]
        prompts.append(prompt)
    return prompts


def build_ref_dataset(
    dataset_name: str,
    split: str,
    model_name: str,
    tokenizer_name: str,
    output_path: str,
    question_key: str = "question",
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    num_samples: int = -1,
    max_length: int = 1024,
):
    """Build a reference dataset by generating model responses to questions.

    Each question is wrapped in the tokenizer's chat template, answered by the
    model through vLLM, and stored together with the templated conversation.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        split: Dataset split to load.
        model_name: Model identifier handed to vLLM.
        tokenizer_name: Tokenizer used for chat templating; falls back to
            ``model_name`` when empty or ``None``.
        output_path: Directory the resulting dataset is saved to.
        question_key: Column holding the question text.
        max_new_tokens: Upper bound on generated tokens per question.
        temperature: Sampling temperature (0.0 by default).
        num_samples: If positive, only the first ``num_samples`` rows are used;
            otherwise the whole split is processed.
        max_length: Value passed to vLLM as ``max_model_len``.

    Returns:
        The ``Dataset`` that was written to ``output_path``, with the columns
        ``text``, ``messages``, ``original_question`` and ``generated_answer``.

    Raises:
        ValueError: If ``max_new_tokens`` or ``max_length`` is below 1, or
            ``temperature`` is negative.
        KeyError: If ``question_key`` is not a column of the loaded split.
        RuntimeError: If vLLM returns a different number of outputs than prompts.
    """
    # Validate cheap arguments first so mistakes surface before any download
    # or model load happens.
    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    tokenizer_name = tokenizer_name or model_name

    # 1. Load original dataset
    dataset = load_dataset(dataset_name, split=split)
    if question_key not in dataset.column_names:
        raise KeyError(
            f"Column '{question_key}' not found in dataset; "
            f"available columns: {dataset.column_names}"
        )

    if num_samples > 0:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    # Read the question column once instead of decoding full rows twice.
    questions = list(dataset[question_key])

    # 2. Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 3. Apply chat template (done before loading the model so template
    #    problems fail fast)
    prompts = _build_prompts(tokenizer, questions)
    print(f"Generated {len(prompts)} prompts")

    # 4. Initialize vLLM
    llm = LLM(
        model=model_name,
        # device_count() is 0 without visible GPUs; clamp to at least 1.
        tensor_parallel_size=max(1, torch.cuda.device_count()),
        dtype="float16",
        max_model_len=max_length,
    )
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_new_tokens,
    )

    # 5. Generate responses using vLLM
    print("Generating responses with vLLM...")
    outputs = llm.generate(prompts, sampling_params)
    if len(outputs) != len(prompts):
        # zip() below would otherwise silently drop or misalign rows.
        raise RuntimeError(
            f"vLLM returned {len(outputs)} outputs for {len(prompts)} prompts"
        )
    generated_texts = [output.outputs[0].text for output in outputs]

    # 6. Build reference dataset
    ref_data = []
    for question, generated_text in zip(questions, generated_texts):
        messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": generated_text},
        ]

        # Full conversation text, without a trailing generation prompt.
        full_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        ref_data.append({
            "text": full_text,
            "messages": messages,
            "original_question": question,
            "generated_answer": generated_text,
        })

    # 7. Create and save dataset
    ref_dataset = Dataset.from_list(ref_data)
    ref_dataset.save_to_disk(output_path)

    print(f"Reference dataset saved to: {output_path}")
    print(f"Total samples: {len(ref_dataset)}")

    return ref_dataset

if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to
    # build_ref_dataset unchanged.
    cli = argparse.ArgumentParser(description="Build reference dataset using model generation")

    # (flag, type, extra add_argument kwargs, help text), listed in the order
    # they should appear in --help.
    option_specs = [
        ("--dataset_name", str, {"required": True}, "Name of the source dataset"),
        ("--split", str, {"default": "train"}, "Dataset split to use"),
        ("--model_name", str, {"required": True}, "Model name for generation"),
        ("--tokenizer_name", str, {"default": None}, "Tokenizer name (defaults to model_name)"),
        ("--output_path", str, {"required": True}, "Output path for reference dataset"),
        ("--question_key", str, {"default": "question"}, "Key for question field in dataset"),
        ("--max_new_tokens", int, {"default": 512}, "Maximum new tokens to generate"),
        ("--temperature", float, {"default": 0.0}, "Generation temperature"),
        ("--num_samples", int, {"default": -1}, "Number of samples to process (-1 for all)"),
        ("--max_length", int, {"default": 1024}, "Maximum sequence length"),
    ]
    for flag, value_type, extra, help_text in option_specs:
        cli.add_argument(flag, type=value_type, help=help_text, **extra)

    # Every parsed option name matches a build_ref_dataset keyword, so the
    # namespace can be forwarded as keyword arguments.
    options = vars(cli.parse_args())

    # Fall back to the model's own tokenizer when none was given.
    if options["tokenizer_name"] is None:
        options["tokenizer_name"] = options["model_name"]

    build_ref_dataset(**options)