# generate_datasets.py
import numpy as np
import pandas as pd
import argparse

def generate_separated_clustered_points(dim, num_clusters, total_points, *,
                                        separation=10.0, std=1.0, rng=None):
    """
    Generate separated clustered points in the given dimension.

    Points are distributed as evenly as possible across clusters. Cluster ``i``
    is centered at ``(i * separation, ..., i * separation)``. Consecutive
    centers therefore differ by ``separation`` along every axis, which is a
    Euclidean gap of ``separation * sqrt(dim)``. Independent Gaussian noise
    with standard deviation ``std`` is added to every coordinate to simulate
    natural spread within clusters.

    Parameters:
    - dim: int, dimension of the points (must be positive)
    - num_clusters: int, number of clusters (must be positive)
    - total_points: int, total number of points across all clusters (must be positive)
    - separation: float, keyword-only, offset between consecutive cluster
      centers along each axis (default 10.0)
    - std: float, keyword-only, standard deviation of the per-coordinate
      Gaussian noise (default 1.0; must be non-negative)
    - rng: keyword-only, optional ``np.random.Generator`` or any seed accepted
      by ``np.random.default_rng``. When None (default), NumPy's legacy global
      RNG is used, so ``np.random.seed(...)`` still controls the output.

    Returns:
    - np.array of shape (total_points, dim) containing all points, ordered
      cluster by cluster. The first ``total_points % num_clusters`` clusters
      get one extra point. If total_points < num_clusters, the trailing
      clusters receive no points.

    Raises:
    - ValueError: if dim, num_clusters or total_points is not positive, or if
      std is negative.
    """
    if num_clusters <= 0 or total_points <= 0:
        raise ValueError("Number of clusters and total points must be positive.")
    if dim <= 0:
        raise ValueError("Dimension must be positive.")
    if std < 0:
        raise ValueError("Noise standard deviation must be non-negative.")

    # Default to the legacy global RNG: np.random.randn(...) is a thin wrapper
    # over this same call, so existing np.random.seed-based runs are unchanged.
    if rng is None:
        standard_normal = np.random.standard_normal
    else:
        # default_rng returns a Generator unaltered, or builds one from a seed
        standard_normal = np.random.default_rng(rng).standard_normal

    # Distribute points as evenly as possible; the first `remainder` clusters get one extra
    base, remainder = divmod(total_points, num_clusters)

    clusters = []
    for i in range(num_clusters):
        center = np.full(dim, i * separation, dtype=float)
        size = base + (1 if i < remainder else 0)
        clusters.append(center + standard_normal((size, dim)) * std)

    # Stack all cluster points into a single (total_points, dim) array
    return np.vstack(clusters)

def parse_dims(text):
    """Parse a comma-separated list of positive dimensions such as ``"1,2,3"``.

    Whitespace around entries and blank entries (e.g. from a trailing comma)
    are ignored. Duplicates are dropped while keeping first-seen order, so the
    same file is not generated twice.

    Intended for use as an argparse ``type=`` callable.

    Raises:
        argparse.ArgumentTypeError: if an entry is not an integer, is not
            positive, or no dimensions remain.
    """
    dims = []
    for token in text.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            dim = int(token)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid dimension {token!r}: not an integer") from None
        if dim <= 0:
            raise argparse.ArgumentTypeError(
                f"invalid dimension {dim}: must be positive")
        if dim not in dims:
            dims.append(dim)
    if not dims:
        raise argparse.ArgumentTypeError("at least one dimension is required")
    return dims


def main(argv=None):
    """Command-line entry point: write one clustered CSV per requested dimension.

    For each dimension ``d`` in ``--dims``, writes ``points_{d}d.csv``
    (columns ``x1..xd``, no index) to the current working directory.
    ``argv`` defaults to ``sys.argv[1:]``, which is argparse's behavior when
    it is None.
    """
    parser = argparse.ArgumentParser(description="Generate point set datasets in various dimensions and save as CSV.")
    # argparse also runs `type` on a string default, so the default is validated too
    parser.add_argument("--dims", type=parse_dims, default="1,2,3,4",
                        help="Comma-separated list of dimensions (e.g., '1,2,3,4')")
    parser.add_argument("--num_clusters", type=int, default=20,
                        help="Number of clusters for each dataset (default: 20)")
    parser.add_argument("--total_points", type=int, default=250000,
                        help="Total number of points per dataset (default: 250000)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for NumPy's global RNG, for reproducible datasets (default: unseeded)")
    args = parser.parse_args(argv)

    # Report bad sizes as a usage error instead of a traceback from the generator
    if args.num_clusters <= 0 or args.total_points <= 0:
        parser.error("--num_clusters and --total_points must be positive.")

    if args.seed is not None:
        # Seed once before the loop so the whole multi-file run is reproducible
        np.random.seed(args.seed)

    for dim in args.dims:
        points = generate_separated_clustered_points(dim, args.num_clusters, args.total_points)

        # Column names x1, x2, ..., xd
        cols = [f"x{i+1}" for i in range(dim)]

        df = pd.DataFrame(points, columns=cols)
        filename = f"points_{dim}d.csv"
        df.to_csv(filename, index=False)
        print(f"Generated {filename} with {args.total_points} points in {dim}D, divided into {args.num_clusters} clusters.")


if __name__ == "__main__":
    main()