
import os
import sys
import inspect

# Prepend the parent directory of this script's directory to sys.path so that
# modules living one level up can be imported when this file is run directly.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import pprint
# Shared pretty-printer used for the debug dumps of the nested result dicts below.
pp = pprint.PrettyPrinter(indent=4)

# folders = {
#     'final_results': 'evaluation',
#     'final_results_sim': 'evaluation_simultaneous',
# }

# Directory holding the seedwise summary files, and the file-name prefixes read
# from it. get_data() reads "{folders}/{prefix}__<data set path>_..." for every
# prefix; results from later prefixes overwrite earlier ones for the same key.
folders = 'final_results'
files = [
    'evaluation',
    'evaluation_simultaneous',
]

# ---- Scale study --------------------------------------------------------
# All lists in this section are index-aligned: entry k describes the k-th
# problem scale. The directory names encode rR (agents) and tT (tasks); e.g.
# r10_t20 is plotted as "10 Agents-20 Tasks".
scale_data_set = [
    "data/problem_set_r10_t20_s0_f10_w25_euc_200_test",
    "data/problem_set_r10_t50_s0_f10_w25_euc_2000_uni",
    "data/problem_set_r20_t100_s0_f10_w25_euc_200",
    "data/problem_set_r40_t200_s0_f10_w25_euc_200"
]

# Number of problems evaluated per data set; it is part of the summary file name
# ("..._{size}_seedwise_{data_type}.txt").
num_problems = [
    200,
    200,
    30,
    10
]

# Display names of the scales (used as dict keys and x-axis group labels).
scales = [
    # '10R20T',
    # '10R50T',
    # '20R100T',
    # '40R200T',
    'Small',
    'Medium',
    'Large',
    'XLarge',
]

# Task count per scale; get_data() divides "infeasible" counts by it to get a
# feasibility percentage.
scale_nums = [
    20,
    50,
    100,
    200
]

# |A| * |T|^2 per scale; used to normalise the time standard deviation for the
# x error bars in get_optimality_rate_vs_time_given_set().
poly_scale=[
    10 * 20 * 20, 
    10 * 50 **2,
    20 * 100**2,
    40 * 200**2
]


# Number of tasks |T| and agents |A| per scale (index-aligned with
# scale_data_set); used to place scales on the complexity axis.
task_sizes = [20, 50, 100, 200]
agent_sizes = [10, 10, 20, 40]


# ---- Sensitivity study --------------------------------------------------
# Every sensitivity data set is a 10-agent / 20-task problem set (r10_t20);
# only the s/f/w parameters or the speed variant change. Index-aligned with
# sensitivity_num_problems and sensitivity_scales.
sensitivity_data_set = [
    "data/problem_set_r10_t20_s0_f5_w25_euc_2000",
    "data/problem_set_r10_t20_s5_f10_w25_euc_2000",
    "data/problem_set_r10_t20_s0_f10_w0_euc_2000",
    "data/problem_set_r10_t20_s0_f10_w50_euc_2000",
    "data/problem_set_r10_t20_s0_f10_w25_euc_2000_test_slow",
    "data/problem_set_r10_t20_s0_f10_w25_euc_2000_test_fast"
]

sensitivity_num_problems = [200] * 6
# Metrics read by get_data(); each one is a separate "..._seedwise_{data_type}.txt" file.
data_types = [
    'makespan',
    'infeasible',
    'makespan_optimality_rate',
    'makespan_optimality_gap'
]

# Y-axis labels per metric. 'infeasible' is shown as "Feasible (%)" because
# get_data() converts infeasible counts into a feasibility percentage.
data_labels = {
    'makespan': 'Reward',
    'infeasible': 'Feasible (%)',
    'makespan_optimality_rate': 'Optimality Rate (%)',
    'makespan_optimality_gap': 'Optimality Gap'
}

# Display names for the sensitivity data sets (same order as sensitivity_data_set).
sensitivity_scales = [
    '0%S - 5%F',
    '5%S - 10%F',
    'W 0%',
    'W 50%',
    'Slow',
    'Fast'
]


# Plotting order of policies (bar order and legend order). make_bar_plots()
# pairs position k in this list with color_indices[k] / color_textures[k]
# (31 entries each), so keep the three lists aligned when editing.
# NOTE(review): the _10/_11/_12 suffixes presumably denote seeds or checkpoints
# of the same model — confirm.
order = [
'milp_solver',
'edf',
'improved_edf',
'gen_random',
'gen_edf',
'gen_random_3',
'gen_edf_3',
# 'gen_random_10',
# 'gen_edf_10',
'hetgat_10',
'hetgat_11',
'hetgat_12',
'hetgat_resnet_10',
'hetgat_resnet_11',
'hetgat_resnet_12',
'hgt_10',
'hgt_11',
'hgt_12',
'hgt_edge_10',
'hgt_edge_11',
'hgt_edge_12',
'hgt_edge_resnet_10',
'hgt_edge_resnet_11',
'hgt_edge_resnet_12',
'simultaneous_hgt_10',
'simultaneous_hgt_11',
'simultaneous_hgt_12',
'simultaneous_hgt_edge_10',
'simultaneous_hgt_edge_11',
'simultaneous_hgt_edge_12',
'simultaneous_hgt_edge_resnet_10',
'simultaneous_hgt_edge_resnet_11',
'simultaneous_hgt_edge_resnet_12',
]

# Learned model families without suffix (not referenced elsewhere in this file).
base_models = [
'hetgat',
'hetgat_resnet',
'hgt',
'hgt_edge',
'hgt_edge_resnet',
'simultaneous_hgt',
'simultaneous_hgt_edge',
'simultaneous_hgt_edge_resnet',
# 'simultaneous_hgt_batch_8',
# 'simultaneous_hgt_edge_batch_8',
# 'simultaneous_hgt_edge_resnet_batch_8',
]

