"""
DTLZ Test Functions for High-Dimensional Benchmarking.

This module implements the DTLZ family of test functions, which are standard
benchmarks in multi-objective optimization. These functions provide:
- Scalable design dimension (P)
- Scalable number of objectives (L)
- Analytically known forward mapping (design -> objectives)
- Known Pareto front geometry

Reference:
    Deb, K., Thiele, L., Laumanns, M., & Zitzler, E. (2005).
    "Scalable Test Problems for Evolutionary Multiobjective Optimization."
    In Evolutionary Multiobjective Optimization (pp. 105-145). Springer.

The key advantage for our inverse design experiments is that the forward
mapping is exact and analytical, allowing precise round-trip error evaluation.
"""

import torch
import numpy as np
from typing import Callable

DTLZ_DATASET_NAME = "dtlz"


# =============================================================================
# Sampling Strategies
# =============================================================================

def sample_dtlz2_uniform(
    num_samples: int,
    num_design_params: int,
    num_objectives: int,
    generator: torch.Generator = None,
) -> torch.Tensor:
    """
    Draw designs independently and uniformly from the unit hypercube [0, 1]^P.

    This baseline breaks down as P grows: with d = P - L + 1 distance
    parameters, g(x_M) = sum((x_i - 0.5)^2) has mean d/12 and standard
    deviation sqrt(d/180), so almost no samples land near the Pareto front
    (g = 0) once d is moderately large.

    Args:
        num_samples: Number of designs to draw.
        num_design_params: Design space dimension P.
        num_objectives: Number of objectives L. Accepted for signature parity
            with the stratified sampler; it does not affect the draw.
        generator: Optional torch generator for reproducible draws.

    Returns:
        Tensor of shape (num_samples, P) with entries in [0, 1).
    """
    shape = (num_samples, num_design_params)
    designs = torch.rand(shape, generator=generator)
    return designs


