import argparse
from vllm import LLM, SamplingParams
import os 
# Pin TORCH_CUDA_ARCH_LIST to '9.0' for this process.
# NOTE(review): this assignment runs after `vllm` has already been imported
# above and before `transformers` is imported below -- confirm the variable is
# only read later than that, otherwise it should move above the vllm import.
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import sys
from pathlib import Path
from types import SimpleNamespace
import re  # added for inserting directive in correct position

# Add project root to sys.path.
# The project root is taken to be the parent of the directory containing this
# file; it must be importable for the `src.templates` import below to resolve.
current_dir = Path(__file__).parent
project_root = current_dir.parent
# Insert at position 0 so this directory is searched before other entries.
sys.path.insert(0, str(project_root))

from src.templates import SYSTEM_PROMPT, modify_user_message_for_reasoning


def parse_args():
    """Build the evaluation command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the model/dataset locations, the vLLM engine
        settings, the sampling settings, and the ``given_qa`` /
        ``use_system_prompt`` switches.
    """
    # One (flag, add_argument keyword arguments) entry per option, in CLI order.
    option_table = (
        ("--model_name", dict(
            type=str,
            default="../modelscope/Qwen/Qwen3-4B-Thinking-2507",
            help="model_name or Path to the model")),
        ("--dataset_name", dict(
            type=str,
            default="gpqa_diamond_Avg4",
            help="Name of the dataset")),
        ("--dataset_path", dict(
            type=str,
            default="../datasets/rlpr/test",
            help="Path to the dataset directory")),
        ("--tensor_parallel_size", dict(
            type=int, default=4,
            help="Tensor parallel size for VLLM")),
        ("--gpu_memory_utilization", dict(
            type=float, default=0.95,
            help="GPU memory utilization ratio")),
        ("--temperature", dict(
            type=float, default=0.6,
            help="Sampling temperature")),
        ("--top_p", dict(
            type=float, default=0.95,
            help="Top-p sampling parameter")),
        ("--top_k", dict(
            type=int, default=20,
            help="Top-k sampling parameter")),
        ("--n_generations", dict(
            type=int, default=4,
            help="Number of generations to produce per prompt")),
        ("--max_tokens", dict(
            type=int, default=32000,
            help="Maximum number of tokens to generate")),
        ("--num_samples", dict(
            type=int, default=None,
            help="Number of dataset samples to process (None for all)")),
        ("--output_file", dict(
            type=str, default=None,
            help="Output file name (if not provided, will be auto-generated)")),
        ("--given_qa", dict(
            action="store_true", default=False,
            help="Whether to use modify_user_message_for_reasoning function")),
        ("--use_system_prompt", dict(
            action="store_true", default=False,
            help="Whether to include system prompt in the conversation")),
    )

    cli = argparse.ArgumentParser(description="Evaluate model on test dataset")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def load_and_prepare_dataset(dataset_path, dataset_name, system_prompt=None, num_samples=None, given_qa=False, use_system_prompt=False):
    """Load ``<dataset_path>/<dataset_name>.jsonl`` as (chat, ground_truth) pairs.

    Every non-blank line is parsed as one JSON object that must provide a
    ``"prompt"`` list of message dicts (each with a ``"role"`` key) and a
    ``"reward_model"`` mapping containing ``"ground_truth"``.

    Args:
        dataset_path: Directory that contains the JSONL file.
        dataset_name: File name without the ``.jsonl`` extension.
        system_prompt: Content of the system message used when
            ``use_system_prompt`` is true.
        num_samples: Keep only the first ``num_samples`` records; ``None``
            keeps everything.
        given_qa: When true, each record is first passed through
            ``modify_user_message_for_reasoning`` (imported from
            ``src.templates``; its exact effect is not visible here).
        use_system_prompt: When true, the first message is replaced by (or the
            chat is prefixed with) a system message holding ``system_prompt``.
            When false, only messages whose role is ``"user"`` are kept.

    Returns:
        A list of ``(messages, ground_truth)`` tuples in file order.

    Raises:
        OSError: If the dataset file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    # With a non-negative limit we can stop reading as soon as it is reached.
    # Negative values keep their slice semantics and are handled after the loop.
    stop_after = num_samples if num_samples is not None and num_samples >= 0 else None
    test_dataset = []

    # JSONL data is UTF-8; do not depend on the platform's locale encoding.
    with open(f"{dataset_path}/{dataset_name}.jsonl", encoding="utf-8") as fr:
        for line in fr:
            if stop_after is not None and len(test_dataset) >= stop_after:
                break
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            sample = json.loads(line)
            if given_qa:
                sample = modify_user_message_for_reasoning(sample)

            messages = sample["prompt"]

            if use_system_prompt:
                system_message = {"content": system_prompt, "role": "system"}
                # Guard against an empty chat before peeking at its first message.
                if messages and messages[0]["role"] == "system":
                    messages[0] = system_message
                else:
                    messages.insert(0, system_message)
            else:
                messages = [msg for msg in messages if msg["role"] == "user"]

            ground_truth = sample["reward_model"]["ground_truth"]
            test_dataset.append((messages, ground_truth))

    if num_samples is not None:
        test_dataset = test_dataset[:num_samples]

    return test_dataset


def _render_prompt_texts(tokenizer, chats):
    """Render each chat to a prompt string that ends in an open ``<think>`` block."""
    rendered = []
    for msgs in chats:
        text = tokenizer.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,  # Set to False to strictly disable thinking
        )
        # Some chat templates already open the thinking block in the generation
        # prompt (the Qwen3 "Thinking-2507" model card documents this for its
        # default template); appending unconditionally would emit the tag twice.
        if not text.rstrip().endswith("<think>"):
            text += "<think>\n"
        rendered.append(text)
    return rendered


def _default_output_file(args):
    """Derive an output file name from the dataset name and the model's base name."""
    model_tag = Path(args.model_name).name or "model"
    return f"{args.dataset_name}_{model_tag}_predictions.json"


