import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..mixture_model_base import MixtureModel


def _reconstruct_model_from_snapshot(
    original_model: "MixtureModel", snapshot
) -> "MixtureModel":
    """Rebuild a CPU-resident, eval-mode model from a training snapshot.

    A new instance of ``original_model``'s class is created from the same
    config. It receives CPU copies of the original's scaled inputs/targets,
    the original's preprocessor (and ``gmm_model`` when the original has
    one). Its rules and expert sub-models are then built, and the
    snapshot's state dicts and per-rule disabled flags are applied.

    Args:
        original_model: Model whose class, config, scaled data and
            preprocessor are reused for the reconstruction.
        snapshot: Object exposing ``mixture_rules_state``,
            ``expert_model_state`` and ``disabled_components``.

    Returns:
        The reconstructed model, with ``rules_model`` and ``expert_model``
        switched to eval mode.
    """
    model_cls = original_model.__class__
    rebuilt = model_cls(original_model.config)

    cpu = torch.device("cpu")
    rebuilt.device = cpu
    rebuilt.X_scaled = original_model.X_scaled.to(cpu)
    rebuilt.Y_scaled = original_model.Y_scaled.to(cpu)
    # The preprocessor is shared by reference, not copied.
    rebuilt.preprocessor = original_model.preprocessor

    # Subclasses that carry a ``gmm_model`` (e.g. RemixMixtureModel) need it
    # in place before their experts can be initialised.
    if hasattr(original_model, "gmm_model"):
        rebuilt.gmm_model = original_model.gmm_model

    rebuilt._setup_rules_model()
    rebuilt._initialize_experts()

    # Restore trained weights for both sub-models.
    rebuilt.rules_model.load_state_dict(snapshot.mixture_rules_state)
    rebuilt.expert_model.load_state_dict(snapshot.expert_model_state)

    # Copy the per-rule disabled flags. Rules beyond the snapshot's flag list
    # are left with whatever ``disabled`` value they were constructed with.
    for idx, rule in enumerate(rebuilt.rules_model.rules):
        if idx >= len(snapshot.disabled_components):
            continue
        rule.disabled = snapshot.disabled_components[idx]

    for submodule in (rebuilt.rules_model, rebuilt.expert_model):
        submodule.eval()

    return rebuilt


def _get_component_colors_and_labels(model: "MixtureModel"):
    """Build a consistent RGBA colour and label mapping for all components.

    Total responsibility per component is computed over ``model.X_scaled``
    without gradient tracking.

    Active components are those that are not disabled and have total
    responsibility above 1e-6. They are coloured from the ``turbo`` colormap
    in order of decreasing total responsibility. Every other component keeps
    a translucent grey.

    When ``config.use_background_component`` is set, an extra final
    component is appended. It is drawn in dark grey and labelled "Sink".

    Args:
        model: Model providing ``rules_model``, ``X_scaled`` and ``config``.

    Returns:
        ``(colors, labels)`` where ``colors`` is an ``(n_components, 4)``
        float array of RGBA rows and ``labels`` is a list of display names.
    """
    with torch.no_grad():
        resp, _ = model.rules_model(model.X_scaled)
        importance = resp.sum(dim=0).cpu().numpy()

    n_rules = model.config.n_mixture_components
    has_sink = model.config.use_background_component
    n_components = n_rules + (1 if has_sink else 0)

    disabled_flags = model.rules_model.get_disabled_rules()
    # Most important active component first; ties keep their original order.
    ranked = sorted(
        (
            idx
            for idx, is_disabled in enumerate(disabled_flags)
            if not is_disabled and importance[idx] > 1e-6
        ),
        key=lambda idx: importance[idx],
        reverse=True,
    )

    # Grey is the fallback for disabled / negligible components.
    colors = np.full((n_components, 4), [0.7, 0.7, 0.7, 0.5])
    labels = [f"Rule {k + 1}" for k in range(n_rules)]

    if ranked:
        palette = plt.cm.turbo(np.linspace(0, 1, len(ranked)))
        # ``ranked`` holds unique indices, so row-wise fancy assignment is safe.
        colors[ranked] = palette

    if has_sink:
        colors[-1] = [0.3, 0.3, 0.3, 0.6]
        labels.append("Sink")

    return colors, labels


def prepare_probs(mixture_probs, hard_assignment):
    """Optionally collapse soft mixture probabilities into one-hot rows.

    Args:
        mixture_probs: 2-D array of per-row component probabilities. Its
            shape is only inspected when ``hard_assignment`` is truthy.
        hard_assignment: If truthy, each row is replaced by a one-hot vector
            marking its arg-max component (first index wins on ties).

    Returns:
        ``mixture_probs`` unchanged when ``hard_assignment`` is falsy or the
        array has zero columns. Otherwise a new float array of one-hot rows
        with the same shape.
    """
    if not hard_assignment:
        return mixture_probs

    n_components = mixture_probs.shape[1]
    if n_components == 0:
        # argmax over an empty axis would raise; there is nothing to encode.
        return mixture_probs

    winners = np.argmax(mixture_probs, axis=1)
    identity = np.eye(n_components)
    return identity[winners]
