import matplotlib.pyplot as plt
import torch
import numpy as np
import matplotlib.patches as patches
import matplotlib.gridspec as gridspec
import matplotlib.animation as animation
from typing import TYPE_CHECKING, List

from .base import BasePlotter
from .utils import _reconstruct_model_from_snapshot, _get_component_colors_and_labels

if TYPE_CHECKING:
    from ..mixture_model_base import MixtureModel
    from ..mixture_model_flows import FlowMixtureModel
    from ..mixture_model_remix import RemixMixtureModel


class TrainingPlotter(BasePlotter):
    """
    A helper class to create visualizations of the mixture model training process.

    Both public plotting methods replay ``model.history`` snapshot by snapshot,
    rebuild the model state for each one and write the resulting figure under
    ``self.output_dir`` (expected to be provided by ``BasePlotter``).
    """

    @staticmethod
    def _padded_limits(low, high, fraction=0.1):
        """Return ``(low, high)`` widened on each side by ``fraction`` of the span."""
        pad = (high - low) * fraction
        return low - pad, high + pad

    @staticmethod
    def _rule_box_bounds(rule, scaler_x):
        """
        Return ``(x1_low, x1_high, x2_low, x2_high)`` for a rule in original units.

        The rule's cut points are mapped back through ``scaler_x.inverse_transform``.
        Assumes ``cut_points`` is indexed as ``[feature, low/high, 0]`` -- this is
        what the slicing below implies; TODO confirm against the discretizer.
        """
        cuts = rule.discretizer.cut_points.detach().cpu().numpy()
        # ``cuts[:, :, 0]`` is always at least 2-D, so the transposed array
        # already has the (n_samples, n_features) layout the scaler expects.
        unscaled_cuts = scaler_x.inverse_transform(cuts[:, :, 0].T).T
        return (
            unscaled_cuts[0, 0],
            unscaled_cuts[0, 1],
            unscaled_cuts[1, 0],
            unscaled_cuts[1, 1],
        )

    @staticmethod
    def _winning_component_colors(frame_model, X, colors):
        """Return, per sample, the color of the component with the largest responsibility."""
        with torch.no_grad():
            responsibilities = frame_model.get_responsibilities(X)
            winner_indices = np.argmax(responsibilities, axis=1)
        return colors[winner_indices]

    def plot_training_animation_2d(
        self,
        model: "MixtureModel",
        output_file="training_animation_2d.mp4",
        n_frames=None,
        figsize=(15, 10),
        dpi=120,
        fps=5,
        legend=False,
    ):
        """
        Creates a 2D animation of the training process, showing the evolution of
        rule partitions and expert densities.

        Args:
            model: Trained model exposing ``history``, ``X_original``,
                ``Y_original`` and ``config``.
            output_file: File name (str or path-like) relative to
                ``self.output_dir``. Names ending in ``.mp4`` are written with
                ``FFMpegWriter``; anything else with ``PillowWriter``.
            n_frames: If given and smaller than the history length, that many
                evenly spaced snapshots are used; otherwise every snapshot
                becomes a frame.
            figsize: Figure size in inches.
            dpi: Resolution used for the figure and the saved animation.
            fps: Frames per second of the saved animation.
            legend: Whether to draw a legend on the density panel.

        Raises:
            ValueError: If the feature space is not 2D or ``n_frames`` is
                smaller than 1.

        An empty history prints an error and returns ``None``. Failures while
        saving are reported via ``print`` rather than raised.
        """
        history = model.history
        if not history:
            print("Error: Training history is empty. Cannot create animation.")
            return

        X_orig = model.X_original
        Y_orig = model.Y_original
        config = model.config

        if X_orig.shape[1] != 2:
            raise ValueError(
                f"This visualization is for 2D feature spaces, but got {X_orig.shape[1]} features."
            )
        if n_frames is not None and n_frames < 1:
            # Zero frames would otherwise surface later as an opaque IndexError.
            raise ValueError(
                f"n_frames must be a positive integer or None, got {n_frames}."
            )

        # Frame selection
        if n_frames is not None and n_frames < len(history):
            frame_indices = np.linspace(0, len(history) - 1, n_frames, dtype=int)
            selected_history = [history[i] for i in frame_indices]
        else:
            selected_history = history

        fig = plt.figure(figsize=figsize, dpi=dpi)
        # Everything after figure creation runs under try/finally so the figure
        # is released even when model reconstruction or drawing fails.
        try:
            gs = gridspec.GridSpec(
                2, 2, height_ratios=[4, 1], width_ratios=[3, 2], figure=fig
            )
            fig.subplots_adjust(wspace=0.3, hspace=0.3)

            main_ax = fig.add_subplot(gs[0, 0])
            density_ax = fig.add_subplot(gs[0, 1])
            loss_ax = fig.add_subplot(gs[1, :])

            # Configure axes
            main_ax.set_xlim(
                *self._padded_limits(X_orig[:, 0].min(), X_orig[:, 0].max())
            )
            main_ax.set_ylim(
                *self._padded_limits(X_orig[:, 1].min(), X_orig[:, 1].max())
            )
            main_ax.set_xlabel("X1")
            main_ax.set_ylabel("X2")
            main_ax.set_title("Feature Space Partitioning")
            scatter_plot = main_ax.scatter(
                X_orig[:, 0], X_orig[:, 1], s=3, c="grey", alpha=0.7
            )

            y_min, y_max = Y_orig.min(), Y_orig.max()
            density_ax.set_xlim(*self._padded_limits(y_min, y_max))
            density_ax.set_xlabel("Y")
            density_ax.set_ylabel("Density")
            density_ax.set_title("Component Densities")
            # The evaluation grid does not change between frames.
            y_grid = np.linspace(y_min, y_max, 200).reshape(-1, 1)

            loss_ax.set_xlabel("Training Step")
            loss_ax.set_ylabel("NLL Loss")
            (loss_line,) = loss_ax.plot([], [], "r-", label="NLL")
            loss_ax.legend()
            fig.suptitle("", fontsize=16)

            steps, nll_losses = [], []
            n_interpretable_rules = config.n_mixture_components
            n_total_components = n_interpretable_rules
            if config.use_background_component:
                n_total_components += 1

            # Create a temporary model to determine color mapping from the final state
            temp_model = _reconstruct_model_from_snapshot(model, selected_history[-1])
            colors, labels = _get_component_colors_and_labels(temp_model)

            # Initialize plot artists
            rule_rects = []
            for i in range(n_interpretable_rules):
                rect = patches.Rectangle(
                    (0, 0), 1, 1, fill=False, alpha=0.8, lw=2.5
                )
                main_ax.add_patch(rect)
                rect.set_color(colors[i])
                # Darker outline. NOTE(review): if colors carry an alpha channel
                # this multiplication scales it as well -- confirm intent.
                rect.set_edgecolor(colors[i] * 0.8)
                rect.set_visible(False)
                rule_rects.append(rect)

            density_lines = [
                density_ax.plot([], [], color=colors[i], lw=2, label=labels[i])[0]
                for i in range(n_total_components)
            ]
            if legend:
                density_ax.legend(fontsize="small", loc="upper right")

            def update(frame_idx):
                snapshot = selected_history[frame_idx]

                # Update loss plot. Without an init_func, FuncAnimation draws
                # frame 0 once for initialisation and again as the first saved
                # frame; truncating before appending keeps the curve free of
                # duplicated points when a frame is drawn more than once.
                del steps[frame_idx:]
                del nll_losses[frame_idx:]
                steps.append(snapshot.step)
                nll_losses.append(snapshot.losses.get("nll", np.nan))
                loss_line.set_data(steps, nll_losses)
                loss_ax.relim()
                loss_ax.autoscale_view()
                loss_ax.set_title(
                    f"Training Loss - Step: {snapshot.step}, Temp: {snapshot.current_temp:.4f}"
                )

                # Reconstruct model state for this frame
                frame_model = _reconstruct_model_from_snapshot(model, snapshot)

                # Update scatter plot colors
                scatter_plot.set_color(
                    self._winning_component_colors(frame_model, X_orig, colors)
                )

                # Update rule rectangles
                for i, rect in enumerate(rule_rects):
                    if snapshot.disabled_components[i]:
                        rect.set_visible(False)
                        continue
                    rect.set_visible(True)
                    x1_low, x1_high, x2_low, x2_high = self._rule_box_bounds(
                        frame_model.rules_model.rules[i],
                        frame_model.preprocessor.scaler_x,
                    )
                    rect.set_xy((x1_low, x2_low))
                    rect.set_width(x1_high - x1_low)
                    rect.set_height(x2_high - x2_low)

                # Update density plots
                all_density_data = self._update_density_lines(
                    density_ax, density_lines, frame_model, y_grid
                )
                if all_density_data:
                    max_density = np.max(all_density_data)
                    density_ax.set_ylim(
                        0, max_density * 1.15 if max_density > 1e-9 else 1.0
                    )

                return [scatter_plot, loss_line] + rule_rects + density_lines

            anim = animation.FuncAnimation(
                fig,
                update,
                frames=len(selected_history),
                blit=True,
                interval=1000 / fps,
            )
            # str() lets callers pass a pathlib.Path as well as a plain string.
            if str(output_file).endswith(".mp4"):
                writer = animation.FFMpegWriter(fps=fps)
            else:
                writer = animation.PillowWriter(fps=fps)

            try:
                anim.save(self.output_dir / output_file, writer=writer, dpi=dpi)
                print(f"Animation saved to {self.output_dir / output_file}")
            except Exception as e:
                # Best effort: a missing encoder should not crash the caller.
                print(
                    f"Error saving animation: {e}. Ensure ffmpeg or pillow is installed."
                )
        finally:
            plt.close(fig)

    def plot_training_snapshots_2d(
        self,
        model: "MixtureModel",
        snapshot_steps: List[int],
        output_file="training_snapshots_2d.png",
        figsize=None,
        dpi=150,
    ):
        """
        Creates a static multi-plot visualization of rule partitions at specific training steps.

        Args:
            model: Trained model exposing ``history``, ``X_original`` and
                ``preprocessor``.
            snapshot_steps: Training steps to draw, one panel per entry, in the
                given order. A step with no matching snapshot in the history
                gets a red "not found" placeholder panel.
            output_file: File name relative to ``self.output_dir``.
            figsize: Figure size in inches; defaults to 5 inches of width per
                panel and a height of 5.5.
            dpi: Resolution used for the figure and the saved image.

        An empty history prints an error and returns ``None``; an empty
        ``snapshot_steps`` returns ``None`` silently. Failures while saving are
        reported via ``print`` rather than raised.
        """
        history = model.history
        if not history:
            print("Error: Training history is empty. Cannot create visualization.")
            return

        X_orig = model.X_original
        num_snapshots = len(snapshot_steps)
        if num_snapshots == 0:
            return

        if figsize is None:
            figsize = (5 * num_snapshots, 5.5)
        fig, axes = plt.subplots(
            1, num_snapshots, figsize=figsize, dpi=dpi, squeeze=False
        )
        # Guarantee the figure is released even if reconstruction or drawing fails.
        try:
            axes = axes.flatten()

            # Determine colors from the final model state
            final_model = _reconstruct_model_from_snapshot(model, history[-1])
            colors, _ = _get_component_colors_and_labels(final_model)

            # Set membership keeps the history scan linear in its length.
            requested_steps = set(snapshot_steps)
            snapshots_to_plot = {
                s.step: s for s in history if s.step in requested_steps
            }
            xlims = self._padded_limits(X_orig[:, 0].min(), X_orig[:, 0].max())
            ylims = self._padded_limits(X_orig[:, 1].min(), X_orig[:, 1].max())

            for i, (ax, step) in enumerate(zip(axes, snapshot_steps)):
                ax.set_title(f"Step: {step}")
                ax.set_xlabel("Feature X1")
                if i == 0:
                    ax.set_ylabel("Feature X2")
                ax.set_xlim(xlims)
                ax.set_ylim(ylims)
                ax.set_aspect("equal", adjustable="box")

                snapshot = snapshots_to_plot.get(step)

                if snapshot is None:
                    # Axes coordinates keep the message centred in the panel;
                    # in data coordinates (0.5, 0.5) may lie outside the limits.
                    ax.text(
                        0.5,
                        0.5,
                        f"Step {step}\nnot found",
                        ha="center",
                        va="center",
                        color="red",
                        transform=ax.transAxes,
                    )
                    continue

                frame_model = _reconstruct_model_from_snapshot(model, snapshot)
                point_colors = self._winning_component_colors(
                    frame_model, X_orig, colors
                )
                ax.scatter(
                    X_orig[:, 0], X_orig[:, 1], s=5, c=point_colors, alpha=0.7
                )

                for j, rule in enumerate(frame_model.rules_model.rules):
                    if snapshot.disabled_components[j]:
                        continue
                    x1_low, x1_high, x2_low, x2_high = self._rule_box_bounds(
                        rule, model.preprocessor.scaler_x
                    )
                    ax.add_patch(
                        patches.Rectangle(
                            (x1_low, x2_low),
                            (x1_high - x1_low),
                            (x2_high - x2_low),
                            fill=False,
                            alpha=0.8,
                            color=colors[j],
                            ec=colors[j] * 0.8,
                            lw=2.5,
                        )
                    )

            # Operate on ``fig`` explicitly rather than on pyplot's notion of the
            # current figure, which other code may have changed in the meantime.
            fig.tight_layout(rect=[0, 0, 1, 0.95])
            try:
                fig.savefig(self.output_dir / output_file, dpi=dpi, bbox_inches="tight")
                print(f"Snapshot plot saved to {self.output_dir / output_file}")
            except Exception as e:
                print(f"Error saving plot: {e}")
        finally:
            plt.close(fig)

    def _update_density_lines(self, ax, lines, model, y_grid_np):
        """
        Helper to compute and plot densities for both Flow and GMMRemix models.

        Args:
            ax: Axes the lines belong to (currently unused; kept for the
                existing call signature).
            lines: Line artists, one per component, updated in place. Lines
                beyond the number of available components are emptied.
            model: Reconstructed model for the frame being drawn.
            y_grid_np: Column vector of target values at which densities are
                evaluated.

        Returns:
            A flat list with every density value that was plotted, or an empty
            list when ``model`` is neither a Flow nor a Remix mixture model (in
            which case ``lines`` are left untouched).
        """
        # The concrete model classes are only imported under TYPE_CHECKING, so
        # the model kind is detected from the class representation instead of
        # with isinstance.
        class_repr = str(model.__class__)

        if "FlowMixtureModel" in class_repr:
            # NOTE(review): the grid is passed to the expert model as-is, while
            # the Remix branch scales it through the preprocessor first --
            # confirm the flow experts expect unscaled targets.
            y_tensor = torch.tensor(
                y_grid_np, dtype=torch.float32, device=model.device
            )
            with torch.no_grad():
                component_log_probs = model.expert_model(y_tensor)
                component_densities = torch.exp(component_log_probs).cpu().numpy()

        elif "RemixMixtureModel" in class_repr:
            with torch.no_grad():
                # The preprocessor transforms X and Y together, so feed it a
                # dummy X with the right number of features and keep only Y.
                dummy_x_for_transform = np.zeros(
                    (y_grid_np.shape[0], model.X_scaled.shape[1])
                )
                y_scaled_np = (
                    model.preprocessor.transform(dummy_x_for_transform, y_grid_np)[1]
                    .cpu()
                    .numpy()
                )
                if y_scaled_np.ndim == 1:
                    y_scaled_np = y_scaled_np.reshape(-1, 1)

                global_densities_on_grid = np.exp(
                    model.gmm_model._estimate_log_prob(y_scaled_np)
                )
                norm_weights = (
                    torch.softmax(model.expert_model.mixing_weights, dim=0)
                    .cpu()
                    .numpy()
                )
                component_densities = global_densities_on_grid @ norm_weights

            # Densities were evaluated in scaled-Y space; dividing by the scale
            # expresses them per unit of the original Y axis. The factor is the
            # same for every component, so it is applied once up front.
            scaler_y = model.preprocessor.scaler_y
            y_scale = getattr(scaler_y, "scale_", None) if scaler_y else None
            if y_scale is not None:
                component_densities = component_densities / np.prod(y_scale)

        else:
            return []

        all_density_data = []
        n_components = component_densities.shape[1]
        for i, line in enumerate(lines):
            if i < n_components:
                component_curve = component_densities[:, i]
                line.set_data(y_grid_np, component_curve)
                all_density_data.extend(component_curve)
            else:
                line.set_data([], [])

        return all_density_data
