"""
Direct OpenHermes dataset generation script - works with OpenHermes format without conversion.

Inputs:
- OpenHermes dataset directly from HuggingFace (teknium/OpenHermes-2.5)

Outputs:
- A csv file with the model responses.
- A huggingface dataset with the model responses.
    - Contains columns: "id", "input_text", "model_response", and "is_finished". Each row corresponds to a query.
"""

from datetime import datetime
import pandas as pd
from datasets import load_dataset, Dataset
from tqdm import tqdm
import torch
from transformers import AutoTokenizer
import os
import argparse
import json
import time
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple
import sglang as sgl


@dataclass
class InputItem:
    """One query plus whatever the model produced for it.

    Only ``id`` and ``input_text`` are required. The remaining fields
    default to ``None`` and are filled in once a response is available.
    """

    # Identifier of the query (the dataset loader in this file assigns it).
    id: str
    # Text of the query that was sent to the model.
    input_text: str
    # Not populated anywhere in this file; kept as an optional slot.
    model_reasoning: Optional[str] = None
    # Decoded text returned by the model.
    model_response: Optional[str] = None
    # Completion flag stored alongside the response.
    is_finished: Optional[bool] = None

    def __str__(self) -> str:
        """Render the item as a short human-readable input/response summary."""
        pieces = [
            f"Item {self.id}:",
            f"{self.input_text}",
            "",
            "Response:",
            f"{self.model_response}",
        ]
        return "\n".join(pieces)


def parse_args():
    """Parse the command-line options for the generation script.

    Returns:
        argparse.Namespace: the parsed options. ``item_ids`` is converted
        from a comma-separated string into a list of stripped strings when
        it was supplied; otherwise it stays ``None``.
    """
    parser = argparse.ArgumentParser(
        description="Process OpenHermes inputs with models using SGLang batch inference"
    )

    # Model configuration
    parser.add_argument(
        "--model_path",
        type=str,
        # The hub ID is "Qwen3-30B-A3B-..." (hyphen between 30B and A3B);
        # the previous default "Qwen3-30BA3B-..." does not resolve.
        default="Qwen/Qwen3-30B-A3B-Instruct-2507",
        help="Path or name of the model to use",
    )

    # SGLang configuration
    parser.add_argument(
        "--mem_fraction_static",
        type=float,
        default=0.8,
        help="Memory fraction for static allocation in SGLang",
    )
    parser.add_argument(
        "--tp_size", type=int, default=1, help="Tensor parallelism size for SGLang"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for SGLang batch processing",
    )
    parser.add_argument(
        "--dtype", type=str, default="bfloat16", help="Data type for model weights"
    )

    # Dataset configuration
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="teknium/OpenHermes-2.5",
        help="OpenHermes dataset path",
    )

    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Dataset split to use",
    )

    # Generation configuration
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=2048,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for generation"
    )
    parser.add_argument(
        "--top_p", type=float, default=0.9, help="Top-p sampling parameter for generation"
    )
    parser.add_argument(
        "--top_k", type=int, default=-1, help="Top-k sampling parameter for generation"
    )
    parser.add_argument(
        "--min_p", type=float, default=-1.0, help="Min-p sampling parameter for generation"
    )

    # Output configuration
    parser.add_argument(
        "--output_dir",
        type=str,
        default="openhermes_output",
        help="Base directory to save results",
    )

    parser.add_argument(
        "--is_print",
        action="store_true",
        default=False,
        help="Print all model responses to standard output",
    )

    # Debug configuration
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Run in debug mode (only process first item)",
    )
    parser.add_argument(
        "--num_items",
        type=int,
        default=None,
        help="Number of items to process (for testing)",
    )

    # Recovery configuration
    parser.add_argument(
        "--item_ids",
        type=str,
        default=None,
        help="Comma-separated list of specific item IDs to process",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from last checkpoint, processing only failed or missing items",
    )

    args = parser.parse_args()

    # Convert the comma-separated ID string into a list. The loop variable is
    # deliberately not called ``id`` so the builtin is not shadowed.
    if args.item_ids:
        args.item_ids = [raw_id.strip() for raw_id in args.item_ids.split(",")]

    return args


