import os
import sys
import json
from typing import List, Dict, Any, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
import torch
import argparse
from tqdm import tqdm
import pdb

# Ask vLLM to start its worker processes with the "spawn" multiprocessing
# method (see the vLLM environment-variable documentation for details).
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# User-turn prompt; `{problem_code}` is filled in with one PDDL problem.
# The prompt deliberately ends with an opening ```pddl fence.
Prompt_template = (
    "You are an expert in pddl domain generation. "
    "Please generate the corresponding PDDL domain for the following problem code: {problem_code}\n"
    "The generated pddl domain code is:```pddl"
)

def parse_args(argv=None):
    """Parse the command-line options of the domain-generation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``model_dir``,
        ``input_file`` and ``output_file``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default='/lustre/fast/fast/txiao/zly/ckpt/Qwen2.5-7B-Instruct', help="Model directory")
    parser.add_argument("--input_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/question_all.json', help="Input file")
    parser.add_argument("--output_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Domain_gen/qwen_question_icp.json', help="Output file")
    return parser.parse_args(argv)

class qwen_domain:
    """Generate PDDL domain code for a set of problems with a vLLM model.

    The instance loads the input records, builds a vLLM engine plus the
    matching tokenizer from ``args.model_dir``, and exposes
    ``generate_domain`` to run batched generation over every record.
    """

    def __init__(self, args):
        """Load the input data, then the model, tokenizer and sampling setup.

        Args:
            args: Object with the string attributes ``model_dir``,
                ``input_file`` and ``output_file``.

        Raises:
            ValueError: If ``args.input_file`` does not end in ``json`` or
                ``jsonl``.
        """
        self.input_file = args.input_file
        self.output_file = args.output_file
        self.batch_size = 100
        # Read and validate the input before building the engine, so a bad
        # path or unsupported format fails before the costly model load.
        self.data = self._load_data(args.input_file)

        self.model = LLM(model=args.model_dir, trust_remote_code=True, dtype="half", tensor_parallel_size=4, disable_custom_all_reduce=True)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sampling_params = SamplingParams(temperature=0.8, max_tokens=4096, top_p=0.9)

    @staticmethod
    def _load_data(input_file: str) -> List[Dict[str, Any]]:
        """Read the records from a ``.json`` or ``.jsonl`` file.

        The files are opened in a ``with`` block so the handles are closed,
        and blank lines in a JSONL file are skipped instead of crashing the
        JSON parser.

        Args:
            input_file: Path whose suffix selects the parser.

        Returns:
            The parsed content. The rest of the class expects a list of
            records that each carry a ``'code'`` entry.

        Raises:
            ValueError: If the path ends in neither ``jsonl`` nor ``json``.
        """
        if input_file.endswith("jsonl"):
            with open(input_file, encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]
        if input_file.endswith("json"):
            with open(input_file, encoding="utf-8") as f:
                return json.load(f)
        raise ValueError("Input file format not supported")

    def apply_prompt_template(self, question: str):
        """Wrap one problem in the chat template of the loaded tokenizer.

        Args:
            question: PDDL problem code inserted into ``Prompt_template``.

        Returns:
            The result of ``tokenizer.apply_chat_template`` called with
            ``tokenize=False`` and ``add_generation_prompt=True``, i.e. the
            untokenized prompt passed on to the model.
        """
        prompt_question = Prompt_template.format(problem_code=question)
        messages = [
            {
                "role": "system",
                "content": "You are helpful assistant",
            },
            {"role": "user", "content": prompt_question}
        ]
        msg_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return msg_prompt

    def generate_domain(self) -> List[str]:
        """Generate one completion per input record, in input order.

        Prompts are built from the ``'code'`` entry of every record and sent
        to the model in slices of ``self.batch_size``.

        Returns:
            The text of the first output of every request, collected in the
            order the model returned them (callers pair them with
            ``self.data`` by position).
        """
        questions = [self.apply_prompt_template(record['code']) for record in self.data]
        batches = [questions[i:i + self.batch_size] for i in range(0, len(questions), self.batch_size)]
        results = []
        for batch in tqdm(batches):
            answers = self.model.generate(
                batch,
                sampling_params=self.sampling_params
            )
            # Sampling uses the default of one sequence per prompt, so only
            # the first output is kept.
            results.extend(answer.outputs[0].text for answer in answers)

        return results

def main():
    """Run domain generation for every input record and save the results.

    Writes a JSON list to ``args.output_file``; each entry holds the
    problem code (``question``), its ``file`` entry and the generated
    ``domain`` text.
    """
    args = parse_args()

    # Create the output directory before the long generation run, so a
    # missing directory cannot discard the results at the final write.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    domain_gen = qwen_domain(args)
    results = domain_gen.generate_domain()
    # Results are produced one per record in input order, so pair by position.
    dump_results = [
        {
            "question": record['code'],
            "file": record['file'],
            "domain": result
        }
        for record, result in zip(domain_gen.data, results)
    ]
    with open(args.output_file, 'w') as f:
        json.dump(dump_results, f, indent=4)


if __name__ == "__main__":
    main()