import numpy as np
import cvxpy as cp
from sklearn.cluster import KMeans
import pandas as pd

def _solve_clustering_sdp(cost, k):
    """Solve the SDP relaxation of k-way clustering and return the optimal X.

    Minimizes ``trace(cost @ X)`` over symmetric ``n x n`` matrices ``X``
    that are positive semidefinite, entrywise non-negative, have
    ``trace(X) == k`` and whose rows each sum to 1.

    Args:
        cost: Square ``(n, n)`` NumPy array of pairwise costs.
        k: Target number of clusters (used as the trace of ``X``).

    Returns:
        The solver's value for ``X`` as an ``(n, n)`` NumPy array.

    Raises:
        RuntimeError: If the solver finishes without producing a value for
            ``X`` (for example when the problem is reported infeasible).
    """
    n = cost.shape[0]
    X = cp.Variable((n, n), symmetric=True)
    constraints = [
        X >> 0,                  # positive semidefinite
        X >= 0,                  # entrywise non-negative
        cp.trace(X) == k,        # constraint 1: trace(X) = k
        cp.sum(X, axis=1) == 1,  # constraint 2: every row of X sums to 1
    ]

    prob = cp.Problem(cp.Minimize(cp.trace(cost @ X)), constraints)
    prob.solve(solver=cp.SCS, verbose=False)  # can change solver if needed

    if X.value is None:
        # Without this check the failure only shows up later as an unrelated
        # error inside the eigendecomposition.
        raise RuntimeError(
            f"SDP clustering did not produce a solution (status: {prob.status})")
    return X.value


def run_sdp_clustering(task_affinities, k, use_exp=True, temperature=1.0, *,
                       dataset=None, method=None, run=None, tasks=None,
                       suffix=None):
    """Cluster tasks via an SDP relaxation, spectral rounding and k-means.

    The affinity matrix is turned into a cost matrix
    (``max(affinity) - affinity``), the SDP relaxation is solved with SCS,
    the eigenvectors of the solution belonging to the ``k`` largest
    eigenvalues are used as an embedding, and k-means with ``k`` clusters
    (``random_state=0``) assigns each row to a cluster. The clusters are
    printed and, when the run context is available, written to a CSV file.

    Args:
        task_affinities: Square ``(n, n)`` array-like of pairwise affinities
            (higher means more related).
        k: Number of clusters; must satisfy ``1 <= k <= n``.
        use_exp: If true, replace the affinities by
            ``exp(task_affinities / temperature)`` before building the cost.
        temperature: Divisor applied inside the exponential when ``use_exp``
            is true.
        dataset: Dataset name used in the log line and output file name.
        method: Method name used in the log line, output directory and
            output file name.
        run: Run identifier used in the log line and output file name.
        tasks: Sequence mapping a row index to its task identifier; used to
            translate clusters before saving.
        suffix: String appended to the output file name before ``.csv``.

        For backward compatibility, any of ``dataset``, ``method``, ``run``,
        ``tasks`` and ``suffix`` left as ``None`` is looked up in this
        module's globals under the names ``dataset``, ``method``, ``run``,
        ``TASKS`` and ``suffix``. Missing globals resolve to ``None``
        (``''`` for ``suffix``) instead of raising ``NameError``.

    Returns:
        dict mapping each cluster label (``int``) to the list of row indices
        assigned to it, in order of first appearance.

    Raises:
        ValueError: If ``task_affinities`` is not a square 2-D matrix,
            contains non-finite values after the optional exponential, or
            ``k`` is outside ``[1, n]``.
        RuntimeError: If the SDP solver returns no solution.

    Side effects:
        Prints the clusters and a summary. When ``dataset``, ``method``,
        ``run`` and ``tasks`` are all available, creates
        ``RESULTS/GROUP_SELECTION/Clustering/{method}/`` (relative to the
        current working directory) if needed and writes the clusters there.
        The file name always contains the literal ``GRADTAG``.
    """
    import os  # local import: only needed to create the output directory

    affinities = np.asarray(task_affinities, dtype=float)
    if affinities.ndim != 2 or affinities.shape[0] != affinities.shape[1]:
        raise ValueError(
            f"task_affinities must be a square matrix, got shape {affinities.shape}")
    n = affinities.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and {n}, got {k}")

    if use_exp:
        affinities = np.exp(affinities / temperature)
    if not np.all(np.isfinite(affinities)):
        raise ValueError("task_affinities contains NaN or infinite values")

    # SDP objective uses maximum - affinity, so high affinity means low cost.
    cost = np.max(affinities) - affinities
    X_value = _solve_clustering_sdp(cost, k)

    # --- Spectral rounding ---
    # eigh returns eigenvalues in ascending order, so the last k columns are
    # the eigenvectors of the k largest eigenvalues.
    _, eigvecs = np.linalg.eigh(X_value)
    U = eigvecs[:, -k:]

    # k-means on the spectral embedding
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit(U).labels_

    # --- Collect results ---
    assignment = {}
    for idx, lbl in enumerate(labels):
        assignment.setdefault(int(lbl), []).append(idx)

    # Histogram of cluster sizes, keyed as 'numtask_<size>'.
    group_len = {}
    for cluster_idx, cluster_tasks in assignment.items():
        print(f"Cluster {cluster_idx}: {' '.join(map(str, cluster_tasks))}")
        size_key = f'numtask_{len(cluster_tasks)}'
        group_len[size_key] = group_len.get(size_key, 0) + 1

    print(f"Formed {len(assignment)} clusters with group len {group_len}")

    # Resolve the run context: explicit arguments win, otherwise fall back to
    # the module-level variables that the driver script sets.
    legacy = globals()
    if dataset is None:
        dataset = legacy.get('dataset')
    if method is None:
        method = legacy.get('method')
    if run is None:
        run = legacy.get('run')
    if tasks is None:
        tasks = legacy.get('TASKS')
    if suffix is None:
        suffix = legacy.get('suffix') or ''

    print(f'dataset: {dataset}, method: {method}, run: {run}, TASKS: {tasks}')

    # Save only if the full run context is available.
    if dataset is not None and method is not None and run is not None and tasks is not None:
        cluster_list = [
            {"cluster": [tasks[idx] for idx in members]}
            for members in assignment.values()
        ]
        df = pd.DataFrame(cluster_list)

        out_dir = f'RESULTS/GROUP_SELECTION/Clustering/{method}'
        os.makedirs(out_dir, exist_ok=True)
        df.to_csv(
            f'{out_dir}/{method}_{dataset}_clusters_{k}_GRADTAG_run_{run}{suffix}.csv',
            index=False)

    return assignment