def sample_dtlz2_stratified(
    num_samples: int,
    num_design_params: int,
    num_objectives: int,
    g_max: float = 2.0,
    generator: torch.Generator = None,
) -> tuple[torch.Tensor, dict]:
    """
    Stratified sampling for DTLZ2 that targets a uniform distribution of g values.

    In high dimensions, uniform sampling fails because the distance function
    g(x_M) = sum((x_i - 0.5)^2) concentrates around d/12 (d = number of
    distance parameters), far away from the Pareto front at g = 0. This
    function uses a stratified approach instead:

    1. Position parameters (x_1, ..., x_{L-1}): Sampled uniformly from [0, 1]
       to cover all trade-off angles on the Pareto front.

    2. Distance parameters (x_L, ..., x_P): A target g is drawn uniformly from
       [0, g_max], split across the d coordinates with a Dirichlet(1, ..., 1)
       draw, and each coordinate is placed at 0.5 +/- sqrt(fraction_i * g).

    Coordinates that land outside [0, 1] are clipped. Clipping can only move a
    coordinate closer to 0.5, so the realized g is never above its target.
    Targets above d/4 (every distance coordinate at 0 or 1) cannot be reached
    at all. The returned stats report how many coordinates were clipped.

    Args:
        num_samples: Number of samples to generate.
        num_design_params: Design space dimension P.
        num_objectives: Number of objectives L.
        g_max: Maximum target g value (must be >= 0). For DTLZ2, g_max=2.0
               gives objectives up to ~3 (since f = (1+g) * angular_terms).
        generator: Optional torch generator. Only its ``initial_seed()`` is
               used, to seed a NumPy generator. The torch generator's state
               is neither consumed nor advanced, so repeated calls with the
               same generator return identical samples.

    Returns:
        Tuple of (designs, stats) where:
            - designs: float32 tensor of shape (num_samples, P)
            - stats: Dictionary with sampling statistics (clipping info, etc.)

    Raises:
        ValueError: If num_objectives < 1, num_design_params < num_objectives,
            or g_max < 0.

    Mathematical details:
        We want to sample g uniformly from [0, g_max], then distribute this
        across d = P - L + 1 distance coordinates such that:
            g = sum_{i=L}^{P} (x_i - 0.5)^2

        Using Dirichlet(1, ..., 1) to get fractions that sum to 1:
            (x_i - 0.5)^2 = fraction_i * g
            x_i = 0.5 +/- sqrt(fraction_i * g)

        The sign is chosen randomly (50/50) for each coordinate.
    """
    P = num_design_params
    L = num_objectives
    if L < 1:
        raise ValueError(f"num_objectives must be >= 1, got L={L}")
    if P < L:
        raise ValueError(f"Design dimension P={P} must be >= num_objectives L={L}")
    if g_max < 0:
        raise ValueError(f"g_max must be non-negative, got {g_max}")
    d = P - L + 1  # Number of distance parameters

    # NumPy supplies the Dirichlet draw; seed it from the torch generator so
    # results are reproducible for a given torch seed.
    if generator is not None:
        np_rng = np.random.default_rng(generator.initial_seed() % (2**32))
    else:
        np_rng = np.random.default_rng()

    # The draw order below (position, g target, fractions, signs) defines the
    # random stream, so it must stay fixed for reproducibility.

    # 1. Position parameters: uniform angle on the Pareto front.
    X_position = torch.from_numpy(np_rng.random((num_samples, L - 1))).float()

    # 2. Target g values, uniform on [0, g_max].
    g_target = torch.from_numpy(np_rng.random(num_samples) * g_max).float()

    # 3. Split each target across the d coordinates (uniform on the simplex).
    fractions = torch.from_numpy(
        np_rng.dirichlet(np.ones(d), size=num_samples)
    ).float()  # (num_samples, d)

    # 4. Per-coordinate deviation magnitudes: (x_i - 0.5)^2 = fraction_i * g.
    deviations = torch.sqrt(fractions * g_target[:, None])

    # 5. Random signs. The comparison is done on the float64 uniforms: taking
    #    sign(float32(u) - 0.5) yields 0 whenever u rounds to 0.5 in float32,
    #    which would silently zero out that coordinate's deviation.
    sign_draws = np_rng.random((num_samples, d))
    signs = torch.from_numpy(np.where(sign_draws < 0.5, -1.0, 1.0)).float()

    # 6. Distance parameters before clipping.
    X_distance = 0.5 + signs * deviations

    # 7. Record how many coordinates fall outside [0, 1], then clip them.
    below_zero = (X_distance < 0).sum().item()
    above_one = (X_distance > 1).sum().item()
    total_elements = X_distance.numel()
    X_distance = torch.clamp(X_distance, 0, 1)

    # 8. Assemble full designs: [position | distance].
    designs = torch.cat([X_position, X_distance], dim=1)

    clipped = below_zero + above_one
    stats = {
        "clipped_below_zero": below_zero,
        "clipped_above_one": above_one,
        "total_distance_elements": total_elements,
        # Guard against num_samples == 0 (no elements -> nothing clipped).
        "clip_percentage": (
            100.0 * clipped / total_elements if total_elements else 0.0
        ),
        "g_max": g_max,
        "num_distance_params": d,
    }

    return designs, stats


# =============================================================================
# DTLZ Test Functions
# =============================================================================

