import matplotlib.pyplot as plt
import torch
import numpy as np
import matplotlib.patches as patches
import seaborn as sns
from typing import TYPE_CHECKING, List, Literal
from .base import BasePlotter
from .utils import _reconstruct_model_from_snapshot, _get_component_colors_and_labels
from ..mixture_utils import get_log_prob
from . import rules as rules_extractor
from . import plotting_rules
from .utils import prepare_probs


if TYPE_CHECKING:
    from ..mixture_model_base import MixtureModel
    from ..neural_rules import GMMRemixer


class ModelPlotter(BasePlotter):
    """
    A helper class to create consistent, publication-quality visualizations
    for final mixture model results, focusing on subgroup and population distributions.

    Figures are saved relative to ``self.output_dir``, axis labels are looked up in
    ``self.name_mapping`` and colours come from ``self._get_colors`` (presumably all
    provided by ``BasePlotter`` — not visible here).
    """

    # ------------------------------------------------------------------ helpers

    def _save_show_and_close(self, save_name=None):
        """Save the current figure (if requested), show it, then close it.

        Args:
            save_name: File name relative to ``self.output_dir``; falsy skips saving.
        """
        # Save before showing: after an interactive window is closed the current
        # figure may already be gone.
        if save_name:
            full_save_path = self.output_dir / save_name
            plt.savefig(full_save_path, dpi=300, bbox_inches="tight")
        plt.show()
        plt.close()

    @staticmethod
    def _module_device(module):
        """Return the device of the first parameter (or buffer) of ``module``, else ``None``."""
        for tensor in module.parameters():
            return tensor.device
        for tensor in module.buffers():
            return tensor.device
        return None

    # ------------------------------------------------------------ public API

    def plot_expert_densities(
        self, model: "MixtureModel", n_points=200, save_name=None, **kwargs
    ):
        """
        Public method to plot the density of each expert component.

        Dispatches on the type of ``model.expert_model``: a ``GMMRemixer`` is drawn
        with ``_plot_gmm_densities`` (which also needs ``model.gmm_model``); a
        ``FlowMixtureExperts`` is drawn with ``_plot_flow_densities``.

        Args:
            model: Fitted mixture model.
            n_points: Number of evaluation points along the target axis.
            save_name: Optional file name relative to ``self.output_dir``.
            **kwargs: Forwarded to the type-specific plotting method.

        Raises:
            AttributeError: If a ``GMMRemixer`` model has no (or a ``None``) ``gmm_model``.
            NotImplementedError: For any other expert type.
        """
        from ..neural_rules import GMMRemixer
        from ..flow_experts import FlowMixtureExperts

        expert_model = model.expert_model
        if isinstance(expert_model, GMMRemixer):
            if getattr(model, "gmm_model", None) is None:
                raise AttributeError(
                    "The provided RemixMixtureModel is missing the 'gmm_model' attribute."
                )
            self._plot_gmm_densities(
                Y=model.Y_original,
                gmm_remixer=expert_model,
                base_gmm=model.gmm_model,
                scaler_y=model.preprocessor.scaler_y,
                disabled_components=model.disabled_components,
                n_points=n_points,
                save_name=save_name,
                **kwargs,
            )
        elif isinstance(expert_model, FlowMixtureExperts):
            self._plot_flow_densities(
                Y=model.Y_original,
                component_flows=expert_model.component_flows,
                scaler_y=model.preprocessor.scaler_y,
                disabled_components=model.disabled_components,
                n_points=n_points,
                save_name=save_name,
                **kwargs,
            )
        else:
            expert_type = type(expert_model).__name__
            raise NotImplementedError(
                f"Density plotting is not implemented for expert type: {expert_type}"
            )

    def plot_rules_summary(
        self,
        model: "MixtureModel",
        output_format="text",
        filepath=None,
        responsibility_threshold=0.1,
        activation_threshold=0.01,
        weight_threshold=0.1,
        assign_max_resp=False,
        rules_to_plot: List[int] = None,
        show_proportional_dist: bool = True,
        show_density_dist: bool = False,
        show_population_histogram: bool = False,
        scale_densities_by_weight: bool = False,
        **kwargs,
    ):
        """
        Extracts and displays rules in a structured, comparable format.

        This is the new, recommended method for inspecting rules. It can output
        to a text table, an HTML file, or various plot formats.

        Args:
            model: Fitted mixture model whose ``rules_model`` is summarised.
            output_format: "text" (default, also used for unknown values), "html",
                "plot", "plot_combined" or "plot_condensed".
            filepath: Output path relative to ``self.output_dir``; required for every
                format except "text".
            responsibility_threshold, activation_threshold, assign_max_resp: Forwarded
                to ``rules_extractor.extract_rules_structured``.
            weight_threshold: Forwarded to the formatter / plotting function.
            rules_to_plot: Optional rule indices (used by "plot_combined" and
                "plot_condensed").
            show_proportional_dist, show_population_histogram,
            scale_densities_by_weight: Forwarded to ``plotting_rules.plot_combined``.
            show_density_dist: Density toggle forwarded to all plot formats.
            **kwargs: Forwarded to the plotting function.

        Returns:
            str: The text table for "text", otherwise a short status message.

        Raises:
            ValueError: If a file-producing format is requested without ``filepath``.
        """
        if (
            output_format in ["html", "plot", "plot_condensed", "plot_combined"]
            and not filepath
        ):
            raise ValueError(
                f"A filepath must be provided when output_format is '{output_format}'."
            )

        structured_rules, X_orig, Y_orig_for_plot = (
            rules_extractor.extract_rules_structured(
                rules_model=model.rules_model,
                X_scaled=model.X_scaled,
                Y_data=torch.from_numpy(model.Y_original).to(model.device),
                feature_names=model.feature_names,
                scaler_x=model.preprocessor.scaler_x,
                responsibility_threshold=responsibility_threshold,
                activation_threshold=activation_threshold,
                assign_max_resp=assign_max_resp,
            )
        )

        if output_format == "html":
            html_out = rules_extractor.format_rules_as_html(
                structured_rules, weight_threshold
            )
            full_save_path = self.output_dir / filepath
            with open(full_save_path, "w", encoding="utf-8") as f:
                f.write(html_out)
            return f"HTML rule table saved to {full_save_path}"

        if output_format == "plot":
            plotting_rules.plot_standard(
                parent_plotter=self,
                structured_rules=structured_rules,
                model=model,
                X_orig=X_orig,
                Y_orig=Y_orig_for_plot,
                filepath=filepath,
                weight_threshold=weight_threshold,
                show_densities=show_density_dist,  # Note: show_densities is an old arg
                **kwargs,
            )
            return "Rule plot generated"

        if output_format == "plot_combined":
            plotting_rules.plot_combined(
                parent_plotter=self,
                structured_rules=structured_rules,
                model=model,
                X_orig=X_orig,
                Y_orig=Y_orig_for_plot,
                filepath=filepath,
                rules_to_plot=rules_to_plot,
                show_proportional_dist=show_proportional_dist,
                show_density_dist=show_density_dist,
                weight_threshold=weight_threshold,
                show_population_histogram=show_population_histogram,
                scale_densities_by_weight=scale_densities_by_weight,
                **kwargs,
            )
            return "Combined rule plot generated"

        if output_format == "plot_condensed":
            plotting_rules.plot_condensed(
                parent_plotter=self,
                structured_rules=structured_rules,
                model=model,
                Y_orig=Y_orig_for_plot,
                filepath=filepath,
                rules_to_plot=rules_to_plot,
                use_density=show_density_dist,
                **kwargs,
            )
            return "Condensed rule plot generated"

        # Default to text
        return rules_extractor.format_rules_as_text_table(
            structured_rules, weight_threshold
        )

    def data_histogram(self, Y, bins=50, save_name=None, stat="count"):
        """
        Plot a histogram of a 1D target variable.

        Args:
            Y: 1D array of target values (an ``(n, 1)`` column is also accepted).
            bins: Number of bins passed to ``sns.histplot``.
            save_name: Optional file name relative to ``self.output_dir``.
            stat: Aggregate statistic for ``sns.histplot`` ("count", "density",
                "probability", ...); also used as the y-axis label.

        Raises:
            ValueError: If ``Y`` has more than one value per sample.
        """
        if Y.ndim == 2 and Y.shape[1] == 1:
            Y = Y.ravel()
        if Y.ndim > 1:
            raise ValueError("Y should be a 1D array.")
        plt.figure(figsize=(8, 6))
        sns.histplot(Y, bins=bins, stat=stat)
        plt.title("Distribution of Target Variable")
        plt.xlabel(self.name_mapping.get("y_variable", "Y"))
        # The y label must describe the statistic that is actually plotted.
        plt.ylabel(stat.capitalize())
        plt.grid(True, linestyle="--", alpha=0.6)
        self._save_show_and_close(save_name)

    def data_scatter(self, X, Y, save_name=None, labels=None, **kwargs):
        """
        Scatter-plot ``Y`` against a single feature ``X``.

        Args:
            X, Y: Coordinates passed to ``plt.scatter``.
            save_name: Optional file name relative to ``self.output_dir``.
            labels: Optional ``(x_label, y_label)``.
            **kwargs: Only ``alpha`` (default 0.5) and ``s`` (default 10) are used.
        """
        plt.figure(figsize=(8, 8))
        plt.scatter(X, Y, alpha=kwargs.get("alpha", 0.5), s=kwargs.get("s", 10))
        plt.xlabel(labels[0] if labels else "Feature 1")
        plt.ylabel(labels[1] if labels else self.name_mapping.get("y_variable", "Y"))
        plt.grid(True, linestyle="--", alpha=0.6)
        self._save_show_and_close(save_name)

    def data_histogram_2d(self, Y, bins=50, save_name=None, labels=None):
        """
        Plot a 2D histogram of a two-dimensional target.

        Args:
            Y: Array of shape ``(n_samples, 2)``.
            bins: Number of bins per axis passed to ``sns.histplot``.
            save_name: Optional file name relative to ``self.output_dir``.
            labels: Optional ``(x_label, y_label)``.

        Raises:
            ValueError: If ``Y`` is not of shape ``(n_samples, 2)``.
        """
        if Y.ndim != 2 or Y.shape[1] != 2:
            raise ValueError("Y should be a 2D array with shape [n_samples, 2].")
        plt.figure(figsize=(8, 6))
        sns.histplot(x=Y[:, 0], y=Y[:, 1], bins=bins)
        plt.title("2D Histogram of Target Variable")
        plt.xlabel(labels[0] if labels else "Y Dimension 1")
        plt.ylabel(labels[1] if labels else "Y Dimension 2")
        plt.grid(True, linestyle="--", alpha=0.6)
        self._save_show_and_close(save_name)

    def plot_soft_histograms(
        self,
        Y,
        mixture_probs,
        n_bins=50,
        title="Component Distributions",
        save_name=None,
        show_stacked=True,
        hard_assignment=False,
        density=False,
        disabled_components=None,
        show_population_ref=False,
        y_name="Y",
    ):
        """
        Plot per-component histograms of a 1D target, weighted by mixture probabilities.

        The right panel overlays one histogram per active component. If
        ``show_population_ref`` is set, a left panel shows the unweighted population
        histogram. Otherwise, if ``show_stacked`` is set, the left panel stacks the
        component histograms. The population reference takes precedence.

        Args:
            Y: Target values, shape ``(n,)`` or ``(n, 1)``.
            mixture_probs: Array of shape ``(n, n_components)``, processed with
                ``prepare_probs``.
            n_bins: Number of equal-width bins spanning the range of ``Y``.
            title: Figure title; suffixed with "(Hard)" or "(Soft)".
            save_name: Optional file name relative to ``self.output_dir``.
            show_stacked: Draw the stacked view in the left panel.
            hard_assignment: Passed to ``prepare_probs``.
            density: Normalise each histogram via ``np.histogram(density=True)``.
            disabled_components: Optional per-component flags; truthy entries are skipped.
            show_population_ref: Draw the population histogram in the left panel.
            y_name: Fallback x-axis label when ``name_mapping`` has no "y_variable".

        Raises:
            ValueError: If ``Y`` and ``mixture_probs`` disagree on the number of
                samples, or ``Y`` holds more than one value per sample.
        """
        if Y.shape[0] != mixture_probs.shape[0]:
            raise ValueError(
                "Y and mixture_probs must have the same number of samples."
            )

        Y_flat = Y.ravel()
        if Y_flat.shape[0] != mixture_probs.shape[0]:
            # A multi-column Y would otherwise fail with an opaque weights-shape
            # error inside np.histogram.
            raise ValueError(
                "plot_soft_histograms expects a 1D target (one value per sample)."
            )
        n_total_components = mixture_probs.shape[1]
        processed_probs = prepare_probs(mixture_probs, hard_assignment)
        assign_suffix = "(Hard)" if hard_assignment else "(Soft)"
        y_label = "Density" if density else "Count"
        full_title = f"{title} {assign_suffix}"

        if disabled_components is None:
            active_indices = list(range(n_total_components))
        else:
            active_indices = [
                i for i, disabled in enumerate(disabled_components) if not disabled
            ]
        if not active_indices:
            print("No active components to plot.")
            return

        colors = self._get_colors(len(active_indices))

        if show_stacked or show_population_ref:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        else:
            fig, ax2 = plt.subplots(1, 1, figsize=(8, 6))
            ax1 = None

        y_lo, y_hi = float(Y_flat.min()), float(Y_flat.max())
        if y_lo == y_hi:
            # Degenerate range: widen it (as np.histogram does) so bins have width.
            y_lo, y_hi = y_lo - 0.5, y_hi + 0.5
        bins = np.linspace(y_lo, y_hi, n_bins + 1)
        bin_widths = np.diff(bins)
        hist_data = np.array(
            [
                np.histogram(
                    Y_flat, bins=bins, weights=processed_probs[:, i], density=density
                )[0]
                for i in active_indices
            ]
        )

        x_label = self.name_mapping.get("y_variable", y_name)

        if ax1 is not None:
            if show_population_ref:
                ax1.hist(
                    Y_flat, bins=bins, density=density, alpha=0.8, label="Population"
                )
            else:  # show_stacked must be set, otherwise ax1 would be None
                bottom = np.zeros_like(hist_data[0])
                for i, comp_idx in enumerate(active_indices):
                    ax1.bar(
                        bins[:-1],
                        hist_data[i],
                        width=bin_widths,
                        bottom=bottom,
                        align="edge",  # x values are left bin edges
                        label=f"Component {comp_idx+1}",
                        color=colors[i],
                        alpha=0.7,
                    )
                    bottom = bottom + hist_data[i]
                ax1.set_title("Stacked View")
            ax1.set_xlabel(x_label)
            ax1.set_ylabel(y_label)
            ax1.legend()
            ax1.grid(True, linestyle="--", alpha=0.0)

        for i, comp_idx in enumerate(active_indices):
            ax2.bar(
                bins[:-1],
                hist_data[i],
                width=bin_widths,
                align="edge",  # x values are left bin edges
                label=f"Component {comp_idx+1}",
                color=colors[i],
                alpha=0.8,
            )
        ax2.set_xlabel(x_label)
        if ax1 is None:
            # With two panels the shared y axis is labelled on the left panel only.
            ax2.set_ylabel(y_label)
        ax2.legend()
        ax2.grid(True, linestyle="--", alpha=0.0)

        fig.suptitle(full_title, fontsize=16)
        # Leave the top 4% of the figure free for the suptitle.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        self._save_show_and_close(save_name)

    def plot_subgroup_grid(
        self,
        Y,
        mixture_probs,
        n_bins=50,
        grid_cols=3,
        title="Subgroup Distributions vs. Population",
        save_name=None,
        hard_assignment=False,
        density=True,
        disabled_components=None,
        subplot_size=(5, 4),
        y_labels=None,
        plot_subsample=2000,
        plot_type_2d="kde",
    ):
        """
        Draw one panel per active component, comparing it with the population.

        For a 1D target, each panel overlays the component's weighted histogram on the
        grey population histogram. For a 2D target, ``plot_type_2d`` selects one of:
        weighted KDE contours ("kde"), weighted 2D histograms ("hist"), or a scatter of
        the points hard-assigned to the component ("scatter"). "scatter" forces
        ``hard_assignment``. Components with total probability mass <= 1e-6 are skipped.

        Args:
            Y: Target values, shape ``(n,)``, ``(n, 1)`` or ``(n, 2)``.
            mixture_probs: Array of shape ``(n, n_components)``.
            n_bins: Number of histogram bins.
            grid_cols: Number of panel columns.
            title: Figure title.
            save_name: Optional file name relative to ``self.output_dir``.
            hard_assignment: Passed to ``prepare_probs``.
            density: Normalise 1D histograms; also picks the "Density"/"Count" y label.
            disabled_components: Optional per-component flags; ignored if their length
                does not match the number of components.
            subplot_size: ``(width, height)`` of each panel in inches.
            y_labels: Optional ``(x_label, y_label)``; the second entry is only used
                for 2D targets.
            plot_subsample: If set, at most this many randomly chosen samples are drawn.
            plot_type_2d: "kde", "hist" or "scatter" (2D targets only).

        Raises:
            ValueError: On a sample-count mismatch, or if the target has more than 2 dims.
        """
        if Y.shape[0] != mixture_probs.shape[0]:
            raise ValueError(
                "Y and mixture_probs must have the same number of samples."
            )

        if Y.ndim == 1:
            Y = Y.reshape(-1, 1)
        y_dim = Y.shape[1]
        if y_dim not in [1, 2]:
            raise ValueError(
                f"This function only supports 1D or 2D targets, but Y has {y_dim} dimensions."
            )

        is_scatter = y_dim == 2 and plot_type_2d == "scatter"
        if is_scatter:
            # Scatter panels show discrete membership, so assignments must be hard.
            hard_assignment = True
        processed_probs = prepare_probs(mixture_probs, hard_assignment)

        Y_plot, probs_plot, assignments_plot = Y, processed_probs, None
        if plot_subsample is not None and Y.shape[0] > plot_subsample:
            sample_indices = np.random.choice(Y.shape[0], plot_subsample, replace=False)
            Y_plot, probs_plot = Y[sample_indices], processed_probs[sample_indices]
        if is_scatter:
            assignments_plot = np.argmax(probs_plot, axis=1)

        n_total_components = mixture_probs.shape[1]
        active_mask = np.ones(n_total_components, dtype=bool)
        if disabled_components is not None:
            disabled_mask = np.array(disabled_components, dtype=bool)
            if disabled_mask.shape[0] == n_total_components:
                active_mask = ~disabled_mask
        # Skip components that received (almost) no probability mass on the full data.
        component_mass = processed_probs.sum(axis=0)
        active_mask = active_mask & (component_mass > 1e-6)
        active_indices = np.arange(n_total_components)[active_mask]

        if len(active_indices) == 0:
            print("No active components to plot.")
            return

        n_active_components = len(active_indices)
        colors = self._get_colors(n_active_components)
        n_rows = int(np.ceil(n_active_components / grid_cols))
        fig, axes = plt.subplots(
            n_rows,
            grid_cols,
            figsize=(subplot_size[0] * grid_cols, subplot_size[1] * n_rows),
            sharex=True,
            sharey=True,
            squeeze=False,
        )
        fig.suptitle(f"{title}", fontsize=16)
        axes_flat = axes.flatten()

        if y_dim == 1:
            # The population histogram is identical in every panel: compute it once.
            Y_flat = Y_plot.ravel()
            pop_heights, bins = np.histogram(Y_flat, bins=n_bins, density=density)
            bin_width = np.diff(bins)[0]

        for i, comp_idx in enumerate(active_indices):
            ax = axes_flat[i]
            color = colors[i]
            if y_dim == 1:
                ax.bar(
                    bins[:-1],
                    pop_heights,
                    width=bin_width,
                    color="grey",
                    alpha=0.2,
                    align="edge",
                )
                subgroup_heights, _ = np.histogram(
                    Y_flat, bins=bins, weights=probs_plot[:, comp_idx], density=density
                )
                ax.bar(
                    bins[:-1],
                    subgroup_heights,
                    width=bin_width,
                    color=color,
                    alpha=0.7,
                    align="edge",
                )
            elif plot_type_2d == "kde":
                sns.kdeplot(
                    x=Y_plot[:, 0],
                    y=Y_plot[:, 1],
                    ax=ax,
                    color="grey",
                    alpha=0.4,
                    levels=5,
                    linewidths=1,
                )
                sns.kdeplot(
                    x=Y_plot[:, 0],
                    y=Y_plot[:, 1],
                    weights=probs_plot[:, comp_idx],
                    ax=ax,
                    fill=True,
                    color=color,
                    alpha=0.6,
                )
            elif plot_type_2d == "hist":
                sns.histplot(
                    x=Y_plot[:, 0],
                    y=Y_plot[:, 1],
                    ax=ax,
                    bins=n_bins,
                    color="grey",
                    alpha=0.2,
                )
                sns.histplot(
                    x=Y_plot[:, 0],
                    y=Y_plot[:, 1],
                    weights=probs_plot[:, comp_idx],
                    ax=ax,
                    bins=n_bins,
                    color=color,
                    alpha=0.7,
                )
            elif plot_type_2d == "scatter":
                ax.scatter(
                    Y_plot[:, 0], Y_plot[:, 1], color="grey", alpha=0.05, s=6
                )
                component_mask = assignments_plot == comp_idx
                ax.scatter(
                    Y_plot[component_mask, 0],
                    Y_plot[component_mask, 1],
                    color=color,
                    alpha=0.5,
                    s=6,
                )
            ax.set_title(f"Component {comp_idx + 1}")
            ax.grid(True, linestyle="--", alpha=0.5)

        for ax in axes_flat[n_active_components:]:
            ax.set_visible(False)

        if y_dim == 1:
            # 1D panels plot histogram heights on the y axis.
            y_label_text = "Density" if density else "Count"
        else:
            y_label_text = (
                y_labels[1] if y_labels and len(y_labels) > 1 else "Y Dimension 2"
            )
        if y_labels:
            x_label_text = y_labels[0]
        elif y_dim == 1:
            x_label_text = self.name_mapping.get("y_variable", "Y")
        else:
            x_label_text = "Y Dimension 1"

        for r in range(n_rows):
            if axes[r, 0].get_visible():
                axes[r, 0].set_ylabel(y_label_text)
        for c in range(grid_cols):
            # Label only the bottom-most visible panel of each column.
            last_visible_row = next(
                (r for r in range(n_rows - 1, -1, -1) if axes[r, c].get_visible()), -1
            )
            if last_visible_row != -1:
                axes[last_visible_row, c].set_xlabel(x_label_text)

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        self._save_show_and_close(save_name)

    def _plot_gmm_densities(
        self,
        Y: np.ndarray,
        gmm_remixer: "GMMRemixer",
        base_gmm,
        scaler_y,
        disabled_components: List[bool],
        n_points=100,
        save_name=None,
        **kwargs,
    ):
        """
        Plot each rule's remixed GMM density on the original target scale.

        The per-component densities of ``base_gmm`` are evaluated on a grid over the
        (first) target dimension, padded by 5%. They are mixed with
        ``gmm_remixer.get_mixing_weights()`` into one density per rule. If the weight
        matrix has more columns than ``gmm_remixer.n_rules``, the last column is drawn
        as a dashed "Background" curve.

        Args:
            Y: Target values in original units; only the first column is used.
            gmm_remixer: Provides ``get_mixing_weights()`` and ``n_rules``.
            base_gmm: Fitted ``sklearn.mixture.GaussianMixture`` in scaled space.
            scaler_y: Target scaler used to fit ``base_gmm``; ``None`` means identity.
            disabled_components: Per-rule flags; truthy rules are not drawn.
            n_points: Number of grid points.
            save_name: Optional file name relative to ``self.output_dir``.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            TypeError: If ``base_gmm`` is not a ``GaussianMixture``.
            ValueError: If ``base_gmm`` is unfitted or did not converge.
        """
        from sklearn.mixture import GaussianMixture

        if not isinstance(base_gmm, GaussianMixture):
            raise TypeError("base_gmm must be a scikit-learn GaussianMixture.")
        if not hasattr(base_gmm, "converged_"):
            # `converged_` is only set by `fit`, so its absence means "unfitted".
            raise ValueError("The base_gmm must be fitted.")
        if not base_gmm.converged_:
            raise ValueError("The base_gmm was fitted but did not converge.")

        if Y.ndim == 1:
            Y = Y.reshape(-1, 1)
        elif Y.shape[1] != 1:
            # Only the first target dimension is plotted.
            Y = Y[:, 0].reshape(-1, 1)

        y_min, y_max = Y.min(), Y.max()
        margin = (y_max - y_min) * 0.05
        y_range_np = np.linspace(y_min - margin, y_max + margin, n_points).reshape(
            -1, 1
        )
        y_range_scaled_np = (
            scaler_y.transform(y_range_np) if scaler_y is not None else y_range_np
        )
        # Private sklearn API: unweighted per-component log-densities,
        # shape (n_points, n_gmm_components).
        component_log_probs_np = base_gmm._estimate_log_prob(y_range_scaled_np)
        component_densities = torch.from_numpy(np.exp(component_log_probs_np)).to(
            torch.float32
        )

        with torch.no_grad():
            norm_weights = gmm_remixer.get_mixing_weights()
            component_densities = component_densities.to(norm_weights.device)
            # (n_points, n_gmm_components) @ (n_gmm_components, n_mix) -> (n_points, n_mix)
            subgroup_densities_scaled = component_densities @ norm_weights

        fig, ax = plt.subplots(figsize=(8, 6))
        active_rules = [
            j for j, disabled in enumerate(disabled_components) if not disabled
        ]
        colors = self._get_colors(len(active_rules))

        # Change of variables back to original units: p_y(y) = p_z(z) / prod(scale_).
        # NOTE(review): valid for a StandardScaler-style transform z = (y - mean) / scale_;
        # a scaler whose scale_ is a multiplier (e.g. MinMaxScaler) would need the inverse.
        scale_product = 1.0
        if hasattr(scaler_y, "scale_") and scaler_y.scale_ is not None:
            scale_product = np.prod(scaler_y.scale_)

        y_axis = y_range_np.ravel()
        for i, j in enumerate(active_rules):
            density_j = subgroup_densities_scaled[:, j].cpu().numpy() / scale_product
            ax.plot(
                y_axis,
                density_j,
                label=f"Rule {j+1}",
                color=colors[i],
                zorder=2,
                linewidth=2,
            )

        if norm_weights.shape[1] > gmm_remixer.n_rules:
            density_background = (
                subgroup_densities_scaled[:, -1].cpu().numpy() / scale_product
            )
            ax.plot(
                y_axis,
                density_background,
                label="Background",
                color="grey",
                linestyle="--",
                zorder=1,
                linewidth=2,
            )

        ax.set_xlabel(self.name_mapping.get("y_variable", "Y"))
        ax.set_ylabel("Density p_j(y)")
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.6)
        ax.set_ylim(bottom=0)
        plt.tight_layout()
        self._save_show_and_close(save_name)

    def _plot_flow_densities(
        self,
        Y: np.ndarray,
        component_flows: List[torch.nn.Module],
        scaler_y,
        disabled_components: List[bool],
        n_points=100,
        save_name=None,
        title="Component Densities from Normalizing Flows",
        **kwargs,
    ):
        """
        Plot the density of each active normalizing-flow component.

        Each flow is evaluated on CPU over a grid spanning ``[Y.min(), Y.max()]``. It is
        then moved back to the device it was on, so plotting does not relocate the
        model's experts.

        Args:
            Y: Target values in original units.
            component_flows: One flow module per component, evaluated with ``get_log_prob``.
            scaler_y: Target scaler used when training the flows; ``None`` means identity.
            disabled_components: Per-component flags; truthy components are not drawn.
            n_points: Number of grid points.
            save_name: Optional file name relative to ``self.output_dir``.
            title: Axes title.
            **kwargs: Accepted for interface compatibility; unused.
        """
        fig, ax = plt.subplots(figsize=(8, 6))
        y_range = np.linspace(Y.min(), Y.max(), n_points).reshape(-1, 1)
        y_range_model = scaler_y.transform(y_range) if scaler_y is not None else y_range
        y_range_scaled = torch.tensor(y_range_model, dtype=torch.float32, device="cpu")

        active_components = [
            i for i in range(len(component_flows)) if not disabled_components[i]
        ]
        colors = self._get_colors(len(active_components))

        # Change of variables back to original units (see StandardScaler assumption:
        # z = (y - mean) / scale_ ⇒ p_y = p_z / prod(scale_)).
        scale_product = 1.0
        if hasattr(scaler_y, "scale_") and scaler_y.scale_ is not None:
            scale_product = np.prod(scaler_y.scale_)

        for color_idx, comp_idx in enumerate(active_components):
            flow = component_flows[comp_idx]
            # nn.Module.to() moves the module in place, so remember where it lived.
            original_device = self._module_device(flow)
            flow.to("cpu")
            try:
                with torch.no_grad():
                    log_prob = get_log_prob(flow, y_range_scaled)
                    if torch.any(torch.isnan(log_prob)):
                        log_prob = torch.nan_to_num(log_prob, nan=-1e8)
                    density_scaled = log_prob.exp().cpu().numpy().flatten()
            finally:
                if original_device is not None:
                    flow.to(original_device)
            density = density_scaled / scale_product
            ax.plot(
                y_range.flatten(),
                density,
                label=f"Component {comp_idx + 1}",
                color=colors[color_idx],
                linewidth=2,
                zorder=2,
            )

        ax.set_xlabel(self.name_mapping.get("y_variable", "Y"))
        ax.set_ylabel("Density")
        ax.set_title(title)
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.6)
        ax.set_ylim(bottom=0)
        plt.tight_layout()
        self._save_show_and_close(save_name)

    def plot_gating_heatmap(
        self,
        model: "MixtureModel",
        mode: Literal["activation", "responsibility"] = "responsibility",
        output_file="gating_heatmap.png",
        figsize=None,
        dpi=150,
        grid_resolution=100,
    ):
        """
        Creates a multi-plot heatmap visualization of the gating network's output.

        The final model is rebuilt from ``model.history[-1]``. Its ``rules_model`` is
        evaluated on a ``grid_resolution`` x ``grid_resolution`` grid covering the first
        two features of ``model.X_original``, padded by 10%. ``mode="activation"`` uses
        ``rules_model.forward_raw``; any other value uses the first output of
        ``rules_model(...)``. Each active rule gets one panel, showing the data points and
        the rule's box derived from its discretizer cut points. The figure is saved to
        ``self.output_dir / output_file`` and is not shown. Save errors are printed
        rather than raised.

        Args:
            model: Mixture model with a non-empty ``history``.
            mode: "activation" or "responsibility".
            output_file: File name relative to ``self.output_dir``.
            figsize: Figure size; defaults to ``(5 * n_active_rules, 5)``.
            dpi: Resolution for the figure and the saved file.
            grid_resolution: Number of grid points per axis.

        Raises:
            ValueError: If ``model.X_original`` does not have at least two feature columns.
        """
        final_model = _reconstruct_model_from_snapshot(model, model.history[-1])
        colors, _ = _get_component_colors_and_labels(final_model)
        active_indices = [
            i for i, r in enumerate(final_model.rules_model.rules) if not r.disabled
        ]

        if not active_indices:
            print("No active rules to plot.")
            return

        X_orig = model.X_original
        if X_orig.ndim != 2 or X_orig.shape[1] < 2:
            raise ValueError(
                "plot_gating_heatmap requires at least two input feature columns."
            )
        x1_min, x1_max = X_orig[:, 0].min(), X_orig[:, 0].max()
        x2_min, x2_max = X_orig[:, 1].min(), X_orig[:, 1].max()
        pad1, pad2 = (x1_max - x1_min) * 0.1, (x2_max - x2_min) * 0.1

        xx, yy = np.meshgrid(
            np.linspace(x1_min - pad1, x1_max + pad1, grid_resolution),
            np.linspace(x2_min - pad2, x2_max + pad2, grid_resolution),
        )
        grid_points_np = np.c_[xx.ravel(), yy.ravel()]

        # The preprocessor transforms X and Y together, so feed a zero Y of the right width.
        scaler_y = final_model.preprocessor.scaler_y
        if scaler_y and hasattr(scaler_y, "n_features_in_"):
            n_y_features = scaler_y.n_features_in_
        else:
            n_y_features = (
                final_model.Y_original.shape[1]
                if final_model.Y_original.ndim > 1
                else 1
            )
        dummy_y = np.zeros((grid_points_np.shape[0], n_y_features))

        with torch.no_grad():
            grid_tensor_scaled, _, _, _ = final_model.preprocessor.transform(
                grid_points_np, dummy_y
            )
            if mode == "activation":
                gating_values = final_model.rules_model.forward_raw(grid_tensor_scaled)
            else:
                gating_values, _ = final_model.rules_model(grid_tensor_scaled)

        num_plots = len(active_indices)
        if figsize is None:
            figsize = (5 * num_plots, 5)
        fig, axes = plt.subplots(1, num_plots, figsize=figsize, dpi=dpi, squeeze=False)
        axes = axes.flatten()

        try:
            for ax, rule_idx in zip(axes, active_indices):
                z = gating_values[:, rule_idx].cpu().numpy().reshape(xx.shape)

                mesh = ax.pcolormesh(
                    xx, yy, z, cmap="viridis", shading="auto", vmin=0, vmax=1
                )
                fig.colorbar(mesh, ax=ax)
                ax.scatter(X_orig[:, 0], X_orig[:, 1], s=5, c="white", alpha=0.3)

                rule = final_model.rules_model.rules[rule_idx]
                # NOTE(review): cut_points presumably has shape (n_features, n_cuts, ...)
                # with cut 0 = lower and cut 1 = upper bound in scaled space — confirm.
                cuts = rule.discretizer.cut_points.detach().cpu().numpy()
                data_to_transform = cuts[:, :, 0].T  # (n_cuts, n_features) for the scaler
                if data_to_transform.ndim == 1:
                    data_to_transform = data_to_transform.reshape(-1, 1)
                unscaled_cuts = final_model.preprocessor.scaler_x.inverse_transform(
                    data_to_transform
                ).T

                x1_low, x1_high = unscaled_cuts[0, 0], unscaled_cuts[0, 1]
                x2_low, x2_high = unscaled_cuts[1, 0], unscaled_cuts[1, 1]
                rect = patches.Rectangle(
                    (x1_low, x2_low),
                    (x1_high - x1_low),
                    (x2_high - x2_low),
                    fill=False,
                    ec=colors[rule_idx],
                    lw=2.5,
                )
                ax.add_patch(rect)

                ax.set_title(f"Rule {rule_idx + 1} ({mode.capitalize()})")
                ax.set_xlabel("Feature X1")
                ax.set_ylabel("Feature X2")
                ax.set_aspect("equal", adjustable="box")

            fig.tight_layout()
            save_path = self.output_dir / output_file
            try:
                fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
                print(f"Gating heatmap saved to {save_path}")
            except Exception as e:  # saving is best-effort; report and continue
                print(f"Error saving plot: {e}")
        finally:
            # Always release the figure, even if drawing failed.
            plt.close(fig)
