import os
import pickle
import numpy as np
import igraph as ig
import leidenalg

# Per-codec configuration.
# Previously the Mimi and Encodec settings were assigned one after the other, so
# the Mimi values were silently overwritten and the active codec depended on
# assignment order. Select the codec explicitly via CODEC instead.
#   MATRIX_DIR     directory holding confusion_trainonly_{i}.npy / confusion_{i}.npy
#   CLUSTERS_PKL   output pickle path
#   CHANNELS_NAMES output-dict keys; the i-th name is paired with matrix file index i
#   COUNTS         min-count thresholds applied to each confusion matrix
#   RESOLUTIONS    Leiden resolution parameters to sweep
CODEC_CONFIGS = {
    "mimi": {
        "MATRIX_DIR": "/home/wmar/wmar_audio/outputs/confusion/matrices_new",
        "CLUSTERS_PKL": "/home/wmar/wmar_audio/models/embeddings/mimi_leiden_clusterings_trainonly_allparams.pkl",
        "CHANNELS_NAMES": ['rvq_first_0', 'rvq_rest_0', 'rvq_rest_1', 'rvq_rest_2',
                           'rvq_rest_3', 'rvq_rest_4', 'rvq_rest_5', 'rvq_rest_6'],
        "COUNTS": [1, 5, 10, 25],
        "RESOLUTIONS": [0.2, 0.5, 0.8, 1.0, 1.2, 1.5],
    },
    "encodec": {
        "MATRIX_DIR": "/home/wmar/wmar_audio/outputs/confusion/matrices_encodec",
        "CLUSTERS_PKL": "/home/wmar/wmar_audio/models/embeddings/encodec_leiden_clusterings_trainonly_allparams.pkl",
        "CHANNELS_NAMES": [0, 1, 2, 3],
        "COUNTS": [1, 5, 10],
        "RESOLUTIONS": [0.2, 0.5, 0.8, 1.0, 1.2],
    },
}

# Codec to cluster. "encodec" reproduces the settings that were previously in effect.
CODEC = "encodec"

_selected_config = CODEC_CONFIGS[CODEC]
MATRIX_DIR = _selected_config["MATRIX_DIR"]
CLUSTERS_PKL = _selected_config["CLUSTERS_PKL"]
CHANNELS_NAMES = _selected_config["CHANNELS_NAMES"]
COUNTS = _selected_config["COUNTS"]
RESOLUTIONS = _selected_config["RESOLUTIONS"]


def _load_confusion_matrix(file_idx):
    """Load the confusion matrix stored under index *file_idx* in MATRIX_DIR.

    The explicit train-only file is preferred; the generic file name is the
    fallback. Raises FileNotFoundError when neither exists.
    """
    # NOTE: for the Mimi matrices (matrices_new) the two files were checked to be
    # identical with `diff /home/wmar/wmar_audio/outputs/confusion/matrices_new/confusion_0.npy /home/wmar/wmar_audio/outputs/confusion/matrices_new/confusion_trainonly_0.npy`
    trainonly_path = os.path.join(MATRIX_DIR, f"confusion_trainonly_{file_idx}.npy")
    generic_path = os.path.join(MATRIX_DIR, f"confusion_{file_idx}.npy")
    for candidate in (trainonly_path, generic_path):
        if os.path.exists(candidate):
            return np.load(candidate)
    raise FileNotFoundError(f"Confusion matrix not found for channel {file_idx}: tried {trainonly_path} and {generic_path}")


def _leiden_sweep(confusion):
    """Cluster one confusion matrix for every (min_count, resolution) pair.

    Returns a dict mapping ``(min_count, resolution)`` to an integer label
    array reindexed to 0..K-1.
    """
    sweep = {}
    for min_count in COUNTS:
        # Entries below the threshold are zeroed, so they produce no edge.
        # NOTE(review): non-zero diagonal entries become self-loops in the graph;
        # confirm self-confusion counts are meant to influence the partition.
        thresholded = np.where(confusion >= min_count, confusion, 0)
        graph = ig.Graph.Weighted_Adjacency(thresholded.tolist(), mode='directed')

        for resolution in RESOLUTIONS:
            partition = leidenalg.find_partition(
                graph,
                leidenalg.RBConfigurationVertexPartition,
                weights=graph.es['weight'],
                resolution_parameter=resolution,
                seed=27,  # fixed seed for reproducible partitions
            )
            membership = np.array(partition.membership)
            _, dense_labels = np.unique(membership, return_inverse=True)
            sweep[(min_count, resolution)] = dense_labels
    return sweep


# Top-level result: channel name -> {(min_count, resolution): labels}
clusterings = {}
for file_idx, ch_name in enumerate(CHANNELS_NAMES):
    S = _load_confusion_matrix(file_idx)
    print(S.shape)
    clusterings[ch_name] = _leiden_sweep(S)

# Save everything into one pickle.
# The output directory is created if missing, so a long clustering run is not
# lost at the very end. The pickle is written to a temporary file and then moved
# into place with os.replace, so an interrupted dump never leaves a truncated
# file at CLUSTERS_PKL.
_out_dir = os.path.dirname(CLUSTERS_PKL)
if _out_dir:
    os.makedirs(_out_dir, exist_ok=True)
_tmp_pkl = CLUSTERS_PKL + ".tmp"
try:
    with open(_tmp_pkl, "wb") as f:
        pickle.dump(clusterings, f)
    os.replace(_tmp_pkl, CLUSTERS_PKL)
finally:
    # After a successful replace the temp file is gone; otherwise remove the partial file.
    if os.path.exists(_tmp_pkl):
        os.remove(_tmp_pkl)