import os
import glob
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import itertools

# Folder holding the per-run summary CSV files for the Diabetes experiments.
diabetes_dir = 'pima_Diabetes_Result'

# Sweep grids: k = 5..25 (inclusive) and lambda = 0.1..0.9 in steps of 0.1.
k_values = [*range(5, 26)]
lambda_values = [tenth / 10 for tenth in range(1, 10)]

# Names of the reported measures, in the order they appear inside each
# algorithm's group of columns in a results row.
measures = [
    'FJR Violation',
    'Core Violation',
    'Within-Cluster Distance',
    'k-means Objective',
    'k-medoids Objective',
]

# Start column of each algorithm's group in a results row.  The groups are
# laid out back to back, one column per measure, so the offsets come out as
# GC -> 0, SemiBall -> 5, k-means -> 10, k-medoids -> 15.
algorithms = {
    name: position * len(measures)
    for position, name in enumerate(('GC', 'SemiBall', 'k-means', 'k-medoids'))
}

# Load all CSV results into dicts: (k, lambda) -> (first row, second row).
def _load_sweep_results(pattern):
    """Load every summary CSV in ``diabetes_dir`` whose name matches *pattern*.

    The sweep parameters are recovered from the file name, which is split on
    ``-``: the first field must end in ``k=<int>`` and the second must be
    ``lambda=<float>`` (e.g. ``pima_k_sweep_k=7-lambda=0.5-...csv``).

    Args:
        pattern: Glob pattern, relative to ``diabetes_dir``.

    Returns:
        Dict mapping ``(k, lam)`` to the tuple ``(data[0], data[1])`` holding
        the first two rows of the comma-separated file.

    Raises:
        ValueError: If a file does not parse into a 2-D table with at least
            two rows (the message names the offending path), or if ``k`` /
            ``lambda`` in the file name are not numeric.
    """
    loaded = {}
    for path in glob.glob(os.path.join(diabetes_dir, pattern)):
        name_fields = os.path.basename(path).split('-')
        k = int(name_fields[0].split('=')[1])
        # NOTE: lam becomes part of a dict key, so later lookups only hit
        # when they use exactly the same float value.
        lam = float(name_fields[1].split('=')[1])
        data = np.loadtxt(path, delimiter=',')
        # Fail here, with the file name, instead of with an opaque index
        # error later when the two rows are indexed per column.
        if data.ndim != 2 or data.shape[0] < 2:
            raise ValueError(
                f'{path}: expected a table with at least two rows, '
                f'got array of shape {data.shape}'
            )
        loaded[(k, lam)] = (data[0], data[1])
    return loaded


# k is swept with lambda fixed at 0.5; lambda is swept with k fixed at 15.
results_ksweep = _load_sweep_results(
    'pima_k_sweep_k=*-lambda=0.5-Diabetes_summary_1.csv')
results_lambdasweep = _load_sweep_results(
    'pima_lambda_sweep_k=15-lambda=*-Diabetes_summary_1.csv')

# Destination folder for the generated PDFs; created if it is missing.
plot_dir = 'pima_Diabetes_Plots_final_pdf'
os.makedirs(plot_dir, exist_ok=True)


# Endlessly repeating marker / colour sequences (four entries each).  They
# are module-level iterators, so every next() call advances shared state.
marker_cycle_all = itertools.cycle(('o', 's', '^', 'D'))
color_cycle_all = itertools.cycle(('b', 'g', 'r', 'c'))