# Policy key -> display name. get_data() and get_time_data() discard any key
# that is not in this dict. "\\ER" / "\\R" render as a literal backslash
# (e.g. "TARGETNET\ER").
label_keys = {
'edf': "EDF",
'improved_edf': "CA-EDF",
'milp_solver': "MILP Solver",
'gen_random': "Gen-Random 1",
"gen_edf": "Gen-EDF 1",
'gen_random_3': "Gen-Random 3",
'gen_edf_3': "Gen-EDF 3",
# 'gen_random_10': "Gen-Random 10",
# 'gen_edf_10': "Gen-EDF 10",
'hetgat_10': "HetGAT-10",
'hetgat_11': "HetGAT-11",
'hetgat_12': "HetGAT-12",
'hetgat_resnet_10': "Res-HetGAT-10",
'hetgat_resnet_11': "Res-HetGAT-11",
'hetgat_resnet_12': "Res-HetGAT-12",
'hgt_10': "Seq-TARGETNET\\ER-10",
'hgt_11': "Seq-TARGETNET\\ER-11",
'hgt_12': "Seq-TARGETNET\\ER-12",
'hgt_edge_10': "Seq-TARGETNET\\R-10",
'hgt_edge_11': "Seq-TARGETNET\\R-11",
'hgt_edge_12': "Seq-TARGETNET\\R-12",
'hgt_edge_resnet_10': "Seq-TARGETNET-10",
'hgt_edge_resnet_11': "Seq-TARGETNET-11",
'hgt_edge_resnet_12': "Seq-TARGETNET-12",
'simultaneous_hgt_10': "TARGETNET\\ER-10",
'simultaneous_hgt_11': "TARGETNET\\ER-11",
'simultaneous_hgt_12': "TARGETNET\\ER-12",
'simultaneous_hgt_edge_10': "TARGETNET\\R-10",
'simultaneous_hgt_edge_11': "TARGETNET\\R-11",
'simultaneous_hgt_edge_12': "TARGETNET\\R-12",
'simultaneous_hgt_edge_resnet_10': "TARGETNET (Ours)-10",
'simultaneous_hgt_edge_resnet_11': "TARGETNET (Ours)-11",
'simultaneous_hgt_edge_resnet_12': "TARGETNET (Ours)-12",
# 'simultaneous_hgt_batch_8': "Sim-HGT-8",
# 'simultaneous_hgt_edge_batch_8': "Sim-HGT-Edge-8",
# 'simultaneous_hgt_edge_resnet_batch_8': "Sim-Res-HGT-Edge-8",
}


def get_data():
    """Collect seedwise mean/std results for every study, metric, scale and policy.

    Returns:
        Nested dict ``ret[title][data_type][scale_label][policy_key] = [mean, std]``.
        Here ``title`` is "Scale" or "Sensitivity", ``data_type`` is an entry of
        ``data_types`` and ``policy_key`` is a key of ``label_keys``. When several
        prefixes in ``files`` report the same policy key, the last one wins.

    Each input file is
    ``{folders}/{prefix}__{data set path with '/' -> '_'}_{num problems}_seedwise_{data_type}.txt``
    and holds ``key:mean,std`` lines. Lines starting with '#' and blank lines
    are skipped. Infeasible counts are converted into a feasibility percentage
    (divided by the task count of the scale, or by 20 for the sensitivity sets).
    Optimality rates are scaled to percent.
    """
    studies = [
        ("Scale", scale_data_set, num_problems, scales),
        ("Sensitivity", sensitivity_data_set, sensitivity_num_problems, sensitivity_scales),
    ]

    ret = {}
    for title, data_set, sizes, scale_names in studies:
        ret[title] = {}
        for data_type in data_types:
            by_scale = ret[title][data_type] = {}
            for idx, (directory, size, problem_scale) in enumerate(zip(data_set, sizes, scale_names)):
                results = by_scale[problem_scale] = {}
                for prefix in files:
                    file_path = os.path.join(folders, f"{prefix}__{directory.replace('/', '_')}")
                    print(file_path)
                    data_file = f"{file_path}_{size}_seedwise_{data_type}.txt"
                    with open(data_file, 'r') as f:
                        lines = f.readlines()
                    if not lines:
                        print(f"File {data_file} is empty")
                        continue

                    for line in lines:
                        # Skip comment and blank lines.
                        if line.startswith('#') or line == '\n':
                            continue
                        key, stats = line.strip().split(':')
                        # Some files tag keys with the training data set; strip that tag.
                        key = key.replace('__10r20t0s10f25w', '')
                        if key not in label_keys:
                            print(f"Key {key} not in label_keys")
                            continue

                        raw_mean, raw_std = stats.split(',')
                        mean, std = float(raw_mean), float(raw_std)

                        if 'infeasible' in data_type:
                            # Convert the infeasible-task count into a feasibility percentage.
                            denom = scale_nums[idx] if 'Scale' in title else 20
                            mean = (1.0 - mean / denom) * 100
                            std = std / denom * 100
                        if 'optimality_rate' in data_type:
                            mean = mean * 100
                            std = std * 100

                        print(f"{key}: {mean} ± {std}")
                        results[key] = [mean, std]

    return ret


import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit

