
import os

import numpy as np
import matplotlib.pyplot as plt

# Models whose timing results are loaded and plotted (also the legend order).
models = [
    "milp",
    "edf",
    "improved_edf",
    "hetgat",
    "hetgat_resnet",
    "hgt",
    "hgt_edge",
    "hgt_edge_resnet",
    "hgt_edge_resnet_bb",
]

# Short x-axis label -> problem-set identifier embedded in each results file name.
problem_sets = dict(
    r2t5="problem_set_r2_t5_s10_f30_w50_euc_200_test",
    r3t10="problem_set_r3_t10_s10_f50_w50_euc_200_test",
    r3t15="problem_set_r3_t15_s10_f50_w50_euc_200_test",
    r5t20="problem_set_r5_t20_s10_f30_w50_euc_200_test",
    r5t50="problem_set_r5_t50_s10_f50_w50_euc_200_test",
    r10t100="problem_set_r10_t100_s10_f80_w25_euc_200_test",
    # Larger sets, currently disabled:
    #   r50t200="problem_set_r50_t200_s10_f80_w10_euc_200_test",
    #   r50t500="problem_set_r50_t500_s10_f80_w25_euc_200_test",
)

# Results files live at
#   <file_path>/<file_prefix>__<problem set>__<model>__<file_suffix>
file_path, file_prefix, file_suffix = (
    "results/evals/",
    "evaluate_time__evaluation",
    "time.txt",
)

# Per-model timing statistics: model_means[model][i] and model_stds[model][i]
# are the mean and standard deviation of the values read for the i-th entry
# of problem_sets (in dict order).
model_means = {}
model_stds = {}
for model in models:
    model_means[model] = []
    model_stds[model] = []
    for problem_set in problem_sets.values():
        file_name = os.path.join(
            file_path, f"{file_prefix}__{problem_set}__{model}__{file_suffix}"
        )
        # Each non-blank line of the results file holds one float value.
        # Blank / whitespace-only lines (e.g. a trailing newline) are skipped.
        with open(file_name, "r", encoding="utf-8") as f:
            times = [float(line) for line in f if line.strip()]
        if times:
            model_means[model].append(np.mean(times))
            model_stds[model].append(np.std(times))
        else:
            # np.mean/np.std of an empty list already give nan, but only with a
            # cryptic RuntimeWarning; record nan explicitly and say which file.
            print(f"Warning: no timing data in {file_name}")
            model_means[model].append(np.nan)
            model_stds[model].append(np.nan)
        
print(model_means)

# One error-bar line per model: mean value per problem set, +/- one std.
# Pass an explicit list of x labels rather than a dict_keys view.
x_labels = list(problem_sets)
for model, means in model_means.items():
    plt.errorbar(x_labels, means, yerr=model_stds[model], label=model, capsize=5)

# Log-scale y-axis. NOTE(review): wherever std >= mean, the lower error bar
# reaches <= 0, which a log axis cannot display; matplotlib clips it.
plt.yscale("log")
plt.xlabel("Problem set")
plt.ylabel("Evaluation time")

# Anchor the legend's upper-left corner at the axes' upper-right corner so it
# sits outside the plot area; bbox_inches="tight" below keeps it in the image.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/time_figures.png", bbox_inches="tight")