import json
import matplotlib.pyplot as plt
import pandas as pd
from os.path import exists

import os
import numpy as np 
from matplotlib.lines import Line2D
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

# Experiment identifiers (not read anywhere else in this script).
model = "rset"
method = "cp-sat"

# Groups of rules whose curves are drawn together on one comparison figure;
# the random baseline is appended to every group when plotting.
rule_pairs = [
    ["closest", "farthest"],
    ["densest", "sparsest"],
    ["increment", "random"],
    ["farthest", "increment"],
    ["increment", "closest", "farthest"],
]

# `figsize` given to every per-dataset figure.
figures_sizes = (8, 8)

# Aggregated curves of each dataset, keyed by dataset name; filled by the
# per-dataset loop and read again when the combined figure is built.
subplot_data = {}

# ---- Experiment grid: identical for every dataset, so it is built once ----

# Folder (relative to the working directory) holding the JSON result files.
folder = "results"

# Combinations of hyperparameters.
# Forest sizes: every value from 1 to 49, then a coarser grid up to 200.
val_trees = list(range(1, 50))
val_trees.extend([50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 120, 125, 130, 135, 140, 145, 150, 175, 200])
val_depths = ['4']
val_seed = list(range(5))
rule_list = ["densest", "sparsest", "closest", "increment", "farthest", "random"]

# One fixed color per curve (each rule, then the random baseline), taken in
# order from the tab10 palette and wrapping around if the palette runs out.
palette = plt.get_cmap('tab10')
random_colors = {
    curve: palette(idx % palette.N)
    for idx, curve in enumerate(rule_list + ["random_baseline"])
}

# Every [rule, #trees, max depth] combination whose result file is loaded.
params_list = []
for rule in rule_list:
    for t in val_trees:
        for d in val_depths:
            params_list.append([rule, t, d])

