import numpy as np
from typing import List
from ..mixture_gen import Component


def diagonal_multivariate_gaussian_generator(
    n_components: int,
    n_features: int,
    n_targets: int,
    rule_base_size: float = 0.5,
    rule_spacing_factor: float = 1.0,  # Stride multiplier for the X-space intervals
    mean_spacing_factor: float = 2.0,  # Stride multiplier for the Y-space means
    cov_scale: float = 0.5,  # Per-dimension standard deviation in Y space
    adaptive_sizing: bool = True,
) -> List[Component]:
    """
    Build mixture components laid out along the main diagonal of both spaces.

    Component ``i`` receives one interval per feature index, all sharing the
    same centre, so the intervals of successive components step along the
    diagonal of feature (X) space. Its mean vector has every coordinate equal
    to the same value, so the means step along the diagonal of target (Y)
    space. Both sequences of centres are symmetric around zero.

    Args:
        n_components: Number of components to create. Values <= 0 yield an
            empty list.
        n_features: Number of feature indices (0 .. n_features - 1) that get
            an interval in each component's ``rules`` dict.
        n_targets: Dimensionality of the target variable Y; must exceed 1.
        rule_base_size: Requested interval width in X space. The width
            actually used is always clamped to [0.1, 0.9].
        rule_spacing_factor: Distance between consecutive interval centres,
            expressed as a multiple of the (clamped) interval width.
        mean_spacing_factor: Per-coordinate distance between consecutive
            means, expressed as a multiple of ``cov_scale``.
        cov_scale: Standard deviation of every target dimension; each
            covariance matrix is ``cov_scale**2`` times the identity.
        adaptive_sizing: When True, ``rule_base_size`` is first multiplied by
            ``1 - 0.05 * (n_components - 2)`` (then clamped), so the width
            shrinks as the number of components grows.

    Returns:
        A list of ``n_components`` Component objects, each using the
        "multivariate_normal" distribution with ``mean``/``cov`` parameters
        and an equal weight of ``1 / n_components``.

    Raises:
        ValueError: If ``n_targets`` is 1 or less.
    """
    if n_targets <= 1:
        raise ValueError("This generator is for n_targets > 1.")

    # --- Interval width in X space, always kept within [0.1, 0.9] ---
    if adaptive_sizing:
        shrink = 1.0 - 0.05 * (n_components - 2)
        rule_size = max(0.1, min(0.9, rule_base_size * shrink))
    else:
        rule_size = min(max(0.1, rule_base_size), 0.9)
    half_width = rule_size / 2

    # --- Evenly spaced centres, symmetric about the origin ---
    x_step = rule_size * rule_spacing_factor
    x_start = -((n_components - 1) * x_step) / 2

    # Spacing in Y is measured in units of the component standard deviation.
    y_step = cov_scale * mean_spacing_factor
    y_start = -((n_components - 1) * y_step) / 2

    def build_component(index: int) -> Component:
        """Create the component occupying position ``index`` on the diagonal."""
        x_center = x_start + index * x_step
        lower, upper = x_center - half_width, x_center + half_width
        y_center = y_start + index * y_step

        return Component(
            # Same interval on every feature axis.
            rules={axis: (lower, upper) for axis in range(n_features)},
            distribution="multivariate_normal",
            dist_params={
                # Every coordinate equal -> a point on the Y-space diagonal.
                "mean": np.ones(n_targets) * y_center,
                # Fresh isotropic covariance per component (arrays not shared).
                "cov": np.eye(n_targets) * (cov_scale**2),
            },
            # Uniform mixing weight.
            weight=1.0 / n_components,
        )

    return [build_component(index) for index in range(n_components)]
