import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import NullLocator
from scipy.stats import linregress, t


# Seed NumPy's global RNG. No random draws are visible in this file, so this
# is presumably kept for reproducibility of code elsewhere -- TODO confirm.
np.random.seed(0)

# Data-scaling results, stored as a 3-D table indexed
# [training-set size, system, metric]:
#   * axis 0: one block per training-set size (1k ... 32k, see inline tags);
#   * axis 1: five rows per block, plotted below against the labels
#     ieee14, ieee39, ieee57, ieee118, ieee300 (in that order);
#   * axis 2: ten metric columns, listed in the two comment lines that follow.
# Columns 6-7 are 0.0 in every block except 32k; these look like placeholders
# for results that were not collected -- TODO confirm.
# tasks: opf, fault detection, state estimation, transient prediction, lmp prediction
# metrics: pass_rate, optimality gap, type_acc, loc_acc, nrmse, r^2, nrmse, r^2, nrmse, r^2
gla_scaling_performance = np.asarray([
    [   # 1k
        [0.801, 0.272, 0.631, 0.281, 11.034, 0.932, 0.0, 0.0, 0.456, 0.968],
        [0.712, 0.421, 0.662, 0.122, 16.721, 0.935, 0.0, 0.0, 0.742, 0.959],
        [0.654, 0.234, 0.644, 0.147, 18.298, 0.945, 0.0, 0.0, 0.951, 0.977],
        [0.567, 0.797, 0.639, 0.069, 19.087, 0.911, 0.0, 0.0, 3.080, 0.897],
        [0.603, 0.308, 0.533, 0.033, 19.119, 0.962, 0.0, 0.0, 1.296, 0.953],
    ],
    [   # 2k
        [0.834, 0.201, 0.734, 0.362, 2.902, 0.974, 0.0, 0.0, 0.334, 0.973],
        [0.732, 0.372, 0.742, 0.161, 3.028, 0.977, 0.0, 0.0, 0.541, 0.968],
        [0.708, 0.196, 0.781, 0.225, 3.826, 0.978, 0.0, 0.0, 0.754, 0.982],
        [0.633, 0.547, 0.694, 0.036, 14.733, 0.952, 0.0, 0.0, 2.249, 0.949],
        [0.621, 0.232, 0.642, 0.011, 11.102, 0.958, 0.0, 0.0, 0.948, 0.954],
    ],
    [   # 4k
        [0.938, 0.133, 0.841, 0.622, 1.916, 0.987, 0.0, 0.0, 0.194, 0.981],
        [0.842, 0.341, 0.742, 0.351, 2.662, 0.986, 0.0, 0.0, 0.287, 0.973],
        [0.824, 0.132, 0.805, 0.293, 2.147, 0.983, 0.0, 0.0, 0.556, 0.987],
        [0.706, 0.373, 0.741, 0.172, 8.012, 0.969, 0.0, 0.0, 1.638, 0.969],
        [0.712, 0.203, 0.795, 0.193, 4.302, 0.971, 0.0, 0.0, 0.643, 0.968],
    ],
    [   # 8k
        [0.971, 0.094, 0.871, 0.845, 1.455, 0.995, 0.0, 0.0, 0.098, 0.987],
        [0.932, 0.198, 0.812, 0.562, 1.954, 0.985, 0.0, 0.0, 0.047, 0.982],
        [0.925, 0.105, 0.814, 0.691, 1.219, 0.986, 0.0, 0.0, 0.448, 0.990],
        [0.881, 0.272, 0.783, 0.489, 6.095, 0.973, 0.0, 0.0, 1.460, 0.974],
        [0.883, 0.131, 0.811, 0.403, 2.326, 0.981, 0.0, 0.0, 0.467, 0.976],
    ],
    [   # 16k
        [0.991, 0.037, 0.894, 0.851, 1.299, 1.00, 0.0, 0.0, 0.008, 0.992],
        [0.982, 0.134, 0.811, 0.724, 1.528, 0.994, 0.0, 0.0, 0.009, 0.995],
        [0.964, 0.073, 0.852, 0.753, 1.176, 0.995, 0.0, 0.0, 0.349, 0.987],
        [0.965, 0.113, 0.854, 0.605, 4.202, 0.989, 0.0, 0.0, 1.237, 0.981],
        [0.961, 0.071, 0.801, 0.544, 1.429, 0.988, 0.0, 0.0, 0.312, 0.993],
    ],
    [   # 32k
        [1.000, 0.014, 0.963, 0.945, 1.10, 1.00, 0.339, 0.991, 0.012, 1.000],
        [0.996, 0.082, 0.844, 0.772, 1.32, 1.00, 0.438, 0.991, 0.013, 0.997],
        [1.000, 0.011, 0.867, 0.813, 0.87, 1.00, 0.820, 0.965, 0.315, 0.992],
        [1.000, 0.069, 0.855, 0.654, 3.92, 0.99, 1.125, 0.894, 1.142, 0.984],
        [1.000, 0.043, 0.834, 0.612, 1.18, 1.00, 1.135, 0.862, 0.165, 0.999],
    ],
])