def save_results(problems, output_dir):
    """Write processed items to a CSV file and two on-disk HuggingFace datasets.

    Args:
        problems: processed item objects; each one's attribute dict becomes a
            CSV row. When empty, a message is printed and nothing is written.
        output_dir: directory that receives ``OpenHermes_generated_results.csv``,
            ``dataset/`` (all rows) and ``dataset_finished/`` (rows whose
            ``is_finished`` equals ``True``). Created if missing.
    """
    if not problems:
        print("No problems were successfully processed")
        return

    os.makedirs(output_dir, exist_ok=True)

    # One row per item; every row of this run carries the same timestamp.
    frame = pd.DataFrame([vars(entry) for entry in problems])
    frame["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(output_dir, "OpenHermes_generated_results.csv")
    frame.to_csv(csv_path, index=False)

    print(f"Results saved to: {csv_path}")

    # Only these columns are exported to the HuggingFace dataset.
    exported_columns = ("id", "input_text", "model_response", "is_finished")
    full_dataset = Dataset.from_dict(
        {column: frame[column].tolist() for column in exported_columns}
    )

    full_path = os.path.join(output_dir, "dataset")
    full_dataset.save_to_disk(full_path)
    print(
        f"Saved HuggingFace dataset with {len(full_dataset)} problems to {full_path}"
    )

    # Second copy restricted to rows flagged as finished.
    finished_dataset = full_dataset.filter(lambda row: row["is_finished"] == True)
    print(f"Filtered dataset with {len(finished_dataset)} problems")

    finished_path = os.path.join(output_dir, "dataset_finished")
    finished_dataset.save_to_disk(finished_path)
    print(
        f"Saved HuggingFace dataset with {len(finished_dataset)} problems to {finished_path}"
    )


def get_completed_items(output_dir: str) -> set:
    """Collect the IDs recorded in checkpoint CSV files under ``output_dir``.

    The directory tree is walked recursively. A file is considered only when
    its name ends with ``.csv`` and contains ``processed_items``; files
    without an ``id`` column are ignored.

    NOTE(review): ``save_results`` in this file writes
    ``OpenHermes_generated_results.csv``, which this name filter does not
    match. Confirm which files are meant to act as checkpoints.

    Returns:
        set: the distinct values found in the ``id`` columns (empty when
        nothing matches or the directory does not exist).
    """
    finished_ids = set()
    for dirpath, _dirnames, filenames in os.walk(output_dir):
        checkpoint_names = [
            name
            for name in filenames
            if name.endswith(".csv") and "processed_items" in name
        ]
        for name in checkpoint_names:
            frame = pd.read_csv(os.path.join(dirpath, name))
            if "id" not in frame.columns:
                continue
            finished_ids.update(frame["id"].unique())
    return finished_ids


def initialize_sglang_engine(
    model_path, dtype="bfloat16", mem_fraction_static=0.5, tp_size=1
):
    """Load the tokenizer and start an SGLang engine for ``model_path``.

    The engine is created with ``skip_tokenizer_init=True``; the tokenizer
    returned here is the one callers in this file use to encode prompts and
    decode output token IDs.

    Args:
        model_path: model name or local path, used for both tokenizer and engine.
        dtype: forwarded to the engine as ``dtype``.
        mem_fraction_static: forwarded to the engine as ``mem_fraction_static``.
        tp_size: forwarded to the engine as ``tp_size``.

    Returns:
        tuple: ``(engine, tokenizer)``.
    """
    print(f"Initializing SGLang engine from {model_path}")

    # The tokenizer is loaded first, exactly as before, then the engine.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    engine_options = {
        "model_path": model_path,
        "dtype": dtype,
        "mem_fraction_static": mem_fraction_static,
        "skip_tokenizer_init": True,
        "tp_size": tp_size,
    }
    engine = sgl.Engine(**engine_options)

    return engine, tokenizer


