import math

import numpy as np
from numpy.random import default_rng
from scipy.stats import truncnorm

def generate_x_expc_size():
    """Placeholder generator; not implemented yet.

    Currently takes no arguments, does no work, and returns ``None``.

    NOTE(review): the name suggests a counterpart to ``generate_x_equal_size``
    with an expected (rather than fixed) set size — confirm the intended
    contract before implementing.
    """
    return None

def generate_x_equal_size(rng, random_seed=0, d=100, l=2, p=0.5, set_size=100, train_set=None):
    """Sample ``set_size`` distinct ``l``-index combinations from ``range(d)``.

    Each candidate is drawn without replacement by ``rng.choice`` and stored as
    an ascending-sorted tuple. Candidates already present in ``train_set`` are
    rejected and re-drawn.

    Args:
        rng: NumPy ``Generator`` used for all sampling.
        random_seed: Unused; kept for interface compatibility (seed ``rng``).
        d: Indices are drawn from ``0 .. d-1``.
        l: Number of distinct indices per combination.
        p: Unused; kept for interface compatibility.
        set_size: Number of distinct combinations to return.
        train_set: Optional iterable of combinations to exclude. Entries are
            compared as tuples, so they must be ascending-sorted to match a
            drawn combination. ``None`` excludes nothing.

    Returns:
        A list of ``set_size`` distinct sorted tuples. The order is not meaningful.

    Raises:
        ValueError: If fewer than ``set_size`` non-excluded combinations exist.
            Without this check the rejection loop would never terminate.
    """
    if set_size <= 0:
        return []

    # Normalize once: O(1) membership tests, and None means "exclude nothing".
    excluded = set() if train_set is None else {tuple(c) for c in train_set}

    # Only excluded entries that a draw could actually produce (length l,
    # strictly increasing, within range) reduce the pool of candidates.
    reachable_excluded = sum(
        1
        for c in excluded
        if len(c) == l
        and all(0 <= x < d for x in c)
        and all(a < b for a, b in zip(c, c[1:]))
    )
    available = math.comb(d, l) - reachable_excluded
    if set_size > available:
        raise ValueError(
            f"cannot draw {set_size} distinct combinations: only {available} "
            f"of C({d}, {l}) remain after excluding train_set"
        )

    results = set()
    while len(results) < set_size:
        # rng.choice(d, ...) samples from arange(d), the same as range(d),
        # but does not rebuild the population array on every draw.
        combination = tuple(sorted(rng.choice(d, size=l, replace=False)))
        if combination not in excluded:
            results.add(combination)
    return list(results)

def generate_theta(mode="normal", d=100, random_seed=42, mu=0.0, sigma=1.0, alpha=1, beta=1):
    """Draw a length-``d`` vector ``theta`` from the distribution named by ``mode``.

    Every mode draws from ``default_rng(random_seed)``, so equal arguments give
    equal output.

    Modes:
        ``"normal"``: i.i.d. N(mu, sigma).
        ``"mix_normal"``: equal-weight mixture of N(mu, sigma) and N(mu + 1, sigma).
        ``"truncated_normal"``: N(mu, sigma) truncated to [0, 1].
        ``"uniform"``: i.i.d. U[0, 1).
        ``"binary"``: i.i.d. uniform over {0, 1}.
        ``"beta"``: i.i.d. Beta(alpha, beta).

    Args:
        mode: Distribution name (see above).
        d: Number of entries in ``theta``.
        random_seed: Seed for the generator.
        mu, sigma: Location and scale for the normal-family modes.
        alpha, beta: Shape parameters for ``"beta"``.

    Returns:
        A 1-D NumPy array of length ``d``.

    Raises:
        ValueError: If ``mode`` is not one of the names above.
    """
    rng = default_rng(random_seed)
    y_min, y_max = 0, 1  # support of the truncated normal
    if mode == "normal":
        theta = rng.normal(mu, sigma, d)
    elif mode == "mix_normal":
        weights = [0.5, 0.5]
        means = np.array([mu, mu + 1])
        components = rng.choice(len(weights), size=d, p=weights)
        theta = rng.normal(loc=means[components], scale=sigma)
    elif mode == "truncated_normal":
        # truncnorm takes its bounds in standardized units.
        a, b = (y_min - mu) / sigma, (y_max - mu) / sigma
        truncated_normal = truncnorm(a, b, loc=mu, scale=sigma)
        # Pass rng so this mode honours random_seed like the others; without it
        # SciPy falls back to NumPy's global RNG state.
        theta = truncated_normal.rvs(size=d, random_state=rng)
    elif mode == "uniform":
        theta = rng.uniform(0, 1, d)
    elif mode == "binary":
        theta = rng.choice([0, 1], size=d)
    elif mode == "beta":
        theta = rng.beta(alpha, beta, d)
    else:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of 'normal', 'mix_normal', "
            "'truncated_normal', 'uniform', 'binary', 'beta'"
        )

    return theta