import os
import glob
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import itertools

# Directory containing your Adults results
adults_dir = 'Adults_Result'

# Parameters to sweep
k_values = list(range(5, 26))
lambda_values = [i / 10 for i in range(1, 10)]

# Measures and per-algorithm column offsets in the results arrays.
# Measure number `mi` of an algorithm is read at column `offset + mi`, so each
# algorithm occupies len(measures) consecutive columns.
measures = ['FJR Violation', 'Core Violation', 'Within-Cluster Distance', 'k-means Objective', 'k-medoids Objective']
algorithms = {
    'GC': 0,
    'SemiBall': 5,
    'k-means++': 10,
    'k-medoids': 15
}


def parse_sweep_filename(fname):
    """Return the ``(k, lambda)`` pair encoded in a sweep result file name.

    Both ``k_sweep_k=<k>-lambda=<lam>-...`` and
    ``lambda_sweep_k=<k>-lambda=<lam>-...`` share the same layout, so a single
    parser handles both.

    Lambda is rounded to 6 decimals. File names written with float noise
    (e.g. ``lambda=0.30000000000000004``) then still produce the same key as
    the ``i / 10`` values in ``lambda_values``, instead of being silently
    unmatched.
    """
    k_field, lam_field = fname.split('-')[:2]
    k = int(k_field.split('=', 1)[1])
    lam = round(float(lam_field.split('=', 1)[1]), 6)
    return k, lam


# Load all CSV results into a dict: (k, lambda) -> (mean metrics, CI half-widths).
# The key (15, 0.5) appears in both sweeps; the lambda-sweep file is read last
# and therefore wins, as before.
sweep_patterns = (
    'k_sweep_k=*-lambda=0.5-Adultsmax.csv',
    'lambda_sweep_k=15-lambda=*-Adultsmax.csv',
)
results = {}
for pattern in sweep_patterns:
    for path in glob.glob(os.path.join(adults_dir, pattern)):
        data = np.loadtxt(path, delimiter=',')
        # data[0] is mean metrics, data[1] is 95% CI half-widths
        results[parse_sweep_filename(os.path.basename(path))] = (data[0], data[1])

# Ensure output directory
plot_dir = 'Adults_Plots_2_pdf'
os.makedirs(plot_dir, exist_ok=True)

# Sweep slices plotted below: lambda varies at k = 15, k varies at lambda = 0.5.
fixed_k = 15
fixed_lambda = 0.5

# One (marker, color) pair per algorithm, assigned in insertion order.
# An explicit mapping keeps every plot styled identically no matter how many
# algorithms there are. Drawing from one shared itertools.cycle across plots
# only stays aligned when the cycle length equals len(algorithms).
alg_styles = dict(zip(
    algorithms,
    zip(itertools.cycle(['o', 's', '^', 'D']), itertools.cycle(['b', 'g', 'r', 'c'])),
))


def _plot_sweep(x_values, key_for, mi, xlabel, filename):
    """Plot measure ``measures[mi]`` for every algorithm along one sweep and save it.

    x_values -- sweep values, in plotting order.
    key_for  -- maps a sweep value to its ``(k, lambda)`` key in ``results``;
                values with no loaded result are skipped.
    mi       -- index of the measure within each algorithm's column group.
    xlabel   -- x-axis label.
    filename -- PDF file name written inside ``plot_dir``.

    Each algorithm is drawn as a line with hollow markers, plus a shaded band
    of +/- the stored CI half-width. The figure is always closed afterwards,
    so repeated calls do not accumulate open figures.
    """
    fig = plt.figure()
    try:
        for alg_name, offset in algorithms.items():
            marker, color = alg_styles[alg_name]
            col = offset + mi
            xs, means, cis = [], [], []
            for x in x_values:
                key = key_for(x)
                if key not in results:
                    continue
                mean_vals, ci_vals = results[key]
                xs.append(x)
                means.append(mean_vals[col])
                cis.append(ci_vals[col])
            means = np.asarray(means, dtype=float)
            cis = np.asarray(cis, dtype=float)
            plt.errorbar(
                xs, means, yerr=None,
                marker=marker, color=color,
                linestyle='-', label=alg_name, alpha=0.85, linewidth=1.2,
                markerfacecolor='none'
            )
            plt.fill_between(xs, means - cis, means + cis, color=color, alpha=0.2)
        plt.xlabel(xlabel, fontsize=22)
        plt.ylabel(measures[mi], fontsize=22)
        plt.minorticks_off()
        plt.grid(True, which='major', linestyle='-', linewidth=0.5)
        ax = plt.gca()
        ax.xaxis.set_major_locator(plt.MaxNLocator(5))
        ax.yaxis.set_major_locator(plt.MaxNLocator(5))
        ax.tick_params(axis='both', labelsize=16)
        # Lay out only after all font sizes are final; running tight_layout
        # before enlarging the tick labels computes margins for the smaller
        # labels and can clip them.
        fig.tight_layout()
        fig.savefig(os.path.join(plot_dir, filename))
    finally:
        plt.close(fig)


# Generate combined plots per measure (no per-plot legend; see the legend-only PDF).
for mi, measure in enumerate(measures):
    # Plot vs λ for fixed k=15
    _plot_sweep(
        lambda_values, lambda lam: (fixed_k, lam), mi,
        'Weighted Loss Parameter $\\lambda$',
        f'{measure}_vs_lambda_k{fixed_k}_Adults.pdf',
    )
    # Plot vs k for fixed λ=0.5
    _plot_sweep(
        k_values, lambda k: (k, fixed_lambda), mi,
        'Number of Clusters $k$',
        f'{measure}_vs_k_lambda{fixed_lambda}_Adults.pdf',
    )

# Plot a separate legend
import matplotlib.lines as mlines

# Proxy artists reproducing the per-algorithm styling: algorithms are paired
# with markers and colors in insertion order (cycling past four entries).
legend_styles = zip(
    algorithms.keys(),
    itertools.cycle(['o', 's', '^', 'D']),
    itertools.cycle(['b', 'g', 'r', 'c']),
)
handles = [
    mlines.Line2D([], [], color=clr, marker=mrk, linestyle='-',
                  label=name, markerfacecolor='none', linewidth=1.5)
    for name, mrk, clr in legend_styles
]
labels = list(algorithms.keys())

# A wide, short figure holding only the legend, laid out as a single row.
fig_legend = plt.figure(figsize=(12, 2))
fig_legend.legend(handles=handles, labels=labels, loc='center', ncol=len(labels), fontsize=22, frameon=False)
plt.axis('off')
plt.tight_layout()
fig_legend.savefig(os.path.join(plot_dir, 'Adults_Legend_Only.pdf'))
plt.close(fig_legend)