import os, re, argparse, gc, pickle
import numpy as np, pandas as pd
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

def load_embeddings(path):
    """Load an embedding matrix from a ``.npy`` file or a pickle file.

    Args:
        path: File location (``str`` or ``os.PathLike``). A name ending in
            ``.npy`` (case-insensitive) is opened with ``np.load`` as a
            read-only memory map; anything else is read with ``pickle``.

    Returns:
        The embeddings as an array. For pickles, a ``dict`` payload must carry
        the data under the ``'embeddings'`` key; any other payload is passed
        to ``np.asarray`` as-is.

    Raises:
        KeyError: If the pickle holds a ``dict`` without an ``'embeddings'``
            key. Previously such a dict was silently wrapped into a 0-d object
            array, which only failed later when it was indexed.
    """
    path_str = os.fspath(path)
    if path_str.lower().endswith('.npy'):
        return np.load(path_str, mmap_mode='r')

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point this at embedding files from a trusted source.
    with open(path_str, 'rb') as f:
        obj = pickle.load(f)

    if isinstance(obj, dict):
        if 'embeddings' not in obj:
            raise KeyError(
                f"pickled dict in {path_str!r} has no 'embeddings' key "
                f"(found keys: {sorted(map(str, obj))})"
            )
        obj = obj['embeddings']
    return np.asarray(obj)

def extract_number(filename):
    """Return the first run of digits found in *filename* as an int.

    Returns -1 when the name contains no digits at all.
    """
    match = re.search(r'(\d+)', filename)
    if match is None:
        return -1
    return int(match.group(1))

def mean_offdiag_cosine_blockwise(X, block=4096):
    """Return the mean cosine similarity over all ordered pairs of distinct rows.

    Rows are L2-normalised (an all-zero row is left as zeros, so its cosine
    with every row, itself included, is 0). The similarity matrix is never
    materialised in full: it is accumulated ``block`` rows at a time.

    Args:
        X: Array-like of shape (n, d).
        block: Number of rows per accumulation step; must be >= 1.

    Returns:
        The mean of the off-diagonal entries of the cosine-similarity matrix
        as a float, or 1.0 when X has at most one row.

    Raises:
        ValueError: If ``block`` is smaller than 1.
    """
    if block < 1:
        raise ValueError("block must be >= 1")
    X = np.asarray(X)
    n = X.shape[0]
    if n <= 1:
        return 1.0

    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid 0/0; zero rows stay zero after division
    Y = X / norms

    total = 0.0
    diag = 0.0
    for start in range(0, n, block):
        Yi = Y[start:start + block]
        # Accumulate in float64 so low-precision embeddings do not lose digits.
        total += float(np.sum(Yi @ Y.T, dtype=np.float64))
        # Subtract the diagonal that is actually there: it is 0 (not 1) for
        # zero-norm rows, so assuming it equals n would bias the mean.
        diag += float(np.sum(Yi * Yi, dtype=np.float64))
    return (total - diag) / (n * (n - 1))

def compute_metrics_streaming(emb, labels, row_idx, m_near, sample_size=50, rng=None):
    """Compute per-cluster cohesion, purity and a nearest-centroid ratio.

    Clusters are processed one at a time so that only one cluster's rows are
    pulled out of ``emb`` at any moment.

    Args:
        emb: Embedding matrix, indexed by row with an integer index array.
        labels: 1-D array of cluster labels, one per entry of ``row_idx``.
            Labels must be convertible with ``int()``.
        row_idx: 1-D array of row positions in ``emb``, aligned with ``labels``.
        m_near: Number of nearest centroids to compare each cluster against.
            It is capped at the number of other clusters available.
        sample_size: Maximum number of rows used for the pairwise-distance
            cohesion estimate; larger clusters are subsampled without
            replacement. Must be >= 2.
        rng: Optional object with a ``choice(n, size, replace=False)`` method
            (e.g. ``np.random.default_rng(0)``) for reproducible subsampling.
            Defaults to the global ``np.random`` state.

    Returns:
        A tuple ``(cohesion, purity, R)`` of dicts keyed by int cluster id:
        ``cohesion`` is the mean pairwise distance between (sampled) members
        (0.0 for a single-member cluster), ``purity`` is the mean off-diagonal
        cosine similarity, and ``R`` is the maximum over the nearest centroids
        of ``(cohesion_i + cohesion_j) / centroid_distance``. A neighbour at
        distance exactly 0 contributes 0.0 to that maximum.

    Raises:
        ValueError: If ``sample_size`` is smaller than 2.
    """
    if sample_size < 2:
        raise ValueError("sample_size must be at least 2")
    sampler = np.random if rng is None else rng

    members = {}
    for k in tqdm(np.unique(labels), desc='members'):
        members[int(k)] = row_idx[labels == k]

    cohesion, purity, centroids = {}, {}, {}
    for k, idx in tqdm(members.items(), desc='metrics'):
        Xk = emb[idx]
        n = Xk.shape[0]
        if n > sample_size:
            sample = Xk[sampler.choice(n, sample_size, replace=False)]
        else:
            sample = Xk
        m = sample.shape[0]
        if m > 1:
            d = pairwise_distances(sample)
            cohesion[k] = (d.sum() - np.trace(d)) / (m * (m - 1))
        else:
            # A single point has no pairs; report 0.0 instead of 0/0 -> NaN.
            cohesion[k] = 0.0
        purity[k] = mean_offdiag_cosine_blockwise(Xk)
        centroids[k] = Xk.mean(0)
        # Dropping the references is enough to release the arrays; a full
        # gc.collect() per cluster only adds overhead.
        del Xk, sample

    if not centroids:
        return cohesion, purity, {}

    ks = sorted(centroids)
    C = np.vstack([centroids[k] for k in ks])
    # kneighbors() raises if n_neighbors exceeds the number of fitted points,
    # so cap the request when there are fewer than m_near + 1 clusters.
    n_neighbors = min(max(int(m_near), 0) + 1, len(ks))
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean').fit(C)
    dist, idxs = nbrs.kneighbors(C)

    R = {}
    for i, ki in enumerate(ks):
        # Column 0 is the query centroid itself; skip it.
        ratios = [
            0.0 if d == 0 else (cohesion[ki] + cohesion[ks[j]]) / d
            for d, j in zip(dist[i][1:], idxs[i][1:])
        ]
        R[ki] = float(max(ratios)) if ratios else 0.0
    return cohesion, purity, R