def make_bar_plots(data, i, title='Makespan vs Problem Set', xlabel='Problem Set', ylabel_name='Makespan', log=False, scale_labels=None, log_scale=False, normalized=False):
    """Draw a grouped bar chart with one group per problem scale and one bar per policy.

    Args:
        data: mapping ``scale_label -> {policy_key: [mean, std]}``. Every group
            label of the selected study must be present as a key.
        i: 0 for the scale study (groups are ``scales``), 1 for the
            sensitivity study (groups are ``sensitivity_scales``).
        title, xlabel, scale_labels, log_scale: accepted for backward
            compatibility; currently unused.
        ylabel_name: a key of ``data_labels`` (mapped to its display label) or
            a literal axis label. If it contains 'infeasible', error bars are
            clipped to [0, 100].
        log: draw the y axis on a log scale.
        normalized: draw a dashed reference line at y = 100.

    Returns:
        The ``matplotlib.pyplot`` module with the new figure current, so the
        caller can ``savefig`` (and should ``close``) it.
    """
    pp.pprint(data)

    x_axis_labels = [scales, sensitivity_scales][i]

    # Colour and hatch per entry of `order` (31 policies): one tab10 colour per
    # policy family, hatches to tell family members apart. Indexed by the
    # policy's position k in `order`.
    palette = plt.colormaps['tab10'].colors
    color_indices = [0] + [1]*2 + [2]*4 + [4]*2*3 + [9]*3*3 + [3]*3*3
    color_textures = [None] + ['/', None] + ['/', '\\', 'o', None] + \
                        [None, None, None, '\\', '\\', '\\'] + \
                        ['/', '/', '/', '\\', '\\', '\\', None, None, None] + \
                        ['/', '/', '/', '\\', '\\', '\\', None, None, None]
    colors = [palette[c] for c in color_indices]

    plt.figure(figsize=(30, 15) if i == 0 else (40, 15))
    if log:
        plt.yscale('log')
    plt.rcParams.update({'font.size': 38})

    width = 0.02
    linewidth = width * 0.9

    for k, label in enumerate(order):
        for j, scale in enumerate(x_axis_labels):
            if label not in data[scale].keys():
                continue
            mean, std = data[scale][label]
            label_key = label_keys[label]
            x = j + k * width

            bar_kwargs = {'label': label_key, 'color': colors[k], 'width': linewidth}
            if color_textures[k] is not None:
                bar_kwargs['hatch'] = color_textures[k]
            plt.bar(x, mean, **bar_kwargs)

            y_max = mean + std
            y_min = mean - std
            # Feasibility is a percentage, so keep the error bar inside [0, 100].
            if 'infeasible' in ylabel_name:
                y_max = min(100.0, y_max)
                y_min = max(0.0, y_min)
            plt.errorbar(x, mean, yerr=[[mean - y_min], [y_max - mean]], color='k', lw=2, capsize=5, capthick=2)

    if normalized and 'milp_solver' in order:
        plt.axhline(y=100.0, color='k', linestyle='--', linewidth=2)

    # No per-bar ticks; groups are labelled with text below the axis instead.
    plt.xticks([], [])

    # Subtitle under each group. Scale groups use their own agent/task counts.
    # All sensitivity data sets are r10_t20 problem sets, so they share one
    # subtitle (previously they cycled through the scale subtitles).
    if i == 0:
        subtitles = [f"{a} Agents-{t} Tasks" for a, t in zip(agent_sizes, task_sizes)]
    else:
        subtitles = ['10 Agents-20 Tasks'] * len(x_axis_labels)

    step_size = 1.0 / len(x_axis_labels)
    half_step_size = step_size / 2.0
    axes = plt.gca()
    for j, scale in enumerate(x_axis_labels):
        x_text = half_step_size + j * step_size  # axes-fraction coordinate of the group centre
        plt.text(x_text, -0.01, scale, ha='center', va='top', fontsize=38, transform=axes.transAxes)
        plt.text(x_text, -0.07, subtitles[j], ha='center', va='top', fontsize=38, transform=axes.transAxes)

    plt.ylabel(data_labels.get(ylabel_name, ylabel_name), fontsize=38)
    plt.yticks(fontsize=38)

    # Every (policy, scale) bar carries a label; keep one legend entry per policy.
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    plt.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1, 1))
    return plt

def get_bar_plots():
    """Save one bar chart per (study, metric) pair under figures/meta_analysis/.

    Studies are "Scale" and "Sensitivity" from get_data(); metrics are the
    entries of ``data_types``. Feasibility and optimality-rate charts get a
    dashed 100% reference line.
    """
    data = get_data()
    os.makedirs("figures/meta_analysis", exist_ok=True)

    # i == 0 -> "Scale", i == 1 -> "Sensitivity" (insertion order of get_data()).
    for i, (title, data_set) in enumerate(data.items()):
        for data_type, data_type_set in data_set.items():
            # 'feasible' also matches 'infeasible', which get_data() has already
            # converted into a feasibility percentage.
            normalized = 'optimality_rate' in data_type or 'feasible' in data_type

            pyplot = make_bar_plots(data_type_set, i, title=data_type, xlabel='Problem Set', ylabel_name=data_type, log=False, normalized=normalized)
            pyplot.savefig(f"figures/meta_analysis/{title}_{data_type}_seedwise.png", bbox_inches='tight')
            # make_bar_plots opens a new figure on each call; release it once saved.
            pyplot.close()

def get_time_data():
    """Read runtime summaries for every scale data set.

    For each entry of ``scale_data_set`` two files under ``final_results`` are
    read:
      * ``time__{problem_set}.txt`` (keys used as-is), and
      * ``time__simultaneous_{problem_set}.txt`` (keys prefixed with ``simultaneous_``).
    Each file holds ``key:mean,std`` lines. Lines starting with '#', blank lines,
    and keys not present in ``label_keys`` are skipped.

    Returns:
        ``{scale_label: {policy_key: [mean, std]}}`` keyed by the names in ``scales``.
    """
    time_data = {}
    for j, data_set_path in enumerate(scale_data_set):
        problem_set = data_set_path.split('/')[-1]
        # (file, key prefix, key postfix) for each runtime file of this scale.
        sources = (
            (os.path.join('final_results', f"time__{problem_set}.txt"), "", ""),
            (os.path.join('final_results', f"time__simultaneous_{problem_set}.txt"), "simultaneous_", ""),
        )
        per_policy = time_data[scales[j]] = {}

        for location, prefix, postfix in sources:
            with open(location, 'r') as f:
                lines = f.readlines()
            if not lines:
                print(f"File {location} is empty")
                continue

            for line in lines:
                if line.startswith('#') or line == '\n':
                    continue
                key, stats = line.strip().split(':')
                policy = f"{prefix}{key}{postfix}"
                if policy not in label_keys:
                    continue
                raw_mean, raw_std = stats.split(',')
                per_policy[policy] = [float(raw_mean), float(raw_std)]
    return time_data

