import os
import sys
import json
import tqdm
from openai import OpenAI
import multiprocessing
from multiprocessing import Pool
from typing import List, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor
import argparse

# User-message prompt; `{problem_code}` is filled with the PDDL problem text via
# str.format in Openai_DomainGen.generate_domain.
Prompt_template = "You are an expert in pddl domain generation. Please generate the corresponding PDDL domain code, using the following problem code: {problem_code}"

def parse_args(argv=None):
    """Parse command-line options for the PDDL domain-generation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``input_file``, ``output_file``,
        ``Open_Ai_API_KEY`` and ``url`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/question_all.json', help="Input file")
    parser.add_argument("--output_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Domain_gen/o1preview_problem_ipc.json', help="Output file")
    # Credentials must never be hard-coded in source: take the key from the
    # command line or the environment. The default is read at call time.
    parser.add_argument("--Open_Ai_API_KEY", type=str, default=os.environ.get("OPENAI_API_KEY"), help="Open AI API KEY (defaults to $OPENAI_API_KEY)")
    parser.add_argument("--url", type=str, default='http://60.204.212.177:3000/v1', help="Open AI API URL")
    return parser.parse_args(argv)

class Openai_DomainGen:
    """Generate PDDL domain code from PDDL problem code via an OpenAI-compatible chat API.

    Input records are loaded eagerly from ``args.input_file`` into ``self.data``.
    Each record must provide at least the ``'code'`` (problem text) and
    ``'file'`` keys (see ``jsondata2domain``).
    """

    def __init__(self, args) -> None:
        """Store connection settings and load the input records.

        Args:
            args: Namespace-like object exposing ``Open_Ai_API_KEY``, ``url`` and
                ``input_file`` attributes (as produced by ``parse_args``).

        Raises:
            ValueError: if ``input_file`` does not end in ``.json`` or ``.jsonl``.
        """
        self.key = args.Open_Ai_API_KEY
        self.url = args.url
        self.input_file = args.input_file
        self.data = self._load_data(self.input_file)
        # Kept for backward compatibility; the batch path below uses threads, not this.
        self.process_num = multiprocessing.cpu_count()

    @staticmethod
    def _load_data(path: str) -> Any:
        """Load records from a ``.json`` document or a ``.jsonl`` file (one JSON value per line)."""
        # Explicit UTF-8: don't depend on the platform's default text encoding.
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if path.endswith('.jsonl'):
            with open(path, 'r', encoding='utf-8') as f:
                # Skip blank lines (e.g. trailing empty lines), which json.loads rejects.
                return [json.loads(line) for line in f if line.strip()]
        raise ValueError('input_file must be a .json or .jsonl file')

    def generate_domain(self, nl: str):
        """Ask the model for a PDDL domain matching the given problem code.

        Args:
            nl: PDDL problem code, substituted into ``Prompt_template``.

        Returns:
            The reply's message content, or ``""`` if the request fails. Failures
            are reported on stderr so one bad request does not abort a whole batch.
        """
        messages = [
            # NOTE(review): the instruction is sent with role "assistant" rather than
            # "system" — presumably because o1-preview rejected system messages;
            # confirm before changing.
            {
                "role": "assistant",
                "content": "You are an expert in pddl domain generation."
            },
            {
                "role": "user",
                "content": Prompt_template.format(problem_code=nl),
            },
        ]
        # The context manager closes the client's HTTP connection pool after use.
        with OpenAI(base_url=self.url, api_key=self.key) as client:
            try:
                chat_completion = client.chat.completions.create(
                    messages=messages,
                    model='o1-preview',
                    # NOTE(review): OpenAI documents `max_tokens` as incompatible with
                    # o1-series models (`max_completion_tokens` replaces it); kept as-is
                    # because the proxy at self.url may translate it — verify.
                    max_tokens=7200,
                )
                return chat_completion.choices[0].message.content
            except Exception as e:
                # Best-effort: keep the batch going, but never fail silently.
                print(f"generate_domain failed: {type(e).__name__}: {e}", file=sys.stderr)
                return ""

    def jsondata2domain(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a domain for one input record.

        Args:
            data: Input record containing ``'code'`` (PDDL problem text) and ``'file'``.

        Returns:
            ``{'problem': ..., 'file': ..., 'domain': ...}`` where ``domain`` is the
            output of ``generate_domain`` (``""`` on request failure).

        Raises:
            KeyError: if ``'code'`` or ``'file'`` is missing.
        """
        # Read both keys first so a malformed record fails before a request is made.
        problem = data['code']
        file_name = data['file']
        return {
            'problem': problem,
            'file': file_name,
            'domain': self.generate_domain(problem),
        }

    def generate_domain_batch(self, data: List[Dict[str, Any]], max_workers: int = 64) -> List[Dict[str, Any]]:
        """Run ``jsondata2domain`` over ``data`` concurrently, showing a progress bar.

        Args:
            data: Input records.
            max_workers: Maximum number of concurrent requests (default 64).

        Returns:
            One result dict per record, in the same order as ``data``.
        """
        # Threads suffice here: each item spends its time waiting on an API request.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map yields results in input order; tqdm only adds progress output.
            return list(tqdm.tqdm(executor.map(self.jsondata2domain, data), total=len(data)))
        
    
if __name__ == '__main__':
    args = parse_args()
    # Create the output directory up front so a bad output path fails before
    # any API requests are made, not after all of them have completed.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    domain_gen = Openai_DomainGen(args)
    result = domain_gen.generate_domain_batch(domain_gen.data)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4)
    print(f"output file saved in {args.output_file}")