import os
import glob
import numpy as np
import matplotlib
matplotlib.use('Agg')          # non-interactive backend
import matplotlib.pyplot as plt
import itertools

# Folder that holds the Iris result CSVs (adjust to the actual location).
iris_dir = 'Iris_Results_1'

# Parameter grids: k = 5, 6, ..., 25 and lambda = 0.1, 0.2, ..., 0.9.
k_list = [*range(5, 26)]
lambda_list = np.linspace(0.1, 0.9, 9).round(1)

# Algorithm name -> row of each result CSV holding its plotted value.
# The plotting code also reads row + 1 of the same column as the
# half-width of the shaded band drawn around the curve.
alg_rows = dict(zip(
    ('GC', 'SemiBall', 'k-means++', 'k-medoids'),
    range(0, 8, 2),
))

# Measure name (used as y-axis label and inside the output file name)
# -> column of each result CSV, numbered from 1 in the order listed.
measures = {
    name: column
    for column, name in enumerate(
        (
            'FJR Violation',
            'Core Violation',
            'Within-Cluster Distance',
            'k-means Objective',
            'k-medoids Objective',
        ),
        start=1,
    )
}

# Destination folder for the generated PDFs; created if it does not exist.
plot_dir = 'Iris_Plots_1_pdf'
os.makedirs(plot_dir, exist_ok=True)

# Endless marker / color sequences for the per-algorithm curves.
marker_cycle_all = itertools.cycle(('o', 's', '^', 'D'))
color_cycle_all = itertools.cycle(('b', 'g', 'r', 'c'))



def _load_sweep(named_files):
    """Load the result tables of one parameter sweep.

    Args:
        named_files: iterable of ``(x_value, file_name)`` pairs; file names
            are resolved relative to ``iris_dir``.

    Returns:
        A list of ``(x_value, table)`` pairs in input order, one for each
        file that exists, where ``table`` is the array returned by
        ``np.loadtxt`` for that file.
    """
    loaded = []
    for x_value, fname in named_files:
        fpath = os.path.join(iris_dir, fname)
        if not os.path.exists(fpath):
            # Best effort: a missing run is simply left out of the curves.
            continue
        loaded.append((x_value, np.loadtxt(fpath, delimiter=',')))
    return loaded


def _plot_sweep(sweep, col, xlabel, ylabel, out_name):
    """Plot one measure for every algorithm over a sweep and save it as a PDF.

    Args:
        sweep: list of ``(x_value, table)`` pairs as built by ``_load_sweep``.
        col: column of each table that holds the measure.
        xlabel: label of the x axis.
        ylabel: label of the y axis.
        out_name: file name of the figure inside ``plot_dir``.
    """
    fig = plt.figure()

    # Restart the style sequences for every figure so an algorithm's
    # marker/color depends only on its position in alg_rows.
    markers = itertools.cycle(['o', 's', '^', 'D'])
    colors = itertools.cycle(['b', 'g', 'r', 'c'])

    xs = [x_value for x_value, _ in sweep]
    for (alg_name, row), marker, color in zip(alg_rows.items(), markers, colors):
        means = np.array([table[row, col] for _, table in sweep])
        # The row right below an algorithm's row gives the band half-width.
        half_widths = np.array([table[row + 1, col] for _, table in sweep])

        # yerr=None: no error bars, the spread is drawn as a shaded band.
        plt.errorbar(
            xs, means, yerr=None,
            marker=marker, color=color,
            linestyle='-', label=alg_name, alpha=0.85, linewidth=1.2,
            markerfacecolor='none'
        )
        plt.fill_between(
            xs,
            means - half_widths,
            means + half_widths,
            color=color, alpha=0.2
        )

    plt.xlabel(xlabel, fontsize=22)
    plt.ylabel(ylabel, fontsize=22)
    # No in-plot legend: a shared legend is saved separately as
    # Iris_Legend_Only.pdf.
    plt.minorticks_off()
    plt.grid(True, which='major', linestyle='-', linewidth=0.5)
    axes = plt.gca()
    axes.xaxis.set_major_locator(plt.MaxNLocator(5))
    axes.yaxis.set_major_locator(plt.MaxNLocator(5))
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    # Compute the layout only after the tick labels have their final size,
    # otherwise the enlarged labels are not accounted for.
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, out_name))
    # Close (not just clear) the figure so pyplot releases it.
    plt.close(fig)


# Read every CSV once and reuse the tables for all measures.
# k sweep at fixed lambda = 0.5, lambda sweep at fixed k = 15.
k_sweep = _load_sweep(
    (k, f'k_sweep_k={k}-lambda=0.5-Irismax.csv') for k in k_list
)
lambda_sweep = _load_sweep(
    (lam, f'lambda_sweep_k=15-lambda={lam}-Irismax.csv') for lam in lambda_list
)

for mname, col in measures.items():
    _plot_sweep(
        k_sweep, col,
        'Number of Clusters $k$', mname,
        f'Iris1_{mname}_vs_k_lambda0.5.pdf',
    )
    _plot_sweep(
        lambda_sweep, col,
        'Weighted Loss Parameter $\\lambda$', mname,
        f'Iris1_{mname}_vs_lambda_k15.pdf',
    )

# --- Standalone legend figure ---
import matplotlib.lines as mlines

# A wide, short figure that contains nothing but the shared legend.
fig_legend = plt.figure(figsize=(10, 2))

# Fresh style sequences, consumed in alg_rows order.
marker_cycle = itertools.cycle(['o', 's', '^', 'D'])
color_cycle = itertools.cycle(['b', 'g', 'r', 'c'])

# One empty proxy line per algorithm carries its marker and color.
labels = list(alg_rows)
handles = [
    mlines.Line2D(
        [], [],
        color=line_color, marker=line_marker, linestyle='-',
        label=name, markerfacecolor='none', linewidth=1.5,
    )
    for name, line_marker, line_color in zip(labels, marker_cycle, color_cycle)
]

# All entries on a single centered row, without a frame.
fig_legend.legend(
    handles=handles,
    labels=labels,
    loc='center',
    ncol=len(labels),
    fontsize=22,
    frameon=False,
)
plt.axis('off')
plt.tight_layout()

legend_path = os.path.join(plot_dir, 'Iris_Legend_Only.pdf')
fig_legend.savefig(legend_path, bbox_inches='tight')
plt.close(fig_legend)