def get_time_plots():
    """Save a log-scale bar chart of per-policy runtime for each problem scale.

    Output: figures/meta_analysis/time_seedwise.png
    """
    time_data = get_time_data()
    os.makedirs("figures/meta_analysis", exist_ok=True)

    pyplot = make_bar_plots(time_data, 0, title='Time vs Problem Set', xlabel='Problem Set', ylabel_name='Time (s)', log=True)
    pyplot.savefig("figures/meta_analysis/time_seedwise.png", bbox_inches='tight')
    # make_bar_plots opens a new figure; release it once saved.
    pyplot.close()
            
            
def get_entropy_and_training_stability():
    """Unfinished stub for an entropy / training-stability analysis.

    It currently only loads the evaluation results via get_data() (which reads
    files and prints progress) and returns None.
    """
    get_data()
    # TODO: derive entropy and training-stability summaries from the loaded data.
    
            
def meta_policy_analysis(meta_policies, data_set):
    """Line-plot how a group of related policies performs across problem scales.

    For each of 'makespan' and 'makespan_optimality_rate' a figure is saved to
    ``figures/meta_analysis/meta_policy_{data_type}_{meta_policies[0]}_seedwise.png``.

    Args:
        meta_policies: policy keys to compare. Each must exist in
            ``data_set['Scale'][data_type][scale]`` for every scale and in
            ``label_keys``.
        data_set: output of get_data().
    """
    colors20 = plt.colormaps['tab20'].colors
    colors = colors20[18:19] + colors20[2:8] + colors20[10:15] + colors20[16:18] + colors20[0:1]

    for data_type in ['makespan', 'makespan_optimality_rate']:
        d_set = data_set['Scale'][data_type]
        pp.pprint(d_set)
        plt.figure(figsize=(20, 10))
        plt.rcParams.update({'font.size': 24})

        # Materialise the x values as a list rather than passing a dict view to matplotlib.
        scale_names = list(d_set.keys())
        for i, meta_policy in enumerate(meta_policies):
            stats = [d_set[scale][meta_policy] for scale in scale_names]
            means = [s[0] for s in stats]
            stds = [s[1] for s in stats]
            color = colors[i % len(colors)]  # cycle instead of IndexError for long lists

            plt.plot(scale_names, means, label=label_keys[meta_policy], color=color, linewidth=4)
            # Symmetric error bars of one standard deviation (previously mean - std was passed).
            plt.errorbar(scale_names, means, yerr=stds, lw=2, capsize=5, capthick=2, color=color)

        plt.title('Meta-Policy Analysis')
        plt.xlabel('Problem Set')
        plt.ylabel(data_labels[data_type])

        # The errorbar artists repeat labels; keep one legend entry per label.
        handles, labels = plt.gca().get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1, 1))
        plt.savefig(f"figures/meta_analysis/meta_policy_{data_type}_{meta_policies[0]}_seedwise.png", bbox_inches='tight')
        plt.close()
        
def get_meta_plots():
    """Disabled: the meta-policy line plots are currently switched off.

    To re-enable, call ``meta_policy_analysis(policy_keys, get_data())`` with a
    group of related policies. An example from earlier runs:
    ``['hgt_edge_resnet', 'simultaneous_hgt_edge_resnet']``. The keys must exist
    in get_data()'s output, which currently only keeps keys listed in
    ``label_keys`` (suffixed names such as ``'hgt_edge_resnet_10'``).
    """
    return None
    