for dataset in ["adult", "bank-marketing", "compas", "default_credit", "diabetes", "fico"]:
    print("==== EXPERIMENT: " + str(dataset) + " ====")
    # Prefix of every figure file written for this dataset.
    expe_suffix = dataset
    
    def get_single_file(params_list, seed, show=False):
        """Load the result files of one seed and index their metrics.

        For every ``[rule, n_estimators, max_depth]`` entry of *params_list*
        the file
        ``./<folder>/<dataset>_<n_estimators>_<max_depth>_<seed>_<rule>.json``
        is read; ``dataset`` and ``folder`` are taken from the enclosing
        scope. A summary (missing files, non-optimal runs, longest solve
        time) is printed at the end.

        Args:
            params_list: Sequence of ``[rule, n_estimators, max_depth]``
                entries.
            seed: Seed of the runs to load. For the ``default_credit``
                dataset, seed 3 is read from the files of seed 6 instead.
            show: Unused; kept so that existing calls remain valid.

        Returns:
            dict: ``results[rule][n_estimators][max_depth]`` is the tuple
            ``(mean_error, best_adv_accuracy)`` for every run whose solver
            status is ``"OPTIMAL"``.
            ``results["random_baseline"][n_estimators]["4"]`` is the random
            error stored in the last file read for that forest size,
            whatever its solver status.

        Raises:
            SystemExit: If a result file does not contain valid JSON.
        """
        # The remapping does not depend on the parameters, so do it once.
        # NOTE(review): the reason seed 3 is replaced by seed 6 for
        # default_credit is not visible in this file -- confirm.
        if dataset == "default_credit" and seed == 3:
            seed = 6

        results_dict = {}
        missing_cnt = 0
        unknown_cnt = 0
        longest_time_res = -1

        for params in params_list:
            rule, n_estimators, max_depth = params[0], params[1], params[2]

            filename = str(dataset) + "_" + str(n_estimators) +"_"+ str(max_depth)+ "_" + str(seed) + "_" + str(rule)
            path_to_file = f"./{folder}/{filename}.json"

            if not exists(path_to_file):
                print("missing file %s" %path_to_file)
                missing_cnt += 1
                continue

            # The context manager closes the handle even when parsing fails.
            with open(path_to_file) as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError:
                    print("Error decoding JSON for file: ", path_to_file)
                    # Stop with a non-zero status so that a corrupted result
                    # file is reported as a failure to the caller.
                    raise SystemExit(1)

            values = data["values"]
            mean_error = values["mean-error"]
            solve_duration = values["solve_duration_time"]
            solve_status = values["solve_status"]
            random_baseline = values["random_error"]
            best_adv_accuracy = values["best_adv_accuracy"]

            longest_time_res = max(longest_time_res, solve_duration)

            if solve_status != "OPTIMAL":
                print("Oups, for this configuration the solver struggled: ", str(n_estimators) + " trees, max. depth " + str(max_depth) + ", seed " + str(seed) + "(time elapsed: " + str(solve_duration) + ")" + "(status: " + str(solve_status) + ")")
                unknown_cnt += 1
            else:
                per_rule = results_dict.setdefault(rule, {})
                per_rule.setdefault(n_estimators, {})[max_depth] = (mean_error, best_adv_accuracy)

            # The baseline is recorded for every readable file (optimal or
            # not) and overwritten by each rule sharing this forest size.
            results_dict.setdefault("random_baseline", {})[n_estimators] = {"4": random_baseline}

        print("missing %d files" %missing_cnt)
        print("ignored %d UNKNOWN runs" %unknown_cnt)
        print("longest run:", longest_time_res, " seconds")
        return results_dict

    # First get per-seed results for the current dataset.
    all_seeds_results = [get_single_file(params_list, seed) for seed in val_seed]

    # Curves to aggregate: every rule (plus "random_baseline") present for at
    # least one seed, in first-seen order. Looking at the first seed only
    # would drop a rule whose runs for that seed are all missing or
    # non-optimal, although the other seeds still provide values for it.
    curve_names = list(dict.fromkeys(
        name for one_seed_results in all_seeds_results for name in one_seed_results
    ))

    # Then compute, for each curve, the mean and standard deviation over the
    # seeds at every forest size (one list entry per value of val_trees).
    average_results = {}
    std_results = {}
    adv_average_results = {}
    adv_std_results = {}
    for rule in curve_names:
        rule_errors_list_avg = []
        rule_errors_list_std = []
        adv_errors_list_avg = []
        adv_errors_list_std = []
        for n_trees in val_trees:  # x axis, i.e. #trees
            errors = []
            adv_accuracies = []
            for one_seed_results in all_seeds_results:
                try:
                    entry = one_seed_results[rule][n_trees]['4']
                except KeyError:
                    # This seed has no usable run for this point.
                    continue
                if rule == "random_baseline":
                    # The baseline stores a bare error value, not a tuple.
                    errors.append(entry)
                else:
                    errors.append(entry[0])
                    adv_accuracies.append(entry[1])

            # A point without any value becomes NaN so that every curve keeps
            # one entry per forest size; it is written explicitly because
            # numpy's mean/std of an empty list also yield NaN but warn.
            for values, means, stds in (
                (errors, rule_errors_list_avg, rule_errors_list_std),
                (adv_accuracies, adv_errors_list_avg, adv_errors_list_std),
            ):
                means.append(np.average(values) if values else np.nan)
                stds.append(np.std(values) if values else np.nan)

        average_results[rule] = rule_errors_list_avg
        std_results[rule] = rule_errors_list_std
        adv_average_results[rule] = adv_errors_list_avg
        adv_std_results[rule] = adv_errors_list_std

    # Accuracy plot: one figure per rule, drawn against the random baseline.
    # The output folder does not depend on the rule, so it is created once.
    os.makedirs(f"./figures", exist_ok=True)
    for rule in all_seeds_results[0].keys():
        if rule == "random_baseline":
            continue

        plt.figure(figsize=figures_sizes)
        # Each curve holds exactly one value per entry of val_trees, so both
        # the rule and the baseline are drawn against the full x axis.
        for curve in (rule, "random_baseline"):
            mean = np.asarray(average_results[curve])
            std = np.asarray(std_results[curve])
            plt.plot(val_trees, mean, c=random_colors[curve], label=curve)
            # Shaded band: mean +/- one standard deviation over the seeds.
            plt.fill_between(val_trees, mean - std, mean + std, color=random_colors[curve], alpha=0.2)

        plt.xlabel("#trees")
        plt.ylabel("Reconstruction Error")
        plt.legend(loc='best')
        plt.title(f"{dataset} - {rule} Reconstruction Error vs Number of Trees")
        plt.savefig(f'./figures/{expe_suffix}_{rule}_reconstruction_error.png', bbox_inches='tight')
        plt.close()

    # Comparison plots: one figure per group of rules, each with the random
    # baseline added.
    os.makedirs("./figures", exist_ok=True)
    for rules in rule_pairs:
        plt.figure(figsize=figures_sizes)
        for rule in rules + ["random_baseline"]:
            # Each curve holds exactly one value per entry of val_trees.
            mean = np.asarray(average_results[rule])
            std = np.asarray(std_results[rule])
            plt.plot(val_trees, mean, c=random_colors[rule], label=rule)
            # Shaded band: mean +/- one standard deviation over the seeds.
            plt.fill_between(val_trees, mean - std, mean + std, color=random_colors[rule], alpha=0.2)
        plt.xlabel("#trees")
        plt.ylabel("Reconstruction Error")
        plt.legend(loc='best')
        # Name every rule of the group: a group may hold more than two rules,
        # and using only the first two would mislabel the figure and its
        # file. Two-rule groups keep the same title and file name as before.
        title_rules = " - ".join(rules)
        file_rules = "_".join(rules)
        plt.title(f"{dataset} - {title_rules} Reconstruction Error vs Number of Trees")
        plt.savefig(f'./figures/{expe_suffix}_{file_rules}_reconstruction_error.png', bbox_inches='tight')
        plt.close()

    # Keep this dataset's aggregated curves (means and standard deviations of
    # both metrics) so they can be read again once the loop is over.
    subplot_data[dataset] = dict(
        average_results=average_results,
        std_results=std_results,
        adv_average_results=adv_average_results,
        adv_std_results=adv_std_results,
    )
    # Scatter plots: for each rule, mean reconstruction error (x) against
    # mean adversarial accuracy (y), one point per forest size.
    for rule in all_seeds_results[0].keys():
        if rule != "random_baseline":
            recon_err, adv_acc = average_results[rule], adv_average_results[rule]

            # Point colors encode the number of trees; the x-axis grid is cut
            # to the number of available points.
            n_estimators_values = val_trees[:len(recon_err)]
            cmap = plt.cm.viridis

            plt.figure(figsize=figures_sizes)
            sc = plt.scatter(recon_err, adv_acc, c=n_estimators_values, cmap=cmap)
            plt.colorbar(sc, label='Number of Trees')
            plt.xlabel("Reconstruction Error")
            plt.ylabel("Adversarial Accuracy")
            plt.title(f"{dataset} Reconstruction Error vs Adversarial Accuracy")
            plt.savefig(f'./figures/{expe_suffix}_{rule}_recon_err_vs_adv_acc.png', bbox_inches='tight')
            plt.close()