# ---------------------------------------------------------------------------
# Figure: metric versus training-set size, one log-log panel per task.
# ---------------------------------------------------------------------------
# Metric columns that are plotted: optimality gap, fault-type accuracy,
# fault-location accuracy, state-estimation NRMSE and LMP NRMSE.
col_idxs = [1, 2, 3, 4, 8]
# Convert the two accuracy columns to error rates in percent (in place, before
# the column selection) so that every plotted metric is "lower is better".
gla_scaling_performance[:, :, 2:4] = (1 - gla_scaling_performance[:, :, 2:4]) * 100
gla_scaling_performance = gla_scaling_performance[:, :, col_idxs]
# Training-set sizes expressed in multiples of 1024 samples: 1, 2, 4, ..., 32.
training_sizes = np.asarray([1024, 2048, 4096, 8192, 16384, 32768])
training_sizes = np.ceil(training_sizes / 1024).astype(int)
titles = [
    'Optimal Power Flow', 'Fault Classification', 'Fault Localization', 'State Estimation', 'LMP Prediction']
topologies = ['ieee14', 'ieee39', 'ieee57', 'ieee118', 'ieee300']
metrics = [
    'Optimality Gap', 'Error Rate(%)', 'Error Rate(%)', 'NRMSE(%)', 'NRMSE(%)']

plt.figure(figsize=(20, 20))
# The regressor and the tick labels are the same for every panel and curve.
log_x = np.log10(training_sizes)
tick_labels = [str(size) for size in training_sizes.tolist()]
for i, (title, metric) in enumerate(zip(titles, metrics)):
    plt.subplot(2, 3, i + 1)
    for j, topology in enumerate(topologies):
        observed = gla_scaling_performance[:, j, i]
        plt.scatter(training_sizes, observed)
        # Fit the power law y = a * x**k as a straight line in log-log space:
        # log10(y) = k * log10(x) + log10(a).
        slope, intercept, r_value, _, _ = linregress(log_x, np.log10(observed))
        y_fit = slope * log_x + intercept
        plt.plot(training_sizes, 10 ** y_fit, '--',
                 label=f'{topology}: a={10 ** intercept:.3f}, k={slope:.3f}, r = {r_value:.3f}')
    plt.xscale('log')
    plt.yscale('log')
    # Label the x axis with the actual sizes instead of powers of ten, and
    # drop the minor ticks that the log scale would otherwise add between them.
    plt.xticks(training_sizes, labels=tick_labels)
    plt.gca().xaxis.set_minor_locator(NullLocator())
    plt.title(title)
    # Raw string: '\c' is not a valid Python escape sequence.
    plt.xlabel(r'Training Data Size ($\cdot$ 1024)')
    plt.ylabel(metric)
    plt.grid()
    plt.legend()