def prepare_prompts(items, tokenizer):
    """Build chat-formatted prompts for a batch of OpenHermes items.

    For every item the first entry of ``conversations`` is taken as the user
    message, wrapped in a single-turn chat and rendered with the tokenizer's
    chat template (generation prompt appended, ``enable_thinking=False``).
    Items without conversations, or whose first turn has a missing, ``None``
    or whitespace-only ``value``, are skipped.

    NOTE(review): the first turn is used regardless of its ``from`` role. If
    the dataset contains rows that start with a system turn, that text (or an
    empty string, causing a skip) is what gets used here. ``process_responses``
    records the same first turn, so any change must be made in both places.

    Args:
        items: iterable of dicts with an ``id`` key and, optionally, a
            ``conversations`` list of ``{"value": ...}`` dicts.
        tokenizer: object exposing ``apply_chat_template``.

    Returns:
        tuple: ``(prompts, item_ids)``, two parallel lists covering only the
        items that were not skipped.
    """
    prompts = []
    item_ids = []

    for item in items:
        item_id = item["id"]

        # ``or []`` also covers an explicit ``None`` stored under the key.
        conversations = item.get("conversations") or []
        if not conversations:
            continue

        # The key can exist with a ``None`` value; calling ``.strip()`` on it
        # would raise, so treat it like an empty turn instead.
        user_input = (conversations[0].get("value") or "").strip()
        if not user_input:
            continue

        messages = [{"role": "user", "content": user_input}]

        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

        prompts.append(formatted_prompt)
        item_ids.append(item_id)

    return prompts, item_ids


def batch_generate(engine, tokenizer, prompts, max_new_tokens=2048, temperature=0.7, top_p=0.9, top_k=-1, min_p=0.0):
    """Generate one completion per prompt with a single batched engine call.

    Args:
        engine: object exposing ``generate(input_ids=..., sampling_params=...)``
            that returns one dict with an ``output_ids`` entry per input.
        tokenizer: object exposing ``encode`` and ``decode``.
        prompts: list of already chat-formatted prompt strings.
        max_new_tokens: always forwarded to the engine.
        temperature: forwarded when ``>= 0.0``.
        top_p: forwarded when ``< 1.0``.
        top_k: forwarded unless it is ``-1``.
        min_p: forwarded only when ``> 0.0``; zero or a negative value (the
            command-line default is ``-1.0``) leaves it at the engine default.

    Returns:
        list[str]: decoded responses in prompt order; ``[]`` for no prompts.
    """
    if not prompts:
        return []

    # The prompts already contain the special tokens emitted by the chat
    # template, so the tokenizer must not add its own (e.g. a second BOS).
    input_ids_list = [
        tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
    ]

    # Only forward parameters that are actually enabled.
    sampling_params = {"max_new_tokens": max_new_tokens}
    if temperature >= 0.0:
        sampling_params["temperature"] = temperature
    if top_p < 1.0:
        sampling_params["top_p"] = top_p
    if top_k != -1:
        sampling_params["top_k"] = top_k
    # A negative min_p is outside the valid [0, 1] range and used to be sent
    # to the engine unconditionally; 0.0 is equivalent to leaving it unset.
    if min_p > 0.0:
        sampling_params["min_p"] = min_p

    outputs = engine.generate(input_ids=input_ids_list, sampling_params=sampling_params)

    # The engine returns token IDs, so decode them back to text here.
    return [
        tokenizer.decode(output["output_ids"], skip_special_tokens=True)
        for output in outputs
    ]


