"""Histogram plotting functionality."""

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
import seaborn as sns
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator
from plot_base import HistogramPlotter, PlotterMixin


class StandardHistogramPlotter(HistogramPlotter, PlotterMixin):
    """Single-panel histogram of one metric at the last optimization step."""

    def plot(self, sampled_data: dict, model_title: str, atk_name: str,
             cfg: dict, threshold: Optional[float] = None, **kwargs):
        """
        Draw a histogram of the metric values found at the final step.

        Parameters
        ----------
        sampled_data : dict
            Mapping from metric key to array-like data. The entry is indexed
            as ``[:, -1, :]``, i.e. it is expected to be 3-D
            (batch, step, sample).
        model_title : str
            Leading part of the figure title.
        atk_name : str
            Attack name. Accepted for interface compatibility; not used here.
        cfg : dict
            Configuration dictionary; only ``title_suffix`` is read.
        threshold : float, optional
            When given, values are binarized with ``value > threshold`` and a
            two-bin False/True histogram is drawn instead of a continuous one.
        **kwargs
            ``metric`` selects the key looked up in ``sampled_data``
            (defaults to ``('scores', 'strong_reject', 'p_harmful')``).

        Returns
        -------
        tuple
            ``(fig, ax)`` as produced by ``self._setup_figure``.

        Raises
        ------
        ValueError
            If the requested metric is not a key of ``sampled_data``.
        """
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        if metric not in sampled_data:
            raise ValueError(f"Metric {metric} not found in sampled_data")

        binarize = threshold is not None

        values = np.array(sampled_data[metric])  # (B, n_steps, n_samples)
        if binarize:
            values = values > threshold

        # Pool every batch entry and sample of the last step into one vector.
        last_step = values[:, -1, :].flatten()

        title = f"{model_title} {cfg.get('title_suffix', '')}"
        fig, ax = self._setup_figure(title, figsize=(10, 6))

        if not binarize:
            # Continuous scores: fixed 50-bin histogram.
            ax.hist(last_step, bins=50, alpha=0.7, edgecolor='black')
            ax.set_xlabel(str(metric))
        else:
            # Boolean outcomes: one bin for False (0) and one for True (1).
            ax.hist(last_step.astype(float), bins=[0, 0.5, 1], alpha=0.7, edgecolor='black')
            ax.set_xlabel(f"{metric} > {threshold}")
            ax.set_xticks([0, 1])
            ax.set_xticklabels(['False', 'True'])

        ax.set_ylabel('Frequency')
        ax.grid(True, alpha=0.3)

        return fig, ax