# plt.show()
plt.savefig('./figs/OpenGLA_training_proportion.png')
# ---------------------------------------------------------------------------
# Result tables, one row per test system.
# The comparison plots later in this file read rows 0-4 under the labels
# IEEE14, IEEE39, IEEE57, IEEE118, IEEE300 and the last row (index -1) under
# the label Texas2000.
# ---------------------------------------------------------------------------
# Column layout of the ten-column tables (the plots use columns 1, 2, 3, 4, 8):
# tasks: opf, fault detection, state estimation, transient prediction, lmp prediction
# metrics: pass_rate, optimality gap, type_acc, loc_acc, nrmse, r^2, nrmse, r^2, nrmse, r^2
# 0.000 entries in columns 6-7 look like placeholders for results that were
# not available -- TODO confirm.
gla_mot_performance = [
    [1.000, 0.014, 0.963, 0.945, 1.10, 1.000, 0.339, 0.991, 0.012, 1.000],
    [0.996, 0.082, 0.844, 0.772, 1.32, 1.000, 0.438, 0.991, 0.013, 0.997],
    [1.000, 0.011, 0.867, 0.813, 0.87, 1.000, 0.820, 0.965, 0.315, 0.992],
    [1.000, 0.069, 0.855, 0.654, 3.92, 0.991, 1.125, 0.894, 1.142, 0.984],
    [1.000, 0.043, 0.834, 0.612, 1.18, 1.000, 1.135, 0.862, 0.165, 0.999],
    [0.453, 1.764, 0.661, 0.091, 4.20, 0.978, 0.000, 0.000, 4.162, 0.949]
]

# Plotted under the label 'OpenGLA-M' in the parameter-scaling figure.
gla_medium_performance = [  # opf, transient TBD
    [1.000, 0.014, 0.961, 0.943, 0.491, 1.000, 0.339, 0.991, 0.001, 1.000],
    [0.995, 0.082, 0.844, 0.781, 1.752, 1.000, 0.438, 0.991, 0.001, 1.000],
    [1.000, 0.011, 0.892, 0.824, 0.673, 1.000, 0.820, 0.965, 0.205, 0.992],
    [1.000, 0.069, 0.831, 0.685, 3.765, 0.992, 1.125, 0.894, 1.094, 0.984],
    [1.000, 0.023, 0.825, 0.658, 0.663, 1.000, 1.135, 0.862, 0.118, 0.999],
    [0.521, 1.443, 0.693, 0.139, 3.971, 0.981, 0.000, 0.000, 3.874, 0.966]
]

# Plotted under the label 'OpenGLA-L' in the parameter-scaling figure.
gla_large_performance = [
    [1.000, 0.001, 0.969, 0.946, 0.412, 1.000, 0.339, 0.991, 0.000, 1.000],
    [1.000, 0.009, 0.861, 0.827, 1.451, 1.000, 0.438, 0.991, 0.000, 1.000],
    [1.000, 0.005, 0.902, 0.823, 0.656, 1.000, 0.820, 0.965, 0.185, 0.992],
    [1.000, 0.029, 0.853, 0.714, 3.184, 0.993, 1.125, 0.894, 1.034, 0.984],
    [1.000, 0.009, 0.824, 0.681, 0.571, 1.000, 1.135, 0.862, 0.118, 0.999],
    [0.583, 1.224, 0.691, 0.153, 3.773, 0.984, 0.000, 0.000, 3.265, 0.974]
]

# Plotted under the label 'Llama'. Only five rows: there is no Texas2000 row,
# and the table is only used in the five-system comparison figure.
llm_performance = [
    [0.993, 0.018, 0.641, 0.602, 1.672, 0.994, 4.492, 0.443, 0.102, 0.992],
    [0.971, 0.163, 0.613, 0.573, 3.013, 0.992, 6.894, 0.667, 0.361, 0.983],
    [0.992, 0.024, 0.615, 0.594, 1.781, 0.993, 7.672, 0.815, 0.623, 0.954],
    [0.884, 0.362, 0.537, 0.422, 9.134, 0.973, 9.531, 0.681, 1.785, 0.912],
    [0.802, 0.761, 0.466, 0.411, 1.791, 0.992, 0.000, 0.000, 0.417, 0.981]
]

# Fills the first baseline slot of every task in the comparison figures.
gcn_performance = [
    [0.982, 0.031, 0.772, 0.456, 2.351, 0.986, 3.331, 0.901, 0.071, 0.981],
    [0.981, 0.142, 0.711, 0.415, 3.222, 0.962, 3.502, 0.852, 0.452, 0.965],
    [0.983, 0.031, 0.802, 0.282, 4.925, 0.931, 4.736, 0.884, 0.394, 0.934],
    [0.874, 0.362, 0.463, 0.211, 8.641, 0.937, 7.113, 0.732, 1.976, 0.921],
    [0.755, 0.981, 0.374, 0.192, 4.687, 0.911, 9.756, 0.671, 0.487, 0.982],
    [0.083, 3.493, 0.255, 0.026, 18.642, 0.872, 0.000, 0.000, 5.327, 0.942]
]

