import scanpy as sc
import numpy as np
import scipy.sparse as sp
from sklearn.neighbors import KDTree
import torch
import os


def save_train_sampled_data(parent_path, num_samples, seed=42):
    """Save a class-balanced random subset of the stored features.

    Loads ``all_samples.pt`` from ``parent_path``, splits ``num_samples``
    evenly over the classes listed under ``"label_classes"`` (any remainder
    is handed out, one sample each, to the classes with the most members),
    draws each class quota without replacement, shuffles the picks and writes
    the selected feature rows to
    ``<parent_path>/num_samples_<num_samples>/samples.pt``.

    A class with fewer members than its quota contributes all of its members
    and the shortfall is not redistributed, so the saved subset may hold
    fewer than ``num_samples`` rows. Only the features are written; labels of
    the subset are not saved.

    Args:
        parent_path: Directory containing ``all_samples.pt``. The file must
            hold a dict with the keys ``"features"``, ``"labels"`` and
            ``"label_classes"``; labels are compared against the integer ids
            ``0 .. len(label_classes) - 1``.
        num_samples: Requested total number of samples.
        seed: Seed for the NumPy random generator used for drawing/shuffling.

    Raises:
        ValueError: If none of the classes has any member, so nothing can be
            selected.
    """
    data = torch.load(os.path.join(parent_path, "all_samples.pt"))
    features = data["features"]
    labels = data["labels"]
    num_classes = len(data["label_classes"])

    base, remainder = divmod(num_samples, num_classes)

    rng = np.random.default_rng(seed)

    # Row indices of every class, in class-id order.
    class_indices = [
        torch.nonzero(labels == cls_id, as_tuple=False).squeeze(1).cpu().numpy()
        for cls_id in range(num_classes)
    ]
    class_sizes = np.asarray([len(idx) for idx in class_indices])

    # Every class gets `base`; the `remainder` largest classes get one extra.
    target_per_class = np.full(num_classes, base, dtype=int)
    if remainder > 0:
        largest_first = np.argsort(-class_sizes)
        target_per_class[largest_first[:remainder]] += 1

    picks = []
    for idx, quota in zip(class_indices, target_per_class):
        if len(idx) == 0:
            continue
        if len(idx) < quota:
            # Not enough members: take the whole class.
            picks.append(idx)
        else:
            picks.append(rng.choice(idx, size=quota, replace=False))

    if not picks:
        raise ValueError(
            f"No samples could be selected: none of the {num_classes} "
            "classes has any member in 'labels'."
        )

    selected_indices = np.concatenate(picks)
    rng.shuffle(selected_indices)

    output_path = os.path.join(parent_path, f"num_samples_{num_samples}")
    os.makedirs(output_path, exist_ok=True)
    torch.save(
        features[selected_indices].cpu(),
        os.path.join(output_path, "samples.pt"),
    )

