import os
import random
import numpy as np
import torch


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed`` (in that order). When ``None`` the
            function does nothing, leaving every generator's state untouched.

    Returns:
        None.
    """
    if seed is not None:
        # Same seed for every library so a single number reproduces a run.
        for seeder in (random.seed, np.random.seed, torch.manual_seed):
            seeder(seed)


def log(msg):
    """Print ``msg`` as text and flush the output stream right away.

    Args:
        msg: Any object; it is converted with ``str`` before printing.
    """
    text = str(msg)
    # flush=True so the line shows up immediately even when output is buffered.
    print(text, flush=True)


def stratified_pick_per_class(y, per_class, num_classes=10, rng=np.random):
    """Pick ``per_class`` sample positions, without replacement, for each class.

    Classes are the integer labels ``0 .. num_classes - 1``. For every class
    the positions where ``y`` equals that label are collected and
    ``per_class`` of them are drawn through ``rng.choice``.

    Args:
        y: Label sequence; converted with ``np.asarray``.
        per_class: Number of positions to draw for each class.
        num_classes: How many class labels (starting at 0) to sample from.
        rng: Object exposing ``choice(a, size=..., replace=...)``. Defaults to
            the ``np.random`` module, i.e. NumPy's global generator state.

    Returns:
        The per-class draws concatenated in ascending class order.

    Raises:
        ValueError: If a class has fewer than ``per_class`` matching samples.
    """
    labels = np.asarray(y)
    # First-axis positions of each label, gathered for all classes up front.
    pools = [np.nonzero(labels == cls)[0] for cls in range(num_classes)]

    chosen = []
    for cls, pool in enumerate(pools):
        available = len(pool)
        if available < per_class:
            raise ValueError(f"Class {cls} has only {available} samples, cannot pick {per_class}.")
        chosen.append(rng.choice(pool, size=per_class, replace=False))
    return np.concatenate(chosen, axis=0)