class RidgePlotter(HistogramPlotter, PlotterMixin):
    """Handles ridge plot (multiple density curves stacked vertically)."""

    def _log_spaced_indices(self, n_cols: int, k: int = 4) -> list[int]:
        """
        Return sorted, unique, roughly log-spaced indices in [0, n_cols-1].

        If ``n_cols <= k`` every index is returned (an empty list when
        ``n_cols`` is 0). Otherwise the result is index 0 together with ``k``
        geometrically spaced integer points between 1 and ``n_cols - 1``, so
        it may contain up to ``k + 1`` entries. In that case 0 and
        ``n_cols - 1`` are always present.

        Parameters
        ----------
        n_cols : int
            Number of available columns (steps).
        k : int, default 4
            Number of geometrically spaced points requested.

        Returns
        -------
        list[int]
            Ascending list of distinct indices.
        """
        # Few columns: there is nothing to thin out, show all of them.
        if n_cols <= k:
            return list(range(n_cols))

        max_idx = n_cols - 1
        # k integer points geometrically spaced over [1, max_idx]
        inner = np.geomspace(1, max_idx, num=k, dtype=int)

        # Add both end points, then sort and drop duplicates.
        idx = np.unique(np.concatenate(([0], inner, [max_idx])))

        # Integer truncation can collapse neighbouring points; top up with
        # linearly spaced indices when fewer than k survived.
        if idx.size < k:
            extra = np.linspace(0, max_idx, num=k, dtype=int)
            idx = np.unique(np.concatenate((idx, extra)))
            if idx.size > k:
                # Trim back to k entries but keep the final index; a plain
                # [:k] slice would silently drop max_idx.
                idx = np.append(idx[:k - 1], max_idx)

        return idx.tolist()

    def plot(self, sampled_data: dict, model_title: str, cfg: dict,
             threshold: Optional[float] = None, **kwargs):
        """
        Generate ridge plot showing distributions across steps.

        One KDE row is drawn per selected step. Each row also gets three
        vertical markers: mean (red, solid), median (black, dashed) and
        95th percentile (blue, dashed). No legend is added to the figure.

        Note that this method changes global plotting state through
        ``sns.set_theme`` and ``plt.style.use``.

        Parameters
        ----------
        sampled_data : dict
            Mapping from metric key to array-like data, indexed here as
            ``[:, step, :]`` (batch, step, sample).
        model_title : str
            Used as the figure title.
        cfg : dict
            Configuration dictionary. Not read by this plotter.
        threshold : float, optional
            Accepted for interface compatibility; not used by this plotter.
        **kwargs
            ``metric`` selects the key looked up in ``sampled_data``
            (defaults to ``('scores', 'strong_reject', 'p_harmful')``).

        Returns
        -------
        tuple
            ``(g.figure, g.axes)`` of the seaborn ``FacetGrid``.

        Raises
        ------
        ValueError
            If the requested metric is not a key of ``sampled_data``.
        """
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))

        # Extract data
        if metric not in sampled_data:
            raise ValueError(f"Metric {metric} not found in sampled_data")

        data = np.array(sampled_data[metric])  # (B, n_steps, n_samples)

        # Transparent axes background so overlapping rows do not hide each other
        sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0), 'figure.figsize': (3, 3)})

        # Select representative steps using log spacing
        step_idxs = self._log_spaced_indices(data.shape[1], 4)

        # Long-format table: one row per (step, value) pair
        value_col = r"$h(Y)$"
        ridge_data = [
            {'step': f'Step {step_idx}', value_col: value}
            for step_idx in step_idxs
            for value in data[:, step_idx, :].flatten()
        ]
        df = pd.DataFrame(ridge_data)

        # Order rows numerically by step (plain string sort would put 10 before 2)
        unique_steps = sorted(df['step'].unique(), key=lambda x: int(x.split()[1]))
        n_steps = len(unique_steps)
        pal = sns.cubehelix_palette(n_steps, rot=-.25, light=.7)

        # Initialize the FacetGrid object
        g = sns.FacetGrid(df, row="step", hue="step", aspect=5, height=.4, palette=pal,
                            row_order=unique_steps)

        # Draw the densities: filled curve on top of a wider white outline
        g.map(sns.kdeplot, value_col, bw_adjust=0.5, clip=(0, 1), fill=True, alpha=1, linewidth=0, zorder=1)
        g.map(sns.kdeplot, value_col, bw_adjust=0.5, clip=(0, 1), color="w", lw=3, zorder=0)

        # Add vertical lines for mean, median, and 95th percentile
        def add_mean_lines(x, **kwargs):
            # **kwargs absorbs the color/label arguments FacetGrid.map passes in
            ax = plt.gca()
            mean_val = np.mean(x)
            median_val = np.median(x)
            percentile_95 = np.percentile(x, 95)
            ax.axvline(median_val, color='black', linestyle='--', alpha=0.7, linewidth=1, ymax=0.5)
            ax.axvline(percentile_95, color='blue', linestyle='--', alpha=0.7, linewidth=1, ymax=0.5)
            ax.axvline(mean_val, color='red', linestyle='-', alpha=0.7, linewidth=1, ymax=0.5)

        g.map(add_mean_lines, value_col)

        # Add reference line at y=0
        g.refline(y=0, linewidth=1, linestyle="-", color=None, clip_on=False)

        # Set the subplots to overlap
        g.figure.subplots_adjust(hspace=-.4)

        # Remove axes details that don't play well with overlap
        g.set_titles("")
        g.set(yticks=[], ylabel="")
        g.despine(bottom=True, left=True)
        g.set_xlabels(r"$h(Y)$", fontsize=14)
        g.set(xlim=(0, 1))
        # NOTE(review): "science" is presumably registered by the module-level
        # `scienceplots` import; the style is applied globally.
        plt.style.use("science")

        g.figure.suptitle(f"{model_title}", fontsize=14, y=0.9, va="top")

        return g.figure, g.axes



