from dataclasses import dataclass, field
from typing import Dict, Any, List, Callable, Optional, Union, Sequence
import numpy as np
from pathlib import Path
import pickle
import pandas as pd
from sklearn.model_selection import train_test_split
from .mixture_gen import (
    Component,
    DatasetConfig,
    generate_mixture_data,
    _generate_y_value,
)
from .generators import (
    diagonal_multivariate_gaussian_generator,
    distributed_component_generator,
    diagonal_cube_generator,
    sparse_distributed_component_generator,
    tree_based_component_generator,
)
from scipy import stats
from copy import deepcopy
from dataclasses import asdict, is_dataclass
import json


@dataclass
class ParameterSpec:
    """Describe one experiment parameter: its name, type, default and allowed values."""

    name: str
    dtype: type  # callable used to test convertibility when valid_values is None
    default: Any
    description: str = ""
    valid_values: Optional[List[Any]] = None  # explicit whitelist; overrides the dtype check

    def validate(self, value: Any) -> bool:
        """Return True when *value* is acceptable for this parameter.

        If ``valid_values`` is set, membership in it decides. Otherwise the value
        is accepted when ``dtype(value)`` succeeds without raising ValueError or
        TypeError.
        """
        if self.valid_values is None:
            try:
                self.dtype(value)
            except (ValueError, TypeError):
                return False
            return True
        return value in self.valid_values


@dataclass
class GeneratorConfig:
    """Describe how to call a component generator and what shape it returns."""

    generator_func: Callable  # called with the dict produced by get_params()
    required_params: List[str]  # copied from the caller's params when present
    optional_params: Dict[str, Any] = field(default_factory=dict)  # name -> fallback value
    supports_multidim: bool = True
    returns_tuple: bool = False  # True if returns (components, metadata)
    returns_empty_leaves: bool = False  # True if returns (components, metadata, empty_leaves)

    def get_params(self, all_params: Dict[str, Any]) -> Dict[str, Any]:
        """Select the keyword arguments this generator should receive.

        Required names are included only if present in *all_params* (missing ones
        are silently skipped). Every optional name is always included, taking the
        caller's value when present and the configured fallback otherwise.
        Required names come first in the result, then optional ones.
        """
        selected = {
            key: all_params[key] for key in self.required_params if key in all_params
        }
        selected.update(
            (key, all_params.get(key, fallback))
            for key, fallback in self.optional_params.items()
        )
        return selected


@dataclass
class ExperimentConfig:
    """A controlled experiment: sweep one parameter while holding the others fixed."""

    name: str
    description: str

    # Parameters shared by every dataset in the sweep.
    base_params: Dict[str, Any]

    # The swept parameter and the values it takes.
    variable_param: str
    variable_values: Sequence[Union[int, float, str]]

    # The component generator used for every dataset.
    generator_config: GeneratorConfig

    # Datasets generated per swept value, and the test-split size passed on to
    # the dataset generator.
    n_replications: int = 3
    test_size: float = 0.2

    def validate(self) -> List[str]:
        """Return a list of configuration problems (empty when none are found).

        Only one rule is checked: every parameter the generator requires must be
        supplied, either in ``base_params`` or as the swept ``variable_param``.
        """
        return [
            f"Generator requires parameter '{param}' but it's not in base_params or variable_param"
            for param in self.generator_config.required_params
            if param not in self.base_params and param != self.variable_param
        ]


class ComponentGeneratorRegistry:
    """Name-indexed catalogue of component generators and their call configurations.

    A new registry already contains the built-in generators, in this order:
    ``tree_based``, ``diagonal_cube``, ``distributed``, ``sparse_distributed``
    and ``diagonal_multivariate_gaussian``.
    """

    def __init__(self):
        self._generators = {}
        self._register_builtin_generators()

    def _register_builtin_generators(self):
        """Register the generators imported from ``.generators`` under their public names."""
        catalogue = {
            # Tree-based generator; additionally reports empty leaves that can
            # host background noise.
            "tree_based": GeneratorConfig(
                generator_func=tree_based_component_generator,
                required_params=["n_components", "n_features", "distributions"],
                optional_params={
                    "max_depth": 8,
                    "expansion_strategy": "breadth_first",
                    "domain_bounds": (0.0, 1.0),
                    "min_leaf_size_fraction": 0.1,
                    "random_seed": None,
                    "leaf_x_distribution": "uniform",
                    "gaussian_std_fraction": 0.25,
                    "empty_leaf_probability": 0.0,
                    "background_noise_fraction": 0.0,
                    "background_y_distribution": "uniform",
                    "background_y_uniform_bounds": (-2.0, 2.0),
                },
                supports_multidim=False,
                returns_empty_leaves=True,
            ),
            "diagonal_cube": GeneratorConfig(
                generator_func=diagonal_cube_generator,
                required_params=["n_components", "n_features", "distributions"],
                optional_params={},
                supports_multidim=False,
                returns_tuple=False,
            ),
            "distributed": GeneratorConfig(
                generator_func=distributed_component_generator,
                required_params=["n_components", "n_features", "distributions"],
                optional_params={"placement_strategy": "random", "spacing_factor": 1.5},
                supports_multidim=False,
                returns_tuple=False,
            ),
            # Returns (components, metadata).
            "sparse_distributed": GeneratorConfig(
                generator_func=sparse_distributed_component_generator,
                required_params=["n_components", "n_features"],
                optional_params={
                    "distributions": ["normal", "uniform", "gamma", "exponential"],
                    "min_features_per_component": 2,
                    "max_features_per_component": None,
                    "placement_strategy": "random",
                    "adaptive_sizing": True,
                    "vary_size": False,
                    "vary_factor": 0.2,
                    "max_attempts": 1000,
                    "domain_expansion_factor": 0.2,
                    "max_domain_expansion_steps": 5,
                    "noise_feature_bounds": (-2.0, 2.0),
                    "auto_scale_noise": True,
                    "noise_scale_factor": 1.0,
                    "verbose": False,
                    "random_seed": 42,
                },
                supports_multidim=False,
                returns_tuple=True,
            ),
            # The only built-in generator flagged as supporting multiple targets.
            "diagonal_multivariate_gaussian": GeneratorConfig(
                generator_func=diagonal_multivariate_gaussian_generator,
                required_params=["n_components", "n_features", "n_targets"],
                optional_params={
                    "rule_base_size": 0.5,
                    "rule_spacing_factor": 1.0,
                    "mean_spacing_factor": 2.0,
                    "cov_scale": 0.5,
                },
                supports_multidim=True,
                returns_tuple=False,
            ),
        }
        # Dict insertion order keeps the registration order stable.
        for generator_name, generator_cfg in catalogue.items():
            self.register(generator_name, generator_cfg)

    def register(self, name: str, config: GeneratorConfig):
        """Add (or replace) the generator configuration stored under *name*."""
        self._generators[name] = config

    def get(self, name: str) -> GeneratorConfig:
        """Return the configuration registered as *name*.

        Raises:
            ValueError: if *name* is not registered; the message lists the
                available names.
        """
        if name in self._generators:
            return self._generators[name]
        available = list(self._generators.keys())
        raise ValueError(f"Unknown generator: {name}. Available: {available}")

    def list_generators(self) -> List[str]:
        """Return the registered generator names in registration order."""
        return [*self._generators]