def dtlz2_forward(x: torch.Tensor, num_objectives: int = 3) -> torch.Tensor:
    """
    DTLZ2 test function: maps design parameters to objectives.

    The DTLZ2 function creates a spherical Pareto front in objective space.
    It is widely used because:
    - The Pareto front (g = 0, i.e. all distance parameters at 0.5) is the
      first-orthant part of the unit hypersphere
    - Every objective lies in [0, 1 + g], and for x in [0, 1]^P
      g <= (P - L + 1) / 4, so objectives are bounded by 1 + (P - L + 1) / 4
    - The problem is scalable in both design dimension and objectives

    Args:
        x: Design parameters tensor of shape (..., P) where P >= num_objectives.
           Typically (batch_size, P); any number of leading dimensions is
           supported. All values should be in [0, 1].
        num_objectives: Number of objectives L (>= 1). Default is 3.

    Returns:
        Objectives tensor of shape (..., num_objectives).

    Raises:
        ValueError: If num_objectives < 1 or P < num_objectives.

    Mathematical formulation:
        Let x = (x_1, ..., x_P) with P >= L (num_objectives)

        g(x_M) = sum_{i=L}^{P} (x_i - 0.5)^2

        f_1(x) = (1 + g) * prod_{i=1}^{L-1} cos(x_i * pi/2)
        f_2(x) = (1 + g) * prod_{i=1}^{L-2} cos(x_i * pi/2) * sin(x_{L-1} * pi/2)
        ...
        f_L(x) = (1 + g) * sin(x_1 * pi/2)

    Example:
        >>> x = torch.rand(100, 12)  # 100 samples, 12 design parameters
        >>> y = dtlz2_forward(x, num_objectives=3)  # 3 objectives
        >>> y.shape
        torch.Size([100, 3])
    """
    P = x.shape[-1]
    L = num_objectives

    if L < 1:
        raise ValueError(f"num_objectives must be >= 1, got L={L}")
    if P < L:
        raise ValueError(f"Design dimension P={P} must be >= num_objectives L={L}")

    # Distance parameters x_M = x[..., L-1:]; g keeps a trailing axis of size 1
    # so it broadcasts against the (..., L) angular terms.
    x_M = x[..., L - 1:]
    g = torch.sum((x_M - 0.5) ** 2, dim=-1, keepdim=True)  # (..., 1)

    # Position parameters x_1, ..., x_{L-1}, mapped to angles in [0, pi/2].
    x_pos = x[..., : L - 1]  # (..., L-1)
    half_pi = np.pi / 2
    cos_x = torch.cos(x_pos * half_pi)
    sin_x = torch.sin(x_pos * half_pi)

    ones = torch.ones_like(g)  # (..., 1)
    # cos_cum[..., j] = prod_{i<j} cos_x[..., i]  (empty product = 1), j = 0..L-1
    cos_cum = torch.cat([ones, torch.cumprod(cos_x, dim=-1)], dim=-1)  # (..., L)
    # sin_ext[..., j] = sin_x[..., j] for j < L-1; the last slot is 1 because
    # f_1 has no sine factor.
    sin_ext = torch.cat([sin_x, ones], dim=-1)  # (..., L)

    # Column j now holds the angular term of objective f_{L-j}; flip the last
    # axis to obtain the order f_1, ..., f_L.
    angular = torch.flip(cos_cum * sin_ext, dims=[-1])

    return (1 + g) * angular


