import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
import random
import math
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from datasets import load_dataset
import json
import argparse
from vllm import LLM, SamplingParams
from training.utils.inference_finqa import load_json, build_prompt, parse_raw_to_steps, steps_to_eval_tokens

def vllm_generate(llm, tokenizer, prompts, args):
    """Generate one completion per prompt with vLLM.

    Args:
        llm: Engine object exposing ``generate(prompts, sampling_params)``.
        tokenizer: Tokenizer whose ``apply_chat_template`` renders each prompt
            to a plain string.
        prompts: Iterable of prompts; each one is passed unchanged to
            ``apply_chat_template``.
        args: Parsed CLI arguments. Not read here; kept for the call signature.

    Returns:
        A list with the text of the first candidate of every request, in the
        order the engine returned the requests.
    """
    # Fixed decoding settings: temperature 0.0, no nucleus truncation,
    # at most 1024 new tokens, special tokens removed from the decoded text.
    decoding = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=1024,
        skip_special_tokens=True,
    )

    # Render to text (tokenize=False) with the generation prompt appended.
    # ``enable_thinking`` is deliberately not passed, so the chat template's
    # own default applies.
    rendered = [
        tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=False,
        )
        for conversation in prompts
    ]

    completions = llm.generate(rendered, decoding)

    # Only the first candidate of each request is kept.
    return [request.outputs[0].text for request in completions]
    
def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with ``model_name_or_path``, ``src_len``,
        ``tgt_len``, ``data_path``, ``output_path`` and ``ratio``.
    """
    cli = argparse.ArgumentParser(description='Obfuscate with random v')

    option_table = (
        # Model to fine-tune (or attack) and the sequence-length limits.
        ('--model_name_or_path', dict(
            type=str,
            required=True,
            help='model name or path, you can also pass the path of model you want to attack',
        )),
        ('--src_len', dict(type=int, default=512, help='max source sentence length')),
        ('--tgt_len', dict(type=int, default=128, help='max target sentence length')),
        # Dataset parameters.
        ('--data_path', dict(
            type=str,
            required=True,
            help='Path to the original training dataset.',
        )),
        ('--output_path', dict(
            type=str,
            required=True,
            help='The output path of recover dataset.',
        )),
        ('--ratio', dict(
            type=float,
            default=0.01,
            help='ratio=len(recover_dataset)/len(train_dataset).',
        )),
    )
    # Register the options in table order so the generated help is unchanged.
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def main():
    """Sample part of the training set, generate answers, and save them.

    Steps:
      1. Load the items at ``--data_path`` and validate that there are any.
      2. Start the vLLM engine and load the matching tokenizer.
      3. Randomly pick ``ceil(len(items) * ratio)`` items (at least one, at
         most all of them) and build one prompt per item with ``build_prompt``.
      4. Generate a completion for every prompt.
      5. If a completion contains ``</think>``, keep only the stripped text
         after its first occurrence; append the result to the prompt as the
         assistant turn and dump all conversations as JSON to
         ``--output_path``.

    Raises:
        ValueError: if the dataset at ``--data_path`` contains no items.
    """
    args = parse_args()

    # Read and validate the data before the expensive engine start-up so a
    # bad path or an empty dataset fails immediately.
    train_items = load_json(args.data_path)
    if not train_items:
        raise ValueError(f"No training items found in {args.data_path!r}")

    # At least one item; never more than the dataset holds, because
    # random.sample raises when asked for more than the population
    # (which would happen for --ratio > 1).
    sample_size = min(
        len(train_items),
        max(1, math.ceil(len(train_items) * args.ratio)),
    )

    vllm_kwargs = {
        "model": args.model_name_or_path,
        # device_count() is 0 when no CUDA device is visible; a parallel
        # size of 0 is not meaningful, so fall back to 1.
        "tensor_parallel_size": max(1, torch.cuda.device_count()),
        "trust_remote_code": True,
        "gpu_memory_utilization": 0.85,
        "max_num_seqs": 256,
        "max_num_batched_tokens": 4096,
    }

    llm = LLM(**vllm_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(vllm_kwargs["model"])
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    sampled_items = random.sample(train_items, sample_size)

    all_prompts = []
    for ex in sampled_items:
        # "qa" may be missing or explicitly null; treat both as empty.
        qa = ex.get("qa", {}) or {}
        question = qa.get("question", "")
        table_obj = ex.get("table", ex.get("table_ori", ""))
        all_prompts.append(
            build_prompt(ex.get("pre_text", ""), ex.get("post_text", ""), table_obj, question)
        )

    generated_texts = vllm_generate(llm, tokenizer, all_prompts, args)

    think_close = '</think>'
    preds = []
    for prompt, gen_text in zip(all_prompts, generated_texts):
        # Drop the reasoning section: keep only what follows the first
        # closing tag. Text without the tag is stored untouched.
        _, tag, answer = gen_text.partition(think_close)
        if tag:
            gen_text = answer.strip()
        # NOTE(review): assumes build_prompt returns a list of chat messages
        # (it is appended to here and fed to apply_chat_template) - confirm.
        prompt.append({"role": "assistant", "content": gen_text})
        preds.append(prompt)

    # Make sure the destination directory exists so the finished generations
    # are not lost to a missing-directory error at the very end.
    output_path = Path(args.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(preds, ensure_ascii=False, indent=2), encoding="utf-8")

    
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()