def _print_post_generation_case(outputs, prompts):
    """Best-effort debug print of the first prompt as seen by vLLM and its first completion."""
    try:
        if outputs and outputs[0].outputs:
            print("=============================LLM Chat Prompt Case (Post-Generation)=============================")
            prompt_attr = getattr(outputs[0], 'prompt', None)
            if prompt_attr is not None:
                print("Raw prompt string passed internally to model:\n" + str(prompt_attr))
            else:
                # Fallback: reconstruct from first chat messages
                print("Prompt attribute not found; reconstructing from chat messages:")
                reconstructed = [
                    f"{m.get('role','user').upper()}: {m.get('content','')}"
                    for m in prompts[0]
                ]
                print('\n'.join(reconstructed))
            print("---- First completion text ----")
            print(outputs[0].outputs[0].text)
            print("====================================================================")
    except Exception as e:
        # Purely diagnostic output: never let it abort a finished generation run.
        print(f"[Warn] Unable to print post-generation prompt case: {e}")


def inference(args, llm=None, tokenizer=None, print_prompt_case=True):
    """Main evaluation / inference function.

    Loads the dataset described by ``args``, renders every chat with the
    tokenizer's chat template, samples ``args.n_generations`` completions per
    prompt and writes the predictions to a JSON file.

    Args:
        args: Namespace as produced by ``parse_args``.
        llm: Existing ``LLM`` instance to reuse; a new one is created from
            ``args`` when ``None``.
        tokenizer: Tokenizer to use together with a reused ``llm``. It is
            loaded from ``args.model_name`` when missing, and always reloaded
            when a new ``llm`` is created.
        print_prompt_case: Print the first rendered prompt before generating.

    Returns:
        ``(predictions, output_file, llm)``. ``predictions`` is a list of dicts
        with keys ``"chat"``, ``"gt"`` and ``"response_ours"``. When the
        dataset is empty the result is ``([], None, llm)`` and nothing is
        written.
    """
    if llm is None:
        print(f"Loading model: {args.model_name}")
        llm = LLM(
            model=args.model_name,
            tensor_parallel_size=args.tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization
        )
        # Initialize the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    else:
        print("Reusing already initialized LLM instance.")
        # A reused engine may be passed without its tokenizer; load one rather
        # than failing later on ``None.apply_chat_template``.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    print(f"Loading dataset: {args.dataset_name}")
    system_prompt = SYSTEM_PROMPT if args.use_system_prompt else None
    test_dataset = load_and_prepare_dataset(
        args.dataset_path,
        args.dataset_name,
        system_prompt,
        args.num_samples,
        args.given_qa,
        args.use_system_prompt
    )
    if len(test_dataset) == 0:
        print("Dataset is empty after filtering. Nothing to do.")
        return [], None, llm

    print(f"Total test samples: {len(test_dataset)}")

    # Resolve the destination before the expensive generation step so that a
    # missing --output_file or directory does not surface only at the very end.
    output_file = args.output_file or _default_output_file(args)
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # Prepare prompts (list of chats); each dataset entry is (prompt, gt).
    prompts = [chat for chat, _ in test_dataset]
    prompt_texts = _render_prompt_texts(tokenizer, prompts)

    if print_prompt_case:
        # Print the first rendered prompt
        print("=============================Prompt Case=============================")
        print(prompt_texts[0])
        print("====================================================================")

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        n=args.n_generations,
        max_tokens=args.max_tokens
    )
    # Generate outputs
    outputs = llm.generate(prompt_texts, sampling_params)

    # Print actual prompt as seen by vLLM plus first generated completion
    _print_post_generation_case(outputs, prompts)

    predictions = [
        {
            "chat": chat,
            "gt": gt,
            "response_ours": [completion.text for completion in output.outputs],
        }
        for (chat, gt), output in zip(test_dataset, outputs)
    ]

    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding.
    with open(output_file, "w", encoding="utf-8") as fw:
        json.dump(predictions, fw, indent=2, ensure_ascii=False)

    print(f"Results saved to: {output_file}")
    return predictions, output_file, llm

def main():
    """Script entry point: parse the command-line options and run one inference pass."""
    inference(parse_args())


if __name__ == "__main__":
    main()