def main():
    """Score the clusters of each level CSV and write the ones that pass.

    For every ``*.csv`` in ``--csv-dir`` (ordered by the first number in the
    file name, truncated to the first ``--max-level`` files) this:

    1. keeps clusters whose size lies in ``[--min-size, --max-size]``;
    2. computes cohesion, purity and R_i for them and writes
       ``cluster_metrics_<lvl>.csv`` to ``--out-dir``;
    3. writes the thresholds in use to ``cluster_metrics_<lvl>_thr.csv``
       (fixed defaults, or mean +/- 2 sigma with ``--use-auto-thr``);
    4. writes the ids of clusters passing all three thresholds to
       ``good_clusters_<lvl>.txt`` in ``--valid-dir``.

    Row ``i`` of each CSV is used as row ``i`` of the embedding matrix.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--emb-file', type=str, required=True)
    ap.add_argument('--csv-dir', type=str, required=True)
    ap.add_argument('--out-dir', type=str, default='outputs/metrics')
    ap.add_argument('--valid-dir', type=str, default='outputs/valid')
    ap.add_argument('--min-size', type=int, default=5)
    ap.add_argument('--max-size', type=int, default=10000)
    ap.add_argument('--m-near', type=int, default=10)
    ap.add_argument('--use-auto-thr', action='store_true', help='use mean±2σ thresholds')
    ap.add_argument('--max-level', type=int, default=10)
    args = ap.parse_args()

    emb = load_embeddings(args.emb_file)
    files = sorted(
        (f for f in os.listdir(args.csv_dir) if f.endswith('.csv')),
        key=extract_number,
    )[:args.max_level]

    os.makedirs(args.out_dir, exist_ok=True)
    os.makedirs(args.valid_dir, exist_ok=True)

    for f in files:
        df = pd.read_csv(os.path.join(args.csv_dir, f), usecols=['cluster'])
        counts = df['cluster'].value_counts(sort=False)
        valid = counts[(counts >= args.min_size) & (counts <= args.max_size)].index
        if len(valid) == 0:
            print(f"[SKIP] {f}: no clusters in size range")
            continue

        mask = df['cluster'].isin(valid)
        labels = df.loc[mask, 'cluster'].to_numpy()
        row_idx = np.nonzero(mask.values)[0]

        coh, pur, Ri = compute_metrics_streaming(emb, labels, row_idx, args.m_near)
        lvl = extract_number(f)

        # Build every column from the same id list so the rows cannot get
        # misaligned if the three dicts ever differ in insertion order.
        cluster_ids = sorted(coh)
        dfm = pd.DataFrame({
            'cluster_id': cluster_ids,
            'cohesion': [coh[c] for c in cluster_ids],
            'purity': [pur[c] for c in cluster_ids],
            'R_i': [Ri[c] for c in cluster_ids],
        }).astype({'cluster_id': int}).sort_values('cluster_id')
        metrics_csv = os.path.join(args.out_dir, f"cluster_metrics_{lvl}.csv")
        dfm.to_csv(metrics_csv, index=False)

        # argparse stores '--use-auto-thr' as 'use_auto_thr'; writing
        # 'args.use-auto-thr' is parsed as a subtraction and fails at runtime.
        if args.use_auto_thr:
            coh_thr = float(dfm['cohesion'].mean() + 2 * dfm['cohesion'].std(ddof=0))
            ri_thr = float(dfm['R_i'].mean() + 2 * dfm['R_i'].std(ddof=0))
            pur_thr = float(dfm['purity'].mean() - 2 * dfm['purity'].std(ddof=0))
        else:
            coh_thr, ri_thr, pur_thr = 1.0, 2.0, 0.7

        thr_csv = os.path.join(args.out_dir, f"cluster_metrics_{lvl}_thr.csv")
        pd.DataFrame(
            [{'cohesion_thr': coh_thr, 'purity_thr': pur_thr, 'R_i_thr': ri_thr}]
        ).to_csv(thr_csv, index=False)

        kept = dfm[dfm['cluster_id'].isin(valid)]
        passing = (
            (kept['cohesion'] <= coh_thr)
            & (kept['purity'] >= pur_thr)
            & (kept['R_i'] <= ri_thr)
        )
        good = kept.loc[passing, 'cluster_id'].astype(int).sort_values().tolist()
        with open(os.path.join(args.valid_dir, f"good_clusters_{lvl}.txt"), 'w') as w:
            w.writelines(f"{cid}\n" for cid in good)
        print(f"[LEVEL {lvl}] kept {len(good)} clusters")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
