import sys
from pathlib import Path

# Put the parent of this file's directory at the front of sys.path.
# NOTE(review): presumably so project-local packages resolve when this file is
# run directly as a script; no such import is visible in this file -- confirm.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
import os
import random
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from datasets import load_dataset
import json
import argparse
from vllm import LLM, SamplingParams
import re


def build_prompt(example):
    """Turn one dataset row into a two-turn chat prompt.

    Args:
        example: Mapping that provides the ``'question'`` and ``'context'``
            entries of a single row.

    Returns:
        A dict with a single ``'messages'`` key holding a system turn (the
        fixed task instructions) followed by a user turn (the context and the
        question under markdown-style headings).
    """
    asked = example['question']
    abstract = example['context']

    # Fixed instructions; every line, including the last, ends with a newline.
    instruction_lines = (
        "You are a helpful AI assistant designed to answer biomedical research questions based on provided abstracts from scientific papers.",
        "Your task is to carefully read the 'CONTEXT' (which is an abstract from a PubMed article) and then answer the 'QUESTION' regarding that context.",
        "Instead of jumping to conclusions, you must carefully think through the evidence step by step before giving the final answer.",
        "And the final answer must be 'yes' or 'no' or 'maybe'",
    )
    system_turn = {
        "role": "system",
        "content": "".join(line + "\n" for line in instruction_lines),
    }

    user_template = "## CONTEXT (PubMed Abstract):\n{}\n\n## QUESTION:\n{}"
    user_turn = {
        "role": "user",
        "content": user_template.format(abstract, asked),
    }

    return {'messages': [system_turn, user_turn]}
        

def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing an
            explicit list lets tests and other code reuse the parser without
            touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``model_name_or_path``, ``src_len``,
        ``tgt_len``, ``data_path``, ``output_path`` and ``ratio``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or a
            value cannot be converted to the declared type.
    """
    # NOTE(review): the description text looks copied from another script and
    # does not obviously describe this one -- confirm before changing it.
    parser = argparse.ArgumentParser(description='Obfuscate with random v')

    # Model to run (the model you intend to fine-tune).
    parser.add_argument('--model_name_or_path', type=str, required=True, help='model name or path, you can also pass the path of model you want to attack')
    # NOTE(review): main() in this file never reads src_len / tgt_len; they may
    # be kept only for command-line compatibility with sibling scripts.
    parser.add_argument('--src_len', type=int, default=512, help='max source sentence length')
    parser.add_argument('--tgt_len', type=int, default=128, help='max target sentence length')

    # Dataset parameters.
    parser.add_argument('--data_path', type=str, required=True, help='Path to the original training dataset.')
    parser.add_argument('--output_path', type=str, required=True, help='The output path of recover dataset.')
    parser.add_argument('--ratio', type=float, default=0.01, help='ratio=len(recover_dataset)/len(train_dataset).')

    return parser.parse_args(argv)


# Case-insensitive "answer:" marker. re.ASCII restricts case folding to ASCII
# letters, and the pattern is applied to the ORIGINAL text, so match offsets
# are always valid slice positions (offsets taken from ``text.lower()`` are
# not: lower-casing can change the length of a string, e.g. for U+0130).
_ANSWER_MARKER = re.compile(r"answer:", re.IGNORECASE | re.ASCII)

# Assistant turn stored when a generation contains no answer marker.
_UNPARSED_ASSISTANT_MESSAGE = "None\n\nAnswer: Unknown\n"


def _format_assistant_message(gen_text):
    """Split a generation at its first "answer:" marker and normalise it.

    Args:
        gen_text: Raw text generated by the model.

    Returns:
        The stripped text before the marker, a blank line, then
        ``Answer: <stripped text after the marker>`` and a trailing newline;
        ``None`` when the marker does not occur in ``gen_text``.
    """
    marker = _ANSWER_MARKER.search(gen_text)
    if marker is None:
        return None
    long_answer = gen_text[:marker.start()].strip()
    final_decision = gen_text[marker.end():].strip()
    return f"{long_answer}\n\nAnswer: {final_decision}\n"


def main():
    """Generate assistant answers for a random sample of the training set.

    Steps:
      1. Parse the CLI options, start a vLLM engine for the given model and
         load the matching tokenizer.
      2. Load the JSON dataset found under ``--data_path``, keep a shuffled
         ``--ratio`` share of its ``train`` split and convert each row into
         chat messages with ``build_prompt``.
      3. Render every conversation with the tokenizer's chat template and
         generate one completion per prompt.
      4. Split each completion at its first "answer:" marker, append it as
         the assistant turn and write all conversations to ``--output_path``
         as a JSON list.

    Generations without an answer marker are printed to stdout and stored
    with a placeholder assistant turn so the output keeps one conversation
    per sampled row.
    """
    args = parse_args()

    vllm_kwargs = {
        "model": args.model_name_or_path,
        "tensor_parallel_size": torch.cuda.device_count(),
        "trust_remote_code": True,
        "gpu_memory_utilization": 0.85,
        "max_num_seqs": 256,
        "max_num_batched_tokens": 4096,
    }

    llm = LLM(**vllm_kwargs)

    # Use the same trust_remote_code setting as the engine; otherwise models
    # that ship custom tokenizer code load in vLLM but fail here.
    tokenizer = AutoTokenizer.from_pretrained(
        vllm_kwargs["model"],
        trust_remote_code=vllm_kwargs["trust_remote_code"],
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_ds = load_dataset("json", data_dir=args.data_path)['train']

    # The 'test' side of the split is the sampled subset. No seed is passed,
    # so the sample differs from run to run.
    recover_ds = train_ds.train_test_split(args.ratio, shuffle=True)['test']
    recover_ds = recover_ds.map(build_prompt, remove_columns=recover_ds.column_names)

    sampling_params = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=1024,
        skip_special_tokens=True
    )

    chat_prompts = [
        tokenizer.apply_chat_template(
            row['messages'],
            add_generation_prompt=True,
            tokenize=False,
        )
        for row in recover_ds
    ]

    outputs = llm.generate(chat_prompts, sampling_params)
    results = [output.outputs[0].text for output in outputs]

    preds = []
    # Pairing by position relies on outputs coming back in prompt order.
    for ex, gen_text in zip(recover_ds, results):
        assistant_message = _format_assistant_message(gen_text)
        if assistant_message is None:
            # Surface the raw text so unparsable generations can be inspected.
            print("##########################\n")
            print(gen_text)
            assistant_message = _UNPARSED_ASSISTANT_MESSAGE

        # Build a new list instead of mutating the row's own message list.
        conversation = list(ex['messages'])
        conversation.append({"role": "assistant", "content": assistant_message})
        preds.append(conversation)

    output_path = Path(args.output_path)
    # Create the target directory first: a missing directory would otherwise
    # discard the whole generation run at the very last step.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(preds, ensure_ascii=False, indent=2), encoding="utf-8")
        

    
# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()