# # # Combined plots

# for pair in rule_pairs:
#     fig, axes = plt.subplots(2, 3, figsize=(24, 8))

#     for i, dataset in enumerate(sorted(subplot_data.keys())):
#         ax = axes[i // 3, i % 3]
#         data = subplot_data[dataset]
#         average_results = data["average_results"]
#         std_results = data["std_results"]
#         for rule in pair + ["random_baseline"]:
#             val_trees_local = val_trees
#             if len(average_results[rule]) < len(val_trees):
#                 last_index = (len(val_trees) - len(average_results[rule]))
#                 print("depth " + str(rule) + " diff is " + str(last_index))
#                 val_trees_local = val_trees[:-last_index]

#             ax.plot(val_trees_local, average_results[rule], c=random_colors[rule], label=rule)
#             ax.fill_between(
#                 val_trees_local,
#                 np.asarray(average_results[rule]) - np.asarray(std_results[rule]),
#                 np.asarray(average_results[rule]) + np.asarray(std_results[rule]),
#                 color=random_colors[rule], alpha=0.2
#             )
#             ax.set_title(f"{dataset}", fontsize=28)
#             ax.grid(True)
#             ax.tick_params(axis='both', labelsize=24)

#     # # Shared legend and axis labels
#     handles, labels = axes[0, 0].get_legend_handles_labels()
#     fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.08), ncol=len(pair) + 1, fontsize=28, frameon=False)

#     fig.text(0.5, 0.04, 'Number of Trees', ha='center', fontsize=28)
#     fig.text(0.04, 0.5, 'Reconstruction Error', va='center', rotation='vertical', fontsize=28)

#     plt.tight_layout(rect=[0.05, 0.075, 1, 0.95])
#     plt.savefig(f"./figures/combined/combined_{pair[0]}_{pair[1]}_line.png", dpi=300, bbox_inches="tight")
#     print(f"Saving to ./figures/combined/combined_{pair[0]}_{pair[1]}_line.png ")
#     plt.close()

# #region SCATTER ON INCREMENT
# fig, axes = plt.subplots(2, 3, figsize=(24, 10), sharex=True, sharey=False)
# rule = "increment"
# cmap = plt.cm.viridis

