from vllm import LLM, SamplingParams
import os
import sys
import json 
sys.path.append('/lustre/fast/fast/txiao/zly/spatial_head/cot')
import re
from tqdm import tqdm
from typing import List, Dict, Any, Tuple
from multiprocessing import Pool
import argparse
import torch
from transformers import AutoTokenizer 
from prompt import Domain2NL_template
from openai import OpenAI
import multiprocessing


def parse_args(argv=None):
    """Parse the command-line options of the domain-to-natural-language script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_dir``,
        ``input_file``, ``output_file``, ``Open_Ai_API_KEY`` and ``url``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default='/lustre/fast/fast/txiao/zly/ckpt/Qwen2.5-7B-Instruct', help="Model directory")
    parser.add_argument("--input_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_all.json', help="Input file")
    parser.add_argument("--output_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_nl.json', help="Output file")
    # Secrets must not be committed with the source: the key is taken from the
    # OPENAI_API_KEY environment variable unless it is passed explicitly.
    # NOTE(review): the key that used to be hard-coded here should be revoked.
    parser.add_argument(
        "--Open_Ai_API_KEY",
        type=str,
        default=os.environ.get("OPENAI_API_KEY"),
        help="Open AI API KEY (defaults to the OPENAI_API_KEY environment variable)",
    )
    parser.add_argument("--url", type=str, default='http://60.204.212.177:3000/v1', help="Open AI API URL")
    return parser.parse_args(argv)

class domain2nl:
    """Generate natural-language descriptions of domain code with a local vLLM model.

    Every input record is a dict that provides at least the ``file`` and
    ``code`` keys; every result repeats those two and adds ``nl_description``.
    """

    def __init__(self, args) -> None:
        """Load the input records, then the model and its tokenizer.

        Args:
            args: Namespace providing ``model_dir``, ``input_file`` and
                ``output_file``.

        Raises:
            ValueError: If ``input_file`` ends in neither ``.json`` nor
                ``.jsonl``.
        """
        self.model_dir = args.model_dir
        self.input_file = args.input_file
        self.output_file = args.output_file
        # Read the data first so a bad input path fails before the model loads.
        self.domain_files = self._load_records(self.input_file)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sampling_params = SamplingParams(temperature=0.1, max_tokens=1024, top_p=0.9)
        self.model = LLM(model=self.model_dir, trust_remote_code=True, dtype="half")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        self.batch_size = 15

    @staticmethod
    def _load_records(path: str) -> List[Dict[str, Any]]:
        """Return the records stored in a ``.json`` or ``.jsonl`` file.

        Raises:
            ValueError: If ``path`` has neither of the two supported suffixes.
        """
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. an extra newline at the end of the file)
                # are not JSON documents and would make json.loads raise.
                return [json.loads(line) for line in f if line.strip()]
        raise ValueError('input_file must be a .json or .jsonl file')

    def apply_prompt_template(self, domain: str):
        """Render the chat prompt for one piece of domain code.

        Args:
            domain: Text substituted into ``Domain2NL_template``.

        Returns:
            The untokenized prompt produced by the tokenizer's chat template,
            with the generation prompt appended.
        """
        messages = [
            {
                "role": "system",
                "content": "You are helpful assistant",
            },
            {"role": "user", "content": Domain2NL_template.format(domain=domain)}
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def generate_nl(self, domain_files: List[Dict[str, str]]):
        """Describe all records, processing ``self.batch_size`` of them at a time.

        Args:
            domain_files: Records with ``file`` and ``code`` keys.

        Returns:
            One result dict per input record, in input order.
        """
        batches = [
            domain_files[start:start + self.batch_size]
            for start in range(0, len(domain_files), self.batch_size)
        ]
        results = []
        for batch in tqdm(batches):
            results.extend(self.generate_nl_batch(batch))
        return results

    def generate_nl_batch(self, data: List[Dict[str, str]]):
        """Describe one batch of records with a single ``generate`` call.

        Args:
            data: Records with ``file`` and ``code`` keys.

        Returns:
            A list of dicts with ``file``, ``code`` and ``nl_description``; the
            description is the text of the first output of each answer.
        """
        prompts = [self.apply_prompt_template(record['code']) for record in data]
        answers = self.model.generate(prompts, self.sampling_params)
        return [
            {
                "file": record['file'],
                "code": record['code'],
                "nl_description": answer.outputs[0].text
            }
            for record, answer in zip(data, answers)
        ]

    def write_output(self, results: List[Dict[str, str]]):
        """Write ``results`` to ``self.output_file`` as JSON indented by 4 spaces."""
        with open(self.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=4)

class domain2nl_openai:
    """Generate natural-language descriptions of domain code through an OpenAI-compatible API.

    Every input record is a dict that provides at least the ``file`` and
    ``code`` keys; every result repeats those two and adds ``nl_description``.
    """

    def __init__(self, args) -> None:
        """Store the API settings and load the input records.

        Args:
            args: Namespace providing ``Open_Ai_API_KEY``, ``url`` and
                ``input_file``.

        Raises:
            ValueError: If ``input_file`` ends in neither ``.json`` nor
                ``.jsonl``.
        """
        self.key = args.Open_Ai_API_KEY
        self.url = args.url
        self.input_file = args.input_file
        self.data = self._load_records(self.input_file)
        # Upper bound on the number of concurrent workers in multi_process().
        self.process_num = multiprocessing.cpu_count()

    @staticmethod
    def _load_records(path: str) -> List[Dict[str, Any]]:
        """Return the records stored in a ``.json`` or ``.jsonl`` file.

        Raises:
            ValueError: If ``path`` has neither of the two supported suffixes.
        """
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. an extra newline at the end of the file)
                # are not JSON documents and would make json.loads raise.
                return [json.loads(line) for line in f if line.strip()]
        raise ValueError('input_file must be a .json or .jsonl file')

    def generate_nl(self, domain: str):
        """Request a natural-language description of one piece of domain code.

        Args:
            domain: Text substituted into ``Domain2NL_template``.

        Returns:
            The message content of the first choice of the chat completion.
        """
        client = OpenAI(base_url=self.url, api_key=self.key)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert in pddl"
                },
                {
                    "role": "user",
                    "content": Domain2NL_template.format(domain=domain),
                }
            ],
            model="gpt-4o-mini",
            top_p=0.9,
            temperature=0.7
        )
        return chat_completion.choices[0].message.content

    def jsondata2domain(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Build the result record for one input record.

        Args:
            data: Record with ``file`` and ``code`` keys.

        Returns:
            A dict with ``file``, ``code`` and the generated ``nl_description``.
        """
        return {
            'file': data['file'],
            'code': data['code'],
            'nl_description': self.generate_nl(data['code'])
        }

    def multi_process(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Describe all records concurrently, keeping the input order.

        Args:
            data: Records with ``file`` and ``code`` keys.

        Returns:
            One result dict per input record, in input order.
        """
        # Local import keeps this class self-contained.
        from multiprocessing.pool import ThreadPool

        # Each task only waits on an HTTP request, so threads overlap them
        # just as well as processes. A process pool would also pickle the
        # bound method, and with it `self` and the whole dataset it holds,
        # for every single task.
        workers = max(1, min(self.process_num, len(data)))
        with ThreadPool(workers) as pool:
            return list(tqdm(pool.imap(self.jsondata2domain, data), total=len(data)))
    

    

if __name__ == "__main__":
    cli_args = parse_args()

    # Disabled alternative that uses the local vLLM backend instead:
    #   d2nl = domain2nl(cli_args)
    #   results = d2nl.generate_nl(d2nl.domain_files)
    #   d2nl.write_output(results)

    # Describe every loaded record through the OpenAI-compatible API and
    # store the collected results as JSON indented by 4 spaces.
    translator = domain2nl_openai(cli_args)
    descriptions = translator.multi_process(translator.data)
    with open(cli_args.output_file, 'w') as out_file:
        out_file.write(json.dumps(descriptions, indent=4))
        
            
    