def dtlz1_forward(x: torch.Tensor, num_objectives: int = 3) -> torch.Tensor:
    """
    DTLZ1 test function: maps design parameters to objectives.

    The DTLZ1 function creates a linear hyperplane Pareto front: at g = 0
    (all distance parameters at 0.5) the objectives satisfy sum_m f_m = 0.5.
    The cosine term in g creates many local Pareto-optimal fronts, making it
    challenging to optimize.

    Args:
        x: Design parameters tensor of shape (..., P) where P >= num_objectives.
           Typically (batch_size, P); any number of leading dimensions is
           supported. All values should be in [0, 1].
        num_objectives: Number of objectives L (>= 1). Default is 3.

    Returns:
        Objectives tensor of shape (..., num_objectives).

    Raises:
        ValueError: If num_objectives < 1 or P < num_objectives.

    Mathematical formulation:
        g(x_M) = 100 * [k + sum_{i in x_M} ((x_i - 0.5)^2 - cos(20*pi*(x_i - 0.5)))]
        where k = |x_M| = P - L + 1

        f_1(x) = 0.5 * x_1 * x_2 * ... * x_{L-1} * (1 + g)
        f_2(x) = 0.5 * x_1 * x_2 * ... * (1 - x_{L-1}) * (1 + g)
        ...
        f_L(x) = 0.5 * (1 - x_1) * (1 + g)
    """
    P = x.shape[-1]
    L = num_objectives

    if L < 1:
        raise ValueError(f"num_objectives must be >= 1, got L={L}")
    if P < L:
        raise ValueError(f"Design dimension P={P} must be >= num_objectives L={L}")

    # Distance parameters and the multimodal distance function g; g keeps a
    # trailing axis of size 1 for broadcasting.
    x_M = x[..., L - 1:]
    k = P - L + 1
    g_term = (x_M - 0.5) ** 2 - torch.cos(20 * np.pi * (x_M - 0.5))
    g = 100 * (k + torch.sum(g_term, dim=-1, keepdim=True))  # (..., 1)

    # Position parameters x_1, ..., x_{L-1}.
    x_pos = x[..., : L - 1]  # (..., L-1)

    ones = torch.ones_like(g)  # (..., 1)
    # x_cum[..., j] = prod_{i<j} x_pos[..., i]  (empty product = 1), j = 0..L-1
    x_cum = torch.cat([ones, torch.cumprod(x_pos, dim=-1)], dim=-1)  # (..., L)
    # one_minus[..., j] = 1 - x_pos[..., j] for j < L-1; the last slot is 1
    # because f_1 has no (1 - x) factor.
    one_minus = torch.cat([1 - x_pos, ones], dim=-1)  # (..., L)

    # Column j holds the product term of objective f_{L-j}; flip to f_1..f_L.
    linear = torch.flip(x_cum * one_minus, dims=[-1])

    return 0.5 * (1 + g) * linear


# =============================================================================
# Ground Truth Surrogate Functions
# =============================================================================