# Four-column tables for GAT and GIN. Column 0 is plotted as state-estimation
# NRMSE and column 2 as LMP NRMSE; columns 1 and 3 are presumably the matching
# r^2 values -- TODO confirm.
gat_performance = np.asarray([
    [1.799, 0.992, 0.091, 0.983],
    [2.213, 0.971, 0.371, 0.975],
    [4.314, 0.944, 0.432, 0.947],
    [7.639, 0.951, 1.677, 0.945],
    [2.932, 0.952, 0.375, 0.987],
    [14.543, 0.903, 5.452, 0.941],
])

gin_performance = np.asarray([
    [1.853, 0.992, 0.095, 0.981],
    [2.078, 0.973, 0.392, 0.962],
    [4.679, 0.941, 0.494, 0.944],
    [8.034, 0.945, 1.563, 0.952],
    [2.532, 0.963, 0.272, 0.981],
    [11.458, 0.922, 4.781, 0.956],
])

# Two-column OPF tables. Column 1 is plotted as optimality gap; column 0 is
# presumably the pass rate -- TODO confirm.
deepopf_performance = np.asarray([
    [0.991, 0.116],
    [0.944, 0.294],
    [0.956, 0.197],
    [0.913, 0.113],
    [0.911, 0.391],
    [0.266, 2.526],
])

canos_performance = np.asarray([
    [1.000, 0.011],
    [1.000, 0.067],
    [1.000, 0.055],
    [0.925, 0.154],
    [0.974, 0.292],
    [0.346, 1.989]
])

# Four-column tables for the time-series baselines. Column 0 is read as
# fault-type accuracy and column 1 as fault-location accuracy; columns 2-3 are
# not read by the plots below.
patchtst_performance = np.asarray([
    [0.813, 0.964, 0.382, 0.998],
    [0.842, 0.891, 0.234, 0.998],
    [0.811, 0.896, 1.071, 0.974],
    [0.795, 0.767, 0.065, 0.999],
    [0.794, 0.635, 0.303, 0.991],
    [0.771, 0.063,  0.000, 0.000]
])

timesnet_performance = np.asarray([
    [0.769, 0.813, 3.813, 0.997],
    [0.747, 0.722, 3.861, 0.998],
    [0.766, 0.726, 27.074, 0.926],
    [0.654, 0.455, 3.723, 0.998],
    [0.632, 0.421, 11.132, 0.987],
    [0.601, 0.003, 0.000, 0.000]
])

# ---------------------------------------------------------------------------
# Figure: effect of model size (OpenGLA vs. OpenGLA-M vs. OpenGLA-L).
# Each bar is the mean of one metric over all rows (systems) of a model's
# result table.
# ---------------------------------------------------------------------------
# Metric columns that are plotted: optimality gap, fault-type accuracy,
# fault-location accuracy, state-estimation NRMSE and LMP NRMSE.
selected_cols = [1, 2, 3, 4, 8]
gla_mot_performance = np.asarray(gla_mot_performance)
gla_medium_performance = np.asarray(gla_medium_performance)
gla_large_performance = np.asarray(gla_large_performance)
# Indexed [model, system, metric].
data = np.stack([gla_mot_performance, gla_medium_performance, gla_large_performance], axis=0)
data = data[:, :, selected_cols]


fig, axes = plt.subplots(nrows=1, ncols=data.shape[2], figsize=(20, 4), sharex=False, sharey=False)

tasks = ['OPF', 'Fault Type', 'Fault Loc', 'State Estimation', 'LMP Pred']
metrics = ['optimality gap', 'error rate(%)', 'error rate(%)', 'NRMSE(%)', 'NRMSE(%)']
# Every panel compares the same three model sizes.
model_names = ['OpenGLA', 'OpenGLA-M', 'OpenGLA-L']
baseline_names = {task: model_names for task in tasks}
# Tasks whose stored metric is an accuracy; they are shown as error rates (%).
accuracy_tasks = ('Fault Type', 'Fault Loc')

