from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Union, NamedTuple
import numpy as np
from pathlib import Path
import pickle
from scipy.optimize import linear_sum_assignment


@dataclass
class Component:
    """Defines a single component in the mixture model.

    A component couples a region of feature space (``rules``) with the
    distribution that target values are drawn from for samples assigned to
    that component.
    """

    # Feature index -> (lower, upper) bounds. `_check_rules` treats both bounds
    # as inclusive and requires every listed feature to fall inside its range.
    rules: Dict[int, Tuple[float, float]]
    # Name of the target distribution. `_generate_y_value` lowercases it and
    # recognises "normal", "uniform", "exponential", "gamma", "noise_only"
    # (single target) and "multivariate_normal", "noise_only" (several targets).
    distribution: str
    # Parameters read by `_generate_y_value` for the chosen distribution:
    #   normal              -> "loc", "scale"
    #   uniform             -> "low", "high"
    #   exponential         -> "scale", optional "loc" (defaults to 0)
    #   gamma               -> "shape", "scale", optional "loc" (defaults to 0)
    #   multivariate_normal -> "mean", "cov" (must be numpy arrays)
    #   noise_only          -> no parameters are read
    dist_params: Dict[str, Union[float, np.ndarray]]
    # Relative mixing weight; `generate_mixture_data` uses the weights
    # (rescaled to sum to 1 if necessary) as component sampling probabilities.
    weight: float


class DatasetConfig(NamedTuple):
    """Configuration for a test dataset.

    Immutable record stored alongside a dataset; `load_test_dataset` returns
    it as the ``"config"`` entry of the pickled payload.  Most fields are not
    read anywhere in this file, so the notes below are partly assumptions.
    """

    # Identifier of the dataset (presumably the name used by
    # `load_test_dataset` to locate the file -- TODO confirm).
    name: str
    # Presumably mirror the same-named arguments of `generate_mixture_data`
    # -- verify against the code that builds the config.
    n_samples: int
    n_features: int
    n_noise_features: int
    # Mixture components that define the dataset.
    components: List[Component]
    # Free-text description of the dataset.
    description: str
    # Number of target dimensions; defaults to a single target.
    n_targets: int = 1
    # NOTE(review): not used in this file; meaning of the 0.2 default and of
    # `None` is defined by callers -- confirm.
    distribution_overlap: float | None = 0.2
    # NOTE(review): presumably row indices of a predefined train/test split;
    # `None` when no split is stored -- confirm against callers.
    train_indices: Optional[Union[np.ndarray, List]] = None
    test_indices: Optional[Union[np.ndarray, List]] = None


def _check_rules(x: np.ndarray, rules: Dict[int, Tuple[float, float]]) -> bool:
    """Return True when ``x`` lies inside every interval listed in ``rules``.

    Each rule maps a feature index to an inclusive ``(lower, upper)`` range.
    An empty rule set is trivially satisfied.  Evaluation stops at the first
    violated rule.
    """
    return all(
        lo <= x[idx] <= hi
        for idx, (lo, hi) in rules.items()
    )


def _generate_y_value(component: Component, n_targets: int) -> np.ndarray:
    """Draw one target value for a sample that belongs to ``component``.

    The distribution name is matched case-insensitively and values are drawn
    from numpy's global random state.

    Args:
        component: Component whose ``distribution`` / ``dist_params`` describe
            the target distribution.
        n_targets: Number of target dimensions.

    Returns:
        1-D array: a single-element array when ``n_targets == 1``, otherwise
        the vector drawn from the multivariate normal, or zeros of length
        ``n_targets`` for "noise_only".

    Raises:
        ValueError: If the distribution is not supported for the requested
            dimensionality, or the multivariate mean/cov shapes do not match
            ``n_targets``.
        TypeError: If the multivariate mean or cov is not a numpy array.
    """
    kind = component.distribution.lower()
    cfg = component.dist_params

    if n_targets != 1:
        # --- Vector-valued target ---
        if kind == "noise_only":
            # Base value is zero in every dimension; noise is added by the caller.
            return np.zeros(n_targets)
        if kind != "multivariate_normal":
            raise ValueError(
                f"Unsupported {n_targets}-D distribution: {component.distribution}. Expected 'multivariate_normal'."
            )

        mean = cfg["mean"]
        cov = cfg["cov"]
        both_arrays = isinstance(mean, np.ndarray) and isinstance(cov, np.ndarray)
        if not both_arrays:
            raise TypeError(
                f"For multivariate_normal, mean and cov must be numpy arrays. Got {type(mean)}, {type(cov)}"
            )
        shapes_match = mean.shape == (n_targets,) and cov.shape == (
            n_targets,
            n_targets,
        )
        if not shapes_match:
            raise ValueError(
                f"Mean/Cov shape mismatch. Expected mean ({n_targets},), cov ({n_targets},{n_targets}). Got {mean.shape}, {cov.shape}"
            )
        return np.random.multivariate_normal(mean=mean.flatten(), cov=cov)

    # --- Scalar target: one lazily-evaluated sampler per distribution ---
    samplers = {
        "normal": lambda: np.random.normal(loc=cfg["loc"], scale=cfg["scale"]),
        "uniform": lambda: np.random.uniform(low=cfg["low"], high=cfg["high"]),
        "exponential": lambda: (
            np.random.exponential(scale=cfg["scale"]) + cfg.get("loc", 0)
        ),
        "gamma": lambda: (
            np.random.gamma(shape=cfg["shape"], scale=cfg["scale"])
            + cfg.get("loc", 0)
        ),
        # Base value is 0; global noise is added later by the caller.
        "noise_only": lambda: 0.0,
    }
    draw = samplers.get(kind)
    if draw is None:
        raise ValueError(f"Unsupported 1D distribution: {component.distribution}")
    return np.array([draw()])