def make_dtlz_surrogate(
    function_name: str = "dtlz2",
    num_objectives: int = 3,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Return the exact analytical forward map for a DTLZ problem.

    DTLZ problems need no learned surrogate (unlike the gas turbine dataset),
    so the "surrogate" is just the closed-form forward function with the
    number of objectives bound in.

    Args:
        function_name: Which DTLZ function to use ("dtlz1" or "dtlz2").
        num_objectives: Number of objectives passed to the forward function.

    Returns:
        A callable mapping a design tensor to an objectives tensor.

    Raises:
        ValueError: If ``function_name`` is not a supported DTLZ function.

    Example:
        >>> surrogate = make_dtlz_surrogate("dtlz2", num_objectives=3)
        >>> x = torch.rand(10, 12)
        >>> y = surrogate(x)
    """
    if function_name not in ("dtlz1", "dtlz2"):
        raise ValueError(f"Unknown DTLZ function: {function_name}")

    if function_name == "dtlz1":
        def surrogate(x):
            return dtlz1_forward(x, num_objectives=num_objectives)
    else:
        def surrogate(x):
            return dtlz2_forward(x, num_objectives=num_objectives)

    return surrogate


# =============================================================================
# Dataset Class
# =============================================================================

class DTLZDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for DTLZ benchmark functions.

    All samples are generated eagerly in ``__init__`` by:
    1. Sampling design parameters in [0, 1]^P with the selected strategy
       ("stratified" by default, or "uniform")
    2. Computing objectives using the analytical DTLZ forward function
    3. Optionally dividing the objectives by a scalar ``label_scale``

    Generation is reproducible: when no seed is given, a fixed per-split
    seed is used (train=42, val=43, test=44; any other split uses 42).

    Attributes:
        split: Split name passed at construction
        num_samples: Number of samples in the dataset
        num_design_params: Dimension of design space (P)
        num_objectives: Number of objectives (L)
        function_name: Name of DTLZ function ("dtlz1" or "dtlz2")
        seed: Seed actually used for generation
        designs: Pre-generated design parameters (num_samples, P)
        labels: Pre-computed (optionally normalized) objectives (num_samples, L)
        label_scale: Divisor applied to the labels (1.0 when not normalizing)
        sampling_stats: Sampler statistics; always contains a "strategy" key

    Example:
        >>> dataset = DTLZDataset(
        ...     split="train",
        ...     num_samples=10000,
        ...     num_design_params=12,
        ...     num_objectives=3,
        ...     function_name="dtlz2",
        ...     seed=42
        ... )
        >>> x, y = dataset[0]
        >>> x.shape, y.shape
        (torch.Size([12]), torch.Size([3]))
    """

    def __init__(
        self,
        split: str,
        num_samples: int = None,
        num_design_params: int = 12,
        num_objectives: int = 3,
        function_name: str = "dtlz2",
        seed: int = None,
        transform: Callable = None,
        normalize_labels: bool = True,
        sampling_strategy: str = "stratified",
        g_max: float = 2.0,
        label_scale: float = None,
    ):
        """
        Initialize the DTLZ dataset.

        Args:
            split: Dataset split ("train", "val", or "test"). Used to set
                   default sample sizes and seeds if not provided.
            num_samples: Number of samples. If None, uses defaults based on split:
                        - train: 100000
                        - val: 5000
                        - test: 5000
            num_design_params: Dimension of design space P (default: 12).
            num_objectives: Number of objectives L (default: 3).
            function_name: DTLZ function to use ("dtlz1" or "dtlz2").
            seed: Random seed for reproducibility. If None, uses split-based seed.
            transform: Optional transform applied to each (x, y) pair in
                       ``__getitem__``.
            normalize_labels: Whether to divide labels by a scale factor.
                - DTLZ2: the factor is 1 + g_max. With stratified sampling the
                  realized g never exceeds g_max, so labels end up in [0, 1].
                  With uniform sampling g can reach (P - L + 1)/4, so
                  normalized labels may exceed 1.
                - DTLZ1: the factor is 1.1 * max(labels) of THIS dataset's own
                  samples, so it differs between splits/seeds. Pass
                  ``label_scale`` to use one fixed factor across splits.
            sampling_strategy: Sampling strategy to use:
                - "uniform": Standard uniform sampling (fails in high dimensions)
                - "stratified": Stratified sampling that controls g distribution
            g_max: Maximum g value for stratified sampling (default: 2.0).
            label_scale: Optional fixed positive divisor that overrides the
                automatic scale. Only used when ``normalize_labels`` is True.

        Raises:
            ValueError: For an unknown split (when num_samples is None), an
                unknown sampling strategy or DTLZ function, or a
                non-positive ``label_scale``.
        """
        if label_scale is not None and label_scale <= 0:
            raise ValueError(f"label_scale must be positive, got {label_scale}")

        self.split = split
        self.num_design_params = num_design_params
        self.num_objectives = num_objectives
        self.function_name = function_name
        self.transform = transform
        self.normalize_labels = normalize_labels
        self.sampling_strategy = sampling_strategy
        self.g_max = g_max
        self._label_scale_override = label_scale

        # Set default sample sizes based on split
        if num_samples is None:
            default_sizes = {"train": 100000, "val": 5000, "test": 5000}
            if split not in default_sizes:
                raise ValueError(f"Unknown split: {split}. Use 'train', 'val', or 'test'.")
            num_samples = default_sizes[split]
        self.num_samples = num_samples

        # Set seed based on split for reproducibility
        if seed is None:
            seed_map = {"train": 42, "val": 43, "test": 44}
            seed = seed_map.get(split, 42)
        self.seed = seed

        # Get forward function (raises ValueError for unknown function names)
        self.forward_fn = make_dtlz_surrogate(function_name, num_objectives)

        # Pre-generate all data for consistency
        self._generate_data()

    def _generate_data(self):
        """Sample designs, evaluate objectives, and apply label normalization."""
        generator = torch.Generator().manual_seed(self.seed)

        # Sample designs based on selected strategy
        if self.sampling_strategy == "uniform":
            self.designs = sample_dtlz2_uniform(
                self.num_samples,
                self.num_design_params,
                self.num_objectives,
                generator=generator,
            )
            self.sampling_stats = {"strategy": "uniform"}
        elif self.sampling_strategy == "stratified":
            self.designs, self.sampling_stats = sample_dtlz2_stratified(
                self.num_samples,
                self.num_design_params,
                self.num_objectives,
                g_max=self.g_max,
                generator=generator,
            )
            self.sampling_stats["strategy"] = "stratified"
        else:
            raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")

        # Compute objectives using analytical forward function
        self.labels = self.forward_fn(self.designs)

        if self.normalize_labels:
            self.label_scale = self._compute_label_scale()
            self.labels = self.labels / self.label_scale
        else:
            self.label_scale = 1.0

    def _compute_label_scale(self) -> float:
        """
        Return the divisor used to normalize ``self.labels``.

        Reads ``self.labels`` (un-normalized), ``self.function_name`` and
        ``self.g_max``. An explicit ``label_scale`` passed to ``__init__``
        always takes precedence.
        """
        override = getattr(self, "_label_scale_override", None)
        if override is not None:
            return float(override)

        if self.function_name == "dtlz2":
            # Objectives are (1 + g) * angular terms with angular terms <= 1,
            # and stratified sampling keeps g <= g_max.
            return 1.0 + self.g_max

        if self.function_name == "dtlz1":
            # DTLZ1 values can be very large because of the multimodal g, so
            # the scale is taken from the data itself. NOTE: this makes the
            # scale depend on the split/seed; pass label_scale to pin it.
            if self.labels.numel() == 0:
                return 1.0
            scale = torch.max(self.labels).item() * 1.1
            # Avoid dividing by zero (or flipping signs) on degenerate data.
            return scale if scale > 0 else 1.0

        return 1.0

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        x = self.designs[idx]
        y = self.labels[idx]

        if self.transform:
            x, y = self.transform(x, y)

        return x, y

    def get_forward_function(self, normalized: bool = True) -> Callable:
        """
        Get the forward function for round-trip evaluation.

        Args:
            normalized: If True and this dataset normalizes its labels, the
                returned function divides by ``self.label_scale`` so its output
                is on the same scale as ``self.labels``.

        Returns:
            Forward function mapping designs to labels.
        """
        if normalized and self.normalize_labels:
            return lambda x: self.forward_fn(x) / self.label_scale
        return self.forward_fn


# =============================================================================
# Utility Functions
# =============================================================================

def get_dtlz_config(
    num_design_params: int = 12,
    num_objectives: int = 3,
    function_name: str = "dtlz2",
) -> dict:
    """
    Build a configuration dictionary describing a DTLZ experiment.

    Meant for logging and reproducibility. ``design_dim`` and ``label_dim``
    repeat ``num_design_params`` and ``num_objectives`` under the generic
    names used for other datasets.

    Args:
        num_design_params: Design space dimension P.
        num_objectives: Number of objectives L.
        function_name: DTLZ function name.

    Returns:
        Configuration dictionary.
    """
    return dict(
        dataset=DTLZ_DATASET_NAME,
        num_design_params=num_design_params,
        num_objectives=num_objectives,
        function_name=function_name,
        design_dim=num_design_params,
        label_dim=num_objectives,
    )


def create_dtlz_dataset_class(
    num_design_params: int = 12,
    num_objectives: int = 3,
    function_name: str = "dtlz2",
    normalize_labels: bool = True,
    sampling_strategy: str = "stratified",
    g_max: float = 2.0,
):
    """
    Create a configured DTLZDataset class for use with training functions.

    This factory function creates a dataset class with fixed hyperparameters,
    matching the interface expected by train_one_model_on_dataset().

    Args:
        num_design_params: Design space dimension P.
        num_objectives: Number of objectives L.
        function_name: DTLZ function name.
        normalize_labels: Whether to normalize labels.
        sampling_strategy: Sampling strategy ("uniform" or "stratified").
        g_max: Maximum g value for stratified sampling.

    Returns:
        A class that can be instantiated with (split, transform) arguments.
        It also accepts optional keyword-only ``num_samples`` and ``seed``,
        which are forwarded to DTLZDataset. Leaving them as None keeps the
        split-based defaults.

    Example:
        >>> DatasetClass = create_dtlz_dataset_class(num_design_params=50)
        >>> train_data = DatasetClass("train")
        >>> val_data = DatasetClass("val", transform=my_transform)
        >>> small_test = DatasetClass("test", num_samples=500, seed=0)
    """
    class ConfiguredDTLZDataset(DTLZDataset):
        """DTLZDataset with the factory's hyperparameters baked in."""

        def __init__(self, split, transform=None, *, num_samples=None, seed=None):
            super().__init__(
                split=split,
                num_samples=num_samples,
                num_design_params=num_design_params,
                num_objectives=num_objectives,
                function_name=function_name,
                seed=seed,
                transform=transform,
                normalize_labels=normalize_labels,
                sampling_strategy=sampling_strategy,
                g_max=g_max,
            )

    return ConfiguredDTLZDataset


# =============================================================================
# Main Entry Point for Testing
# =============================================================================

if __name__ == "__main__":
    import sys

    _RULE = "=" * 70

    def _banner(title):
        """Print ``title`` between two horizontal rules."""
        print(_RULE)
        print(title)
        print(_RULE)

    def _distance_g(dataset, num_objectives):
        """Recompute g(x_M) for every design stored in ``dataset``."""
        distance_part = dataset.designs[:, num_objectives - 1:]
        return torch.sum((distance_part - 0.5) ** 2, dim=1)

    def _report_g(g_values):
        """Print min/max/mean/std of g and the share of samples near the front."""
        print("  g statistics:")
        for label, stat in (
            ("min: ", g_values.min()),
            ("max: ", g_values.max()),
            ("mean:", g_values.mean()),
            ("std: ", g_values.std()),
        ):
            print(f"    {label} {stat.item():.4f}")
        for threshold in (0.1, 0.5):
            share = 100 * (g_values < threshold).float().mean().item()
            print(f"  % samples with g < {threshold}: {share:.2f}%")

    def _label_summary(labels):
        """Format min/max/mean of a label tensor on one line."""
        return (
            f"min={labels.min().item():.3f}, "
            f"max={labels.max().item():.3f}, "
            f"mean={labels.mean().item():.3f}"
        )

    def _run_test_dtlz2():
        """Smoke-test dtlz2_forward and check Pareto-front norms."""
        print("Testing DTLZ2 forward function...")

        for dim in (6, 12, 50):
            designs = torch.rand(100, dim)
            objectives = dtlz2_forward(designs, num_objectives=3)
            print(f"  P={dim}: x.shape={designs.shape}, y.shape={objectives.shape}")
            lo, hi = objectives.min().item(), objectives.max().item()
            print(f"    y range: [{lo:.4f}, {hi:.4f}]")

        # On the Pareto front (x_M = 0.5, g = 0) every objective vector has norm 1.
        print("\nVerifying Pareto front (g=0):")
        front = torch.zeros(10, 12)
        front[:, :2] = torch.rand(10, 2)
        front[:, 2:] = 0.5
        front_norms = torch.norm(dtlz2_forward(front, num_objectives=3), dim=1)
        print(f"  Objective norms (should be ~1): {front_norms}")

    def _run_test_dataset():
        """Build a small DTLZDataset and verify the normalized round trip."""
        print("Testing DTLZDataset...")

        ds = DTLZDataset(
            split="train",
            num_samples=1000,
            num_design_params=12,
            num_objectives=3,
            function_name="dtlz2",
        )

        print(f"  Dataset length: {len(ds)}")
        first_x, first_y = ds[0]
        print(f"  Sample shapes: x={first_x.shape}, y={first_y.shape}")
        print(f"  x range: [{ds.designs.min():.4f}, {ds.designs.max():.4f}]")
        print(f"  y range: [{ds.labels.min():.4f}, {ds.labels.max():.4f}]")

        recomputed = ds.get_forward_function(normalized=True)(first_x.unsqueeze(0))
        print(f"  Round-trip check: {torch.allclose(first_y, recomputed.squeeze())}")

    def _run_test_scaling():
        """Instantiate datasets across several (P, L) combinations."""
        print("Testing DTLZ2 with various dimensions...")

        for dim in (12, 24, 50, 100):
            for n_obj in (3, 5):
                if dim < n_obj:
                    continue
                ds = DTLZDataset(
                    split="train",
                    num_samples=100,
                    num_design_params=dim,
                    num_objectives=n_obj,
                )
                sample_x, sample_y = ds[0]
                print(f"  P={dim:3d}, L={n_obj}: x.shape={sample_x.shape}, y.shape={sample_y.shape}")

    def _run_test_sampling():
        """Compare g and label distributions of uniform vs stratified sampling."""
        _banner("COMPARING UNIFORM vs STRATIFIED SAMPLING")

        n_obj = 3
        n_samples = 10000
        for dim in (12, 24, 50, 100):
            print(f"\n{_RULE}")
            print(f"P = {dim} design parameters, L = {n_obj} objectives")
            print(_RULE)

            print("\n--- Uniform Sampling ---")
            uniform_ds = DTLZDataset(
                split="train",
                num_samples=n_samples,
                num_design_params=dim,
                num_objectives=n_obj,
                sampling_strategy="uniform",
                normalize_labels=False,
            )
            _report_g(_distance_g(uniform_ds, n_obj))

            print("\n--- Stratified Sampling ---")
            strat_ds = DTLZDataset(
                split="train",
                num_samples=n_samples,
                num_design_params=dim,
                num_objectives=n_obj,
                sampling_strategy="stratified",
                g_max=2.0,
                normalize_labels=False,
            )
            _report_g(_distance_g(strat_ds, n_obj))
            clip_pct = strat_ds.sampling_stats['clip_percentage']
            print(f"  Clipping percentage: {clip_pct:.2f}%")

            print("\n--- Label Distribution Comparison ---")
            print(f"  Uniform labels:    {_label_summary(uniform_ds.labels)}")
            print(f"  Stratified labels: {_label_summary(strat_ds.labels)}")

    def _run_test_clipping():
        """Report the stratified sampler's clipping rate over a (P, g_max) grid."""
        _banner("CLIPPING ANALYSIS FOR STRATIFIED SAMPLING")

        for dim in (12, 24, 50, 100, 200):
            for g_limit in (1.0, 2.0, 3.0):
                ds = DTLZDataset(
                    split="train",
                    num_samples=100000,
                    num_design_params=dim,
                    num_objectives=3,
                    sampling_strategy="stratified",
                    g_max=g_limit,
                    normalize_labels=False,
                )
                clip_pct = ds.sampling_stats['clip_percentage']
                print(f"P={dim:3d}, g_max={g_limit:.1f}: clip={clip_pct:5.2f}%")

    _COMMANDS = {
        "test_dtlz2": _run_test_dtlz2,
        "test_dataset": _run_test_dataset,
        "test_scaling": _run_test_scaling,
        "test_sampling": _run_test_sampling,
        "test_clipping": _run_test_clipping,
    }

    _handler = _COMMANDS.get(sys.argv[1]) if len(sys.argv) >= 2 else None
    if _handler is None:
        print("Usage: python data_utils_dtlz.py [test_dtlz2|test_dataset|test_scaling|test_sampling|test_clipping]")
    else:
        _handler()
