import os, time, argparse, numpy as np, pickle

def detect_structure(path: str, dim: int):
    """Load ``path`` and report how its rows are stored.

    Args:
        path: A ``.npy`` file (extension matched case-insensitively) or, for
            any other extension, a pickle file.
        dim: Expected row length. It is enforced only for 2-D ``.npy`` arrays
            that can be memory-mapped.

    Returns:
        ``(kind, data, n_rows)``:

        * ``"2d"`` when ``path`` is a 2-D ``.npy`` array that NumPy could
          memory-map read-only. ``data`` is that memmap and ``n_rows`` is
          ``data.shape[0]``.
        * ``"obj"`` otherwise. ``data`` is the fully loaded object and
          ``n_rows`` is ``len(data)``.

    Raises:
        ValueError: if a memory-mappable 2-D ``.npy`` array has
            ``shape[1] != dim``.

    Security:
        Both the ``allow_pickle=True`` fallback and ``pickle.load`` can execute
        arbitrary code embedded in the file. Only pass trusted inputs.
    """
    if path.lower().endswith('.npy'):
        try:
            arr = np.load(path, mmap_mode="r", allow_pickle=False)
        except (ValueError, OSError):
            # Object-dtype / pickled .npy files (and files that cannot be
            # mmapped) land here; they are handled by the full load below.
            arr = None
        # The dim check sits outside the try so a mismatch is reported
        # instead of being swallowed and silently reloaded as "obj".
        if arr is not None and arr.ndim == 2:
            if arr.shape[1] != dim:
                raise ValueError(f"dim mismatch: {arr.shape[1]} vs {dim}")
            return "2d", arr, arr.shape[0]
        arr = np.load(path, allow_pickle=True)
        return "obj", arr, len(arr)
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return "obj", obj, len(obj)

def pick_index_dtype(dim: int):
    """Return the smallest signed NumPy integer type that can store every
    index in ``range(dim)``.

    ``np.int16`` and ``np.int32`` are tried in order. ``np.int64`` is the
    fallback, so very wide vectors never have their indices wrap around
    when cast with ``astype``.
    """
    max_index = dim - 1
    for dtype in (np.int16, np.int32):
        if max_index <= np.iinfo(dtype).max:
            return dtype
    return np.int64

def main():
    """Command-line entry point: compute per-row sets of strongly active neurons.

    For every row ``v`` of ``--input``, collect the indices ``i`` with
    ``v[i] > mean(v) + k * std(v)``. The statistics are computed in float32.
    The ragged result is saved, with pickling enabled, as a 1-D object array
    of length N to ``<out-dir>/<input stem>_neuron_set.npy``. Element ``i``
    holds row ``i``'s index array, cast to ``pick_index_dtype(--dim)``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--input', type=str, required=True,
                    help='.npy array or pickle file of row vectors')
    ap.add_argument('--out-dir', type=str, default='outputs/neuron_sets')
    ap.add_argument('--dim', type=int, required=True,
                    help='expected vector length; also selects the index dtype')
    ap.add_argument('--k', type=float, default=2.0,
                    help='per-row threshold is mean + k * std')
    ap.add_argument('--batch-size', type=int, default=10000,
                    help='rows processed per batch for 2-D .npy inputs')
    args = ap.parse_args()
    if args.batch_size <= 0:
        # range(0, N, 0) would otherwise fail later with an opaque ValueError.
        ap.error('--batch-size must be a positive integer')

    os.makedirs(args.out_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(args.input))[0]
    out_path = os.path.join(args.out_dir, f"{base}_neuron_set.npy")

    kind, arr, N = detect_structure(args.input, args.dim)
    print(f"[INFO] Detected kind={kind}, N={N}")
    idx_dtype = pick_index_dtype(args.dim)

    # Fill the object array element by element. `out[:] = list_of_arrays`
    # makes NumPy coerce the list first. When every row yields the same
    # number of indices (e.g. all empty), that becomes a 2-D array and the
    # assignment fails with "could not broadcast".
    out = np.empty(N, dtype=object)
    t0 = time.perf_counter()

    if kind == '2d':
        for s in range(0, N, args.batch_size):
            e = min(s + args.batch_size, N)
            X = arr[s:e].astype(np.float32, copy=False)
            thr = X.mean(axis=1) + args.k * X.std(axis=1)
            mask = X > thr[:, None]
            for j, row in enumerate(mask, start=s):
                out[j] = np.flatnonzero(row).astype(idx_dtype, copy=False)
            print(f"[2d] {e}/{N} ({e/N*100:.2f}%) in {time.perf_counter()-t0:.1f}s")
    else:
        for i in range(N):
            v = np.asarray(arr[i], dtype=np.float32)
            thr = float(v.mean() + args.k * v.std())
            out[i] = np.flatnonzero(v > thr).astype(idx_dtype, copy=False)
            if (i + 1) % 100000 == 0:
                print(f"[obj] {i+1}/{N} ({(i+1)/N*100:.2f}%) in {time.perf_counter()-t0:.1f}s")

    np.save(out_path, out, allow_pickle=True)
    print(f"[DONE] saved → {out_path}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module is side-effect free.
    main()