# %%
# Datasets to process, in this order.
datasets = ['School', 'Chemical', 'Landmine']

# Task identifiers for each dataset; run_sdp_clustering indexes into the
# selected list (via the module-level TASKS) when it saves clusters.
TASKS_DICT = {
    'School': list(range(1, 140)),
    'Landmine': list(range(0, 29)),
    'Chemical': [2, 5, 6, 9, 10, 12, 18, 20, 22, 24, 25, 27, 28, 30, 46, 52, 55,
                 57, 59, 61, 67, 70, 76, 78, 80, 81, 83, 84, 85, 86, 87, 89, 90, 91, 92],
}

# NOTE: dataset, method, run, TASKS and suffix are deliberately module-level
# names -- run_sdp_clustering reads them as globals for logging and saving.
for dataset in datasets:
    print(f'\n******* Dataset: {dataset} *******')

    TASKS = TASKS_DICT[dataset]

    # Numbers of clusters to try; School gets a wider sweep.
    cluster_range = (
        [3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 40, 50]
        if dataset == 'School'
        else [2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
    )

    # Affinity source to cluster. Other values handled below:
    # 'GRADTAE', 'ITA' (TAG), 'HOA'.
    method = 'Ours'

    # Run identifiers available for the chosen method.
    if method == 'Ours':
        ModelName = 'Quadratic'
        RUNS = [1, 2, 3, 4, 5, 6]
    elif method == 'ITA':
        RUNS = [1, 2, 3]
    elif method == 'GRADTAE':
        RUNS = [1, 3, 4] if dataset == 'School' else [1, 2, 3]
    elif method == 'HOA':
        RUNS = [1, 3, 4, 5, 6]

    for run in RUNS:
        # Only 'Ours' files carry a seed suffix in their names.
        suffix = ''
        if method == 'Ours':
            selected_seed = 2025
            suffix = f'_{selected_seed}'

        # Load the pairwise affinity matrix for this method / dataset / run.
        if method == 'ITA':
            datapath = {
                'Chemical': '../mtl_training/chem_results/',
                'School': '../mtl_training/sch_results/',
                'Landmine': '../mtl_training/landmine_results/',
            }[dataset]
            pairwise_preds = np.array(
                pd.read_csv(f'{datapath}/ITA/{method}_matrix_run_{run}_FIXED.csv'))
            np.fill_diagonal(pairwise_preds, 0.0)

        elif method == 'Ours':
            # NOTE(review): unlike the other methods, the diagonal is not
            # zeroed here -- confirm this is intended.
            pairwise_preds = np.load(
                f'../RESULTS/New_{ModelName}_{dataset}_Predicted_Pairwise_Affinity_{run}{suffix}.npy')
            print(f'shape of pairwise_preds: {pairwise_preds.shape}')

        elif method == 'GRADTAE':
            pairwise_preds = np.array(
                pd.read_csv(f'../mtl_estimation/Results/Estimated_Task_Affinity_{dataset}_{run}.csv'))
            print(f'shape of pairwise_preds: {pairwise_preds.shape}')

            # Zero out the diagonal (self-affinity) entries.
            np.fill_diagonal(pairwise_preds, 0.0)
            print(f'shape of pairwise_preds: {pairwise_preds.shape}')
            print(method, pairwise_preds[0][:10])

        elif method == 'HOA':
            # 'AVG' selects the run-averaged file; anything else is a run id.
            pairwise_affinity = pd.read_csv(
                f'../RESULTS/{dataset}_Pairwise_Affinity_Avg_SGD_FIXED.csv'
                if run == 'AVG'
                else f'../RESULTS/{dataset}_Pairwise_Affinity_run_{run}_SGD_FIXED.csv'
            )
            pairwise_preds = np.array(pairwise_affinity)
            np.fill_diagonal(pairwise_preds, 0.0)
            print(f'shape of pairwise_preds: {pairwise_preds.shape}')
            print(method, pairwise_preds[0][:10])

        # Cluster the same matrix once per requested number of clusters;
        # np.array makes a fresh copy for every call.
        for k in cluster_range:
            task_affinities = np.array(pairwise_preds)
            run_sdp_clustering(task_affinities, k, use_exp=False)