def get_meta_analysis():
    """Plot "meta-policy" results: per scale, pick one variant of each model family.

    For each family in ``meta_policies``, ``scale_to_policy[scale]`` selects
    which variant's results to report at that scale. Non-'hgt' baselines are
    copied through unchanged. Bar charts are then saved for every metric in
    ``data_types`` and for runtime.

    NOTE(review): the family members looked up here ('hgt', 'simultaneous_hgt',
    ...) have no _10/_11/_12 suffix. get_data() and get_time_data() drop every
    key not in ``label_keys``, and all learned-model keys there are suffixed.
    So the lookups below appear to raise KeyError with the current tables.
    make_bar_plots also only draws keys listed in ``order``, which does not
    contain 'hgt', 'hgt_edge' or 'hgt_edge_resnet'. Confirm which suffix is
    intended before relying on this function.
    """
    # Actual Meta-Policy Performance:
    data = get_data()
    # pp.pprint(data)
    # meta_policies = {
    #     'hgt': ['hgt', 'simultaneous_hgt', 'simultaneous_hgt_batch_8'],
    #     'hgt_edge': ['hgt_edge', 'simultaneous_hgt_edge', 'simultaneous_hgt_edge_batch_8'],
    #     'hgt_edge_resnet': ['hgt_edge_resnet', 'simultaneous_hgt_edge_resnet', 'simultaneous_hgt_edge_resnet_batch_8'],
    # }
    
    # family name -> candidate policy keys; scale_to_policy picks an index into the list.
    meta_policies = {
        'hgt': ['hgt', 'simultaneous_hgt'],
        'hgt_edge': ['hgt_edge', 'simultaneous_hgt_edge'],
        'hgt_edge_resnet': ['hgt_edge_resnet', 'simultaneous_hgt_edge_resnet'],
    }
    
    # scale -> index into each meta_policies list (0 = sequential, 1 = simultaneous_*).
    scale_to_policy = {
        # '10R20T': 1, # simultaneous_hgt_batch_8
        # '10R50T': 1, # simultaneous_hgt_batch_8
        # '20R100T': 1, # hgt
        # '40R200T': 0, # hgt
        'Small': 1, # simultaneous_* variant
        'Medium': 1, # simultaneous_* variant
        'Large': 1, # simultaneous_* variant
        'XLarge': 0, # sequential variant
    }
    
    meta_scores = {} # data_type -> scale -> policy -> [mean, std]
    
    for data_type in data_types:
        meta_scores[data_type] = {}
        for scale in scales:
            meta_scores[data_type][scale] = {}
            for policy in meta_policies.keys():
                meta_policy = meta_policies[policy][scale_to_policy[scale]]
                # print(policy, scale, meta_policy)
                meta_policy_performance = data['Scale'][data_type][scale][meta_policy]
                
                pp.pprint(meta_policy_performance)
                # Stored under the family name, not the selected variant's key.
                meta_scores[data_type][scale][policy] = meta_policy_performance

    # for scale in scales:
    #     for policy in meta_policies.keys():
    
    # copy all the data to the meta_scores where the policy key does not contain 'hgt'
    # (i.e. the MILP/EDF/Gen/HetGAT baselines are carried over unchanged)
    for data_type in data_types:
        for scale in scales:
            for policy in data['Scale'][data_type][scale].keys():
                if 'hgt' in policy:
                    continue
            
                # copy the data to the meta_scores
                meta_scores[data_type][scale][policy] = data['Scale'][data_type][scale][policy]
                # print(f"{policy}: {data}")
    pp.pprint(meta_scores)
    
    # get meta bar plots
    # for each data_type in ['makespan', 'infeasible', 'makespan_optimality_rate', 'makespan_optimality_gap']
    for data_type, data_type_set in meta_scores.items():
        # make a bar plot (the returned pyplot module shadows the global name locally)
        plt = make_bar_plots(meta_scores[data_type], 0, title=data_type, xlabel='Problem Set', ylabel_name=data_type, log=False, normalized=('optimality_rate' in data_type))
        plt.savefig(f"figures/meta_analysis/meta_policy_{data_type}_seedwise.png", bbox_inches='tight')
    
    # data['Scale'][data_type][scale][policy]
    # now do the same for time
    meta_time_data = {}
    time_data = get_time_data()
    
    # pp.pprint(time_data)
    for scale in scales:
        meta_time_data[scale] = {}        
        for policy in meta_policies.keys():
            meta_policy = meta_policies[policy][scale_to_policy[scale]]
            print(f"Time {policy} {scale} {meta_policy} - {time_data.keys()}")
            meta_time_data[scale][policy] = time_data[scale][meta_policy]
        
        # carry over non-'hgt' baselines, excluding the 10-sample generators
        for policy in time_data[scale].keys():
            if 'hgt' in policy:
                continue
            if policy in ['gen_random_10', 'gen_edf_10']:
                continue
            meta_time_data[scale][policy] = time_data[scale][policy]
    
    plt = make_bar_plots(meta_time_data, 0, title='Time vs Problem Set', xlabel='Problem Set', ylabel_name='Time (s)', log=True)
    plt.savefig(f"figures/meta_analysis/meta_time_seedwise.png", bbox_inches='tight')
    
    
def _feasibility_points(feasibility_data, time_data, label, num_scales):
    """Collect the drawable points of one policy over the first *num_scales* scales.

    A scale is included only when *label* has both a feasibility entry and a
    time entry there. Each point is the tuple
    ``(scale_index, x, y, x_err, y_err_low, y_err_high)``:
      * ``x`` is ``log10(|T| * log10|A|)`` for that scale;
      * ``y`` is the mean feasibility (%);
      * ``x_err`` is the time std divided by ``poly_scale``;
      * the y errors are the feasibility std, clipped so the bar stays inside [0, 100].
    Keeping all values in one tuple keeps markers and error bars aligned with
    the right scale even when some scales are missing.
    """
    points = []
    for j, scale in enumerate(scales[:num_scales]):
        if label not in feasibility_data[scale].keys():
            continue
        if scale not in time_data.keys() or label not in time_data[scale].keys():
            continue
        mean_feasibility, std_feasibility = feasibility_data[scale][label]
        std_time = time_data[scale][label][1]
        # NOTE(review): this is log10(log10(|A|^|T|)); the 10^x tick labels therefore
        # show log10 of the complexity rather than the complexity itself — confirm intended.
        x = np.log10(task_sizes[j] * np.log10(agent_sizes[j]))
        points.append((
            j,
            x,
            mean_feasibility,
            std_time / poly_scale[j],
            min(mean_feasibility - 0.0, std_feasibility),
            min(100.0 - mean_feasibility, std_feasibility),
        ))
    return points


