import torch

# Draw a pool of num_randoms random vectors, then build x1 by sampling num_samples rows from that pool with replacement
def generate_x1_with_uncertainty(num_samples, input_dimA, num_randoms):
    """Sample ``num_samples`` rows (with replacement) from a random Gaussian pool.

    Args:
        num_samples: number of rows in the returned ``x1``.
        input_dimA: dimensionality of every vector.
        num_randoms: number of standard-normal vectors in the pool.

    Returns:
        A tuple ``(x1, random_pool)``:
        - ``x1`` is a ``(num_samples, input_dimA)`` leaf copy of the chosen
          pool rows, with ``requires_grad=True``.
        - ``random_pool`` is the ``(num_randoms, input_dimA)`` pool itself.
    """
    pool = torch.randn(num_randoms, input_dimA)
    picks = torch.randint(0, num_randoms, (num_samples,))
    # index_select already yields new storage; detach/clone keeps x1 a
    # standalone leaf so gradients never flow back into the pool.
    samples = pool.index_select(0, picks).detach().clone()
    samples.requires_grad_(True)
    return samples, pool

def generate_x1_from_clusters(num_samples, input_dimA, num_clusters, points_per_cluster, cluster_std):
    """Sample ``num_samples`` points (with replacement) from a Gaussian-cluster pool.

    ``num_clusters`` standard-normal centers are drawn. Around each center,
    ``points_per_cluster`` points are generated as ``center + cluster_std * N(0, I)``.

    Returns:
        A tuple ``(x1, cluster_centers)``:
        - ``x1`` is a ``(num_samples, input_dimA)`` leaf tensor with
          ``requires_grad=True``.
        - ``cluster_centers`` is the ``(num_clusters, input_dimA)`` tensor of centers.
    """
    centers = torch.randn(num_clusters, input_dimA)
    noise = torch.randn(num_clusters, points_per_cluster, input_dimA)
    # Broadcast each center across its own block of points, then flatten the
    # (cluster, point) axes into a single pool axis.
    flat_pool = (centers[:, None, :] + cluster_std * noise).reshape(-1, input_dimA)

    pool_size = num_clusters * points_per_cluster
    choice = torch.randint(0, pool_size, (num_samples,))
    x1 = flat_pool[choice].detach().clone().requires_grad_(True)
    return x1, centers

def generate_x1_from_clusters_deter(num_samples, input_dimA, num_clusters, points_per_cluster, cluster_std):
    """Build a Gaussian-cluster pool and deterministically take its first points.

    The pool is built as ``center + cluster_std * N(0, I)``, with
    ``points_per_cluster`` points around each of ``num_clusters``
    standard-normal centers. Rather than sampling randomly, the first
    ``num_samples`` pool rows are taken in order.

    Args:
        num_samples: number of rows to return; must be between 0 and
            ``num_clusters * points_per_cluster`` inclusive.
        input_dimA: dimensionality of every point.
        num_clusters: number of cluster centers.
        points_per_cluster: points generated around each center.
        cluster_std: scale applied to the per-point Gaussian noise.

    Returns:
        A tuple ``(x1, cluster_centers)``:
        - ``x1`` is a ``(num_samples, input_dimA)`` leaf copy of the first pool
          rows, with ``requires_grad=True``.
        - ``cluster_centers`` is the ``(num_clusters, input_dimA)`` tensor of centers.

    Raises:
        ValueError: if ``num_samples`` is negative or exceeds the pool size.
    """
    total_points = num_clusters * points_per_cluster
    # Explicit checks instead of ``assert``. An assert is stripped under
    # ``python -O``, and the old ``% total_points`` would then silently wrap
    # around and reuse points. A negative count would otherwise slice from
    # the end.
    if num_samples < 0:
        raise ValueError("num_samples must be non-negative")
    if num_samples > total_points:
        raise ValueError("num_samples exceeds the available points")

    cluster_centers = torch.randn(num_clusters, input_dimA)
    random_pool = torch.randn(num_clusters, points_per_cluster, input_dimA) * cluster_std + cluster_centers.unsqueeze(1)
    random_pool = random_pool.reshape(-1, input_dimA)

    # Rows are laid out cluster by cluster (all of cluster 0, then cluster 1,
    # ...). Taking the first num_samples rows therefore fills clusters in
    # order; it does not interleave them.
    x1 = random_pool[:num_samples].clone().detach().requires_grad_(True)
    return x1, cluster_centers

def generate_x1_with_random_walk(num_samples, input_dimA, perturbation_scale=0.01):
    """Generate a Gaussian random walk of ``num_samples`` points in ``input_dimA`` dims.

    Row 0 is a standard-normal starting point. Every later row equals the
    previous row plus ``perturbation_scale * N(0, I)``, with a fresh draw for
    each step.

    Returns:
        A tuple ``(x1, initial_point)``:
        - ``x1`` is the ``(num_samples, input_dimA)`` walk with ``requires_grad=True``.
        - ``initial_point`` is the ``(1, input_dimA)`` starting point.
    """
    start = torch.randn(1, input_dimA)
    walk = start.repeat(num_samples, 1)  # independent copy; start stays untouched

    # One randn call per step, in order, so seeded runs reproduce the same path.
    for row in range(1, num_samples):
        step = perturbation_scale * torch.randn(1, input_dimA)
        walk[row] = walk[row - 1] + step

    return walk.requires_grad_(True), start

def generate_x1_grid(num_samples, input_dimA, grid_steps=None):
    """Place points on a regular grid over ``[-1, 1] ** input_dimA``.

    Args:
        num_samples: number of rows to return in ``x1``.
        input_dimA: dimensionality of every point.
        grid_steps: points per axis. When None, the largest integer ``p``
            with ``p ** input_dimA <= num_samples`` is used.

    Returns:
        A tuple ``(x1, points)``:
        - ``x1`` is a ``(num_samples, input_dimA)`` leaf tensor with
          ``requires_grad=True``. It holds the grid points in 'ij' (row-major)
          order, truncated to ``num_samples``. If the grid is smaller, it is
          padded by repeating the last grid point.
        - ``points`` is the full grid, with shape
          ``(p,) * input_dimA + (input_dimA,)``.
    """
    if grid_steps is not None:
        points_per_dim = grid_steps
    else:
        # Exact integer d-th root. The float power can land just below an
        # integer (e.g. 1000 ** (1/3) == 9.999999999999998), and int() would
        # then truncate to 9. Start from the rounded estimate and correct it
        # with integer arithmetic.
        points_per_dim = int(round(num_samples ** (1 / input_dimA)))
        while points_per_dim > 0 and points_per_dim ** input_dimA > num_samples:
            points_per_dim -= 1
        while (points_per_dim + 1) ** input_dimA <= num_samples:
            points_per_dim += 1

    axes = [torch.linspace(-1, 1, points_per_dim) for _ in range(input_dimA)]
    grid = torch.meshgrid(*axes, indexing='ij')
    points = torch.stack(grid, dim=-1)
    # clone() so the returned leaf never shares storage with ``points``.
    x1 = points.reshape(-1, input_dimA)[:num_samples].clone()
    if x1.size(0) < num_samples:
        padding = x1[-1].unsqueeze(0).repeat(num_samples - x1.size(0), 1)
        x1 = torch.cat([x1, padding], dim=0)
    x1.requires_grad_(True)
    return x1, points