class RidgeSideBySidePlotter(HistogramPlotter, PlotterMixin):
    """Handles side-by-side ridge plot (one density panel per selected step)."""

    def _log_spaced_indices(self, n_cols: int, k: int = 4) -> list[int]:
        """
        Return sorted, unique, roughly log-spaced indices in [0, n_cols-1].

        If ``n_cols <= k`` every index is returned (an empty list when
        ``n_cols`` is 0). Otherwise the result is index 0 together with ``k``
        geometrically spaced integer points between 1 and ``n_cols - 1``, so
        it may contain up to ``k + 1`` entries. In that case 0 and
        ``n_cols - 1`` are always present.

        Parameters
        ----------
        n_cols : int
            Number of available columns (steps).
        k : int, default 4
            Number of geometrically spaced points requested.

        Returns
        -------
        list[int]
            Ascending list of distinct indices.
        """
        # Few columns: there is nothing to thin out, show all of them.
        if n_cols <= k:
            return list(range(n_cols))

        max_idx = n_cols - 1
        # k integer points geometrically spaced over [1, max_idx]
        inner = np.geomspace(1, max_idx, num=k, dtype=int)

        # Add both end points, then sort and drop duplicates.
        idx = np.unique(np.concatenate(([0], inner, [max_idx])))

        # Integer truncation can collapse neighbouring points; top up with
        # linearly spaced indices when fewer than k survived.
        if idx.size < k:
            extra = np.linspace(0, max_idx, num=k, dtype=int)
            idx = np.unique(np.concatenate((idx, extra)))
            if idx.size > k:
                # Trim back to k entries but keep the final index; a plain
                # [:k] slice would silently drop max_idx.
                idx = np.append(idx[:k - 1], max_idx)

        return idx.tolist()

    def plot(self, sampled_data: dict, model_title: str, cfg: dict,
             threshold: Optional[float] = None, **kwargs):
        """
        Generate side-by-side ridge plot showing distributions across steps.

        Up to four steps are shown, one KDE panel per step, left to right.
        Each panel is marked with the median (black, dashed), the mean (red,
        solid) and the 95th percentile (blue, dashed). The figure also gets a
        legend and, for more than one panel, an arrow above the panels.

        Note that this method changes global plotting state through
        ``sns.set_theme`` and ``plt.style.use``.

        Parameters
        ----------
        sampled_data : dict
            Mapping from metric key to array-like data, indexed here as
            ``[:, step, :]`` (batch, step, sample).
        model_title : str
            Model title. Not drawn by this plotter.
        cfg : dict
            Configuration dictionary. Not read by this plotter.
        threshold : float, optional
            Accepted for interface compatibility; not used by this plotter.
        **kwargs
            ``metric`` selects the key looked up in ``sampled_data``
            (defaults to ``('scores', 'strong_reject', 'p_harmful')``).

        Returns
        -------
        tuple
            ``(g.figure, g.axes)`` of the seaborn ``FacetGrid``.

        Raises
        ------
        ValueError
            If the requested metric is not a key of ``sampled_data``.
        """
        from data_processor import generate_sample_sizes
        from matplotlib.patches import FancyArrowPatch

        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))

        # Extract data
        if metric not in sampled_data:
            raise ValueError(f"Metric {metric} not found in sampled_data")

        data = np.array(sampled_data[metric])  # (B, n_steps, n_samples)

        n_steps_to_show = 4

        # Basic theming
        sns.set_theme(
            style="white",
            rc={
                "axes.facecolor": (0, 0, 0, 0),
                "figure.figsize": (1.5 * n_steps_to_show, 1.5),   # widen for columns
            },
        )

        # Candidate steps: 0, whatever generate_sample_sizes yields, and the
        # final step. The final step is the last *valid* index, shape[1] - 1;
        # appending shape[1] itself would index past the end of the step axis.
        last_idx = data.shape[1] - 1
        all_step_idxs = [0] + list(generate_sample_sizes(last_idx))
        if last_idx not in all_step_idxs:
            all_step_idxs.append(last_idx)

        # Thin the candidates down to n_steps_to_show entries
        if len(all_step_idxs) > n_steps_to_show:
            # Take n_steps_to_show + 1 evenly spaced positions and drop the
            # second one, so the first candidate is kept and the rest lean late.
            indices = np.linspace(0, len(all_step_idxs) - 1, n_steps_to_show + 1, dtype=int)
            indices = [indices[0], *indices[2:]]
            step_idxs = [all_step_idxs[i] for i in indices]
        else:
            step_idxs = all_step_idxs

        # Long-format table: one row per (step, value) pair
        value_col = r"$h(Y)$"
        ridge_rows = [
            {"step": f"Step {idx}", value_col: val}
            for idx in step_idxs
            for val in data[:, idx, :].ravel()
        ]
        df = pd.DataFrame(ridge_rows)

        # Build the faceted plot; order columns numerically by step
        unique_steps = sorted(df["step"].unique(), key=lambda x: int(x.split()[1]))
        pal = sns.cubehelix_palette(int(len(unique_steps)*1.5), rot=-0.25, light=0.7)

        g = sns.FacetGrid(
            df,
            col="step",
            hue="step",
            palette=pal,
            col_order=unique_steps,
            sharey=False,          # independent y-axis per column
            aspect=1,
            height=2.5,
        )

        # Densities: filled curve on top of a wider white outline
        g.map(sns.kdeplot, value_col, bw_adjust=0.5, fill=True, alpha=1, linewidth=0, zorder=1)
        g.map(sns.kdeplot, value_col, bw_adjust=0.5, color="w", lw=3, zorder=0)

        g.set(yticks=[], ylabel="")      # hide y-ticks

        for ax in g.axes.flat:
            # kdeplot can leave a facet without any line (seaborn skips the
            # density estimate for zero-variance data); nothing to annotate then.
            if not ax.lines:
                continue

            # At this point the only line on the axes is the KDE outline
            kde_line = ax.lines[0]
            x_data = kde_line.get_xdata()
            y_data = kde_line.get_ydata()

            if len(x_data) > 0 and len(y_data) > 0:
                # Density at h(Y)=0 by linear interpolation; 0 when the curve
                # does not cover x=0.
                density_at_zero = float(np.interp(0.0, x_data, y_data, left=0.0, right=0.0))

                # Two y-ticks: half of and the full density at h(Y)=0
                ax.set_yticks([density_at_zero/2, density_at_zero])
                ax.set_yticklabels([f'{density_at_zero/2:.1f}', f'{density_at_zero:.1f}'])
                ax.tick_params(axis='y', labelsize=12, pad=-2)

        # Central-tendency & cut-off lines
        def add_mean_lines(x, **kwargs):
            # **kwargs absorbs the color/label arguments FacetGrid.map passes in
            ax = plt.gca()
            mean_val = np.mean(x)
            median_val = np.median(x)
            p95 = np.percentile(x, 95)
            ax.axvline(median_val, ls="--", lw=1, color="black", ymax=0.5, alpha=0.7)
            ax.axvline(p95,       ls="--", lw=1, color="blue",  ymax=0.5, alpha=0.7)
            ax.axvline(mean_val,  ls="-",  lw=1, color="red",   ymax=0.5, alpha=0.7)

        g.map(add_mean_lines, value_col)

        # Aesthetics
        g.set_titles("")                 # no subplot headers
        tick_vals = np.linspace(0, 1, 6)  # [0. , 0.2, 0.4, 0.6, 0.8, 1.]
        g.set(xticks=tick_vals)
        for ax in g.axes.flat:
            ax.tick_params(axis='x', pad=0, labelsize=12)
        g.set_xlabels("Harmfulness", fontsize=14)
        g.set(xlim=(0, 1))
        # NOTE(review): "science" is presumably registered by the module-level
        # `scienceplots` import; the style is applied globally.
        plt.style.use("science")

        # Y-axis label on first facet
        first_ax = g.axes.flat[0]
        first_ax.set_ylabel("Density", fontsize=13)

        # Make sure the other facets stay unlabeled
        for ax in g.axes.flat[1:]:
            ax.set_ylabel("")

        # Add "Step x" labels
        for ax, step in zip(g.axes.flat, unique_steps):
            # Labels whose number ends in 9 are displayed one higher
            # (e.g. "Step 99" is shown as "Step 100").
            if step[-1] == "9":
                step = "Step " + str(int(step.split()[1])+1)
            ax.text(
                0.5, 0.95, step,                       # centered near the top of each panel
                ha="center", va="bottom",
                transform=ax.transAxes,
                fontsize=12,
                fontweight="bold"                      # optional, adjust to taste
            )

        # Single vertical legend. The red line drawn by add_mean_lines is the
        # sample mean, so it is labelled as such.
        legend_elements = [
            Line2D([0], [0], color="black",  lw=1, label="Median", ls="--"),
            Line2D([0], [0], color="red",   lw=1, label="Mean"),
            Line2D([0], [0], color="blue",  lw=1, label="95th percentile", ls="--"),
        ]

        # A lone panel needs the legend shifted further right
        if len(g.axes.flat) == 1:
            bbox_anchor = (0.2, 0.8)
        else:
            bbox_anchor = (0.055, 0.8)

        g.figure.legend(
            handles=legend_elements,
            loc="upper left",              # anchor to top-left of the figure
            bbox_to_anchor=bbox_anchor,   # fine-tune position (x, y in fig-coords)
            frameon=False,
            ncol=1,                        # vertical stack
            handletextpad=0.4,
            labelspacing=0.3,
            borderaxespad=0.0,
        )

        # Add horizontal time arrow above the plots
        if len(g.axes.flat) > 1:
            # Get the positions of the first and last subplots
            first_ax = g.axes.flat[0]
            last_ax = g.axes.flat[-1]

            # Get the positions in figure coordinates
            first_pos = first_ax.get_position()
            last_pos = last_ax.get_position()

            # Calculate arrow position (slightly above the plots)
            arrow_y = first_pos.y1 + 0.08  # 0.08 figure units above the panel tops
            arrow_start_x = first_pos.x0 + 0.1 * first_pos.width  # 10% into first subplot
            arrow_end_x = last_pos.x1 - 0.1 * last_pos.width     # 90% into last subplot

            # Create and add the arrow patch
            arrow = FancyArrowPatch((arrow_start_x, arrow_y), (arrow_end_x, arrow_y),
                                   connectionstyle="arc3",
                                   arrowstyle='-|>',
                                   mutation_scale=10,
                                   linewidth=0.75,
                                   color='black',
                                   alpha=1.0,
                                   transform=g.figure.transFigure)
            g.figure.patches.append(arrow)

        return g.figure, g.axes