colors = sns.dark_palette("#4E79A7", n_colors=5, reverse=True)

for i, task in enumerate(tasks):
    ax = axes[i]
    bar_width = 0.8
    x = np.arange(len(baseline_names[task]))  # one bar position per model
    values = (1 - data[:, :, i]) * 100 if task in accuracy_tasks else data[:, :, i]

    for idx, baseline in enumerate(baseline_names[task]):
        # The label is what makes the bar visible to the figure legend below;
        # unlabeled bars are skipped by get_legend_handles_labels().
        ax.bar(x[idx], values[idx].mean(), width=bar_width, color=colors[idx], label=baseline)

    ax.set_title(f'{task} - {metrics[i]}')
    ax.set_xticks(x)
    ax.set_xticklabels(baseline_names[task], fontsize=12, rotation=20)
    ax.tick_params(axis='y', labelsize=12)

# All panels share the same models and colours, so one legend taken from the
# first panel describes the whole figure.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=3, fontsize=12)

# Leave a strip at the bottom of the figure for the legend.
plt.tight_layout(pad=2.0, rect=[0, 0.05, 1, 0.95])
# plt.show()
plt.savefig('./figs/OpenGLA_parameter_scaling.png')



# ---------------------------------------------------------------------------
# Figures: OpenGLA against task-specific baselines, once for Texas2000 (last
# row of each table) and once for the five IEEE systems (rows 0-4).
# ---------------------------------------------------------------------------
# Metric columns that are plotted: optimality gap, fault-type accuracy,
# fault-location accuracy, state-estimation NRMSE and LMP NRMSE.
selected_cols = [1, 2, 3, 4, 8]
gcn_performance = np.asarray(gcn_performance)
gla_mot_performance = np.asarray(gla_mot_performance)
llm_performance = np.asarray(llm_performance)

tasks = ['OPF', 'Fault Type', 'Fault Loc', 'State Estimation', 'LMP Pred']
metrics = ['optimality gap', 'error rate(%)', 'error rate(%)', 'NRMSE(%)', 'NRMSE(%)']
# Tasks whose stored metric is an accuracy; they are shown as error rates (%).
accuracy_tasks = ('Fault Type', 'Fault Loc')

# Baselines whose numbers live in their own narrow table instead of the
# stacked ten-column `data` array: (task, baseline) -> (table, column).
baseline_overrides = {
    ('OPF', 'DeepOPF'): (deepopf_performance, 1),
    ('OPF', 'CANOS'): (canos_performance, 1),
    ('Fault Type', 'PatchTst'): (patchtst_performance, 0),
    ('Fault Type', 'TimesNet'): (timesnet_performance, 0),
    ('Fault Loc', 'PatchTst'): (patchtst_performance, 1),
    ('Fault Loc', 'TimesNet'): (timesnet_performance, 1),
    ('State Estimation', 'GAT'): (gat_performance, 0),
    ('State Estimation', 'GIN'): (gin_performance, 0),
    ('LMP Pred', 'GAT'): (gat_performance, 2),
    ('LMP Pred', 'GIN'): (gin_performance, 2),
}


def _as_plotted(task, raw):
    """Return *raw* as it is plotted: accuracies become error rates in percent."""
    return (1 - raw) * 100 if task in accuracy_tasks else raw


def _bar_value(task, baseline, row, fallback):
    """Return the bar height for *baseline* on *task* at table row *row*.

    *fallback* (already converted for plotting) is used when the baseline has
    no entry in ``baseline_overrides``.
    """
    override = baseline_overrides.get((task, baseline))
    if override is None:
        return fallback
    table, column = override
    return _as_plotted(task, table[row, column])


def _draw_panel(ax, task, names, stacked, row, palette):
    """Draw one bar per baseline in *names* on *ax* and set the shared ticks.

    *stacked* holds one raw value per baseline taken from the stacked `data`
    array; *row* selects the system in the per-baseline override tables.
    """
    bar_width = 0.8
    x = np.arange(len(names))
    fallbacks = _as_plotted(task, stacked)
    for idx, baseline in enumerate(names):
        value = _bar_value(task, baseline, row, fallbacks[idx])
        ax.bar(x[idx], value, width=bar_width, color=palette[idx])
    ax.set_xticks(x)
    ax.tick_params(axis='y', labelsize=12)