# # Use consistent color values
# n_estimators_values = val_trees[:len(next(iter(subplot_data.values()))["average_results"][rule])]

# for i, dataset in enumerate(sorted(subplot_data.keys())):
#     ax = axes[i // 3, i % 3]
#     data = subplot_data[dataset]
#     recon_err = data["average_results"][rule]
#     adv_acc = data["adv_average_results"][rule]

#     sc = ax.scatter(recon_err, adv_acc, c=n_estimators_values, cmap=cmap, s=100)
#     ax.set_title(f"{dataset}", fontsize=28)
#     ax.grid(True)
#     ax.tick_params(axis='both', labelsize=24)

# # Tighter overall layout with reduced right space
# plt.subplots_adjust(left=0.05, right=0.9, bottom=0.15, top=0.92, wspace=0.25)

# # Move colorbar closer (reduce right margin)
# cbar_ax = fig.add_axes([0.92, 0.2, 0.015, 0.6])
# cbar = fig.colorbar(sc, cax=cbar_ax)
# cbar.ax.tick_params(labelsize=24)
# cbar.set_label('Number of Trees', fontsize=28)

# # Move fig.text labels further outward
# fig.text(0.5, 0.035, 'Reconstruction Error', ha='center', fontsize=28)
# fig.text(0.005, 0.5, 'Adversarial Accuracy', va='center', rotation='vertical', fontsize=28)

# # Save
# plt.savefig(f"./figures/combined/combined_scatter_full.png", dpi=300)
# plt.close()
# #endregion


# "densest", "sparsest"
# rules_for_combined_scatter = ["increment", "sparsest", "densest"]
# Combined figure: one panel per dataset (2 x 3 grid), overlaying the selected
# rules as scatter plots of reconstruction error (x) against adversarial
# accuracy (y). Each rule has its own marker; colors encode the forest size.
rules_for_combined_scatter = ["increment", "closest", "farthest"]
markers = ["+", "*", "^"]
cmap = "magma"

fig, axes = plt.subplots(2, 3, figsize=(24, 10), sharex=True, sharey=False)
sc = None  # first scatter drawn; used as the mappable of the shared colorbar
for i, dataset in enumerate(sorted(subplot_data.keys())):
    ax = axes[i // 3, i % 3]
    data = subplot_data[dataset]

    for rule_idx, rule in enumerate(rules_for_combined_scatter):
        recon_err = data["average_results"][rule]
        adv_acc = data["adv_average_results"][rule]
        # Computed here, from the points actually plotted, instead of reusing
        # whatever value the per-dataset loop last left in this variable.
        n_estimators_values = val_trees[:len(recon_err)]

        st = ax.scatter(recon_err, adv_acc, c=n_estimators_values, cmap=cmap, s=85, label=rule, alpha=0.8, marker=markers[rule_idx])
        if sc is None:
            sc = st

    ax.set_title(f"{dataset}", fontsize=28)
    ax.grid(True)
    ax.tick_params(axis='both', labelsize=24)

# Tighter overall layout with reduced right space
plt.subplots_adjust(left=0.09, right=0.9, bottom=0.15, top=0.92, wspace=0.25)

# Shared colorbar in the margin freed on the right
cbar_ax = fig.add_axes([0.92, 0.2, 0.015, 0.6])
cbar = fig.colorbar(sc, cax=cbar_ax)
cbar.ax.tick_params(labelsize=24)
cbar.set_label('Number of Trees', fontsize=28)

# Shared legend (one entry per rule) and axis labels
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.08), ncol=len(rules_for_combined_scatter), fontsize=28, frameon=False)

# The captions name the plotted quantities: the number of trees is shown by
# the colorbar, not by an axis.
fig.text(0.5, 0.04, 'Reconstruction Error', ha='center', fontsize=28)
fig.text(0.04, 0.5, 'Adversarial Accuracy', va='center', rotation='vertical', fontsize=28)

# No tight_layout() here: the spacing is set explicitly by subplots_adjust()
# above, and the rect previously passed ([0.05, 0.075, 0.03, 0.95]) had its
# right bound (0.03) to the left of its left bound (0.05).
# NOTE(review): matplotlib presumably could not apply that layout at all, so
# the saved figure should be unchanged -- confirm on a rendered output.

# savefig does not create missing folders.
os.makedirs("./figures/combined", exist_ok=True)
plt.savefig(f"./figures/combined/combined_stack_scatter.png", dpi=300, bbox_inches="tight")
print(f"Saving to ./figures/combined/combined_stack_scatter.png ")
plt.close()