def get_optimality_rate_vs_time_given_set(data, time_data, label_order, location="figures/meta_analysis/dummy"):
    """Plot feasibility against problem complexity, one figure per scale prefix.

    For every ``n`` in 1..len(scales) a figure containing the first ``n``
    scales is saved to ``f"{location}{n}_seedwise.png"``. Each policy is drawn
    as one point per scale (marker shape = scale), joined by arrows in scale
    order.

    Args:
        data: output of get_data(). ``data['Scale']['infeasible']`` holds the
            feasibility percentages.
        time_data: output of get_time_data(). A point is drawn only when the
            policy has a time entry for that scale.
        label_order: policy keys to draw, in legend order. Keys missing from
            the smallest scale are skipped entirely.
        location: output path prefix.
    """
    from matplotlib.lines import Line2D

    def legend_spacer():
        # Invisible entry used to pad/separate the legend sections.
        return Line2D([0], [0], color='w', label='', markerfacecolor='w', markersize=0, antialiased=True)

    # get_data() stores the feasibility percentage under the 'infeasible' key.
    feasibility_data = data['Scale']['infeasible']

    colors20 = plt.colormaps['tab20'].colors
    colors = colors20[18:19] + colors20[2:8] + colors20[10:15] + colors20[16:18] + colors20[0:1]

    # Marker per scale: Small, Medium, Large, XLarge.
    scale_markers = ['o', 'X', '^', 's']

    for last_scale_id in range(1, len(scales) + 1):
        fig, _ = plt.subplots(figsize=(20, 15))
        plt.rcParams.update({'font.size': 24})

        policy_legend = []
        for i, label in enumerate(label_order):
            if label not in feasibility_data[scales[0]].keys():
                continue
            # Only 15 colours are defined; cycle so long label lists (e.g. `order`) don't IndexError.
            color = colors[i % len(colors)]
            policy_legend.append(Line2D([0], [0], marker='o', color='w', label=label_keys[label],
                                        markerfacecolor=color, markersize=10))

            points = _feasibility_points(feasibility_data, time_data, label, last_scale_id)
            for j, x, y, x_err, y_low, y_high in points:
                plt.errorbar(x, y, xerr=x_err, yerr=[[y_low], [y_high]], fmt=scale_markers[j],
                             color=color, capsize=5, capthick=2, markersize=14)

            # Arrows between consecutive drawn points, shortened to 95% so the
            # heads do not cover the next marker. Pairs come from the drawn
            # points themselves, so missing scales cannot cause an IndexError.
            for start, end in zip(points, points[1:]):
                dx = end[1] - start[1]
                dy = end[2] - start[2]
                plt.quiver(start[1], start[2], dx * 0.95, dy * 0.95,
                           angles='xy', scale_units='xy', scale=1, color=color, width=0.002, headwidth=5, alpha=0.5)

        plt.ylim(0, 105)
        plt.xlabel('Complexity $O(|A|^{|T|})$')
        plt.ylabel('Feasibility (%)')

        # Relabel the (log-valued) x ticks as powers of ten.
        x_ticks = [round(x, 2) for x in plt.xticks()[0]]
        plt.xticks(x_ticks, ["$10^{" + str(round(x, 2)) + "}$" for x in x_ticks])

        # Two-part legend: colours identify policies, black markers identify scales.
        scale_legend = [Line2D([0], [0], marker=scale_markers[j], color='k', label=scales[j],
                               markersize=10, linestyle='None')
                        for j in range(last_scale_id)]
        legend_elements = ([legend_spacer(), legend_spacer()] + policy_legend
                           + [legend_spacer(), legend_spacer()] + scale_legend)
        plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1, 1.025))

        plt.savefig(f"{location}{last_scale_id}_seedwise.png", bbox_inches='tight')
        # One figure per iteration; close it so figures don't accumulate.
        plt.close(fig)

def get_optimality_rate_vs_time():
    """Render optimality-rate-vs-complexity figures for two policy selections.

    The first call covers every policy in the module-level ``order`` list; the
    second covers an explicit subset ending with the simultaneous HGT model.
    Drawing and saving are delegated to ``get_optimality_rate_vs_time_given_set``;
    the last argument of each call is the output path prefix it writes to.
    """
    scale_results = get_data()
    timing_results = get_time_data()

    # Full comparison across every policy listed in `order`.
    get_optimality_rate_vs_time_given_set(
        scale_results,
        timing_results,
        order,
        location="figures/meta_analysis/optimality_rate_vs_complexity_full_",
    )

    # Subset comparison: MILP, heuristics, genetic baselines, HetGAT variants,
    # then the sequential and simultaneous HGT models.
    selected_policies = [
        'milp_solver', 'edf', 'improved_edf',
        'gen_random', 'gen_edf', 'gen_random_3', 'gen_edf_3',
        'hetgat', 'hetgat_resnet', 'hgt_edge_resnet',
        'simultaneous_hgt_edge_resnet',
    ]
    get_optimality_rate_vs_time_given_set(
        scale_results,
        timing_results,
        selected_policies,
        "figures/meta_analysis/optimality_rate_vs_complexity_6_milp_edf_gen_hetgat_sim_resnet_",
    )
    
def log_factorial(x):
    """Return ln(x!) as the sum ln(x) + ln(x - 1) + ... over all terms greater than 1.

    The empty product gives ln(0!) == ln(1!) == 0, so any ``x <= 1`` returns 0.0.
    The sum is accumulated iteratively, so large ``x`` does not hit Python's
    recursion limit.

    Args:
        x: the value whose factorial's natural log is wanted; intended for
            non-negative integers.

    Returns:
        The natural logarithm of ``x!`` (0.0 when ``x <= 1``).
    """
    total = 0.0
    while x > 1:
        total += np.log(x)
        x -= 1
    return total
    