# ----------------------------- Texas2000 ----------------------------------
# Slots 1 and 2 are zero placeholders; their bars always come from
# `baseline_overrides`.
# NOTE(review): slot 0 is gcn_performance for every task, so the bar labelled
# 'Informer' in the two fault panels shows gcn_performance columns 2-3 --
# confirm this is intended.
data = np.stack([gcn_performance, np.zeros_like(gcn_performance), np.zeros_like(gcn_performance), gla_mot_performance], axis=0)
data = data[:, :, selected_cols]

envs = ['Texas2000']
baseline_names = {
    'OPF': ['GCN', 'DeepOPF', 'CANOS', 'OpenGLA'],
    'Fault Type': ['Informer', 'PatchTst', 'TimesNet', 'OpenGLA'],
    'Fault Loc': ['Informer', 'PatchTst', 'TimesNet', 'OpenGLA'],
    'State Estimation': ['GCN', 'GAT', 'GIN', 'OpenGLA'],
    'LMP Pred': ['GCN', 'GAT', 'GIN', 'OpenGLA']
}

colors = ["#C9D1D9", "#A8B1BA", "#8B949E", '#4E79A7']

fig, axes = plt.subplots(nrows=1, ncols=data.shape[2], figsize=(20, 4), sharex=False, sharey=False)

texas_row = -1  # Texas2000 is the last row of every table
for i, task in enumerate(tasks):
    ax = axes[i]
    names = baseline_names[task]
    _draw_panel(ax, task, names, data[:, texas_row, i], texas_row, colors)
    ax.set_title(f'{task} - {metrics[i]}')
    if i == 0:
        ax.set_ylabel(envs[0])
    ax.set_xticklabels(names, fontsize=12, rotation=20)

# No figure legend: the bars carry no labels (a legend built from them is
# empty) and the baselines differ per panel, so the tick labels identify them.
plt.tight_layout(pad=2.0, rect=[0, 0.05, 1, 0.95])
# plt.show()
plt.savefig('./figs/OpenGLA_Performance_TX2000.png')

# ------------------------- Five IEEE systems ------------------------------
# Same layout as above with an extra 'Llama' slot; the Texas2000 row is
# dropped because llm_performance has no such row.
data = np.stack([gcn_performance[:-1], np.zeros_like(llm_performance), np.zeros_like(llm_performance), llm_performance, gla_mot_performance[:-1]], axis=0)
data = data[:, :, selected_cols]
envs = ['IEEE14', 'IEEE39', 'IEEE57', 'IEEE118', 'IEEE300']
baseline_names = {
    'OPF': ['GCN', 'DeepOPF', 'CANOS', 'Llama', 'OpenGLA'],
    'Fault Type': ['Informer', 'PatchTst', 'TimesNet', 'Llama', 'OpenGLA'],
    'Fault Loc': ['Informer', 'PatchTst', 'TimesNet', 'Llama', 'OpenGLA'],
    'State Estimation': ['GCN', 'GAT', 'GIN', 'Llama', 'OpenGLA'],
    'LMP Pred': ['GCN', 'GAT', 'GIN', 'Llama', 'OpenGLA']
}
colors = ["#C9D1D9", "#A8B1BA", "#8B949E", "#6E7781", '#4E79A7']

fig, axes = plt.subplots(nrows=data.shape[1], ncols=data.shape[2], figsize=(20, 20), sharex=False, sharey=False)

last_row = len(envs) - 1
for i, task in enumerate(tasks):
    names = baseline_names[task]
    for j, env in enumerate(envs):
        ax = axes[j, i]
        _draw_panel(ax, task, names, data[:, j, i], j, colors)
        if j == 0:
            ax.set_title(f'{task} - {metrics[i]}')
        if i == 0:
            ax.set_ylabel(env)
        # Only the bottom row of panels shows the baseline names.
        if j == last_row:
            ax.set_xticklabels(names, fontsize=12, rotation=20)
        else:
            ax.tick_params(axis='x', labelbottom=False)

plt.tight_layout(pad=2.0, rect=[0, 0.05, 1, 0.95])
# plt.show()
plt.savefig('./figs/OpenGLA_General_Performance.png')