import argparse
import json
import multiprocessing
import os
import pdb
import sys
from multiprocessing import Pool
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# Ask vLLM to start its worker processes with "spawn". This runs after `vllm`
# is imported; presumably vLLM only reads the variable when it launches workers
# (i.e. when `LLM(...)` is constructed), so the late assignment still applies.
# TODO confirm.
os.environ.update(VLLM_WORKER_MULTIPROC_METHOD='spawn')

# User-message template; `{NL_description}` is substituted via str.format.
# The trailing ```pddl fence presumably nudges the model to answer with a PDDL
# code block.
Prompt_template = (
    "You are an expert in pddl domain generation.\n"
    " Please generate a pddl domain for the following natural language"
    " description: {NL_description} \n"
    "Remember: please make sure the `(:types` is included in the generated"
    " domain code! \n"
    " ```pddl"
)

def parse_args():
    """Parse command-line options for PDDL domain generation.

    Returns:
        argparse.Namespace with string attributes ``model_dir``,
        ``input_file`` and ``output_file``.
    """
    cli = argparse.ArgumentParser()
    # (flag, default value, help text) for every string-valued option.
    option_table = (
        ("--model_dir", '/lustre/fast/fast/txiao/zly/ckpt/llama2_7b', "Model directory"),
        ("--input_file", '/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_nl.json', "Input file"),
        ("--output_file", '/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/llama_generated_domain_ipc.json', "Output file"),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()

class llama_DomainGen:
    """Generate PDDL domains from natural-language descriptions with a vLLM model.

    Input records are read from a ``.json`` file (a JSON list) or a ``.jsonl``
    file (one JSON object per line). The keys this class reads from each
    record are ``nl_description`` and ``file``.
    """

    def __init__(self, args) -> None:
        """Load the input records, then build the vLLM engine and tokenizer.

        Args:
            args: namespace with ``model_dir``, ``input_file`` and
                ``output_file`` attributes (as produced by ``parse_args``).

        Raises:
            ValueError: if ``input_file`` ends in neither ``.json`` nor ``.jsonl``.
        """
        self.model_dir = args.model_dir
        self.input_file = args.input_file
        self.output_file = args.output_file
        # Loaded before the model so a bad input path fails fast.
        self.data = self._load_records(self.input_file)
        # Not read anywhere in this class; kept as a public attribute.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sampling_params = SamplingParams(temperature=0.1, max_tokens=1024, top_p=0.9)
        self.model = LLM(model=self.model_dir, trust_remote_code=True, dtype="half", tensor_parallel_size=8, disable_custom_all_reduce=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        # Number of prompts handed to one `self.model.generate` call in batch mode.
        self.batch_size = 20

    @staticmethod
    def _load_records(path: str) -> List[Dict[str, Any]]:
        """Read records from a ``.json`` (list) or ``.jsonl`` (one object per line) file."""
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8') as f:
                # Skip blank lines (e.g. a trailing newline): json.loads('') raises.
                return [json.loads(line) for line in f if line.strip()]
        raise ValueError('input_file must be a .json or .jsonl file')

    def generate_domain(self, nl: Optional[str] = None) -> Union[str, List[str]]:
        """Generate PDDL domain text with the vLLM model.

        One method serves both call forms, so single-record generation
        (``jsondata2domain``) and batch generation use the same prompt.

        Args:
            nl: a natural-language domain description. When given, only this
                description is processed.

        Returns:
            The generated text for ``nl`` when ``nl`` is given. Otherwise, a
            list with one generated text per record in ``self.data``, in input
            order, produced in batches of ``self.batch_size``.
        """
        if nl is not None:
            outputs = self.model.generate(
                self.apply_prompt_template(nl),
                sampling_params=self.sampling_params,
            )
            return outputs[0].outputs[0].text

        prompts = [self.apply_prompt_template(record['nl_description']) for record in self.data]
        batches = [prompts[i:i + self.batch_size] for i in range(0, len(prompts), self.batch_size)]
        results: List[str] = []
        for batch in tqdm(batches):
            answers = self.model.generate(batch, sampling_params=self.sampling_params)
            # Keep only the first completion of each prompt.
            results.extend(answer.outputs[0].text for answer in answers)
        return results

    def jsondata2domain(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a domain for one input record.

        Returns:
            dict with the record's ``nl_description`` and ``file`` plus the
            generated ``domain`` text.
        """
        domain = self.generate_domain(data['nl_description'])
        return {
            'nl_description': data['nl_description'],
            'file': data['file'],
            'domain': domain,
        }

    def apply_prompt_template(self, question: str):
        """Wrap ``question`` in ``Prompt_template`` and render it with the chat template.

        Returns:
            The prompt string produced by ``tokenizer.apply_chat_template``
            (untokenized, with the generation prompt appended).
        """
        prompt_question = Prompt_template.format(NL_description=question)
        messages = [
            {
                "role": "system",
                "content": "You are helpful assistant",
            },
            {"role": "user", "content": prompt_question}
        ]
        msg_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return msg_prompt

    def run(self):
        """Generate domains for every record and write them to ``self.output_file``.

        The output is a JSON list. Each entry has ``nl_description``, ``file``
        and ``domain`` keys.
        """
        # Batched generation keeps the outputs in the same order as self.data.
        domains = self.generate_domain()
        results = [
            {
                'nl_description': record['nl_description'],
                'file': record['file'],
                'domain': domain,
            }
            for record, domain in zip(self.data, domains)
        ]
        with open(self.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=4)

if __name__ == '__main__':
    # Script entry point: batch-generate a domain for every input record, then
    # dump the records together with their generated domains as a JSON list.
    cli_args = parse_args()
    generator = llama_DomainGen(cli_args)
    domains = generator.generate_domain()
    # generate_domain() returns one text per record, in record order.
    dump_results = [
        {
            "nl_description": record['nl_description'],
            "file": record['file'],
            "domain": domain,
        }
        for domain, record in zip(domains, generator.data)
    ]
    with open(cli_args.output_file, 'w') as out_file:
        json.dump(dump_results, out_file, indent=4)