class DatasetGenerator:
    """Generate individual datasets for one generator and one parameter combination.

    Component definitions come from a generator looked up by name in a
    ``ComponentGeneratorRegistry``. Data points are then sampled from the active
    components via ``generate_mixture_data``. For generators flagged with
    ``returns_empty_leaves``, a fraction of the points can optionally be drawn
    uniformly inside the empty leaves as background noise.
    """

    def __init__(self, generator_registry: ComponentGeneratorRegistry):
        self.generator_registry = generator_registry

    def _adjust_conditional_densities(
        self, components: List[Component], overlap_alpha: float = 0.1
    ) -> List[Component]:
        """
        Shift 1D target densities so neighbouring components overlap by a controlled amount.

        Components are deep-copied and sorted by the median of their target
        distribution. Each component after the first is then shifted so that its
        ``overlap_alpha`` quantile coincides with the ``1 - overlap_alpha``
        quantile of its (already shifted) predecessor.

        Args:
            components: components whose ``distribution``/``dist_params`` describe
                the target distribution.
            overlap_alpha: tail probability shared by each neighbouring pair.

        Returns:
            A new median-sorted list of shifted copies. The input list itself is
            returned unchanged when it has at most one element.
        """
        if len(components) <= 1:
            return components
        adjusted = deepcopy(components)

        # Sorting by median makes each shift relative to the left neighbour.
        adjusted.sort(key=lambda c: self._get_median(c.distribution, c.dist_params))

        upper_tail_q = 1.0 - overlap_alpha  # quantile taken on the previous component
        lower_tail_q = overlap_alpha  # quantile taken on the current component
        # zip walks the same objects, so prev_comp already carries its own shift.
        for prev_comp, curr_comp in zip(adjusted, adjusted[1:]):
            prev_upper = self._get_percentile(
                prev_comp.distribution, prev_comp.dist_params, upper_tail_q
            )
            curr_lower = self._get_percentile(
                curr_comp.distribution, curr_comp.dist_params, lower_tail_q
            )
            curr_comp.dist_params = self._shift_distribution(
                curr_comp.distribution, curr_comp.dist_params, prev_upper - curr_lower
            )

        return adjusted

    @staticmethod
    def _unpack_generator_result(generator_config: GeneratorConfig, result: Any):
        """Split a raw generator result into ``(active_components, empty_leaves)``.

        Handles the return shapes declared by the ``GeneratorConfig`` flags:
        ``(components, metadata, empty_leaves)`` when ``returns_empty_leaves``,
        ``(components, metadata)`` when ``returns_tuple``, and a bare component
        list otherwise. Metadata is discarded.
        """
        if generator_config.returns_empty_leaves:
            components, _metadata, empty_leaves = result
            return components, empty_leaves
        if generator_config.returns_tuple:
            components, _metadata = result
            return components, []
        return result, []

    def _normalize_component_generator_call(
        self, generator_config: GeneratorConfig, **kwargs
    ) -> List[Component]:
        """Call ``generator_config.generator_func(**kwargs)`` and return only its components.

        All return shapes described by the config flags are supported. Any
        metadata or empty leaves are dropped.
        """
        result = generator_config.generator_func(**kwargs)
        components, _empty_leaves = self._unpack_generator_result(generator_config, result)
        return components

    def _get_percentile(self, dist_type: str, params: dict, p: float) -> float:
        """Return the quantile at probability *p* of a supported 1D distribution.

        ``dist_type`` is matched case-insensitively against ``normal`` (loc,
        scale), ``uniform`` (low, high), ``gamma`` (shape, scale, loc) and
        ``exponential`` (scale, loc). Missing parameters fall back to the
        defaults used below.

        Raises:
            ValueError: for any other ``dist_type``.
        """
        kind = dist_type.lower()
        if kind == "normal":
            quantile = stats.norm.ppf(
                p, loc=params.get("loc", 0), scale=params.get("scale", 1)
            )
        elif kind == "uniform":
            low = params.get("low", 0)
            quantile = stats.uniform.ppf(p, loc=low, scale=params.get("high", 1) - low)
        elif kind == "gamma":
            quantile = stats.gamma.ppf(
                p,
                a=params.get("shape", 1),
                scale=params.get("scale", 1),
                loc=params.get("loc", 0),
            )
        elif kind == "exponential":
            quantile = stats.expon.ppf(
                p, scale=params.get("scale", 1), loc=params.get("loc", 0)
            )
        else:
            raise ValueError(f"Unsupported distribution type: {dist_type}")
        return float(quantile)

    def _shift_distribution(self, dist_type: str, params: dict, shift: float) -> dict:
        """Return a copy of *params* translated by *shift* along the target axis.

        Normal, gamma and exponential distributions move ``loc``. Uniform moves
        both ``low`` and ``high`` while keeping the width. Unknown types are
        returned as an unchanged copy.
        """
        kind = dist_type.lower()
        new_params = params.copy()
        if kind == "normal":
            new_params["loc"] = params.get("loc", 0) + shift
        elif kind == "uniform":
            width = params.get("high", 1) - params.get("low", 0)
            new_params["low"] = params.get("low", 0) + shift
            new_params["high"] = params.get("low", 0) + width + shift
        elif kind in ("gamma", "exponential"):
            new_params["loc"] = params.get("loc", 0) + shift
        return new_params

    def _get_median(self, dist_type: str, params: dict) -> float:
        """Return the median (0.5 quantile) of a supported 1D distribution."""
        return self._get_percentile(dist_type, params, 0.5)

    def _calculate_separability(
        self, X: np.ndarray, component_labels: np.ndarray
    ) -> float:
        """Return a between-centroid / within-scatter separability score.

        The score is the mean pairwise Euclidean distance between component
        centroids divided by the mean distance of points to their own centroid.
        It is 0.0 for fewer than two components. When within-component scatter
        is zero, the mean centroid distance times 10 is returned instead.
        """
        unique_components = np.unique(component_labels)
        if len(unique_components) <= 1:
            return 0.0

        centroids = []
        for comp in unique_components:
            mask = component_labels == comp
            if np.sum(mask) > 0:
                centroids.append(np.mean(X[mask], axis=0))

        distances = [
            np.linalg.norm(centroids[i] - centroids[j])
            for i in range(len(centroids))
            for j in range(i + 1, len(centroids))
        ]

        within_scatter = 0
        for comp in unique_components:
            mask = component_labels == comp
            if np.sum(mask) > 1:
                comp_points = X[mask]
                centroid = np.mean(comp_points, axis=0)
                within_scatter += np.sum(np.linalg.norm(comp_points - centroid, axis=1))

        if within_scatter > 0:
            within_scatter /= len(X)
            return float(np.mean(distances) / within_scatter)
        return float(np.mean(distances) * 10)

    def _distribute_samples_across_components(
        self, total_samples: int, n_components: int, distribution_mode: str = "equal"
    ) -> List[int]:
        """Split *total_samples* into per-component counts.

        Modes:
            ``"equal"``: counts differ by at most one; the first
                ``total_samples % n_components`` components get the extra sample.
            ``"random"``: each component first gets
                ``max(1, total_samples // (n_components * 10))`` samples (10% of an
                equal share). The rest are assigned one at a time to uniformly
                random components, using a fixed seed (42) for reproducibility.

        Raises:
            ValueError: if ``n_components`` is not positive or the mode is unknown.
        """
        if n_components <= 0:
            raise ValueError(f"n_components must be positive, got {n_components}")

        if distribution_mode == "equal":
            base_samples, remainder = divmod(total_samples, n_components)
            return [
                base_samples + 1 if i < remainder else base_samples
                for i in range(n_components)
            ]

        if distribution_mode == "random":
            min_samples = max(1, total_samples // (n_components * 10))
            remaining_samples = total_samples - (min_samples * n_components)
            samples_per_component = [min_samples] * n_components

            # A private RandomState(42) yields exactly the draws the old
            # np.random.seed(42) + np.random.randint produced, without resetting
            # the process-wide global RNG as a side effect.
            local_rng = np.random.RandomState(42)
            for _ in range(remaining_samples):
                samples_per_component[local_rng.randint(0, n_components)] += 1
            return samples_per_component

        raise ValueError(f"Unknown distribution mode: {distribution_mode}")

    @staticmethod
    def _sample_background(
        params: Dict[str, Any],
        empty_leaves: List[Any],
        n_background: int,
        n_targets: int,
        rng: np.random.Generator,
    ):
        """Draw *n_background* (X, y) points uniformly inside randomly chosen empty leaves.

        Each point picks a leaf uniformly with replacement. Every feature listed
        in ``leaf.bounds`` is drawn uniformly in its (low, high) interval, and
        y comes from a background component described by
        ``background_y_distribution`` / ``background_y_uniform_bounds``.
        Gaussian ``noise_Y`` is added when it is positive.

        Returns:
            ``(X_bg, y_bg)`` with shapes ``(n_background, n_features + n_noise_features)``
            and ``(n_background, n_targets)``.
        """
        total_features = params["n_features"] + params.get("n_noise_features", 0)
        X_bg = np.zeros((n_background, total_features))
        y_bg = np.zeros((n_background, n_targets))
        # NOTE(review): features absent from leaf.bounds (e.g. noise features)
        # stay at 0 for background rows — confirm this is intended.

        # A single dummy component defines the background Y distribution.
        bg_y_dist = params.get("background_y_distribution", "uniform")
        bg_y_params = {}
        if bg_y_dist == "uniform":
            low, high = params.get("background_y_uniform_bounds", (-2.0, 2.0))
            bg_y_params = {"low": low, "high": high}
        bg_component = Component(
            rules={}, distribution=bg_y_dist, dist_params=bg_y_params, weight=1
        )

        noise_y = params.get("noise_Y", 0.0)
        # Draw leaf *indices*: this uses the same random stream as
        # rng.choice(empty_leaves, ...) but never converts leaf objects to an ndarray.
        leaf_indices = rng.choice(len(empty_leaves), size=n_background, replace=True)
        for i, leaf_idx in enumerate(leaf_indices):
            leaf = empty_leaves[leaf_idx]
            for feat_idx, (low, high) in leaf.bounds.items():
                X_bg[i, feat_idx] = rng.uniform(low, high)

            y_val = _generate_y_value(bg_component, n_targets)
            if noise_y > 0:
                y_val += rng.normal(0, noise_y, size=n_targets)
            y_bg[i, :] = y_val

        return X_bg, y_bg

    def generate_dataset(
        self, params: Dict[str, Any], generator_name: str, name: str, seed: int
    ) -> Dict[str, Any]:
        """
        Generate a single dataset.

        Stage 1 asks the registered generator for component definitions and, for
        generators with ``returns_empty_leaves``, the empty regions of the domain.
        Stage 2 samples points from the active components and, when
        ``background_noise_fraction > 0`` and empty leaves exist, background
        points inside the empty leaves. Stage 3 shuffles the points, builds the
        ``DatasetConfig`` and makes a train/test split.

        Args:
            params: generation parameters. ``n_samples`` and ``n_features`` are
                required. ``use_n_samples_per_component`` (default False) makes
                ``n_samples`` a per-active-component count.
            generator_name: registry key of the component generator.
            name: dataset name stored in the resulting config.
            seed: seed for all randomness in this call. It is passed to
                ``np.random.seed`` and ``train_test_split``.

        Returns:
            Dict with keys ``X``, ``y``, ``component_labels``, ``config``,
            ``seed``, ``separability``, ``train_indices``, ``test_indices`` and
            ``generation_params``.

        Raises:
            ValueError: if the generator yields no active components and no
                background noise was requested, if component weights sum to a
                non-positive value, if ``background_noise_fraction`` lies outside
                [0, 1], or if no samples were generated.
        """
        # The legacy global RNG is seeded too, because helpers imported from
        # .mixture_gen may draw from it (NOTE(review): confirm which ones do).
        np.random.seed(seed)
        rng = np.random.default_rng(seed)

        generator_config = self.generator_registry.get(generator_name)
        generator_params = generator_config.get_params(params)

        # --- Stage 1: Generate Component Definitions ---
        raw_result = generator_config.generator_func(**generator_params)
        active_components, empty_leaves = self._unpack_generator_result(
            generator_config, raw_result
        )
        # Own the list so appending the background component later never mutates
        # the generator's return value or the active-component count.
        active_components = list(active_components)

        if not active_components and not (
            empty_leaves and params.get("background_noise_fraction", 0) > 0
        ):
            raise ValueError(
                "Component generator returned no active components and no background noise was requested."
            )

        # --- Stage 2: Generate Data Points ---
        total_samples = params["n_samples"]
        if params.get("use_n_samples_per_component", False):
            total_samples *= len(active_components)
        n_targets = params.get("n_targets", 1)
        # Background noise only exists where the generator reported empty leaves.
        bg_noise_fraction = (
            params.get("background_noise_fraction", 0.0) if empty_leaves else 0.0
        )
        if not 0.0 <= bg_noise_fraction <= 1.0:
            raise ValueError(
                f"background_noise_fraction must be in [0, 1], got {bg_noise_fraction}"
            )
        if not active_components:
            # Only background was requested, so it accounts for every sample.
            bg_noise_fraction = 1.0

        n_active = int(total_samples * (1.0 - bg_noise_fraction))
        n_background = total_samples - n_active

        # --- A: Generate data for ACTIVE components ---
        X_active, y_active, labels_active = np.array([]), np.array([]), np.array([])
        if n_active > 0:
            overlap = params.get("distribution_overlap")
            # Adjust Y distributions for overlap BEFORE generating data.
            if n_targets == 1 and overlap is not None:
                active_components = self._adjust_conditional_densities(
                    active_components, overlap_alpha=overlap
                )

            # Normalize weights of the active components to sum to 1 for this subset.
            total_active_weight = sum(c.weight for c in active_components)
            if total_active_weight <= 0:
                raise ValueError("Active component weights must sum to a positive value.")
            for c in active_components:
                c.weight /= total_active_weight

            X_active, y_active, labels_active = generate_mixture_data(
                n_samples=n_active,
                n_features=params["n_features"],
                components=active_components,
                n_targets=n_targets,
                n_noise_features=params.get("n_noise_features", 0),
                seed=rng.integers(1_000_000),
                noise_X=params.get("noise_X", 0.0),
                noise_Y=params.get("noise_Y", 0.0),
            )

        # --- B: Generate data for BACKGROUND noise ---
        X_bg, y_bg, labels_bg = np.array([]), np.array([]), np.array([])
        if n_background > 0:
            X_bg, y_bg = self._sample_background(
                params, empty_leaves, n_background, n_targets, rng
            )
            # All background points share one label: one past the last active component.
            labels_bg = np.full(n_background, len(active_components), dtype=int)

        # --- Stage 3: Combine and Finalize ---
        if n_active > 0 and n_background > 0:
            if y_active.ndim == 1:
                # Promote 1-D targets so they stack with the 2-D background block.
                y_active = y_active.reshape(-1, 1)
            X = np.vstack((X_active, X_bg))
            y = np.vstack((y_active, y_bg))
            component_labels = np.concatenate((labels_active, labels_bg))
        elif n_active > 0:
            X, y, component_labels = X_active, y_active, labels_active
        elif n_background > 0:
            X, y, component_labels = X_bg, y_bg, labels_bg
        else:
            raise ValueError("No samples were generated.")

        # Shuffle so active and background rows are interleaved.
        shuffled_indices = rng.permutation(X.shape[0])
        X = X[shuffled_indices]
        y = y[shuffled_indices]
        component_labels = component_labels[shuffled_indices]

        final_components = list(active_components)
        if n_background > 0:
            # Conceptual component describing the background noise, for metadata only.
            final_components.append(
                Component(
                    rules={},
                    distribution="background",
                    dist_params={},
                    weight=bg_noise_fraction,
                )
            )

        separability_score = self._calculate_separability(X, component_labels)

        test_size = params.get("test_size", 0.2)
        train_indices, test_indices = train_test_split(
            range(X.shape[0]), test_size=test_size, random_state=seed
        )

        config = DatasetConfig(
            name=name,
            n_samples=total_samples,
            n_features=params["n_features"],
            n_noise_features=params.get("n_noise_features", 0),
            components=final_components,
            description=f"Generated dataset with {len(active_components)} active components and {bg_noise_fraction*100:.1f}% background noise.",
            n_targets=n_targets,
            distribution_overlap=params.get("distribution_overlap"),
            train_indices=train_indices,
            test_indices=test_indices,
        )
        return {
            "X": X,
            "y": y,
            "component_labels": component_labels,
            "config": config,
            "seed": seed,
            "separability": separability_score,
            "train_indices": train_indices,
            "test_indices": test_indices,
            "generation_params": params.copy(),
        }


class EnhancedJSONEncoder(json.JSONEncoder):
    """
    A JSON encoder that can handle additional types:
    - dataclass instances (via ``dataclasses.asdict``)
    - numpy arrays (as nested lists)
    - numpy scalars of any kind, including ``np.bool_`` (as the matching Python scalar)
    """

    def default(self, o):
        # is_dataclass() is also true for dataclass *classes*; only instances
        # can go through asdict().
        if is_dataclass(o) and not isinstance(o, type):
            return asdict(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, np.generic):
            # Covers np.integer / np.floating as before, plus np.bool_, which
            # the stock encoder rejects.
            return o.item()
        # Let the base class default method raise the TypeError
        return super().default(o)


class DataSuitGenerator:
    """Generate, save and summarise suites of controlled-experiment datasets."""

    def __init__(self, generator_registry: Optional[ComponentGeneratorRegistry] = None):
        self.generator_registry = generator_registry or ComponentGeneratorRegistry()
        self.dataset_generator = DatasetGenerator(self.generator_registry)

    @staticmethod
    def _replication_seed(
        base_seed: int, experiment_name: str, param_value: Any, rep: int
    ) -> int:
        """Derive a per-dataset seed that is identical across interpreter runs.

        The built-in ``hash`` of strings is salted per process (PYTHONHASHSEED),
        so it cannot give reproducible seeds. CRC32 of a text key can. The result
        lies in ``[0, 2**32)``, the range accepted by ``np.random.seed``.
        """
        import zlib  # only needed here

        key = f"{experiment_name}|{param_value}|{rep}".encode("utf-8")
        return (base_seed + zlib.crc32(key)) % (2**32)

    def _resolve_generator_name(self, generator_config: GeneratorConfig) -> str:
        """Return the registry name whose generator function matches *generator_config*.

        Raises:
            ValueError: if no registered generator uses that function.
        """
        for candidate in self.generator_registry.list_generators():
            registered = self.generator_registry.get(candidate)
            if registered.generator_func == generator_config.generator_func:
                return candidate
        raise ValueError("Generator function not found in registry")

    def _save_dataset(self, dataset: Dict[str, Any], base_dir: Path) -> Path:
        """
        Save one dataset and its assets to ``base_dir / <dataset name>``.

        Writes a pickle of the full dataset dict, ``metadata.json``, headed CSVs
        (``full``/``train``/``test`` plus test-set ``true_labels``) for the ABDA
        pipeline, and header-less CSVs (X/y splits, component labels and split
        indices) for the CDE pipeline.

        Returns:
            The dataset directory.
        """
        dataset_name = dataset["config"].name
        dataset_dir = base_dir / dataset_name
        dataset_dir.mkdir(parents=True, exist_ok=True)

        # 1. Full dataset object as a pickle, for internal use only.
        with open(dataset_dir / "dataset.pkl", "wb") as f:
            pickle.dump(dataset, f)

        # 2. Human-readable metadata.
        metadata = {
            "config": dataset["config"],
            "generation_params": dataset["generation_params"],
            "seed": dataset["seed"],
            "separability": dataset["separability"],
        }
        with open(dataset_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, cls=EnhancedJSONEncoder, indent=4)

        X, y = dataset["X"], dataset["y"]
        train_idx, test_idx = dataset["train_indices"], dataset["test_indices"]
        component_labels = dataset["component_labels"]

        # Ensure y is 2D so concatenation and per-target headers work uniformly.
        if y.ndim == 1:
            y = y.reshape(-1, 1)

        # 3. ABDA pipeline: combined (X, y) dataframes with headers.
        full_dataset_np = np.concatenate((X, y), axis=1)
        columns = [f"feature_{i}" for i in range(X.shape[1])] + [
            f"target_{i}" for i in range(y.shape[1])
        ]
        df_full = pd.DataFrame(full_dataset_np, columns=columns)

        df_full.to_csv(dataset_dir / "full.csv", index=False)
        df_full.iloc[train_idx].to_csv(dataset_dir / "train.csv", index=False)
        df_full.iloc[test_idx].to_csv(dataset_dir / "test.csv", index=False)

        # True labels for the TEST set, used by the ABDA evaluation script.
        df_labels = pd.DataFrame(component_labels[test_idx], columns=["true_label"])
        df_labels.to_csv(dataset_dir / "true_labels.csv", index=False)

        # 4. CDE pipeline: separate header-less X/y files plus full index/label files.
        np.savetxt(dataset_dir / "X_train.csv", X[train_idx], delimiter=",")
        np.savetxt(dataset_dir / "y_train.csv", y[train_idx], delimiter=",")
        np.savetxt(dataset_dir / "X_test.csv", X[test_idx], delimiter=",")
        np.savetxt(dataset_dir / "y_test.csv", y[test_idx], delimiter=",")

        np.savetxt(
            dataset_dir / "component_labels.csv",
            component_labels,
            delimiter=",",
            fmt="%i",
        )
        np.savetxt(
            dataset_dir / "train_indices.csv", train_idx, delimiter=",", fmt="%i"
        )
        np.savetxt(dataset_dir / "test_indices.csv", test_idx, delimiter=",", fmt="%i")

        return dataset_dir

    def run_experiment(
        self, experiment: ExperimentConfig, output_dir: str, base_seed: int = 42
    ) -> List[Dict[str, Any]]:
        """Run one controlled experiment, saving each dataset in its own folder.

        A dataset is generated for every swept value and every replication. Its
        seed is derived from ``base_seed``, the experiment name, the value and
        the replication index, and the derivation is stable across interpreter
        runs. If a single dataset fails, the error and traceback are printed and
        the rest of the sweep continues.

        Returns:
            One summary dict per successfully generated dataset.

        Raises:
            ValueError: if the experiment fails validation or its generator
                function is not registered.
        """
        errors = experiment.validate()
        if errors:
            raise ValueError(f"Experiment validation failed: {errors}")

        exp_dir = Path(output_dir) / experiment.name
        exp_dir.mkdir(parents=True, exist_ok=True)

        dataset_summary = []
        print(f"Running experiment: {experiment.name}")
        print(
            f"Variable parameter: '{experiment.variable_param}' with values {experiment.variable_values}"
        )

        generator_name = self._resolve_generator_name(experiment.generator_config)

        for param_value in experiment.variable_values:
            for rep in range(experiment.n_replications):
                config_seed = self._replication_seed(
                    base_seed, experiment.name, param_value, rep
                )

                params = experiment.base_params.copy()
                params[experiment.variable_param] = param_value
                params["test_size"] = experiment.test_size

                dataset_name = f"{experiment.name}_{experiment.variable_param}-{param_value}_rep-{rep}"

                try:
                    dataset = self.dataset_generator.generate_dataset(
                        params=params,
                        generator_name=generator_name,
                        name=dataset_name,
                        seed=config_seed,
                    )
                    dataset_path = self._save_dataset(dataset, exp_dir)

                    summary_entry = {
                        "experiment": experiment.name,
                        "variable_param": experiment.variable_param,
                        "param_value": param_value,
                        "replication": rep,
                        "n_samples": dataset["config"].n_samples,
                        "n_features": dataset["config"].n_features,
                        "n_targets": dataset["config"].n_targets,
                        # Includes the "background" pseudo-component when present.
                        "n_components": len(dataset["config"].components),
                        "separability": dataset["separability"],
                        "path": str(dataset_path),  # the dataset's directory
                        "seed": config_seed,
                    }
                    summary_entry.update(
                        {
                            f"param_{k}": v
                            for k, v in dataset["generation_params"].items()
                        }
                    )
                    dataset_summary.append(summary_entry)

                    print(f"  Generated and saved dataset to: {dataset_path}")

                except Exception as e:  # best-effort: keep going with the sweep
                    print(f"Error generating dataset {dataset_name}: {e}")
                    import traceback

                    traceback.print_exc()
        return dataset_summary

    def run_experiments(
        self, experiments: List[ExperimentConfig], output_dir: str, base_seed: int = 42
    ) -> pd.DataFrame:
        """Run several experiments and write a combined ``experiment_summary.csv``.

        An experiment that raises is reported and skipped. Returns the combined
        summary, or an empty DataFrame when nothing was generated; in that case
        no summary file is written.
        """
        all_summaries = []
        for experiment in experiments:
            try:
                summary = self.run_experiment(experiment, output_dir, base_seed)
                all_summaries.extend(summary)
            except Exception as e:
                print(f"Failed to run experiment {experiment.name}: {e}")

        if not all_summaries:
            print("No datasets were generated.")
            return pd.DataFrame()

        summary_df = pd.DataFrame(all_summaries)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        summary_path = output_path / "experiment_summary.csv"
        summary_df.to_csv(summary_path, index=False)
        print(f"\nExperiment summary saved to {summary_path}")
        return summary_df


def create_component_scaling_experiment(
    name: str,
    generator_name: str = "sparse_distributed",
    n_samples: int = 1000,
    n_features: int = 3,
    n_targets: int = 1,
    component_values: Optional[List[int]] = None,
    distributions: Optional[List[str]] = None,
    n_replications: int = 5,
    **kwargs,
) -> ExperimentConfig:
    """Build an experiment that sweeps ``n_components``.

    ``use_n_samples_per_component`` is set to True, so ``n_samples`` is meant
    per component. Extra keyword arguments are merged into ``base_params`` last
    and therefore override the defaults set here.

    Raises:
        ValueError: if *generator_name* is not a built-in generator.
    """
    sweep = [2, 3, 4, 5, 7, 10] if component_values is None else component_values
    dist_names = (
        ["normal", "uniform", "gamma", "exponential"]
        if distributions is None
        else distributions
    )

    generator_config = ComponentGeneratorRegistry().get(generator_name)

    base_params = dict(
        n_samples=n_samples,
        n_features=n_features,
        n_targets=n_targets,
        distributions=dist_names,
        distribution_overlap=0.1,
        sample_distribution_mode="equal",
        use_n_samples_per_component=True,
    )
    base_params.update(kwargs)

    return ExperimentConfig(
        name=name,
        description=f"Impact of increasing number of components using {generator_name} generator",
        base_params=base_params,
        variable_param="n_components",
        variable_values=list(sweep),
        generator_config=generator_config,
        n_replications=n_replications,
    )


def create_noise_features_experiment(
    name: str,
    generator_name: str = "sparse_distributed",
    n_samples: int = 1000,
    n_features: int = 5,
    n_components: int = 3,
    noise_values: Optional[List[int]] = None,
    distributions: Optional[List[str]] = None,
    n_replications: int = 5,
    **kwargs,
) -> ExperimentConfig:
    """Build an experiment that sweeps the number of noise features.

    Args:
        name: Identifier stored on the returned config.
        generator_name: Key looked up in ``ComponentGeneratorRegistry``.
        n_samples: Stored in the base parameters as ``n_samples``.
        n_features: Stored in the base parameters as ``n_features``.
        n_components: Stored in the base parameters as ``n_components``.
        noise_values: Values of ``n_noise_features`` to sweep over; when
            ``None`` the sweep is ``[0, 2, 5, 10, 15, 20]``.
        distributions: Distribution names; when ``None`` uses normal,
            uniform, gamma and exponential.
        n_replications: Replication count stored on the returned config.
        **kwargs: Extra base parameters, merged last so they override the
            fixed entries (including ``distribution_overlap``).

    Returns:
        An ``ExperimentConfig`` whose ``variable_param`` is
        ``"n_noise_features"``.
    """
    sweep = [0, 2, 5, 10, 15, 20] if noise_values is None else noise_values
    dists = (
        ["normal", "uniform", "gamma", "exponential"]
        if distributions is None
        else distributions
    )

    gen_cfg = ComponentGeneratorRegistry().get(generator_name)

    params = dict(
        n_samples=n_samples,
        n_features=n_features,
        n_components=n_components,
        distributions=dists,
        distribution_overlap=0.1,
    )
    params.update(kwargs)

    return ExperimentConfig(
        name=name,
        description=f"Impact of increasing noise features using {generator_name} generator",
        base_params=params,
        variable_param="n_noise_features",
        variable_values=list(sweep),
        generator_config=gen_cfg,
        n_replications=n_replications,
    )


def create_overlap_experiment(
    name: str,
    generator_name: str = "sparse_distributed",
    n_samples: int = 1000,
    n_features: int = 3,
    n_components: int = 4,
    overlap_values: Optional[List[float]] = None,
    distributions: Optional[List[str]] = None,
    n_replications: int = 5,
    **kwargs,
) -> ExperimentConfig:
    """Build an experiment that sweeps ``distribution_overlap``.

    Args:
        name: Identifier stored on the returned config.
        generator_name: Key looked up in ``ComponentGeneratorRegistry``.
        n_samples: Stored in the base parameters as ``n_samples``.
        n_features: Stored in the base parameters as ``n_features``.
        n_components: Stored in the base parameters as ``n_components``.
        overlap_values: Overlap values to sweep over; when ``None`` the
            sweep is ``[0.01, 0.05, 0.1, 0.15, 0.25, 0.4]``.
        distributions: Distribution names; when ``None`` uses normal,
            uniform, gamma and exponential.
        n_replications: Replication count stored on the returned config.
        **kwargs: Extra base parameters, merged last so they override the
            fixed entries.

    Returns:
        An ``ExperimentConfig`` whose ``variable_param`` is
        ``"distribution_overlap"``.
    """
    sweep = (
        [0.01, 0.05, 0.1, 0.15, 0.25, 0.4]
        if overlap_values is None
        else overlap_values
    )
    dists = (
        ["normal", "uniform", "gamma", "exponential"]
        if distributions is None
        else distributions
    )

    gen_cfg = ComponentGeneratorRegistry().get(generator_name)

    # No fixed overlap entry here: it is the swept parameter.
    params = dict(
        n_samples=n_samples,
        n_features=n_features,
        n_components=n_components,
        distributions=dists,
    )
    params.update(kwargs)

    return ExperimentConfig(
        name=name,
        description=f"Impact of varying distribution overlap using {generator_name} generator",
        base_params=params,
        variable_param="distribution_overlap",
        variable_values=list(sweep),
        generator_config=gen_cfg,
        n_replications=n_replications,
    )


def create_samples_experiment(
    name: str,
    generator_name: str = "sparse_distributed",
    n_features: int = 3,
    n_components: int = 4,
    sample_values: Optional[List[int]] = None,
    distributions: Optional[List[str]] = None,
    n_replications: int = 5,
    **kwargs,
) -> ExperimentConfig:
    """Create an experiment that varies the number of samples.

    Args:
        name: Identifier stored on the returned config.
        generator_name: Key looked up in ``ComponentGeneratorRegistry``.
        n_features: Stored in the base parameters as ``n_features``.
        n_components: Stored in the base parameters as ``n_components``.
        sample_values: Values of ``n_samples`` to sweep over; when ``None``
            the sweep is ``[1000, 500, 200]``.
        distributions: Distribution names; when ``None`` uses normal,
            uniform, gamma and exponential.
        n_replications: Replication count stored on the returned config.
        **kwargs: Extra base parameters, merged last so they override the
            fixed entries (e.g. ``distribution_overlap``, default 0.1).

    Returns:
        An ``ExperimentConfig`` whose ``variable_param`` is ``"n_samples"``.
    """
    if sample_values is None:
        sample_values = [1000, 500, 200]

    if distributions is None:
        distributions = ["normal", "uniform", "gamma", "exponential"]

    registry = ComponentGeneratorRegistry()
    generator_config = registry.get(generator_name)

    base_params = {
        "n_features": n_features,
        "n_components": n_components,
        "distributions": distributions,
        # Default only; a caller-supplied value in kwargs overrides it below.
        "distribution_overlap": 0.1,
        # NOTE(review): presumably makes n_samples a per-component count —
        # confirm against the data generator.
        "use_n_samples_per_component": True,
        **kwargs,
    }

    return ExperimentConfig(
        name=name,
        # Previously mislabelled as a distribution-overlap sweep.
        description=f"Impact of varying number of samples using {generator_name} generator",
        base_params=base_params,
        variable_param="n_samples",
        variable_values=list(sample_values),
        generator_config=generator_config,
        n_replications=n_replications,
    )


def create_distnoiseY_experiment(
    name: str,
    generator_name: str = "sparse_distributed",
    n_features: int = 3,
    n_components: int = 4,
    noise_std_values: Optional[List[float]] = None,
    distributions: Optional[List[str]] = None,
    n_replications: int = 5,
    **kwargs,
) -> ExperimentConfig:
    """Build an experiment that sweeps the noise level added to Y.

    Args:
        name: Identifier stored on the returned config.
        generator_name: Key looked up in ``ComponentGeneratorRegistry``.
        n_features: Stored in the base parameters as ``n_features``.
        n_components: Stored in the base parameters as ``n_components``.
        noise_std_values: Values of ``noise_Y`` to sweep over; when ``None``
            the sweep is ``[0, 0.05, 0.1, 0.3, 0.7, 1.0]``.
        distributions: Distribution names; when ``None`` uses normal,
            uniform, gamma and exponential.
        n_replications: Replication count stored on the returned config.
        **kwargs: Extra base parameters. ``n_samples`` (default 500),
            ``n_targets`` (default 1) and ``distribution_overlap`` (default
            0.1) are taken from here when present; all kwargs override the
            fixed entries.

    Returns:
        An ``ExperimentConfig`` whose ``variable_param`` is ``"noise_Y"``.
    """
    sweep = (
        [0, 0.05, 0.1, 0.3, 0.7, 1.0]
        if noise_std_values is None
        else noise_std_values
    )
    dists = (
        ["normal", "uniform", "gamma", "exponential"]
        if distributions is None
        else distributions
    )

    gen_cfg = ComponentGeneratorRegistry().get(generator_name)

    # Literal defaults are listed first; kwargs then replaces any of them
    # in place, keeping the same key order.
    params = {
        "n_features": n_features,
        "n_components": n_components,
        "n_samples": 500,
        "n_targets": 1,
        "distribution_overlap": 0.1,
        "distributions": dists,
    }
    params.update(kwargs)

    return ExperimentConfig(
        name=name,
        description=f"Experiment varying Y noise levels for {generator_name}",
        base_params=params,
        variable_param="noise_Y",
        variable_values=list(sweep),
        generator_config=gen_cfg,
        n_replications=n_replications,
    )


def create_distnoiseX_experiment(
    name: str,
    generator_name: str = "sparse_distributed",
    n_features: int = 3,
    n_components: int = 4,
    noise_std_values: Optional[List[float]] = None,
    distributions: Optional[List[str]] = None,
    n_replications: int = 5,
    **kwargs,
) -> ExperimentConfig:
    """Build an experiment that sweeps the noise level added to X.

    Args:
        name: Identifier stored on the returned config.
        generator_name: Key looked up in ``ComponentGeneratorRegistry``.
        n_features: Stored in the base parameters as ``n_features``.
        n_components: Stored in the base parameters as ``n_components``.
        noise_std_values: Values of ``noise_X`` to sweep over; when ``None``
            the sweep is ``[0, 0.05, 0.1, 0.2, 0.5]``.
        distributions: Distribution names; when ``None`` uses normal,
            uniform, gamma and exponential.
        n_replications: Replication count stored on the returned config.
        **kwargs: Extra base parameters. ``n_samples`` (default 500),
            ``n_targets`` (default 1) and ``distribution_overlap`` (default
            0.1) are taken from here when present; all kwargs override the
            fixed entries.

    Returns:
        An ``ExperimentConfig`` whose ``variable_param`` is ``"noise_X"``.
    """
    sweep = (
        [0, 0.05, 0.1, 0.2, 0.5]
        if noise_std_values is None
        else noise_std_values
    )
    dists = (
        ["normal", "uniform", "gamma", "exponential"]
        if distributions is None
        else distributions
    )

    gen_cfg = ComponentGeneratorRegistry().get(generator_name)

    # Literal defaults are listed first; kwargs then replaces any of them
    # in place, keeping the same key order.
    params = {
        "n_features": n_features,
        "n_components": n_components,
        "n_samples": 500,
        "n_targets": 1,
        "distributions": dists,
        "distribution_overlap": 0.1,
    }
    params.update(kwargs)

    return ExperimentConfig(
        name=name,
        description=f"Experiment varying X noise levels for {generator_name}",
        base_params=params,
        variable_param="noise_X",
        variable_values=list(sweep),
        generator_config=gen_cfg,
        n_replications=n_replications,
    )
