import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def get_sweep_results(api, sweep_id, models, project="emalgorithm/network_games"):
    """Collect the final test ROC AUC metrics of every finished run in a sweep.

    Args:
        api: Object exposing ``sweep(path)`` whose result has a ``runs``
            iterable (presumably a ``wandb.Api`` instance — verify at caller).
        sweep_id: Sweep identifier; looked up as ``"{project}/{sweep_id}"``.
        models: Model name prefixes. For each one the metrics
            ``{model}_test_roc_auc_mean`` and ``{model}_test_roc_auc_std`` are read.
        project: ``entity/project`` path the sweep belongs to.

    Returns:
        A DataFrame with one row per run whose ``state`` is ``"finished"``.
        Each row holds the whitelisted config fields, the last logged value
        of every requested metric (NaN if the run never logged it), and the
        run name under ``run_name``.
    """
    config_columns = {"target_spectral_radius", "graph_type", "game_type", "alpha", "n_games", "n_nodes"}
    result_keys = [f"{model}_test_roc_auc_mean" for model in models] + [f"{model}_test_roc_auc_std" for model in models]

    sweep = api.sweep(f"{project}/{sweep_id}")
    result_list = []
    for run in sweep.runs:
        # Only finished runs have complete metrics.
        if run.state != "finished":
            continue

        run_dict = {k: v for k, v in run.config.items() if k in config_columns}
        results = run.history(keys=result_keys)
        for model in models:
            for stat in ("mean", "std"):
                key = f"{model}_test_roc_auc_{stat}"
                run_dict[key] = _last_logged_value(results, key)
        # Set once per run. It used to be assigned inside the model loop,
        # so it was missing when `models` was empty.
        run_dict["run_name"] = run.name

        result_list.append(run_dict)

    return pd.DataFrame.from_records(result_list)


def _last_logged_value(history, key):
    """Return the last value of column ``key`` in ``history``, or NaN if absent or empty."""
    if key not in history.columns or history.empty:
        return np.nan
    return history[key].to_numpy()[-1]

def plot_linear_quadratic(df, models, colors, names, graph_names, title, ylims, file_name, alphas=(0.0, 0.5, 1.0)):
    """Plot test ROC AUC against spectral radius on an (alpha x graph type) grid.

    There is one subplot row per value in ``alphas`` and one column per graph
    type. Each model is drawn as a line of mean ROC AUC with a shaded band of
    +/- one std. The figure is written to ``plots/{file_name}.pdf`` and then
    shown.

    Args:
        df: Results table with ``graph_type``, ``alpha``,
            ``target_spectral_radius`` and ``{model}_test_roc_auc_mean``/``_std``
            columns.
        models: Model prefixes to plot.
        colors: Mapping model -> matplotlib color (used for line and band).
        names: Mapping model -> legend label.
        graph_names: Mapping graph type -> display name used in subplot titles.
        title: Currently unused; kept for interface compatibility.
        ylims: ``(ymin, ymax)`` applied to every subplot.
        file_name: Output file stem inside ``plots/``.
        alphas: Alpha values, one subplot row each.
    """
    graphs = ["watts_strogatz", "erdos_renyi", "barabasi_albert"]
    fig, axs = plt.subplots(nrows=len(alphas), ncols=3, figsize=(24, 6 * len(alphas)))

    # With a single row plt.subplots returns a 1-D array; keep 2-D indexing.
    if axs.ndim == 1:
        axs = np.expand_dims(axs, axis=0)

    for i, alpha in enumerate(alphas):
        for j, graph in enumerate(graphs):
            ax = axs[i, j]
            df_tmp = df.loc[(df["graph_type"] == graph) & (df["alpha"] == alpha)].sort_values(by=["target_spectral_radius"])
            # Take x from the same sorted rows as y so the two stay aligned
            # (and equal in length) even if a spectral radius repeats.
            spectral_radiuses = df_tmp["target_spectral_radius"].to_numpy()

            for model in models:
                # NOTE(review): lin_quad_opt is deliberately omitted at alpha=0.5;
                # the reason is not visible here.
                if model == "lin_quad_opt" and alpha == 0.5:
                    continue

                means = df_tmp[f"{model}_test_roc_auc_mean"].to_numpy()
                stds = df_tmp[f"{model}_test_roc_auc_std"].to_numpy()

                ax.plot(spectral_radiuses, means, marker="o", label=names[model], color=colors[model])
                # Give the band its line's color; otherwise fill_between takes
                # the next color from its own cycle.
                ax.fill_between(spectral_radiuses, means - stds, means + stds, alpha=0.2, color=colors[model])

            ax.set_ylim(ylims[0], ylims[1])
            ax.set_title(f"{graph_names[graph]} graphs, α={alpha}", fontweight='bold', fontsize=17)
            ax.legend(prop={'weight': 'bold'}, loc="lower right")
            ax.set_xlabel("Spectral Radius ρ(βA)", fontweight='bold', fontsize=13)
            ax.set_ylabel("Test ROC AUC", fontweight='bold', fontsize=13)

    fig.tight_layout(pad=3.5)
    fig.subplots_adjust(top=0.90)
    # Save before show(): interactive backends may block or tear the figure
    # down inside show().
    fig.savefig(f'plots/{file_name}.pdf')
    plt.show()