def _targets_minwise(labels, num_samples, verbose=True):
    """Plan how many samples to draw from each class.

    Class ids run from ``0`` to ``labels.max()``. Every class first receives
    the same amount: the even share ``num_samples // num_classes``, capped by
    the size of the smallest class. What is still missing is then handed out
    in rounds among the classes that have unused members left, each round
    giving every such class the same amount. Once fewer samples remain than
    there are classes with room, the rest goes one-by-one to the largest
    classes. A class is never asked for more samples than it contains, so the
    plan can sum to less than ``num_samples``.

    Args:
        labels: Tensor of integer class ids.
        num_samples: Total number of samples wanted.
        verbose: Currently unused; kept for interface compatibility.

    Returns:
        ``np.ndarray`` of dtype ``int64`` with one target count per class id.
        Empty when ``labels`` has no elements.
    """
    labels_np = labels.cpu().numpy().astype(np.int64)

    # max() of an empty array raises, so settle the class count first.
    num_classes = int(labels_np.max()) + 1 if labels_np.size else 0
    if num_classes <= 0:
        return np.zeros(0, dtype=np.int64)

    sizes = np.array(
        [np.count_nonzero(labels_np == c) for c in range(num_classes)],
        dtype=np.int64,
    )

    # Equal starting quota that every class is able to supply.
    base = int(min(sizes.min(), num_samples // num_classes))
    targets = np.minimum(sizes, base).astype(np.int64)
    remainder = num_samples - int(targets.sum())

    order_desc = np.argsort(-sizes)  # largest classes first
    while remainder > 0:
        room = sizes - targets
        room_classes = np.flatnonzero(room > 0)
        num_room = len(room_classes)
        if num_room == 0:
            # Every class is exhausted; the request cannot be met in full.
            break

        share = remainder // num_room
        if share == 0:
            # Fewer leftovers than classes with room: one each, largest first.
            winners = order_desc[room[order_desc] > 0][:remainder]
            targets[winners] += 1
            remainder -= len(winners)
            continue

        # Equal raise for all classes with room, limited by the tightest one.
        step = int(min(room[room_classes].min(), share))
        targets[room_classes] += step
        remainder -= step * num_room
    return targets

def save_test_sampled_data(parent_path, num_pairs, seed=42):
    """Sample ``2 * num_pairs`` rows and save them as two equally sized sets.

    Loads ``all_samples.pt`` from ``parent_path``, asks ``_targets_minwise``
    for a per-class sampling plan totalling ``2 * num_pairs``, draws each
    class quota without replacement and shuffles the picks. The first
    ``num_pairs`` selected rows become "pcs1", the remaining ones "pcs2".
    Features, labels and centers of both halves, plus ``label_classes``, are
    written to ``<parent_path>/test/num_pairs_<num_pairs>/``.

    Args:
        parent_path: Directory containing ``all_samples.pt``. The file must
            hold a dict with the keys ``"features"``, ``"labels"``,
            ``"label_classes"`` and ``"centers"``.
        num_pairs: Number of rows in each of the two saved halves.
        seed: Seed for the NumPy random generator used for drawing/shuffling.

    Raises:
        RuntimeError: If a class holds fewer rows than its planned quota, or
            if fewer than ``2 * num_pairs`` rows could be selected in total.
    """
    data = torch.load(os.path.join(parent_path, "all_samples.pt"))
    features = data["features"]
    labels = data["labels"]
    label_classes = data["label_classes"]
    centers = data["centers"]
    num_classes = len(label_classes)

    rng = np.random.default_rng(seed)
    half = int(num_pairs)
    num_samples = half * 2

    # The plan has labels.max() + 1 entries, which can differ from
    # len(label_classes) (e.g. trailing classes without any sample). Align it
    # to num_classes so that indexing by class id below cannot go out of range.
    planned = _targets_minwise(labels, num_samples, verbose=True)
    targets = np.zeros(num_classes, dtype=np.int64)
    shared = min(num_classes, len(planned))
    targets[:shared] = planned[:shared]

    labels_np = labels.cpu().numpy().astype(np.int64)

    chosen = []
    for c in range(num_classes):
        t = int(targets[c])
        if t <= 0:
            continue
        pool = np.flatnonzero(labels_np == c)
        if len(pool) < t:
            raise RuntimeError(f"Class {c} thiếu: cần {t}, có {len(pool)}.")
        chosen.append(rng.choice(pool, size=t, replace=False))

    # np.concatenate rejects an empty list; fall back to an empty selection so
    # a shortfall is reported by the RuntimeError below.
    if chosen:
        selected_indices = np.concatenate(chosen)
    else:
        selected_indices = np.zeros(0, dtype=np.int64)
    rng.shuffle(selected_indices)
    if selected_indices.shape[0] < num_samples:
        raise RuntimeError(f"Đã chọn {selected_indices.shape[0]} < {num_samples}. Không đủ để chia đôi.")
    selected_indices = selected_indices[:num_samples]

    features_sub = features[selected_indices]
    labels_sub = labels[selected_indices]
    centers_sub = centers[selected_indices]

    output_path = os.path.join(parent_path, "test", f"num_pairs_{num_pairs}")
    os.makedirs(output_path, exist_ok=True)

    # First half of the shuffled selection -> pcs1, second half -> pcs2.
    torch.save(features_sub[:half].cpu(), os.path.join(output_path, "pcs1.pt"))
    torch.save(features_sub[half:].cpu(), os.path.join(output_path, "pcs2.pt"))
    torch.save(labels_sub[:half].cpu(), os.path.join(output_path, "pcs1_labels.pt"))
    torch.save(labels_sub[half:].cpu(), os.path.join(output_path, "pcs2_labels.pt"))
    # Centers are stored as loaded (no .cpu() call), as in the original code.
    torch.save(centers_sub[:half], os.path.join(output_path, "pcs1_centers.pt"))
    torch.save(centers_sub[half:], os.path.join(output_path, "pcs2_centers.pt"))
    torch.save(label_classes, os.path.join(output_path, "label_classes.pt"))

if __name__ == "__main__":
    # Build the training subsets of several sizes, then one test split of
    # 10000 pairs, all from the same preprocessed dataset and the same seed.
    dataset_dir = "preprocessed_dataset/merfish"
    train_sizes = (10, 50, 100, 200)

    for size in train_sizes:
        save_train_sampled_data(parent_path=dataset_dir, num_samples=size, seed=42)

    save_test_sampled_data(parent_path=dataset_dir, num_pairs=10000, seed=42)