# Generate combined plots per measure
def _plot_measure_sweep(results, x_values, make_key, measure_index,
                        xlabel, ylabel, out_path):
    """Draw one measure for every algorithm against a swept parameter.

    For each algorithm one line is drawn through the values taken from the
    first stored row, with a shaded band of +/- the matching value from the
    second stored row.  Sweep points that have no entry in *results* are
    skipped.  The figure is written to *out_path* and then closed.

    Args:
        results: Dict mapping ``(k, lam)`` to a pair of indexable rows.
        x_values: Swept parameter values, in plotting order.
        make_key: Callable turning an x value into the ``(k, lam)`` key.
        measure_index: Position of the measure inside an algorithm's group
            of columns; added to the algorithm's offset.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        out_path: Path of the file to save the figure to.
    """
    fig = plt.figure()
    try:
        for alg_name, offset in algorithms.items():
            column = offset + measure_index
            xs, means, cis = [], [], []
            for x in x_values:
                entry = results.get(make_key(x))
                if entry is None:
                    continue
                mean_vals, ci_vals = entry
                xs.append(x)
                means.append(mean_vals[column])
                cis.append(ci_vals[column])
            # The shared cycles are advanced once per algorithm per figure.
            color = next(color_cycle_all)
            marker = next(marker_cycle_all)
            # yerr=None: no error bars; the spread is shown by the band below.
            plt.errorbar(
                xs, means, yerr=None,
                marker=marker, color=color,
                linestyle='-', label=alg_name, alpha=0.85, linewidth=1.2,
                markerfacecolor='none'
            )
            means_arr = np.array(means)
            cis_arr = np.array(cis)
            plt.fill_between(
                xs,
                means_arr - cis_arr,
                means_arr + cis_arr,
                color=color, alpha=0.2
            )
        plt.xlabel(xlabel, fontsize=22)
        plt.ylabel(ylabel, fontsize=22)
        # No title or per-plot legend is drawn on these figures.
        plt.minorticks_off()
        plt.grid(True, which='major', linestyle='-', linewidth=0.5)
        plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
        plt.gca().yaxis.set_major_locator(plt.MaxNLocator(5))
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.tight_layout()
        plt.savefig(out_path)
    finally:
        # plt.clf() would only clear the figure and leave it registered with
        # pyplot; close it so figures do not pile up across iterations.
        plt.close(fig)


for mi, measure in enumerate(measures):
    # Plot vs λ for fixed k=15
    _plot_measure_sweep(
        results_lambdasweep, lambda_values, lambda lam: (15, lam), mi,
        'Weighted Loss Parameter $\\lambda$', measure,
        os.path.join(plot_dir, f'{measure}_vs_lambda_k15_Diabetes.pdf'),
    )

    # Plot vs k for fixed λ=0.5
    _plot_measure_sweep(
        results_ksweep, k_values, lambda k: (k, 0.5), mi,
        'Number of Clusters $k$', measure,
        os.path.join(plot_dir, f'{measure}_vs_k_lambda0.5_Diabetes.pdf'),
    )

# --- Standalone legend image ---
# Builds a figure that contains only a one-row legend (one entry per
# algorithm) and saves it as its own PDF.
import matplotlib.lines as mlines

legend_fig = plt.figure(figsize=(10, 2))

# Fresh cycles, so the legend styles start from the first marker/colour.
marker_cycle = itertools.cycle(['o', 's', '^', 'D'])
color_cycle = itertools.cycle(['b', 'g', 'r', 'c'])

labels = list(algorithms)
# Data-less proxy lines: they carry only the style shown in the legend.
# zip() stops when the algorithm names run out, so each cycle is advanced
# exactly once per algorithm.
handles = [
    mlines.Line2D([], [], color=line_color, marker=line_marker, linestyle='-',
                  label=name, markerfacecolor='none', linewidth=1.5)
    for name, line_marker, line_color in zip(labels, marker_cycle, color_cycle)
]

legend_fig.legend(
    handles=handles,
    labels=labels,
    loc='center',
    ncol=len(labels),
    fontsize=18,
    frameon=False,
)
plt.axis('off')
plt.tight_layout()
legend_path = os.path.join(plot_dir, 'Diabetes_Legend_Only.pdf')
legend_fig.savefig(legend_path, bbox_inches='tight')
plt.close(legend_fig)