def plot_results_over_variable(df, models, colors, names, graph_names, game_names, game_types, variable, title, xlabel, ylims, file_name, print_game=False, top=0.80):
    """Plot test ROC AUC against an arbitrary sweep variable on a (game x graph) grid.

    There is one subplot row per entry of ``game_types`` and one column per
    graph type. Each model is drawn as a line of mean ROC AUC over ``variable``
    with a shaded band of +/- one std. The figure is written to
    ``plots/{file_name}.pdf`` and then shown.

    Args:
        df: Results table with ``graph_type``, ``game_type``, ``variable`` and
            ``{model}_test_roc_auc_mean``/``_std`` columns.
        models: Model prefixes to plot.
        colors: Mapping model -> matplotlib color (used for line and band).
        names: Mapping model -> legend label.
        graph_names: Mapping graph type -> display name used in titles.
        game_names: Mapping game type -> display name (used when ``print_game``).
        game_types: Game types, one subplot row each.
        variable: Column of ``df`` used as the x axis.
        title: Currently unused; kept for interface compatibility.
        xlabel: X-axis label.
        ylims: ``(ymin, ymax)`` applied to every subplot.
        file_name: Output file stem inside ``plots/``.
        print_game: Prefix each subplot title with the game's display name.
        top: Value passed to ``fig.subplots_adjust(top=...)``.
    """
    graphs = ["watts_strogatz", "erdos_renyi", "barabasi_albert"]
    nrows = len(game_types)
    fig, axs = plt.subplots(nrows=nrows, ncols=3, figsize=(24, 6 * nrows))

    # With a single row plt.subplots returns a 1-D array; keep 2-D indexing.
    if axs.ndim == 1:
        axs = np.expand_dims(axs, axis=0)

    for row, game in enumerate(game_types):
        for col, graph in enumerate(graphs):
            ax = axs[row, col]
            df_tmp = df.loc[(df['graph_type'] == graph) & (df['game_type'] == game)].sort_values(by=[variable])
            # x must come from this subset, not from the whole df. Otherwise a
            # (graph, game) pair missing a value gives x/y of different lengths
            # or misaligned points.
            xs = df_tmp[variable].to_numpy()

            for model in models:
                means = df_tmp[f'{model}_test_roc_auc_mean'].to_numpy()
                stds = df_tmp[f'{model}_test_roc_auc_std'].to_numpy()

                ax.plot(xs, means, marker="o", label=names[model], color=colors[model])
                # Give the band its line's color; otherwise fill_between takes
                # the next color from its own cycle.
                ax.fill_between(xs, means - stds, means + stds, alpha=0.2, color=colors[model])

            ax.set_ylim(ylims[0], ylims[1])

            game_text = f'{game_names[game]} game, ' if print_game else ""
            ax.set_title(f"{game_text}{graph_names[graph]} graphs", fontweight='bold', fontsize=17)
            ax.legend(prop={'weight': 'bold'})
            ax.set_xlabel(xlabel, fontweight='bold', fontsize=13)
            ax.set_ylabel("Test ROC AUC", fontweight='bold', fontsize=13)

    fig.tight_layout(pad=3.5)
    fig.subplots_adjust(top=top)
    # Save before show(): interactive backends may block or tear the figure
    # down inside show().
    fig.savefig(f'plots/{file_name}.pdf')
    plt.show()
    
def barchart_plot(df, models, colors, names, graph_names, graph_types, title, ylims, file_name):
    """Draw a grouped bar chart of test ROC AUC per graph type and model.

    Each graph type gets one group of bars, with one bar per model. Error bars
    show the std, and each bar is annotated with its (3-decimal) mean. The
    figure is written to ``plots/{file_name}.pdf`` and then shown.

    Args:
        df: Results table with ``graph_type`` and
            ``{model}_test_roc_auc_mean``/``_std`` columns.
        models: Model prefixes, one bar per group each.
        colors: Mapping model -> bar color.
        names: Mapping model -> legend label.
        graph_names: Mapping graph type -> display name for tick labels.
            Types missing from it (or a falsy value) fall back to the raw key.
        graph_types: Graph types to plot, in group order.
        title: Currently unused; kept for interface compatibility.
        ylims: ``(ymin, ymax)`` of the y axis.
        file_name: Output file stem inside ``plots/``.
    """
    results = {model: {"means": [], "stds": []} for model in models}

    for graph_type in graph_types:
        df_tmp = df.loc[df["graph_type"] == graph_type]
        for model in models:
            # Only the first matching row is used; assumes one row per graph
            # type — TODO confirm with the caller.
            results[model]["means"].append(df_tmp[f'{model}_test_roc_auc_mean'].to_numpy().round(decimals=3)[0])
            results[model]["stds"].append(df_tmp[f'{model}_test_roc_auc_std'].to_numpy().round(decimals=3)[0])

    n_models = len(models)
    x = np.arange(len(graph_types))  # left edge of each group
    width = 0.92 / n_models  # bar width; a group spans 0.92 of the unit spacing

    fig, ax = plt.subplots()

    for i, model in enumerate(models):
        rects = ax.bar(x + i * width, results[model]["means"], width, yerr=results[model]["stds"], align='center', alpha=0.7, ecolor='black', capsize=10, label=names[model], color=colors[model])
        ax.bar_label(rects, padding=3, fontsize=35 / n_models, fontweight='bold')

    ax.set_ylabel('Test ROC AUC', fontweight='bold', fontsize=13)
    # Bars sit at x + i * width for i in 0..n_models-1, so a group's centre is
    # (n_models - 1) / 2 bar widths in. A hard-coded 2 * width only centred
    # groups of exactly five models.
    ax.set_xticks(x + (n_models - 1) / 2 * width)
    tick_labels = [graph_names.get(g, g) for g in graph_types] if graph_names else list(graph_types)
    ax.set_xticklabels(tick_labels, fontweight='bold', fontsize=10)
    ax.legend(loc="upper left", prop={'weight': 'bold'})

    ax.set_ylim(ylims[0], ylims[1])

    # Save before show(): interactive backends may block or tear the figure
    # down inside show().
    fig.savefig(f'plots/{file_name}.pdf')
    plt.show()