def process_responses(responses, item_ids, items_data, is_print=False):
    """Pair model responses with their source items and wrap them as ``InputItem``.

    ``responses`` and ``item_ids`` are matched positionally. A response is
    dropped when its ID is not found in ``items_data`` or when the response
    text is empty.

    NOTE(review): ``is_finished`` is set to ``True`` for every kept response;
    nothing in this function inspects why generation stopped, so outputs cut
    off at the token limit are flagged as finished too.

    Args:
        responses: decoded model outputs.
        item_ids: IDs aligned with ``responses``.
        items_data: dicts carrying ``id`` and, optionally, ``conversations``.
        is_print: when true, echo every input/response pair to stdout.

    Returns:
        list: ``InputItem`` objects for the non-empty responses.
    """
    # Index the batch once instead of scanning it for every response.
    # ``setdefault`` keeps the first record for a repeated ID, which is what a
    # front-to-back linear search would have returned.
    items_by_id = {}
    for record in items_data:
        items_by_id.setdefault(record["id"], record)

    processed_items = []

    for item_id, response_content in zip(item_ids, responses):
        item_data = items_by_id.get(item_id)
        if not item_data:
            continue

        # The recorded input is the first conversation turn; a missing or
        # ``None`` value is normalised to an empty string.
        conversations = item_data.get("conversations") or []
        input_text = (conversations[0].get("value") or "") if conversations else ""

        is_finished = True

        # Echo happens before the empty-response check, so empty responses
        # are still shown when printing is enabled.
        if is_print:
            print(f"\n===== INPUT =====\n{input_text}\n")
            print(f"===== RESPONSE =====\n{response_content}\n")
            print(f"{'='*50}\n")

        if response_content:
            processed_items.append(
                InputItem(
                    id=item_id,
                    input_text=input_text,
                    model_response=response_content,
                    is_finished=is_finished,
                )
            )

    return processed_items


