# plot the val table


import os
import pandas as pd
import numpy as np
import argparse
import json
import pdb
import matplotlib.pyplot as plt
#/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_nl_bad.json /lustre/fast/fast/txiao/zly/spatial_head/cot/result/Val
def args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what a plain ``args()``
            call does.

    Returns:
        argparse.Namespace: namespace with the attributes ``input_dir``,
        ``out_path``, ``image_path`` and ``input_file``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_dir',
        type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Val/baseline',
        help='Path to input directory',
    )
    parser.add_argument(
        '--out_path',
        type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/baseline.csv',
        help='Path of the CSV table to write',
    )
    parser.add_argument(
        '--image_path',
        type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/baseline.png',
        help='Path of the rendered table image to write',
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Val/GradientD/coder8_80_prob.json",
        help='Path to input file',
    )
    return parser.parse_args(argv)
    
class Visual:
    """Turn result JSON files into summary tables (CSV plus a rendered image).

    Every record of a result file is a dict with a ``result`` entry and an
    identifier (``file`` or ``name``); the text-gradient file additionally
    carries a numeric ``rounds`` value.  A result string starting with
    ``"Type-checking"`` is treated as a run that reached type checking
    (presumably the output of a PDDL validator -- TODO confirm).  A result
    that is not a string (e.g. JSON ``null``) is handled like an empty one.
    """

    # Root folder scanned by get_dir_number(); one sub-directory per model.
    _FAST_DOWNWARD_DIR = '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/fast_downward'
    # Image written by process_data_text_grad(), selected by input-file kind.
    _GRAD_IMAGES = {
        'nl': '/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/gard_nl.png',
        'prob': '/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/gard_prob.png',
    }

    def __init__(self, args):
        """Copy the paths from the parsed command-line namespace.

        Args:
            args: object exposing the attributes ``input_dir``, ``out_path``,
                ``image_path`` and ``input_file``.
        """
        self.input_dir = args.input_dir
        self.out_path = args.out_path
        self.image_path = args.image_path
        self.input_file = args.input_file
        # When True, plot_table() appends a 'fast_downward' column.
        self.ff = False

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _load_json(path):
        """Return the parsed content of the JSON file at *path*."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _starts_type_check(result):
        """True when *result* is a string that begins with "Type-checking"."""
        return isinstance(result, str) and result.startswith("Type-checking")

    @classmethod
    def _val_right(cls, result):
        """Type check reached and neither "fail" nor "incorrectly" reported."""
        return (
            cls._starts_type_check(result)
            and "fail" not in result
            and "incorrectly" not in result
        )

    @classmethod
    def _val_wrong(cls, result):
        """Type check reached; only consulted once no result was right."""
        return cls._starts_type_check(result)

    @classmethod
    def _planet_right(cls, result):
        """Type check reached and the output contains "Errors: 0"."""
        return cls._starts_type_check(result) and "Errors: 0" in result

    @classmethod
    def _planet_wrong(cls, result):
        """Type check reached and the output mentions "Error"."""
        return cls._starts_type_check(result) and "Error" in result

    @staticmethod
    def _group_results(records, key):
        """Map every distinct ``record[key]`` to the list of its results."""
        grouped = {}
        for record in records:
            grouped.setdefault(record[key], []).append(record['result'])
        return grouped

    @staticmethod
    def _count_groups(groups, is_right, is_wrong):
        """Classify each group of results and return (right, wrong, bad_domain).

        A group is right when any result satisfies *is_right*; otherwise wrong
        when any result satisfies *is_wrong*; otherwise (including an empty
        group) it is a bad domain.
        """
        right = wrong = bad_domain = 0
        for results in groups:
            if any(is_right(result) for result in results):
                right += 1
            elif any(is_wrong(result) for result in results):
                wrong += 1
            else:
                bad_domain += 1
        return right, wrong, bad_domain

    @staticmethod
    def _render_table(df, title, image_path, figsize):
        """Draw *df* as a matplotlib table, save it to *image_path*, show it."""
        plt.figure(figsize=figsize)
        ax = plt.gca()  # current Axes object
        ax.axis('off')  # hide the axes, only the table is wanted
        ax.axis('tight')
        plt.title(title)
        table = ax.table(cellText=df.values, colLabels=df.columns, loc='center')

        # Table styling: fixed font size and taller rows.
        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1, 1.5)

        # Save the table as an image file, then display it.
        plt.savefig(image_path, bbox_inches='tight')
        plt.show()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def get_dir_number(self):
        """Count the entries inside every model directory of the fast_downward results.

        Returns:
            dict: ``{model: {'dir_num': n}}`` where ``n`` is the number of
            entries (files and directories alike) directly inside the
            model's directory.
        """
        root = self._FAST_DOWNWARD_DIR
        results = {}
        for model in os.listdir(root):
            model_dir = os.path.join(root, model)
            if os.path.isdir(model_dir):
                results[model] = {'dir_num': len(os.listdir(model_dir))}
        return results

    def process_data(self, data_file):
        """Classify every record of *data_file* on its own.

        An empty (or non-string) result is a bad domain, a result accepted by
        the type check is right, and anything else is wrong.

        Returns:
            tuple: ``(right, wrong, bad_domain)`` counts.
        """
        right = wrong = bad_domain = 0
        for record in self._load_json(data_file):
            result = record['result']
            if not isinstance(result, str) or result == "":
                bad_domain += 1
            elif self._val_right(result):
                right += 1
            else:
                wrong += 1
        return right, wrong, bad_domain

    def process_data_text_grad(self, max_rounds=80):
        """Build the per-round table for the text-gradient run in ``self.input_file``.

        For every round limit ``1..max_rounds`` each file is classified from
        the records whose ``rounds`` value is below that limit; files without
        such a record count as bad domain.  The table is written to
        ``self.out_path`` as CSV and rendered to a fixed image path chosen by
        whether the input path contains ``'prob'``.

        Args:
            max_rounds: number of round limits to tabulate (default 80).
        """
        datas = self._load_json(self.input_file)
        # One decision drives both the model label and the image name so the
        # two can never disagree.
        kind = 'prob' if 'prob' in self.input_file else 'nl'

        # Group (rounds, result) pairs per file once instead of rescanning
        # every record for every file and round.
        attempts = {}
        for data in datas:
            attempts.setdefault(data['file'], []).append((data['rounds'], data['result']))

        rows = []
        for i in range(max_rounds):
            print(f"process the rounds small than {i}")
            groups = [
                [result for rounds, result in file_attempts if rounds < (i + 1)]
                for file_attempts in attempts.values()
            ]
            right, wrong, bad_domain = self._count_groups(
                groups, self._val_right, self._val_wrong
            )
            rows.append({
                'model': f'coder8_{i}_{kind}',
                'round': i + 1,
                'right': right,
                'wrong': wrong,
                'bad_domain': bad_domain,
                'all': right + wrong + bad_domain,
            })

        columns = ['model', 'round', 'right', 'wrong', 'bad_domain', 'all']
        df = pd.DataFrame(rows, columns=columns)
        df.to_csv(self.out_path, index=False)

        self._render_table(
            df,
            'Text gard Model Val Table(IPC)',
            self._GRAD_IMAGES[kind],
            (12, 6),
        )

    def process_data_pass(self, data_file):
        """Classify *data_file* per ``file``: right when any attempt passes.

        Records are grouped by their ``file`` value.  A group is right when
        one of its results passes the type check, wrong when none passes but
        at least one reached type checking, and a bad domain otherwise.

        Returns:
            tuple: ``(right, wrong, bad_domain)`` counts over the groups.
        """
        data = self._load_json(data_file)
        groups = self._group_results(data, 'file').values()
        return self._count_groups(groups, self._val_right, self._val_wrong)

    def process_data_planet(self, data_file):
        """Classify *data_file* per ``name`` using the "Errors:" report.

        Records are grouped by their ``name`` value.  A group is right when
        one result reached type checking and contains ``"Errors: 0"``, wrong
        when none does but one reached type checking and contains
        ``"Error"``, and a bad domain otherwise.

        Returns:
            tuple: ``(right, wrong, bad_domain)`` counts over the groups.
        """
        data = self._load_json(data_file)
        groups = self._group_results(data, 'name').values()
        return self._count_groups(groups, self._planet_right, self._planet_wrong)

    def plot_table(self):
        """Summarise every ``*.json`` file of ``self.input_dir`` into one table.

        Each file becomes one row (model name = file name up to ``.json``)
        scored with process_data_pass().  The table is written to
        ``self.out_path`` as CSV and rendered to ``self.image_path``.
        """
        result = {}
        for file in sorted(os.listdir(self.input_dir)):
            path = os.path.join(self.input_dir, file)
            # Sub-directories and stray non-JSON files cannot be parsed.
            if not file.endswith('.json') or not os.path.isfile(path):
                continue
            model_name = file.split('.json')[0]
            right, wrong, bad_domain = self.process_data_pass(path)
            total = right + wrong + bad_domain
            # An empty result file has no problems; report a rate of 0.
            right_rate = right / total if total else 0.0

            result[model_name] = {
                'right_rate': f"{right_rate:.3f}",
                'right': right,
                'wrong': wrong,
                'bad_domain': bad_domain,
                'all': total,
            }

        columns = ['model', 'right_rate', 'right', 'wrong', 'bad_domain', 'all']
        if self.ff:
            ff_number = self.get_dir_number()
            for model_name, stats in result.items():
                num = ff_number[model_name]['dir_num'] if model_name in ff_number else 0
                stats['fast_downward'] = num
            columns.append('fast_downward')

        rows = [{'model': model_name, **stats} for model_name, stats in result.items()]
        df = pd.DataFrame(rows, columns=columns)
        df.to_csv(self.out_path, index=False)

        self._render_table(
            df,
            'baseline Model Val Table(IPC)',
            self.image_path,
            (26, 6),
        )

        
if __name__ == '__main__':
    # Store the parsed namespace under its own name so the ``args`` parser
    # function is not overwritten by its own return value.
    cli_args = args()
    v = Visual(cli_args)

    # Alternative entry points, enable as needed:
    # v.process_data_text_grad()
    # result = v.get_dir_number()
    v.plot_table()


    