def generate_mixture_data(
    n_samples: int,
    n_features: int,
    components: List[Component],
    n_targets: int = 1,
    n_noise_features: int = 0,
    seed: Optional[int] = None,
    feature_ranges: Optional[List[Tuple[float, float]]] = None,
    noise_X: float = 0.0,
    noise_Y: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate synthetic data from a mixture of components (1D or MD target).

    Every sample is assigned to a component with probability proportional to
    the component weight.  Features constrained by the component's rules are
    redrawn uniformly inside the rule bounds until all rules hold, and the
    target is then drawn from the component's distribution.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of informative features (columns rules may refer to).
        components: Mixture components; weights are rescaled to sum to 1.
        n_targets: Number of target dimensions.
        n_noise_features: Extra columns drawn from uniform(-1, 1) and appended
            after the informative features.
        seed: If given, seeds numpy's global random state.
        feature_ranges: ``(low, high)`` per informative feature; defaults to
            ``(-1, 1)`` for every feature.
        noise_X: Std. dev. of Gaussian noise added to all of X after rule
            enforcement (so noisy features may leave their rule bounds).
        noise_Y: Std. dev. of Gaussian noise added to each target value.

    Returns:
        Tuple of (X, y, component_labels) where:
        - X: array of shape (n_samples, n_features + n_noise_features)
        - y: array of shape (n_samples, n_targets)
        - component_labels: index of the generating component for each sample

    Raises:
        ValueError: If the component weights sum to zero.
    """
    if seed is not None:
        np.random.seed(seed)

    # Always rescale the weights.  np.random.choice rejects probabilities whose
    # sum differs from 1 by more than ~1.5e-8, which is far stricter than
    # np.isclose, so "close enough" weights must not be passed through as-is.
    total_weight = sum(c.weight for c in components)
    if components and total_weight == 0:
        raise ValueError("Component weights must not sum to zero.")
    weights = np.array([c.weight for c in components], dtype=float) / total_weight

    if feature_ranges is None:
        feature_ranges = [(-1, 1)] * n_features

    total_features = n_features + n_noise_features
    X = np.zeros((n_samples, total_features))

    for i in range(n_features):
        low, high = feature_ranges[i]
        X[:, i] = np.random.uniform(low, high, n_samples)

    if n_noise_features > 0:
        noise_low, noise_high = (-1, 1)
        X[:, n_features:] = np.random.uniform(
            noise_low, noise_high, (n_samples, n_noise_features)
        )

    y = np.zeros((n_samples, n_targets))
    component_indices = np.arange(len(components))
    component_assignments = np.random.choice(
        component_indices, size=n_samples, p=weights
    )
    component_labels = component_assignments.astype(int)

    max_attempts = 1000
    for i in range(n_samples):
        assigned_comp_idx = component_assignments[i]
        component = components[assigned_comp_idx]

        # Redraw the rule-constrained features until the sample satisfies the
        # component's rules.  Rule keys must be valid indices into the first
        # n_features columns, otherwise _check_rules raises IndexError.
        satisfied = _check_rules(X[i, :n_features], component.rules)
        attempts = 0
        while not satisfied and attempts < max_attempts:
            for feat_idx, (low, high) in component.rules.items():
                if feat_idx < n_features:
                    X[i, feat_idx] = np.random.uniform(low, high)
            attempts += 1
            satisfied = _check_rules(X[i, :n_features], component.rules)

        # Warn only when the rules are still violated; a success on the very
        # last attempt is not a failure.
        if not satisfied:
            print(
                f"Warning: Max attempts reached for satisfying rules for component {assigned_comp_idx} at sample {i}."
            )

        y_val = _generate_y_value(component, n_targets)

        if noise_Y > 0:
            y_val = y_val + np.random.normal(0, noise_Y, size=n_targets)

        y[i, :] = y_val

    if noise_X > 0:
        X += np.random.normal(0, noise_X, size=X.shape)

    return X, y, component_labels


def load_test_dataset(
    dataset_name: str, data_dir: str
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, DatasetConfig]:
    """
    Load a pickled test dataset by name.

    The file is looked up as ``<data_dir>/<dataset_name>.pkl`` first and as
    ``<data_dir>/<dataset_name>/dataset.pkl`` second.

    Returns:
        Tuple of (X, y, component_labels, config) where:
        - X: feature matrix
        - y: target values
        - component_labels: array indicating which component generated each sample
        - config: full dataset configuration

    Raises:
        ValueError: If neither candidate file exists.
    """
    root = Path(data_dir)
    candidates = (
        root / f"{dataset_name}.pkl",
        root / f"{dataset_name}/dataset.pkl",
    )
    data_path = next((path for path in candidates if path.exists()), None)
    if data_path is None:
        raise ValueError(f"Dataset {dataset_name} not found in {data_dir}")

    # NOTE: unpickling executes arbitrary code -- only load trusted files.
    with data_path.open("rb") as handle:
        payload = pickle.load(handle)

    features = payload["X"]
    targets = payload["y"]
    labels = payload["component_labels"]
    config = payload["config"]
    return features, targets, labels, config


def _expected_target_value(component: Component) -> float:
    """Return the mean of a component's single-target distribution."""
    # Lowercase to stay consistent with _generate_y_value, which accepts
    # distribution names case-insensitively.
    dist = component.distribution.lower()
    params = component.dist_params

    if dist == "normal":
        return params["loc"]
    if dist == "uniform":
        return (params["low"] + params["high"]) / 2
    if dist == "exponential":
        # Mean of Exp(scale) is scale; the generator adds an optional shift.
        return params["scale"] + params.get("loc", 0)
    if dist == "gamma":
        # Mean of Gamma(shape, scale) is shape * scale, plus the optional shift.
        return params["shape"] * params["scale"] + params.get("loc", 0)
    if dist == "noise_only":
        # The generator emits a base value of 0 for noise-only components.
        return 0.0
    raise ValueError(f"Unsupported distribution: {component.distribution}")


def extract_feature_rules(
    components: List[Component],
) -> Dict[int, List[Tuple[float, float, float]]]:
    """
    Extract rules for each feature from components.

    Args:
        components: Mixture components to inspect. Components without rules
            contribute nothing and are not validated.

    Returns:
        Dictionary mapping feature index to list of (lower_bound, upper_bound, target_mean)
        for each component that has rules for that feature.

    Raises:
        ValueError: If a component with rules uses a distribution whose mean
            is not defined here (e.g. "multivariate_normal").
    """
    feature_rules: Dict[int, List[Tuple[float, float, float]]] = {}

    for component in components:
        if not component.rules:
            continue

        # The expected target is the same for every rule of this component.
        target_mean = _expected_target_value(component)
        for feat_idx, (lower, upper) in component.rules.items():
            feature_rules.setdefault(feat_idx, []).append(
                (lower, upper, target_mean)
            )

    return feature_rules


def compare_component_assignments(
    true_labels: np.ndarray, pred_labels: np.ndarray
) -> Tuple[float, dict]:
    """
    Compare predicted component assignments with ground truth,
    accounting for permutation of components.

    Handles cases where:
    - Number of true and predicted components differ
    - Some predicted components may have no samples

    Predicted components left without a true counterpart by the optimal
    one-to-one matching always count as wrong, whatever label values are used.

    Args:
        true_labels: Ground truth component labels (array-like, any shape)
        pred_labels: Predicted component labels (array-like, same number of
            elements as ``true_labels``)

    Returns:
        Tuple of (accuracy, optimal_mapping) where:
        - accuracy is the fraction of correctly assigned samples
          (NaN for empty input)
        - optimal_mapping maps predicted labels to true labels

    Raises:
        ValueError: If the two inputs contain a different number of labels.
    """
    true_labels = np.asarray(true_labels).flatten().astype(int)
    pred_labels = np.asarray(pred_labels).flatten().astype(int)

    # Reject mismatched lengths explicitly; a length-1 input would otherwise
    # broadcast silently and yield a meaningless score.
    if true_labels.shape != pred_labels.shape:
        raise ValueError(
            "true_labels and pred_labels must contain the same number of "
            f"elements. Got {true_labels.size} and {pred_labels.size}"
        )
    if true_labels.size == 0:
        return float("nan"), {}

    # Only labels that actually occur take part in the matching.  The inverse
    # indices give each sample's row/column in the contingency table.
    unique_true, true_codes = np.unique(true_labels, return_inverse=True)
    unique_pred, pred_codes = np.unique(pred_labels, return_inverse=True)

    # Contingency table built in one pass over the samples.
    counts = np.zeros((len(unique_true), len(unique_pred)), dtype=int)
    np.add.at(counts, (true_codes, pred_codes), 1)

    # The solver minimises cost, so negate the counts to maximise agreement.
    true_idx, pred_idx = linear_sum_assignment(-counts)

    label_mapping = {
        int(pred): int(true)
        for pred, true in zip(unique_pred[pred_idx], unique_true[true_idx])
    }

    # Compare in index space: -1 can never equal a valid row index, so an
    # unmatched predicted component cannot collide with a real label value.
    matched_true_code = np.full(len(unique_pred), -1)
    matched_true_code[pred_idx] = true_idx
    accuracy = float(np.mean(matched_true_code[pred_codes] == true_codes))

    return accuracy, label_mapping