def get_2d_pareto_curve():
    """Plot mean computation time (log x-axis) against a mean per-policy metric (y-axis).

    For each scale in ``scales`` and each policy in ``order`` that has both a metric
    entry and a timing entry, a marker with x/y error bars (the stored std values)
    is drawn. Markers are coloured by policy and shaped by scale. Points belonging
    to the same scale are joined, in ``order``, by a dashed line in that scale's
    colour. The figure is saved to
    ``figures/meta_analysis/pareto_optimality_rate_vs_time_seedwise.png``.
    """
    from matplotlib.lines import Line2D

    data = get_data()
    time_data = get_time_data()

    # NOTE(review): this reads the 'makespan' metric (labelled 'Reward' in
    # data_labels), while the title and y-label say "Optimality Rate" - confirm intent.
    optimality_rate_data = data['Scale']['makespan']

    scale_colors = ['red', 'green', 'blue', 'orange']  # line / legend colour per scale
    scale_keys = ['o', 'X', '^', 's']  # marker per scale
    # One colour per policy, indexed by the policy's position in `order`.
    colors20 = plt.colormaps['tab20'].colors
    colors = colors20[18:19] + colors20[2:8] + colors20[10:15] + colors20[16:18] + colors20[0:1]

    fig = plt.figure(figsize=(20, 10))
    ax = fig.add_subplot()
    ax.set_xscale('log')
    ax.set_yscale('linear')
    plt.rcParams.update({'font.size': 24})

    for j, scale_name in enumerate(scales):
        times = []
        rates = []
        for i, label in enumerate(order):
            # Only plot policies present at the smallest scale that also have both a
            # metric and a timing entry at this scale.
            if label not in optimality_rate_data[scales[0]]:
                continue
            if label not in optimality_rate_data[scale_name]:
                continue
            if scale_name not in time_data or label not in time_data[scale_name]:
                continue
            mean_optimality_rate, std_optimality_rate = optimality_rate_data[scale_name][label]
            mean_time, std_time = time_data[scale_name][label]
            times.append(mean_time)
            rates.append(mean_optimality_rate)

            print(f"{label} {scale_name} {mean_time} {mean_optimality_rate} - {j}")

            ax.scatter(mean_time, mean_optimality_rate, label=label_keys[label], color=colors[i], s=100)
            ax.errorbar(mean_time, mean_optimality_rate, xerr=std_time, yerr=std_optimality_rate,
                        fmt=scale_keys[j], color=colors[i], capsize=5, capthick=2)

        print(times)
        print(rates)
        # Join consecutive plotted points. Labels in `order` may have been skipped
        # above, so pair the collected points instead of indexing them by `order`.
        points = list(zip(times, rates))
        for (t0, r0), (t1, r1) in zip(points, points[1:]):
            ax.plot([t0, t1], [r0, r1], color=scale_colors[j], linewidth=2, linestyle='--', alpha=0.5)

    ax.set_title('Time vs Optimality Rate', fontsize=24)
    ax.set_xlabel('Log Time (s)', fontsize=24)
    ax.set_ylabel('Optimality Rate', fontsize=24)
    ax.tick_params(axis='both', labelsize=20)

    # Policy entries come from the plotted artists; one proxy entry per scale is
    # appended. Proxies avoid plotting dummy data at x=0, which a log axis cannot show.
    handles, labels = ax.get_legend_handles_labels()
    for marker, colour, scale in zip(scale_keys, scale_colors, scales):
        handles.append(Line2D([0], [0], marker=marker, color=colour, markersize=10, linestyle='--'))
        labels.append(scale)
    print(labels)
    # Drop duplicate labels (each policy is plotted once per scale).
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1, 1))

    fig.savefig("figures/meta_analysis/pareto_optimality_rate_vs_time_seedwise.png", bbox_inches='tight')
    plt.close(fig)
    
    