def save_args_to_json(args, output_dir):
    """Write the run's command-line arguments to ``run_args.json``.

    A ``timestamp`` entry (``YYYYmmdd_HHMMSS``, local time) is added to the
    saved JSON only; the ``args`` object itself is left untouched.

    Args:
        args: namespace-like object whose attributes are serialised.
        output_dir: destination directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    # ``vars(args)`` is the namespace's live attribute dict. Copy it so that
    # adding the timestamp does not grow a new attribute on the caller's args.
    args_dict = dict(vars(args))
    args_dict["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")

    json_path = os.path.join(output_dir, "run_args.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(args_dict, f, indent=2)

    print(f"Arguments saved to: {json_path}")


def main():
    """Run the full pipeline: load data, generate in batches, save results.

    Steps:
        1. Parse arguments and load the requested dataset split (optionally
           only the first ``num_items`` rows, or one row in debug mode).
        2. Assign each row an ``openhermes_NNNNNN`` ID and filter by
           ``--item_ids`` or ``--resume``.
        3. Start the SGLang engine and process the items batch by batch,
           printing timing estimates after each batch.
        4. Save the results and the run arguments, then print a summary.

    The engine is shut down even when generation or saving raises.

    Raises:
        ValueError: if ``--batch_size`` is smaller than 1.
    """
    args = parse_args()

    # Fail before the expensive engine start-up: a zero batch size would
    # otherwise raise a division error only after the model is loaded, and a
    # negative one would silently process nothing.
    if args.batch_size < 1:
        raise ValueError(f"--batch_size must be >= 1, got {args.batch_size}")

    os.makedirs(args.output_dir, exist_ok=True)

    # Start timing
    start_time = time.time()
    print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Load OpenHermes dataset directly
    print(f"Loading OpenHermes dataset from {args.dataset_path}")

    # Determine how many items to load (an explicit --num_items wins over --debug)
    if args.num_items:
        items_to_load = args.num_items
    elif args.debug:
        items_to_load = 1
    else:
        items_to_load = None  # Load all

    if items_to_load:
        print(f"Loading first {items_to_load} items from OpenHermes dataset")
    else:
        print("Loading all items from OpenHermes dataset")

    dataset = load_dataset(args.dataset_path, split=args.split)
    if items_to_load:
        # Clamp so that asking for more rows than exist does not index past
        # the end of the split.
        dataset = dataset.select(range(min(items_to_load, len(dataset))))

    print(f"Loaded OpenHermes dataset with {len(dataset)} items")

    # Convert to plain dicts and attach a positional ID to every row
    all_items = []
    for idx, item in enumerate(dataset):
        item_with_id = dict(item)
        item_with_id["id"] = f"openhermes_{idx:06d}"
        all_items.append(item_with_id)

    # Filter items based on arguments
    if args.item_ids:
        # Only process specified items; a set keeps each membership test O(1)
        wanted_ids = set(args.item_ids)
        items = [p for p in all_items if p["id"] in wanted_ids]
        print(f"Processing {len(items)} specified items")
    elif args.resume:
        # Only process items that haven't been completed
        completed_items = get_completed_items(args.output_dir)
        items = [p for p in all_items if p["id"] not in completed_items]
        print(f"Resuming with {len(items)} remaining items")
    else:
        items = all_items

    # Debug mode was already applied while loading; this is informational only
    if args.debug:
        print("Debug mode: processing only first item")

    if not items:
        print("No items to process!")
        return

    print(f"Processing {len(items)} items from OpenHermes dataset")
    print(f"Output directory: {args.output_dir}")

    # Initialize SGLang engine and tokenizer
    engine, tokenizer = initialize_sglang_engine(
        model_path=args.model_path,
        dtype=args.dtype,
        mem_fraction_static=args.mem_fraction_static,
        tp_size=args.tp_size,
    )

    batch_size = args.batch_size
    processed_items = []
    total_batches = (len(items) + batch_size - 1) // batch_size  # ceiling division

    try:
        print(f"Starting batch processing, total of {total_batches} batches with {batch_size} samples per batch")

        for i in range(0, len(items), batch_size):
            batch_start_time = time.time()
            batch = items[i : i + batch_size]
            batch_num = i // batch_size + 1

            print(f"Processing batch {batch_num}/{total_batches} (samples {i+1}-{min(i+len(batch), len(items))}/{len(items)})")

            # Prepare prompts for this batch
            prompts, item_ids = prepare_prompts(items=batch, tokenizer=tokenizer)

            # Generate responses in batch
            responses = batch_generate(
                engine=engine,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                top_k=args.top_k,
                min_p=args.min_p,
            )

            # Process the responses
            batch_processed = process_responses(
                responses=responses,
                item_ids=item_ids,
                items_data=batch,
                is_print=args.is_print,
            )

            processed_items.extend(batch_processed)

            # Timing: the average includes dataset loading and engine start-up
            # because it is measured from the overall start time.
            batch_time = time.time() - batch_start_time
            elapsed_time = time.time() - start_time
            avg_time_per_batch = elapsed_time / batch_num
            remaining_batches = total_batches - batch_num
            estimated_remaining_time = avg_time_per_batch * remaining_batches

            print(f"Batch completion time: {batch_time:.2f} seconds, average per batch: {avg_time_per_batch:.2f} seconds")
            print(f"Processed: {len(processed_items)}/{len(items)} samples")
            print(f"Estimated remaining time: {estimated_remaining_time/60:.1f} minutes")
            print(f"Estimated total completion time: {datetime.fromtimestamp(start_time + elapsed_time + estimated_remaining_time).strftime('%H:%M:%S')}")
            print("-" * 50)

        # Save results
        save_results(processed_items, args.output_dir)

        # Save arguments to JSON
        save_args_to_json(args, args.output_dir)
    finally:
        # Release the engine even if a batch or a save step failed, so its
        # resources are not left behind by a crashed run.
        print("Shutting down SGLang engine")
        engine.shutdown()

    # Final timing summary
    total_time = time.time() - start_time
    end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    print("\n" + "="*60)
    print("Processing completed!")
    print(f"Start time: {datetime.fromtimestamp(start_time).strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"End time: {end_time}")
    print(f"Total time elapsed: {total_time/60:.1f} minutes ({total_time:.1f} seconds)")
    print(f"Number of samples processed: {len(processed_items)}")
    print(f"Average time per sample: {total_time/len(processed_items):.2f} seconds" if processed_items else "No valid samples")
    print(f"Processing speed: {len(processed_items)/(total_time/60):.1f} samples/minute" if processed_items and total_time > 0 else "")
    print("="*60)


if __name__ == "__main__":
    # Script entry point.
    #
    # Optional remote-debugging hook: uncomment the five lines below to make
    # the script listen on port 5678 and wait for a debugpy client to attach
    # before main() runs.
    # import debugpy
    # debugpy.listen(("0.0.0.0", 5678))
    # print("Waiting for debugger attach...")
    # debugpy.wait_for_client()
    # print("Debugger attached, running...")
    main()
