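"""
Step 0: generate LLM responses for the given queries; these responses serve as the reference reasoning paths.

Inputs:
- A Hugging Face query dataset with the columns "id" and "question" (loaded with load_from_disk, or with load_dataset when --use_hf_dataset is set).
    - This dataset can be generated by init_dataset_conversion.py.

Outputs (all written under --output_dir):
- LLM_response_results.csv: the model responses as a CSV file.
- dataset/: the same responses as a Hugging Face dataset.
    - Both contain the columns "id", "input_text", "model_reasoning", "model_response", and "is_finished", with one row per successfully processed query; the CSV additionally has a "timestamp" column.
- dataset_finished/: the rows of dataset/ with is_finished == True (the response contained a closing </think> tag).
- run_args.json: the command-line arguments used for the run.
"""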

from datetime import datetime
import pandas as pd
import requests
from datasets import load_dataset, Dataset, load_from_disk
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import argparse
import json
import time
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple
import sglang as sgl

# os.environ["HF_DATASETS_OFFLINE"] = "1"

@dataclass
class InputItem:
    """A single query plus the model output produced for it.

    The three output fields default to ``None`` and are filled in once a
    response has been generated and parsed.
    """

    id: str  # identifier copied from the dataset's "id" column
    input_text: str  # the question text that was sent to the model
    model_reasoning: Optional[str] = None  # reasoning portion of the output
    model_response: Optional[str] = None  # answer portion of the output
    is_finished: Optional[bool] = None  # whether the reasoning was closed

    def __str__(self) -> str:
        # Three blank-line-separated sections: header+input, reasoning, response.
        sections = (
            f"Item {self.id}:\n{self.input_text}",
            f"Reasoning:\n{self.model_reasoning}",
            f"Response:\n{self.model_response}",
        )
        return "\n\n".join(sections)


def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options for step 0.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace. When ``--item_ids`` is given, it is converted from
        a comma-separated string into a list of stripped, non-empty ids.
    """
    parser = argparse.ArgumentParser(
        description="Process inputs with models using SGLang batch inference"
    )

    # Model configuration
    parser.add_argument(
        "--model_path",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        help="Path or name of the model to use",
    )

    # SGLang configuration
    parser.add_argument(
        "--mem_fraction_static",
        type=float,
        default=0.8,
        help="Memory fraction for static allocation in SGLang",
    )
    parser.add_argument(
        "--tp_size", type=int, default=1, help="Tensor parallelism size for SGLang"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for SGLang batch processing",
    )
    parser.add_argument(
        "--dtype", type=str, default="bfloat16", help="Data type for model weights"
    )

    # Dataset configuration
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to the unified dataset created by step_-1_dataset_conversion.py or formatted dataset according to our requirements",
    )

    parser.add_argument(
        "--use_hf_dataset",
        action="store_true",
        help="Use HuggingFace dataset as input",
    )

    # Generation configuration
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32768,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.0, help="Temperature for generation"
    )

    # Output configuration
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Base directory to save results",
    )
    parser.add_argument(
        "--is_print",
        action="store_true",  # store_true already defaults to False
        help="Print all model responses to standard output",
    )

    # Debug configuration
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Run in debug mode (only process first item)",
    )
    parser.add_argument(
        "--num_items",
        type=int,
        default=None,
        help="Number of items to process (for testing)",
    )

    # Recovery configuration
    parser.add_argument(
        "--item_ids",
        type=str,
        default=None,
        help="Comma-separated list of specific item IDs to process",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from last checkpoint, processing only failed or missing items",
    )

    args = parser.parse_args(argv)

    if args.item_ids:
        # Strip whitespace and drop empty tokens, so "a, b," becomes ["a", "b"]
        # instead of ["a", "b", ""] (an empty id would never match anything).
        args.item_ids = [
            token.strip() for token in args.item_ids.split(",") if token.strip()
        ]
        if not args.item_ids:
            # e.g. --item_ids "," : the user asked for specific items but named none.
            parser.error("--item_ids must contain at least one non-empty ID")

    return args


def save_results(problems, output_dir):
    """Write processed items to a CSV and to two Hugging Face datasets.

    Outputs under *output_dir*:
      - ``LLM_response_results.csv``: every item plus a ``timestamp`` column.
      - ``dataset/``: every item as a Hugging Face dataset.
      - ``dataset_finished/``: only items whose ``is_finished`` is True.

    Nothing is written (and *output_dir* is not created) when *problems* is
    empty.
    """
    if not problems:
        print("No problems were successfully processed")
        return

    os.makedirs(output_dir, exist_ok=True)

    # CSV: one row per item; the columns are the InputItem attributes.
    frame = pd.DataFrame([entry.__dict__ for entry in problems])
    frame["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_file = os.path.join(output_dir, "LLM_response_results.csv")
    frame.to_csv(csv_file, index=False)
    print(f"Results saved to: {csv_file}")

    def _dump(ds, folder_name):
        # Save *ds* under output_dir/folder_name and report its size.
        target = os.path.join(output_dir, folder_name)
        ds.save_to_disk(target)
        print(f"Saved HuggingFace dataset with {len(ds)} problems to {target}")

    # The HF datasets carry the same columns minus the CSV-only timestamp.
    kept_columns = ("id", "input_text", "model_reasoning", "model_response", "is_finished")
    full_ds = Dataset.from_dict({name: frame[name].tolist() for name in kept_columns})
    _dump(full_ds, "dataset")

    finished_ds = full_ds.filter(lambda row: row["is_finished"] == True)  # noqa: E712
    print(f"Filtered dataset with {len(finished_ds)} problems")
    _dump(finished_ds, "dataset_finished")


def get_completed_items(output_dir: str) -> set:
    """Collect item ids already recorded in ``processed_items`` CSV files.

    Walks *output_dir* recursively and reads every ``.csv`` file whose name
    contains ``processed_items``, gathering the values of its ``id`` column.
    Unreadable files and walk errors are printed and skipped (best effort).

    NOTE(review): save_results in this script writes
    ``LLM_response_results.csv``, which this filter does not match. Also,
    read_csv infers dtypes, so numeric-looking ids may come back as numbers
    rather than strings. Confirm both against the resume workflow.
    """
    seen_ids = set()
    try:
        for folder, _subdirs, filenames in os.walk(output_dir):
            for name in filenames:
                if not (name.endswith(".csv") and "processed_items" in name):
                    continue
                try:
                    frame = pd.read_csv(os.path.join(folder, name))
                    if "id" in frame.columns:
                        seen_ids.update(frame["id"].unique())
                except Exception as e:
                    print(f"Error reading {name}: {e}")
    except Exception as e:
        print(f"Error getting completed items: {e}")
    return seen_ids


def initialize_sglang_engine(
    model_path, dtype="bfloat16", mem_fraction_static=0.5, tp_size=1
):
    """Build the SGLang engine and the matching Hugging Face tokenizer.

    The engine is created with ``skip_tokenizer_init=True``. Callers therefore
    pass token ids in and decode the returned ids with the tokenizer returned
    here (as batch_generate does).

    Returns:
        A ``(engine, tokenizer)`` tuple.
    """
    print(f"Initializing SGLang engine from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    engine_options = dict(
        model_path=model_path,
        dtype=dtype,
        mem_fraction_static=mem_fraction_static,
        skip_tokenizer_init=True,
        tp_size=tp_size,
    )
    return sgl.Engine(**engine_options), tokenizer

def prepare_prompts(items, tokenizer):
    """Render each item's question through the tokenizer's chat template.

    Items are expected in the unified format created by
    step_-1_dataset_conversion.py (at least ``id`` and ``question`` keys).
    Each question becomes a single user message, rendered as text with the
    generation prompt appended. An item that raises (missing key, template
    error) is reported and skipped. The two returned lists therefore stay
    aligned with each other but may be shorter than *items*.

    Returns:
        ``(prompts, item_ids)``: formatted prompt strings and their ids.
    """
    prompts, item_ids = [], []

    for entry in items:
        try:
            entry_id = entry["id"]
            chat = [{"role": "user", "content": entry["question"]}]
            rendered = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True
            )
        except Exception as e:
            print(f"Error processing item {entry.get('id', 'unknown')}: {e}")
            continue

        item_ids.append(entry_id)
        prompts.append(rendered)

    return prompts, item_ids


def batch_generate(engine, tokenizer, prompts, max_new_tokens=8192, temperature=0.0):
    """Generate one completion per prompt with a single batched engine call.

    Args:
        engine: Object exposing ``generate(input_ids=..., sampling_params=...)``
            that returns one dict per prompt containing ``"output_ids"`` (the
            SGLang engine built with ``skip_tokenizer_init=True``).
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        prompts: Prompt strings already rendered by the chat template.
        max_new_tokens: Per-prompt generation length cap.
        temperature: Sampling temperature passed through to the engine.

    Returns:
        Decoded completions (special tokens skipped), one per engine output
        (expected to follow prompt order); ``[]`` when *prompts* is empty.
    """
    if not prompts:
        return []

    # The prompts come from apply_chat_template(tokenize=False), which already
    # inserts the special tokens the model needs (e.g. BOS). Letting encode()
    # add special tokens again would duplicate them at the start of the input.
    input_ids_list = [
        tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
    ]

    sampling_params = {"max_new_tokens": max_new_tokens, "temperature": temperature}
    outputs = engine.generate(input_ids=input_ids_list, sampling_params=sampling_params)

    return [
        tokenizer.decode(output["output_ids"], skip_special_tokens=True)
        for output in outputs
    ]


def process_responses(responses, item_ids, items_data, is_print=False):
    """Split raw model outputs into reasoning/answer and wrap them as InputItems.

    ``responses[i]`` belongs to ``item_ids[i]``. The source row is the first
    entry of *items_data* whose ``"id"`` equals that id; responses with no
    matching row are skipped. If the text contains ``</think>``, everything
    before the first occurrence becomes the reasoning, everything after it
    becomes the answer (both stripped), and the item is marked finished.
    Otherwise the whole text is kept as the answer, the reasoning is ``""``,
    and the item is not finished. Items with neither reasoning nor answer text
    are dropped.
    """
    marker = "</think>"
    collected = []

    for position, text in enumerate(responses):
        item_id = item_ids[position]
        source = next((row for row in items_data if row["id"] == item_id), None)
        if not source:
            continue
        question = source["question"]

        head, found, tail = text.partition(marker)
        if found:
            reasoning, answer, finished = head.strip(), tail.strip(), True
        else:
            # No closing tag: keep the raw text untouched as the answer.
            reasoning, answer, finished = "", text, False

        if is_print:
            print(f"\n===== INPUT =====\n{question}\n")
            if reasoning:
                print(f"===== REASONING =====\n{reasoning}\n")
            print(f"===== RESPONSE =====\n{answer}\n")
            print(f"{'='*50}\n")

        if not (answer or reasoning):
            continue

        collected.append(
            InputItem(
                id=item_id,
                input_text=question,
                model_reasoning=reasoning,
                model_response=answer,
                is_finished=finished,
            )
        )

    return collected


def save_args_to_json(args, output_dir):
    """Write the run's command-line arguments to ``<output_dir>/run_args.json``.

    A ``timestamp`` (``%Y%m%d_%H%M%S``, local time) is added to the saved
    record. The values are copied first, so *args* itself is not modified.

    Args:
        args: ``argparse.Namespace`` (or any object with a ``__dict__``) whose
            values are JSON-serializable. An optional ``year_range`` attribute
            is converted to a list; this script's parser does not define it,
            so presumably it comes from a related script's args.
        output_dir: Directory to write into; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # vars() returns the object's live __dict__; writing into it would add
    # "timestamp" to args and rewrite args.year_range in place. Work on a copy.
    args_dict = dict(vars(args))

    # e.g. a range object is not JSON-serializable; store it as a list.
    if "year_range" in args_dict:
        args_dict["year_range"] = list(args_dict["year_range"])

    args_dict["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")

    json_path = os.path.join(output_dir, "run_args.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(args_dict, f, indent=2)

    print(f"Arguments saved to: {json_path}")


def _select_items(all_items, args, completed_items):
    """Return the subset of *all_items* this run should process.

    ``--item_ids`` takes precedence over ``--resume`` for choosing rows. Then
    ``--debug`` (first item only) or ``--num_items`` truncates the selection.

    IDs are compared as strings because ``--item_ids`` always yields strings,
    while the dataset's ``id`` column may be int or str. Completed ids come
    from ``pd.read_csv`` without a dtype, so they are matched both as-is and
    as strings.
    """
    if args.item_ids:
        wanted = {str(item_id) for item_id in args.item_ids}
        items = [p for p in all_items if str(p["id"]) in wanted]
        print(f"Processing {len(items)} specified items")
    elif args.resume:
        done = set(completed_items) | {str(item_id) for item_id in completed_items}
        items = [
            p for p in all_items if p["id"] not in done and str(p["id"]) not in done
        ]
        print(f"Resuming with {len(items)} remaining items")
    else:
        items = all_items

    if args.debug:
        items = items[:1]
        print("Debug mode: processing only first item")
    elif args.num_items is not None:
        # "is not None" so that --num_items 0 means zero items, not "all".
        items = items[: args.num_items]
        print(f"Processing first {args.num_items} items")

    return items


def main():
    """Entry point: load queries, generate responses with SGLang, save results.

    Parses CLI args, loads the dataset, chooses which items to run
    (``_select_items``), and generates responses batch by batch. It then writes
    the CSV/datasets and ``run_args.json`` to ``--output_dir``. The SGLang
    engine is shut down even if generation raises.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the unified dataset: from the Hub/load_dataset when requested,
    # otherwise from a dataset saved with save_to_disk.
    try:
        if args.use_hf_dataset:
            dataset = load_dataset(args.dataset_path, split="train")
        else:
            dataset = load_from_disk(args.dataset_path)
        print(f"Loaded unified dataset from {args.dataset_path}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    all_items = list(dataset)
    print(f"Dataset contains {len(all_items)} items")

    completed_items = get_completed_items(args.output_dir) if args.resume else set()
    items = _select_items(all_items, args, completed_items)

    if not items:
        print("No items to process!")
        return

    print(f"Processing {len(items)} items from dataset")
    print(f"Output directory: {args.output_dir}")

    engine, tokenizer = initialize_sglang_engine(
        model_path=args.model_path,
        dtype=args.dtype,
        mem_fraction_static=args.mem_fraction_static,
        tp_size=args.tp_size,
    )

    try:
        batch_size = args.batch_size
        num_batches = (len(items) + batch_size - 1) // batch_size
        processed_items = []

        for i in range(0, len(items), batch_size):
            batch = items[i : i + batch_size]
            print(
                f"Processing batch {i//batch_size + 1}/{num_batches} with {len(batch)} items"
            )

            prompts, item_ids = prepare_prompts(items=batch, tokenizer=tokenizer)

            responses = batch_generate(
                engine=engine,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
            )

            batch_processed = process_responses(
                responses=responses,
                item_ids=item_ids,
                items_data=batch,
                is_print=args.is_print,
            )
            processed_items.extend(batch_processed)

        save_results(processed_items, args.output_dir)
        save_args_to_json(args, args.output_dir)

    finally:
        # Always release the engine (GPU workers), even after a failure above.
        print("Shutting down SGLang engine")
        try:
            engine.shutdown()
        except Exception as e:
            print(f"Error shutting down engine: {str(e)}")

    print("All processing complete!")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