class HistogramTwoPlotter(HistogramPlotter, PlotterMixin):
    """Two-panel histogram comparing the first and the final step."""

    def plot(self, sampled_data: dict, model_title: str, cfg: dict,
             threshold: Optional[float] = None, **kwargs):
        """
        Draw first-step and final-step histograms next to each other.

        Parameters
        ----------
        sampled_data : dict
            Mapping from metric key to array-like data, indexed here as
            ``[:, step, :]`` (batch, step, sample).
        model_title : str
            Leading part of the figure title.
        cfg : dict
            Configuration dictionary. Not read by this plotter.
        threshold : float, optional
            When given, values are binarized with ``value > threshold`` and
            both panels show a two-bin False/True histogram.
        **kwargs
            ``metric`` selects the key looked up in ``sampled_data``
            (defaults to ``('scores', 'strong_reject', 'p_harmful')``).

        Returns
        -------
        tuple
            ``(fig, (ax1, ax2))`` with the left (first step) and right
            (final step) axes.

        Raises
        ------
        ValueError
            If the requested metric is not a key of ``sampled_data``.
        """
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        if metric not in sampled_data:
            raise ValueError(f"Metric {metric} not found in sampled_data")

        binarize = threshold is not None

        y = np.array(sampled_data[metric])  # (B, n_steps, n_samples)
        if binarize:
            y = y > threshold

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        fig.suptitle(f"{model_title} Histogram Comparison")

        # (axes, step index, panel title): left panel first, then right panel.
        panels = (
            (ax1, 0, 'First Step'),
            (ax2, -1, 'Final Step'),
        )
        for axis, step, heading in panels:
            # Pool every batch entry and sample of this step into one vector.
            samples = y[:, step, :].flatten()

            if binarize:
                # Boolean outcomes: one bin for False (0) and one for True (1).
                axis.hist(samples.astype(float), bins=[0, 0.5, 1], alpha=0.7, edgecolor='black')
                axis.set_xlabel(f"{metric} > {threshold}")
                axis.set_xticks([0, 1])
                axis.set_xticklabels(['False', 'True'])
            else:
                axis.hist(samples, bins=30, alpha=0.7, edgecolor='black')
                axis.set_xlabel(str(metric))

            axis.set_ylabel('Frequency')
            axis.set_title(heading)
            axis.grid(True, alpha=0.3)

        plt.tight_layout()
        return fig, (ax1, ax2)


