import math
import os
import random
import warnings

import numpy as np
import torch

# --- Configuration -----------------------------------------------------------
# Input tensors are read from IN_DIR; the train/ and test/ output trees are
# written under OUT_DIR (currently the same directory as the inputs).
IN_DIR = "preprocessed_dataset/cellular_niche"
OUT_DIR = IN_DIR

# Training-subset sizes: one train/num_samples_<K>/ directory per entry.
SIZES = [10, 50, 100, 200]

# Requested number of test pairs; capped at N // 2 when the split is built,
# and None means "as many pairs as the data allows".
NUM_PAIRS = 10000

SEED = 28032003
DTYPE = torch.float32  # dtype the point-cloud tensor X is cast to after loading

# --- Reproducibility ---------------------------------------------------------
# Seed Python's `random` module and torch's global generator (the latter
# drives torch.randperm below), and build a dedicated NumPy Generator that all
# NumPy-side sampling in this script goes through.
random.seed(SEED)
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)


# --- Paths and input loading -------------------------------------------------
x_path = os.path.join(IN_DIR, "X_clouds.pt")
y_path = os.path.join(IN_DIR, "y_clouds.pt")
train_root = os.path.join(OUT_DIR, "train")
test_root = os.path.join(OUT_DIR, "test")

# Load both tensors onto the CPU; X is cast to the configured DTYPE and the
# labels to int64.
X = torch.load(x_path, map_location="cpu").to(DTYPE)
y = torch.load(y_path, map_location="cpu").long()
N = int(X.shape[0])

# The sampling below indexes X with row positions derived from y, so both
# tensors must describe the same samples; fail loudly instead of silently
# misaligning features and labels.
if int(y.shape[0]) != N:
    raise ValueError(
        f"{x_path} has {N} samples but {y_path} has {int(y.shape[0])} labels"
    )

# Group row indices by label: label_to_idx[label] -> list of rows with that label.
labels_np = y.cpu().numpy()
uniq_labels = np.unique(labels_np)  # sorted ascending
if uniq_labels.size == 0:
    # Every later per-label computation divides by the number of labels.
    raise ValueError(f"no labels found in {y_path}")
label_to_idx = {
    int(lbl): np.where(labels_np == lbl)[0].tolist() for lbl in uniq_labels
}

# Only create the output tree once the inputs are known to be usable.
os.makedirs(train_root, exist_ok=True)
os.makedirs(test_root, exist_ok=True)

def _allocate_per_label(k, labels, capacities):
    """Split a budget of ``k`` samples across ``labels`` as evenly as capacity allows.

    Each label first gets ``k // len(labels)`` samples; the remainder goes one
    apiece to the labels with the most samples (stable sort, so ties keep the
    order of ``labels``). A label asked for more than it has is clamped to its
    capacity, and the resulting deficit is handed back out one sample per
    label per pass, round-robin, to labels that still have spare samples
    (largest spare first), so the top-up stays balanced.

    Args:
        k: total number of samples requested.
        labels: the distinct labels, in the order used for tie-breaking.
        capacities: number of available samples for each label.

    Returns:
        Dict mapping each label to the number of samples to draw from it. The
        values sum to ``min(k, sum(capacities.values()))``.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    n_labels = len(labels)
    if n_labels == 0:
        raise ValueError("cannot allocate samples: no labels available")

    need = {lbl: k // n_labels for lbl in labels}
    by_capacity = sorted(labels, key=lambda lbl: capacities[lbl], reverse=True)
    for lbl in by_capacity[: k - n_labels * (k // n_labels)]:
        need[lbl] += 1

    # Clamp labels that cannot supply their share and remember the deficit.
    shortfall = 0
    for lbl in labels:
        excess = need[lbl] - capacities[lbl]
        if excess > 0:
            shortfall += excess
            need[lbl] = capacities[lbl]

    # Redistribute the deficit one sample at a time over labels with spare
    # capacity; stops early only when every label is exhausted (k > total).
    spare = {lbl: capacities[lbl] - need[lbl] for lbl in labels}
    donors = sorted(
        (lbl for lbl in labels if spare[lbl] > 0),
        key=lambda lbl: spare[lbl],
        reverse=True,
    )
    while shortfall > 0 and donors:
        for lbl in donors:
            if shortfall == 0:
                break
            need[lbl] += 1
            spare[lbl] -= 1
            shortfall -= 1
        donors = [lbl for lbl in donors if spare[lbl] > 0]
    return need


# Per-label availability does not depend on K, so compute it once.
label_ids = [int(lbl) for lbl in uniq_labels]
capacities = {lbl: len(label_to_idx[lbl]) for lbl in label_ids}

for K in SIZES:
    need = _allocate_per_label(K, label_ids, capacities)

    # Draw each label's quota without replacement. Labels are visited in
    # ascending order (np.unique order), which fixes the NumPy RNG stream.
    picked = []
    for lbl in label_ids:
        if need[lbl] > 0:
            chosen = rng.choice(label_to_idx[lbl], size=need[lbl], replace=False)
            picked.extend(chosen.tolist())

    # The allocator already uses every sample when it cannot reach K, so a
    # short subset here means K exceeds the dataset size.
    if len(picked) < K:
        warnings.warn(
            f"requested {K} training samples but only {len(picked)} are "
            f"available; num_samples_{K} will contain {len(picked)} samples"
        )

    # Shuffle so the saved subset is not grouped by label (torch global RNG).
    idx_tensor = torch.tensor(picked, dtype=torch.long)
    perm = torch.randperm(idx_tensor.numel())
    X_sub = X.index_select(0, idx_tensor)[perm]
    y_sub = y.index_select(0, idx_tensor)[perm]

    out_dir_k = os.path.join(train_root, f"num_samples_{K}")
    os.makedirs(out_dir_k, exist_ok=True)
    torch.save(X_sub, os.path.join(out_dir_k, "samples.pt"))
    torch.save(y_sub, os.path.join(out_dir_k, "y_train.pt"))

# --- Test split ----------------------------------------------------------------
# P disjoint pairs of point clouds are cut from a single shuffled ordering of
# all N samples: pcs1 holds rows [0, P) and pcs2 rows [P, 2P) of that
# ordering. N // 2 is the most disjoint pairs N samples can form, so
# NUM_PAIRS is capped there, and None means "use the cap".
# NOTE(review): the pairs are drawn from the full dataset, so they may share
# samples with the train/num_samples_* subsets above. Confirm this is intended.
max_pairs = N // 2
if NUM_PAIRS is None:
    P = max_pairs
else:
    P = min(NUM_PAIRS, max_pairs)

perm_all = torch.tensor(rng.permutation(N), dtype=torch.long)
X_shuf = X.index_select(0, perm_all)
pcs1, pcs2 = X_shuf[:P], X_shuf[P : 2 * P]

out_dir_test = os.path.join(test_root, f"num_pairs_{P}")
os.makedirs(out_dir_test, exist_ok=True)
for fname, payload in (("pcs1.pt", pcs1), ("pcs2.pt", pcs2)):
    torch.save(payload, os.path.join(out_dir_test, fname))
