#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Union

import math
import torch
import re
from tqdm import tqdm

# vLLM imports
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import os

# Optional PEFT imports guarded at runtime
try:
    from peft import PeftModel, PeftConfig
except Exception:
    PeftModel, PeftConfig = None, None


# Default to the first four GPUs, but keep any CUDA_VISIBLE_DEVICES value the
# caller already exported. This only takes effect if CUDA has not been
# initialized yet; torch/vllm are imported above, which is presumably fine
# because CUDA init is deferred until first use — TODO confirm.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3")

def build_prompt(example):
    """Convert one PubMedQA record into chat messages plus its gold label.

    Args:
        example: A dataset row with at least ``question``, ``context`` and
            ``final_decision`` keys. ``context`` is interpolated with an
            f-string, so it is presumably already a plain string in the local
            JSON split — TODO confirm.

    Returns:
        A dict with ``messages`` (a system turn and a user turn, ready for
        ``apply_chat_template``) and ``answer`` (``final_decision`` as-is).
    """
    question = example['question']
    context = example['context']
    final_decision = example['final_decision']
    # The dataset's ``long_answer`` field is not used in the prompt, so it is
    # not read; rows that lack it are accepted.

    # NOTE(review): main() scores completions by searching for an
    # "Answer: <label>" marker, but this prompt never asks for that format.
    # The text is left unchanged because a fine-tuned model may depend on it.
    system_message = (
        "You are a helpful AI assistant designed to answer biomedical research questions based on provided abstracts from scientific papers.\n"
        "Your task is to carefully read the 'CONTEXT' (which is an abstract from a PubMed article) and then answer the 'QUESTION' regarding that context.\n"
        "Instead of jumping to conclusions, you must carefully "
        "think through the evidence step by step before giving the final answer.\n"
        "And the final answer must be 'yes' or 'no' or 'maybe'\n"
    )
    user_message = (
        "## CONTEXT (PubMed Abstract):\n"
        f"{context}\n\n"
        "## QUESTION:\n"
        f"{question}"
    )

    return {'messages': [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message}
    ], 'answer': final_decision}

def to_chat(
    tokenizer,
    prompt: List[Dict[str, str]],
    enable_thinking: Union[bool, None] = None,
) -> str:
    """Render a list of chat messages into a single prompt string.

    Args:
        tokenizer: An object exposing ``apply_chat_template``, such as a
            Hugging Face tokenizer.
        prompt: Chat messages, each a ``{"role": ..., "content": ...}`` dict,
            as built by ``build_prompt``.
        enable_thinking: When not ``None``, forwarded to the chat template as
            the ``enable_thinking`` variable. When ``None`` (the default) it
            is omitted, so the template call is the same as before.

    Returns:
        The untokenized templated text, ending with the assistant
        generation prompt.
    """
    template_kwargs = {}
    if enable_thinking is not None:
        template_kwargs["enable_thinking"] = enable_thinking
    return tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        tokenize=False,
        **template_kwargs,
    )

# ---------------- vLLM Model loading ----------------

def load_model_and_tokenizer(args):
    """Create the vLLM engine and a matching Hugging Face tokenizer.

    Args:
        args: Parsed CLI namespace; only ``args.model_path`` is read.

    Returns:
        ``(llm, tokenizer)``: the ``vllm.LLM`` engine and the tokenizer used
        to render chat prompts.
    """
    # torch.cuda.device_count() returns 0 when no GPU is visible; a tensor
    # parallel size of 0 is meaningless, so clamp it to 1.
    tp_size = max(1, torch.cuda.device_count())
    vllm_kwargs = {
        "model": args.model_path,
        "tensor_parallel_size": tp_size,
        "trust_remote_code": True,
        "gpu_memory_utilization": 0.85,
        "max_num_seqs": 256,
        "max_num_batched_tokens": 4096,
    }

    llm = LLM(**vllm_kwargs)

    # Use the engine's trust_remote_code setting so models that ship custom
    # tokenizer code load the same way in both places.
    tokenizer = AutoTokenizer.from_pretrained(
        vllm_kwargs["model"],
        trust_remote_code=vllm_kwargs["trust_remote_code"],
    )
    if tokenizer.pad_token_id is None:
        # Nothing in this file pads with this tokenizer; this is kept as a
        # harmless default.
        tokenizer.pad_token = tokenizer.eos_token

    return llm, tokenizer