class RatioPlotter(HistogramPlotter, PlotterMixin):
    """Handles ratio line plots showing P(¬refusal) and P(harmful | ¬refusal) over optimization steps."""

    def plot(self, sampled_data: dict, model_title: str, cfg: dict,
             threshold: Optional[float] = None, **kwargs):
        """
        Generate ratio line plot showing refusal and harmfulness ratios over optimization steps.

        For every step the values of all batch entries and samples are pooled
        and two ratios are computed:

        - P(¬refusal): share of values in [0.1, 1.0] among values in [0.0, 1.0]
        - P(harmful | ¬refusal): share of values in (0.5, 1.0] among values
          in [0.1, 1.0]

        A ratio is NaN for a step whose denominator count is zero.

        Note that this method changes global plotting state through
        ``plt.style.use``.

        Parameters
        ----------
        sampled_data : dict
            Mapping from metric key to array-like data with three axes
            (batch, step, sample).
        model_title : str
            Used as the axes title.
        cfg : dict
            Configuration dictionary. Not read by this plotter.
        threshold : float, optional
            Accepted for interface compatibility; not used by this plotter
            (the 0.1 and 0.5 cut-offs are fixed).
        **kwargs
            ``metric`` selects the key looked up in ``sampled_data``
            (defaults to ``('scores', 'strong_reject', 'p_harmful')``).

        Returns
        -------
        tuple
            ``(fig, ax)`` of the created matplotlib figure.

        Raises
        ------
        ValueError
            If the requested metric is not a key of ``sampled_data``.
        """
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))

        # Extract data
        if metric not in sampled_data:
            raise ValueError(f"Metric {metric} not found in sampled_data")

        data = np.array(sampled_data[metric])  # (B, n_steps, n_samples)
        num_steps = data.shape[1]

        # Event masks. "Not refused" uses the same inclusive 0.1 bound in the
        # numerator of P(¬refusal) and in the denominator of
        # P(harmful | ¬refusal), so both ratios refer to the same event.
        valid = (data >= 0.00) & (data <= 1.0)
        not_refused = (data >= 0.10) & (data <= 1.0)
        harmful = (data > 0.50) & (data <= 1.0)

        # Count per step, pooling the batch and sample axes
        pooled_axes = (0, 2)
        n_valid = valid.sum(axis=pooled_axes)
        n_not_refused = not_refused.sum(axis=pooled_axes)
        n_harmful = harmful.sum(axis=pooled_axes)

        # Steps with an empty denominator keep the NaN they were initialised with
        ratios_1 = np.divide(  # P(harmful | ¬refusal)
            n_harmful, n_not_refused,
            out=np.full(num_steps, np.nan), where=n_not_refused > 0,
        )
        ratios_2 = np.divide(  # P(¬refusal)
            n_not_refused, n_valid,
            out=np.full(num_steps, np.nan), where=n_valid > 0,
        )

        # Create figure for ratio plot
        fig, ax = plt.subplots(1, 1, figsize=(6.5, 1.9875))  # 2.65 normal
        # NOTE(review): the style is set after the figure/axes already exist,
        # so it presumably only affects artists created from here on (and
        # later figures, since the style is global) - confirm this is intended.
        plt.style.use("science")

        # Create ratio plots; a single step gets markers so the point is visible
        steps = np.arange(num_steps)
        sns.lineplot(x=steps, y=ratios_2, label=r"$P(\text{¬refusal})$",
                    ax=ax, marker="o" if num_steps == 1 else None)
        sns.lineplot(x=steps, y=ratios_1, linestyle="--",
                    label=r"$P(\text{harmful} \mid \text{¬refusal})$",
                    ax=ax, marker="x" if num_steps == 1 else None)

        # Configure axes
        ax.yaxis.set_major_locator(MaxNLocator(nbins="auto", integer=False))
        ax.set_xlabel("Step", fontsize=14)
        ax.set_ylabel("Frequency", fontsize=14)
        ax.set_title(f"{model_title}")
        # If the auto-scaled lower y-limit sits above 0.3, pull it down to 0.275
        current_ylim = ax.get_ylim()
        if current_ylim[0] > 0.3:
            ax.set_ylim(bottom=0.275)
        # Clean tick styling
        ax.tick_params(axis='both', which='both', top=False, right=False, left=True, bottom=True)

        # Place legend to the left of the subplot
        ax.legend(bbox_to_anchor=(-0.3, 0.5), loc='center right')

        plt.tight_layout()

        return fig, ax