def get_3d_pareto_curve(connect_scale=False):
    """Draw a 3-D Pareto view: log10 time (x), feasibility metric (y), log10 problem size (z).

    The problem size for a scale with |T| tasks and |A| agents is |A|^|T| * |T|!.
    It is plotted as its base-10 logarithm, matching the z-axis label and the
    ``10^k`` tick labels.

    Args:
        connect_scale: if True, join the points of each scale (in ``order``) with a
            dashed line in that scale's colour. Otherwise, join each policy's points
            across scales in that policy's colour (solid for the simultaneous model).

    Saves one PNG per 10-degree azimuth step, plus a final
    ``3d_pareto_curve_seedwise.png``, under ``figures/meta_analysis/``.
    """
    import math
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - registers the '3d' projection on older matplotlib

    data = get_data()
    time_data = get_time_data()

    # optimality_rate_data = data['Scale']['makespan']
    optimality_rate_data = data['Scale']['infeasible']

    # log10(|A|^|T| * |T|!) = |T| * log10(|A|) + ln(|T|!) / ln(10); lgamma(k + 1) == ln(k!).
    scales_complexity = [
        k * math.log10(a) + math.lgamma(k + 1) / math.log(10)
        for k, a in zip(task_sizes, agent_sizes)
    ]

    scale_colors = ['red', 'green', 'blue', 'orange']  # one colour per scale
    # One colour per policy, indexed by the policy's position in `order`.
    colors20 = plt.colormaps['tab20'].colors
    colors = colors20[18:19] + colors20[2:8] + colors20[10:15] + colors20[16:18] + colors20[0:1]

    def _lookup(scale_name, label):
        """Return (mean_time, mean_rate) for `label` at `scale_name`, or None if it is not plottable."""
        if label not in optimality_rate_data[scale_name]:
            return None
        if scale_name not in time_data or label not in time_data[scale_name]:
            return None
        if label not in optimality_rate_data[scales[0]]:
            return None
        mean_rate, _ = optimality_rate_data[scale_name][label]
        mean_time, _ = time_data[scale_name][label]
        return mean_time, mean_rate

    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(projection='3d')
    print("==" * 20)

    all_x_log = []  # log10 times of every plotted point; used for the x tick range
    for j, scale_name in enumerate(scales):
        z = scales_complexity[j]
        scale_points = []  # (log10 time, rate) of this scale's plotted points, in `order`
        for i, label in enumerate(order):
            point = _lookup(scale_name, label)
            if point is None:
                continue
            mean_time, mean_rate = point
            x = math.log10(mean_time)
            all_x_log.append(x)
            scale_points.append((x, mean_rate))
            print(f"{label} {scale_name} {mean_time} {mean_rate} - {j} {z}")
            ax.scatter(x, mean_rate, z, label=label_keys[label], color=colors[i])

        if connect_scale:
            # Pair consecutive plotted points; skipped labels must not shift the indices.
            for (x0, y0), (x1, y1) in zip(scale_points, scale_points[1:]):
                ax.plot([x0, x1], [y0, y1], [z, z], color=scale_colors[j],
                        linewidth=2, linestyle='--', alpha=0.5)

    if not connect_scale:
        # Connect each policy's points across the scales where it was plotted.
        for i, policy in enumerate(order):
            policy_points = []  # (log10 time, rate, complexity) per plotted scale
            for j, scale_name in enumerate(scales):
                point = _lookup(scale_name, policy)
                if point is None:
                    continue
                mean_time, mean_rate = point
                policy_points.append((math.log10(mean_time), mean_rate, scales_complexity[j]))

            linestyle = '-' if policy == 'simultaneous_hgt_edge_resnet' else '--'
            for (x0, y0, z0), (x1, y1, z1) in zip(policy_points, policy_points[1:]):
                ax.plot([x0, x1], [y0, y1], [z0, z1], color=colors[i],
                        linewidth=2, linestyle=linestyle, alpha=0.5)

    ax.set_xlabel('Time log (s)')
    ax.set_ylabel('Feasibility Rate (%)')
    ax.set_zlabel(r'$\log_{10}(\mathcal{O}(|A|^{|T|}*|T|!))$')

    # Integer log10 ticks spanning every plotted time (sub-second times give negative exponents).
    if all_x_log:
        x_ticks = np.arange(np.floor(min(all_x_log)), np.ceil(max(all_x_log)) + 1)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels([f"$10^{{{int(x)}}}$" for x in x_ticks])

    print(scales_complexity)
    z_step = 100
    z_top = np.ceil(max(scales_complexity) / z_step) * z_step
    z_ticks = np.arange(0, z_top + 1, z_step)
    ax.set_zticks(z_ticks)
    ax.set_zticklabels(["$10^{" + str(int(z)) + "}$" for z in z_ticks])

    # Drop duplicate labels (each policy is scattered once per scale).
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='upper left', bbox_to_anchor=(1, 1))

    # Rotate the view and save one frame per 10 degrees of azimuth.
    for angle in range(0, 360, 10):
        print(angle)
        ax.view_init(20, angle)
        fig.savefig(f"figures/meta_analysis/3d_pareto_curve_{angle}_seedwise.png", bbox_inches='tight')

    fig.savefig(f"figures/meta_analysis/3d_pareto_curve_seedwise.png", bbox_inches='tight')
    plt.close(fig)
            
    
def save_data_as_table():
    """Write one summary CSV per experiment type to ``figures/meta_analysis/<exp_type>_seedwise.csv``.

    ``get_data()[exp_type]`` is read as data type -> scale -> policy -> (mean, std).
    It is re-keyed to scale -> data type -> policy, so each output row is
    ``scale, <data label>, <cell per policy in order>``, with cells formatted as
    "mean ± std". A policy without a value for a row gets an empty cell, keeping
    every row aligned with the header, which lists every policy in ``order``.

    For the 'Scale' experiment, each scale's metric rows are followed by a
    "Computation Time" row built from ``get_time_data()``.
    """
    import csv

    data = get_data()
    time_data = get_time_data()

    def _format(mean, std):
        # Render one (mean, std) pair as a table cell.
        return f"{mean:.2f} ± {std:.2f}"

    for exp_type in ['Scale', 'Sensitivity']:
        data_set = data[exp_type]  # data type -> scale -> policy -> (mean, std)

        # Re-key as scale -> data type -> policy -> (mean, std), preserving insertion order.
        by_scale = {}
        for d_type, per_scale in data_set.items():
            for scale, per_policy in per_scale.items():
                by_scale.setdefault(scale, {})[d_type] = {
                    policy: (mean, std) for policy, (mean, std) in per_policy.items()
                }

        file_name = f"figures/meta_analysis/{exp_type}_seedwise.csv"
        # Explicit UTF-8 because cells contain '±'; newline='' as the csv module requires.
        with open(file_name, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, lineterminator='\n')
            writer.writerow(['Scale', 'Data Type'] + [label_keys[o] for o in order])
            for scale, per_type in by_scale.items():
                for d_type, per_policy in per_type.items():
                    cells = []
                    for o in order:
                        if o not in per_policy:
                            cells.append('')  # keep the column aligned with the header
                            continue
                        mean, std = per_policy[o]
                        # NOTE(review): data_types uses 'makespan_optimality_rate', so this
                        # percentage conversion may never fire - confirm the stored units.
                        if d_type in ['optimality_rate']:
                            mean, std = 100 * mean, 100 * std
                        cells.append(_format(mean, std))
                    writer.writerow([scale, data_labels[d_type]] + cells)

                # Computation time is only reported for the 'Scale' experiment.
                if exp_type != 'Scale':
                    continue
                scale_times = time_data[scale] if scale in time_data else {}
                time_cells = [_format(*scale_times[o]) if o in scale_times else '' for o in order]
                writer.writerow([scale, 'Computation Time'] + time_cells)

        print(f"Data written to {file_name}")
                        

if __name__ == "__main__":
    # Outputs to regenerate, run in order. Uncomment an entry to include that analysis.
    pipeline = (
        get_bar_plots,
        get_time_plots,
        # get_meta_plots,
        # get_meta_analysis,
        # get_optimality_rate_vs_time,
        # get_2d_pareto_curve,
        # lambda: get_3d_pareto_curve(connect_scale=False),
        save_data_as_table,
    )
    for step in pipeline:
        step()
    
