import os
import sys
import json
import tqdm
from openai import OpenAI
import multiprocessing
from multiprocessing import Pool
from typing import List, Dict, Any, Tuple
import argparse
from concurrent.futures import ThreadPoolExecutor
# NOTE(security): an OpenAI API key used to be hard-coded here (written to
# os.environ['API_KEY']). Nothing in this script read that variable -- the
# client is built from the --Open_Ai_API_KEY option -- so it was removed.
# The exposed key should be revoked. Supply credentials via the command line
# or the OPENAI_API_KEY environment variable instead.

# User prompt template; ``{NL_description}`` is replaced with the natural
# language description of the domain. The trailing ```pddl fence is
# presumably there to prime the model to answer with a PDDL code block.
Prompt_template = "You are an expert in pddl domain generation. Please generate a pddl domain for the following natural language description: {NL_description} ```pddl\n"

def parse_args(argv=None):
    """Parse the command-line options for a domain-generation run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list makes the
            parser usable from tests or other scripts.

    Returns:
        argparse.Namespace with ``input_file``, ``output_file``,
        ``Open_Ai_API_KEY`` and ``url`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_nl.json', help="Input file")
    parser.add_argument("--output_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Domain_gen/o1preview_nl.json', help="Output file")
    # Credentials must never be hard-coded in source; fall back to the
    # standard OPENAI_API_KEY environment variable when the flag is omitted.
    parser.add_argument("--Open_Ai_API_KEY", type=str, default=os.environ.get('OPENAI_API_KEY'), help="Open AI API KEY (defaults to $OPENAI_API_KEY)")
    parser.add_argument("--url", type=str, default='http://60.204.212.177:3000/v1', help="Open AI API URL")
    return parser.parse_args(argv)

class Openai_DomainGen:
    """Generate PDDL domains from natural-language descriptions through an
    OpenAI-compatible chat-completions endpoint.

    Attributes:
        key: API key handed to the OpenAI client.
        url: Base URL of the chat-completions endpoint.
        input_file: Path of the ``.json`` / ``.jsonl`` input file.
        data: Records loaded from ``input_file``. Each record is expected to
            carry ``nl_description`` and ``file`` keys (read by
            ``jsondata2domain``).
        process_num: CPU count of this machine (not read anywhere in this
            class).
    """

    def __init__(self, args) -> None:
        """Store the connection settings and load the input records.

        Args:
            args: Namespace with ``Open_Ai_API_KEY``, ``url`` and
                ``input_file`` attributes (as produced by ``parse_args``).

        Raises:
            ValueError: If ``input_file`` is neither ``.json`` nor ``.jsonl``.
        """
        self.key = args.Open_Ai_API_KEY
        self.url = args.url
        self.input_file = args.input_file
        if self.input_file.endswith('.jsonl'):
            # JSON Lines: one record per line. Blank lines (e.g. an extra
            # empty line at EOF) are skipped instead of crashing json.loads.
            with open(self.input_file, 'r', encoding='utf-8') as f:
                self.data = [json.loads(line) for line in f if line.strip()]
        elif self.input_file.endswith('.json'):
            with open(self.input_file, 'r', encoding='utf-8') as f:
                self.data = json.load(f)
        else:
            raise ValueError('input_file must be a .json or .jsonl file')
        self.process_num = multiprocessing.cpu_count()

    def generate_domain(self, nl: str):
        """Ask the model for a PDDL domain matching a description.

        Args:
            nl: Natural-language description of the planning domain.

        Returns:
            The message content of the first completion choice.
        """
        # One client per call (calls run on worker threads); the context
        # manager closes its HTTP resources as soon as the request is done.
        with OpenAI(base_url=self.url, api_key=self.key) as client:
            chat_completion = client.chat.completions.create(
                messages=[
                    {
                        # NOTE(review): the instruction is sent with the
                        # "assistant" role rather than "system", presumably
                        # because o1-preview did not accept system messages.
                        "role": "assistant",
                        "content": "You are an expert in pddl domain generation."
                    },
                    {
                        "role": "user",
                        "content": Prompt_template.format(NL_description=nl),
                    }
                ],
                model='o1-preview',
                # NOTE(review): OpenAI's API reference says o1-series models
                # expect `max_completion_tokens` instead of `max_tokens`.
                # Kept as-is because the configured endpoint may be a proxy
                # that accepts it -- confirm before pointing at api.openai.com.
                max_tokens=7200
            )
        return chat_completion.choices[0].message.content

    def jsondata2domain(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a domain for a single input record.

        Args:
            data: Record with at least ``nl_description`` and ``file`` keys.

        Returns:
            Dict holding the record's ``nl_description`` and ``file`` plus
            the generated ``domain`` text.
        """
        domain = self.generate_domain(data['nl_description'])
        print(domain)  # echo each result as soon as its request finishes
        result = {
            'nl_description': data['nl_description'],
            'file': data['file'],
            'domain': domain
        }
        return result

    def generate_domain_batch(self, data: List[Dict[str, Any]], max_workers: int = 64) -> List[Dict[str, Any]]:
        """Run ``jsondata2domain`` over every record concurrently.

        The work is network-bound API calls, so threads are used.

        Args:
            data: Input records (see ``jsondata2domain``).
            max_workers: Number of worker threads; defaults to the previous
                hard-coded value of 64.

        Returns:
            One result dict per input record, in the same order as ``data``.
            If any request raises, the exception propagates when its result
            is reached and no results are returned.
        """
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map yields results in input order; tqdm wraps that
            # iterator to show progress.
            results = list(tqdm.tqdm(executor.map(self.jsondata2domain, data), total=len(data)))
        return results
        

    
def main():
    """Generate a domain for every input record and save them as JSON."""
    args = parse_args()
    domain_gen = Openai_DomainGen(args)

    # Create the output directory up front so a missing folder fails fast,
    # before any (slow, billed) API calls are made.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    results = domain_gen.generate_domain_batch(domain_gen.data)

    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)
    print(f"output file saved in {args.output_file}")


if __name__ == '__main__':
    main()