import os, re, argparse, numpy as np, pandas as pd, gc

# Layers to process, in output order. Each entry is
# (label written to the "Layer" column, file name of that layer's .npy file).
# File names are resolved against the --npy-dir directory in main().
LAYER_FILES = [
    ("layer1", "layer1_2_avgpool_neuron_set.npy"),
    ("layer2", "layer2_3_avgpool_neuron_set.npy"),
    ("layer3", "layer3_5_avgpool_neuron_set.npy"),
    ("layer4", "layer4_2_avgpool_neuron_set.npy"),
]

def level_from_name(name):
    """Extract the numeric level from a file name containing ``level_<digits>``.

    Only the final path component is inspected, so a matching directory name
    is ignored. The first match in the file name wins.

    Returns the level as an ``int``, or ``None`` when the file name has no
    ``level_<digits>`` token.
    """
    filename = os.path.basename(name)
    match = re.search(r'level_(\d+)', filename)
    if match is None:
        return None
    return int(match.group(1))

def good_set(valid_dir, level):
    """Load the cluster ids listed in ``<valid_dir>/good_clusters_<level>.txt``.

    The file is read one id per line; surrounding whitespace is stripped and
    blank lines are skipped. Each remaining line is converted with ``int``,
    so a non-numeric line raises ``ValueError``.

    Returns a ``set`` of ints, or an empty set when the file does not exist.
    """
    list_path = os.path.join(valid_dir, f"good_clusters_{level}.txt")
    if not os.path.exists(list_path):
        return set()

    ids = set()
    with open(list_path, 'r') as handle:
        for raw_line in handle:
            token = raw_line.strip()
            if token:
                ids.add(int(token))
    return ids

def common_indices(layer_arr, idxs, min_count):
    """Return the sorted values that occur in at least ``min_count`` selected entries.

    For every position in ``idxs`` the entry ``layer_arr[position]`` is turned
    into a set: an iterable entry contributes its distinct elements, while a
    non-iterable entry contributes the single value ``int(entry)``. A value is
    therefore counted at most once per entry, however often it repeats inside it.

    Returns a sorted list of the values whose entry count is >= ``min_count``.
    """
    tally = {}
    for position in idxs:
        entry = layer_arr[position]
        if hasattr(entry, '__iter__'):
            members = set(entry)
        else:
            members = {int(entry)}
        for member in members:
            if member in tally:
                tally[member] += 1
            else:
                tally[member] = 1

    kept = [member for member, hits in tally.items() if hits >= min_count]
    kept.sort()
    return kept

def main():
    """Write, for each valid cluster at each level, the filter indices its members share.

    Steps:
      1. Find ``clustered_captions_level_<N>.csv`` files in ``--finch-dir`` with
         ``1 <= N <= --max-level`` and process them in ascending level order.
      2. For each level, keep the clusters that appear both in the CSV's
         ``cluster`` column and in ``good_clusters_<N>.txt`` under ``--valid-dir``.
         A level with no such clusters is skipped with a message.
      3. For each layer in ``LAYER_FILES`` and each kept cluster, select the
         indices that occur in at least ``int(cluster_size * --appearance-rate)``
         of the cluster's rows.
      4. Write one CSV per cluster to ``--out-dir`` as
         ``level<N>_cluster_<id>.csv`` with the columns ``Layer`` and
         ``Selected Filter Indices`` (comma-separated).

    Raises:
        FileNotFoundError: if a layer ``.npy`` file is missing from ``--npy-dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--finch-dir', type=str, default='outputs/clusters')
    ap.add_argument('--valid-dir', type=str, default='outputs/valid')
    ap.add_argument('--npy-dir',  type=str, default='outputs/neuron_sets')
    ap.add_argument('--out-dir',  type=str, default='outputs/cluster_neuron_sets')
    ap.add_argument('--max-level', type=int, default=10)
    ap.add_argument('--appearance-rate', type=float, default=0.75)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # argparse exposes ``--finch-dir`` as ``args.finch_dir``. The previous
    # ``args.finch-dir`` parsed as ``args.finch - dir`` and raised AttributeError.
    lvls = []
    for f in os.listdir(args.finch_dir):
        if not (f.startswith('clustered_captions_level_') and f.endswith('.csv')):
            continue
        path = os.path.join(args.finch_dir, f)
        file_level = level_from_name(path)  # parsed once per file
        if file_level is not None and 1 <= file_level <= args.max_level:
            lvls.append((file_level, path))
    lvls.sort()

    for level, path in lvls:
        good = good_set(args.valid_dir, level)
        if not good:
            print(f"[Level {level}] no valid clusters. skip."); continue

        df = pd.read_csv(path, usecols=['cluster'])
        clusters = sorted(set(df['cluster']) & good)
        if not clusters:
            print(f"[Level {level}] no overlap with good_clusters. skip."); continue

        # Row labels of each cluster's members. These are used below as
        # positions into every layer array.
        # NOTE(review): assumes CSV row order matches the .npy row order - confirm.
        idx_cache = {c: df.index[df['cluster']==c].to_numpy() for c in clusters}
        # Minimum number of member rows an index must appear in. The product is
        # truncated by int(); the 1e-8 presumably guards against float round-off
        # landing just below a whole number.
        min_cache = {c: int(len(idx_cache[c]) * args.appearance_rate + 1e-8) for c in clusters}

        rows_by_cluster = {c: [] for c in clusters}
        for lname, fname in LAYER_FILES:
            npy_path = os.path.join(args.npy_dir, fname)
            if not os.path.exists(npy_path): raise FileNotFoundError(npy_path)
            # NOTE(security): allow_pickle=True unpickles object arrays, which can
            # execute arbitrary code. Only load .npy files from a trusted source.
            layer_arr = np.load(npy_path, allow_pickle=True)
            for c in clusters:
                sel = common_indices(layer_arr, idx_cache[c], min_cache[c])
                rows_by_cluster[c].append({
                    "Layer": lname,
                    "Selected Filter Indices": ", ".join(map(str, sel))
                })
            # Drop the layer array before the next one is loaded.
            del layer_arr; gc.collect()

        for c in clusters:
            out_csv = os.path.join(args.out_dir, f"level{level}_cluster_{c}.csv")
            pd.DataFrame(rows_by_cluster[c], columns=["Layer", "Selected Filter Indices"]).to_csv(out_csv, index=False)
            print(f"[SAVE] {out_csv}")

# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
