"""
This script test o1's planning capabilities on simple case studies
"""
# using multithreading to run multiple cases in parallel
import argparse
import json
import multiprocessing
import os
import sys
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Tuple

import tqdm
from openai import OpenAI

def parse_args(argv=None):
    """Parse the command-line options of the case-study runner.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``
            (handy for tests and programmatic use).

    Returns:
        argparse.Namespace with ``input_file``, ``output_file``,
        ``Open_Ai_API_KEY``, ``url`` and ``model`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/llm-pddl/case_all.json', help="Input file")
    parser.add_argument("--output_file", type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/o1_case/o1preview_case.json', help="Output file")
    # Credentials must never live in source control: the key defaults to the
    # OPENAI_API_KEY environment variable (None if unset, in which case the
    # OpenAI client will refuse to start).
    parser.add_argument("--Open_Ai_API_KEY", type=str, default=os.environ.get("OPENAI_API_KEY"), help="Open AI API KEY (defaults to $OPENAI_API_KEY)")
    parser.add_argument("--url", type=str, default='http://60.204.212.177:3000/v1', help="Open AI API URL")
    parser.add_argument("--model", type=str, default='o1-preview', help="Open AI model")
    return parser.parse_args(argv)

class Openai_CaseStudy:
    """Run planning case studies against an OpenAI-compatible chat endpoint.

    Every record handled here is a dict describing one planning problem. The
    methods read the keys ``domain_type``, ``problem_name``, ``domain_nl``,
    ``domain_code``, ``problem_nl``, ``problem_code`` and ``plan`` (the
    ground-truth plan), plus stage-specific keys: ``o1_plan`` (the model's
    answer), ``critique_revised_plan`` (the model's revised plan) and
    ``round`` (index of a repeated run).
    """

    # Fields copied verbatim from every input record into every output record,
    # in this order.
    _BASE_KEYS = (
        'domain_type',
        'problem_name',
        'domain_nl',
        'domain_code',
        'problem_nl',
        'problem_code',
        'plan',
    )
    # Model used to grade plans; independent of ``self.model``.
    _EVAL_MODEL = "gpt-4o-mini"
    _INSTRUCTION = "You are an expert in planning"
    # Output-format instruction shared by both grading prompts.
    _VERDICT_FORMAT = 'In the end, output the final result of the evaluation in``【】``. e.g. The plan is 【correct】 or The plan is 【incorrect】'

    def __init__(self, args) -> None:
        """Store connection settings and load the input records.

        Args:
            args: Namespace with ``Open_Ai_API_KEY``, ``url``, ``input_file``
                and ``model`` attributes (see ``parse_args``).

        Raises:
            ValueError: if ``input_file`` ends in neither ``.json`` nor ``.jsonl``.
        """
        self.key = args.Open_Ai_API_KEY
        self.url = args.url
        self.input_file = args.input_file
        if self.input_file.endswith('.json'):
            with open(self.input_file, 'r', encoding='utf-8') as f:
                self.data = json.load(f)
        elif self.input_file.endswith('.jsonl'):
            with open(self.input_file, 'r', encoding='utf-8') as f:
                # Skip blank lines (e.g. a trailing newline) instead of crashing.
                self.data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError('input_file must be a .json or .jsonl file')
        # Number of workers used by the *_batch methods.
        self.process_num = multiprocessing.cpu_count()
        self.model = args.model

    # ------------------------------------------------------------------ helpers

    def _chat(self, prompt: str, model: str, max_tokens: int) -> str:
        """Send ``prompt`` as one user turn to ``model`` and return the reply text."""
        client = OpenAI(base_url=self.url, api_key=self.key)
        chat_completion = client.chat.completions.create(
            messages=[
                # NOTE(review): the instruction is sent with the "assistant"
                # role rather than "system" - presumably because o1-preview
                # rejected system messages. Kept as-is; confirm before changing.
                {"role": "assistant", "content": self._INSTRUCTION},
                {"role": "user", "content": prompt},
            ],
            model=model,
            max_tokens=max_tokens,
        )
        return chat_completion.choices[0].message.content

    @classmethod
    def _base_record(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict holding the problem fields shared by all outputs."""
        return {key: data[key] for key in cls._BASE_KEYS}

    def _run_batch(self, func: Any, data: List[Dict[str, Any]]) -> List[Any]:
        """Apply ``func`` to every record of ``data`` in parallel, keeping order."""
        # Threads rather than processes: each task just waits on an HTTP
        # response, and a process pool would pickle the bound method - i.e.
        # ``self`` including all of ``self.data`` - for every single task.
        with ThreadPool(self.process_num) as pool:
            return list(tqdm.tqdm(pool.imap(func, data), total=len(data)))

    # ------------------------------------------------------------ plan solving

    def generate_case(self, data: Dict[str, Any]) -> str:
        """Ask ``self.model`` to solve the problem in ``data``; return its raw answer."""
        prompt = data['domain_nl'] + data['problem_nl'] + 'Examine whether you follow the rule at each step!'
        return self._chat(prompt, self.model, 12000)

    def jsondata2case(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Solve one problem and return its record with the answer under ``o1_plan``."""
        o1_answer = self.generate_case(data)
        print(o1_answer)
        result = self._base_record(data)
        result['o1_plan'] = o1_answer
        return result

    def generate_case_batch(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run ``jsondata2case`` over all records in parallel."""
        return self._run_batch(self.jsondata2case, data)

    def generate_case_multiple_rounds(self, data: Dict[str, Any], num_rounds: int = 8) -> List[Dict[str, Any]]:
        """Solve one problem ``num_rounds`` times independently.

        Returns one record per run, each tagged with its 0-based ``round``.
        """
        results = []
        for round_num in range(num_rounds):
            o1_answer = self.generate_case(data)
            result = self._base_record(data)
            result['o1_plan'] = o1_answer
            result['round'] = round_num
            results.append(result)
        return results

    def generate_case_multiple_rounds_batch(self, data: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
        """Run ``generate_case_multiple_rounds`` over all records in parallel.

        Returns a list (one entry per problem) of lists (one entry per round).
        """
        return self._run_batch(self.generate_case_multiple_rounds, data)

    # ---------------------------------------------------------------- critique

    def generate_critique(self, data: Dict[str, Any]) -> str:
        """Ask ``self.model`` to judge the ground-truth plan and propose a revision."""
        prompt = 'The plan task is: \n' + data['domain_nl'] + data['problem_nl'] + 'The answer is: \n' + data['plan'] + 'Now, please act as judger to check if the plan is optimal? No matter what the result is, revise the plan to make it better!'
        return self._chat(prompt, self.model, 12000)

    def jsondata2critique(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Critique one record; the reply is stored under ``critique_revised_plan``."""
        o1_critique = self.generate_critique(data)
        result = self._base_record(data)
        result['o1_plan'] = data['o1_plan']
        result['critique_revised_plan'] = o1_critique
        return result

    def generate_critique_batch(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run ``jsondata2critique`` over all records in parallel."""
        return self._run_batch(self.jsondata2critique, data)

    # -------------------------------------------------------------- evaluation

    def generate_evaluation(self, data: Dict[str, Any]) -> str:
        """Have the grader model compare ``o1_plan`` against the ground-truth ``plan``."""
        prompt = (
            'The plan task is: \n' + data['domain_nl'] + data['problem_nl']
            + 'The plan given by LLM is: \n' + data['o1_plan']
            + 'The Ground Truth plan (100% true answer) is: \n' + data['plan']
            + 'please check whether the plan given by LLM is correct or not, '
            'if it aligns with the ground truth answer then it is correct, else it is incorrect. '
            + self._VERDICT_FORMAT
        )
        return self._chat(prompt, self._EVAL_MODEL, 9000)

    def jsondata2evaluation(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Grade one record's ``o1_plan``; the grader's reply goes under ``evaluation``.

        ``round`` is carried over when present (multi-round records).
        """
        o1_evaluation = self.generate_evaluation(data)
        result = self._base_record(data)
        result['o1_plan'] = data['o1_plan']
        if 'round' in data:
            result['round'] = data['round']
        result['evaluation'] = o1_evaluation
        return result

    def generate_evaluation_batch(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run ``jsondata2evaluation`` over all records in parallel."""
        return self._run_batch(self.jsondata2evaluation, data)

    def generate_critique_evaluation(self, data: Dict[str, Any]) -> str:
        """Have the grader model compare ``critique_revised_plan`` against ``plan``."""
        prompt = (
            'The plan task is: \n' + data['domain_nl'] + data['problem_nl']
            + 'The revised plan given by LLM is: \n' + data['critique_revised_plan']
            + 'The Ground Truth plan given by pddl is: \n' + data['plan']
            + 'please provide the evaluation of the revised plan against the ground truth plan, '
            'tell us whether the revised plan is optimal, and whether it breaks any rule. '
            + self._VERDICT_FORMAT
        )
        return self._chat(prompt, self._EVAL_MODEL, 9000)

    def jsondata2critique_evaluation(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Grade one record's ``critique_revised_plan``; the reply goes under ``evaluation``."""
        evaluation = self.generate_critique_evaluation(data)
        result = self._base_record(data)
        result['o1_plan'] = data['o1_plan']
        result['critique_revised_plan'] = data['critique_revised_plan']
        result['evaluation'] = evaluation
        return result

    def generate_evaluation4revisedplan_batch(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Grade the critique-revised plans of all records in parallel."""
        # Previously mapped jsondata2evaluation, which graded ``o1_plan``
        # instead of the revised plan.
        return self._run_batch(self.jsondata2critique_evaluation, data)




if __name__ == '__main__':
    args = parse_args()

    # ------------------------------------------------------------------
    # Earlier pipeline stages, kept for reference. Uncomment the ones you
    # need; they start from the records in --input_file.
    # ------------------------------------------------------------------
    # case_study = Openai_CaseStudy(args)
    # result = case_study.generate_case_batch(case_study.data)
    # with open(args.output_file, 'w') as f:
    #     json.dump(result, f, indent=4)

    # case_study.data = result
    # evaluate_result = case_study.generate_evaluation_batch(case_study.data)
    # with open(args.output_file.replace('.json', '_evaluation.json'), 'w') as f:
    #     json.dump(evaluate_result, f, indent=4)

    # critique_result = case_study.generate_critique_batch(case_study.data)
    # with open(args.output_file.replace('.json', '_critique.json'), 'w') as f:
    #     json.dump(critique_result, f, indent=4)

    # case_study.data = critique_result
    # evaluate_revised_result = case_study.generate_evaluation4revisedplan_batch(case_study.data)
    # with open(args.output_file.replace('.json', '_evaluation_revised.json'), 'w') as f:
    #     json.dump(evaluate_revised_result, f, indent=4)

    # case_study.data = result
    # multiple_rounds_result = case_study.generate_case_multiple_rounds_batch(case_study.data)
    # with open(args.output_file.replace('.json', '_multiple_rounds.json'), 'w') as f:
    #     json.dump(multiple_rounds_result, f, indent=4)

    # ------------------------------------------------------------------
    # Evaluate the multi-round results. The input is the file the
    # multiple-rounds stage above writes next to --output_file.
    # ------------------------------------------------------------------
    args.input_file = args.output_file.replace('.json', '_multiple_rounds.json')
    case_study = Openai_CaseStudy(args)
    # That file holds a list (one entry per problem) of lists (one entry per
    # round); flatten it into a single list of records.
    case_study.data = [item for sublist in case_study.data for item in sublist]
    evaluate_multiple_rounds_result = case_study.generate_evaluation_batch(case_study.data)
    with open(args.output_file.replace('.json', '_multiple_rounds_evaluation.json'), 'w') as f:
        json.dump(evaluate_multiple_rounds_result, f, indent=4)

    