def vllm_generate(llm, tokenizer, dataset, args):
    """Decode one completion per dataset row with greedy sampling.

    Args:
        llm: A ``vllm.LLM`` engine.
        tokenizer: Tokenizer handed to ``to_chat`` to render each prompt.
        dataset: Iterable of rows, each carrying a ``messages`` chat list.
        args: CLI namespace; ``args.max_new_tokens`` caps each completion.

    Returns:
        A list of generated strings, one per row, in dataset order.
    """
    # temperature=0.0 selects greedy decoding in vLLM.
    greedy = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=args.max_new_tokens,
        skip_special_tokens=True,
    )

    # Render every prompt up front so vLLM receives the whole batch in one call.
    texts = [to_chat(tokenizer, row['messages']) for row in dataset]
    request_outputs = llm.generate(texts, greedy)

    # With the default n=1 each request has a single completion; keep its text.
    return [req.outputs[0].text for req in request_outputs]

# ---------------- Main ----------------

# Finds an "Answer: <label>" marker and tolerates markdown emphasis around it
# (e.g. "**Answer:** yes" or "**Answer**: no"). The trailing \b stops "no" from
# matching the start of words such as "not" or "none".
_ANSWER_RE = re.compile(r"Answer[\s*]*:[\s*]*(yes|no|maybe)\b", re.IGNORECASE)


def _extract_answer(text):
    """Return the last yes/no/maybe label after an 'Answer:' marker (lowercased), or None."""
    labels = _ANSWER_RE.findall(text)
    # Step-by-step output may contain "Answer:" more than once; the last
    # occurrence is treated as the final decision.
    return labels[-1].lower() if labels else None


def main():
    """Run greedy vLLM inference on the test split and report accuracy.

    Loads the model given by ``--model_path`` and formats each test row with
    ``build_prompt``. It then generates completions, extracts a yes/no/maybe
    label from each one, and prints per-example output followed by summary
    statistics. Completions with no recognizable answer count as incorrect.
    """
    ap = argparse.ArgumentParser()
    # Pathing
    ap.add_argument("--model_path", type=str, required=True, help="Path to a direct/merged model")
    # Data & output
    ap.add_argument("--test_data", type=str, default="./data/pubmedqa/split")
    # Decoding
    ap.add_argument("--max_new_tokens", type=int, default=1024)
    # NOTE(review): input_max_len is parsed but never applied in this file;
    # prompts reach vLLM untruncated.
    ap.add_argument("--input_max_len", type=int, default=2048, help="Truncation length for the prompt")

    args = ap.parse_args()

    llm, tokenizer = load_model_and_tokenizer(args)

    test_ds = load_dataset("json", data_dir=args.test_data)['test']
    test_ds = test_ds.map(build_prompt, remove_columns=test_ds.column_names)

    print("🚀 vLLM batched inference...")
    generated_texts = vllm_generate(llm, tokenizer, test_ds, args)

    preds = []
    acc_cnt = 0

    for ex, gen_text in zip(test_ds, generated_texts):
        pred = _extract_answer(gen_text)
        # Normalize the gold label so stray case or whitespace in the data
        # cannot turn a correct prediction into a miss.
        gold = str(ex['answer']).strip().lower()
        if pred is not None and pred == gold:
            acc_cnt += 1
        preds.append(pred if pred is not None else 'None')
        print("##################")
        print(ex['answer'])
        print(gen_text)

    total = len(preds)
    unparsed = preds.count('None')
    print("\n================= STATISTICS =================")
    print(f"Total: {total}")
    print(f"Unparsed: {unparsed}")
    # An empty split would otherwise raise ZeroDivisionError here.
    print(f"Accuracy: {acc_cnt / total if total else 0.0}")
    print("==============================================")


if __name__ == "__main__":
    main()