# plot the val table


import os
import pandas as pd
import numpy as np
import argparse
import json
import re
import pdb
import matplotlib.pyplot as plt
#/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_nl_bad.json /lustre/fast/fast/txiao/zly/spatial_head/cot/result/Val
def args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with three string attributes: ``out_path``
        (where the rendered table image is saved), ``test_path`` (JSON file
        with the reference problems) and ``input_file`` (JSON file with the
        model outputs to score). Every option has a default path.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--out_path',
        type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/same_planet.png',
    )

    # The two input options share the same type and help text; only the flag
    # and its default path differ.
    input_options = (
        ('--test_path', '/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/planetarium/planetarium_test.json'),
        ('--input_file', '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/planetarium/coder16prob_new.json'),
    )
    for flag, default_path in input_options:
        cli.add_argument(flag, type=str, default=default_path, help='Path to input file')

    namespace = cli.parse_args()
    return namespace
    
class Visual:
    """Score generated PDDL goals against reference goals and plot a table.

    Two JSON files are loaded on construction: ``args.input_file`` (the model
    outputs) and ``args.test_path`` (the reference problems). Both are used as
    lists of dicts that carry at least a ``'name'`` and a ``'question'`` key.
    """

    # Default location scanned by `get_dir_number`.
    _FAST_DOWNWARD_DIR = '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/fast_downward'
    # Fenced ```pddl block; only its content is kept when present.
    _PDDL_FENCE = r'```pddl\n(.*?)\n```'
    # Everything before this marker is dropped before goals are compared.
    _GOAL_MARKER = '(:goal'

    def __init__(self, args):
        """Store the paths from *args* and load both JSON files.

        Args:
            args: object with ``out_path``, ``test_path`` and ``input_file``
                attributes (e.g. the namespace returned by ``args()``).
        """
        self.out_path = args.out_path
        self.test_path = args.test_path
        self.input_file = args.input_file
        with open(self.input_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        with open(self.test_path, 'r', encoding='utf-8') as f:
            self.test_data = json.load(f)
        # When True, `plot_table` adds a 'fast_downward' column.
        self.ff = False

    def get_dir_number(self, path=None):
        """Count the entries inside each sub-directory of *path*.

        Args:
            path: directory to scan; defaults to the fast_downward result
                directory used by this project.

        Returns:
            dict mapping each sub-directory name to ``{'dir_num': count}``.
        """
        if path is None:
            path = self._FAST_DOWNWARD_DIR
        results = {}
        for model in os.listdir(path):
            dir_path = os.path.join(path, model)
            if not os.path.isdir(dir_path):
                continue
            # NOTE(review): this counts every entry (files included), not
            # only directories -- confirm that is the intended number.
            results[model] = {'dir_num': len(os.listdir(dir_path))}
        return results

    @staticmethod
    def _strip_to_goal(text):
        """Keep *text* from '(:goal' onwards and drop spaces and newlines."""
        marker = Visual._GOAL_MARKER
        if marker in text:
            text = text[text.index(marker):]
        return text.replace(" ", "").replace("\n", "")

    def re(self, str):
        """Normalise a model answer into a comparable goal string.

        If the answer contains a fenced ```pddl block, only the content of
        the first such block is used. The text is then cut to start at
        '(:goal' (when present) and all spaces and newlines are removed.

        Args:
            str: raw answer text. The name shadows the builtin but is kept
                because it is part of the method's signature.

        Returns:
            The normalised goal string.
        """
        text = str
        # Inside a method body `re` resolves to the module, not this method.
        fenced = re.search(self._PDDL_FENCE, text, re.DOTALL)
        if fenced:
            text = fenced.group(1)
        return self._strip_to_goal(text)

    def process_data_planet(self):
        """Count problems whose goal was reproduced by at least one answer.

        A problem (identified by ``'name'`` in the reference data) counts as
        right when any of its answers in ``self.data`` normalises to one of
        the problem's normalised reference goals, and as wrong when it has
        answers but none matches. Problems without any answer are reported
        on stdout and counted in neither bucket.

        Returns:
            Tuple ``(right, wrong)``.
        """
        # Group once by name instead of rescanning both lists per problem.
        references = {}
        for item in self.test_data:
            goal = self._strip_to_goal(item['question'])
            references.setdefault(item['name'], set()).add(goal)

        answers = {}
        for item in self.data:
            answers.setdefault(item['name'], []).append(item['question'])

        right = 0
        wrong = 0
        for name, reference_goals in references.items():
            candidates = answers.get(name)
            if not candidates:
                # No answer was produced for this problem.
                print("xxxxx")
                continue
            # References are kept as a set so a problem with several
            # distinct reference texts can still be matched.
            if any(self.re(candidate) in reference_goals for candidate in candidates):
                right += 1
            else:
                wrong += 1

        return right, wrong

    def plot_table(self):
        """Render the right/wrong summary as a table image.

        The image is saved to ``self.out_path`` and then shown with
        ``plt.show()``.
        """
        model_name = self.input_file.split('.json')[0]
        right, wrong = self.process_data_planet()
        total = right + wrong
        # Avoid ZeroDivisionError when no problem could be scored.
        right_rate = right / total if total else 0.0

        result = {
            model_name: {
                'right': right,
                'wrong': wrong,
                'right_rate': f"{right_rate:.3f}",
                'all': total,
            }
        }
        if self.ff:
            ff_number = self.get_dir_number()
            for name, row in result.items():
                # `name` is a path without the '.json' suffix while the
                # directory counts are keyed by directory name, so also try
                # the last path component.
                entry = ff_number.get(name) or ff_number.get(os.path.basename(name))
                row['fast_downward'] = entry['dir_num'] if entry else 0

        df = pd.DataFrame.from_dict(result, orient='index').reset_index().rename(columns={'index': 'model'})

        plt.figure(figsize=(10, 6))
        ax = plt.gca()  # current Axes object
        ax.axis('off')  # hide the axes; only the table is drawn
        ax.axis('tight')
        plt.title('planet same as goal Table(IPC)')
        table = ax.table(cellText=df.values, colLabels=df.columns, loc='center')

        # Table styling: fixed font size and taller rows.
        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1, 1.5)

        # Save the table as an image file.
        plt.savefig(self.out_path, bbox_inches='tight')

        # Display the table.
        plt.show()

        
if __name__ == '__main__':
    # Script entry point: parse the CLI options, load the data and render
    # the table. The namespace gets its own name so the module-level `args`
    # function is not rebound to its own return value.
    cli_args = args()
    v = Visual(cli_args)
    v.plot_table()

    
