"""Pareto plotting functionality."""

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from data_processor import _pareto_frontier, get_n_schedule, subsample_and_aggregate_n
from plot_base import FrontierPlotter, PlotterMixin, get_schedule_points, setup_color_normalization
from scipy.interpolate import interp1d


class ParetoPlotter(FrontierPlotter, PlotterMixin):
    """Handles Pareto frontier plotting."""

    @staticmethod
    def _pareto_sweep_steps(level, n_steps, schedule, y, flops, rng, **agg_kwargs):
        """Return one aggregated point per step count 1..n_steps for a fixed sample level.

        ``flops`` is the (optimization, prefill-cache, generation) triple that is
        forwarded positionally to ``subsample_and_aggregate_n``.
        """
        return [
            subsample_and_aggregate_n(
                get_n_schedule(step, level, schedule), y, *flops, rng, **agg_kwargs
            )
            for step in range(1, n_steps + 1)
        ]

    @staticmethod
    def _pareto_mean_frontier(xs, ys, x_interp):
        """Step-interpolate every frontier onto ``x_interp`` and NaN-average them.

        Positions outside a frontier's x-range are NaN for that frontier; a grid
        position that no frontier covers stays NaN in the result.
        """
        y_interp = [
            interp1d(x_, y_, kind="previous", bounds_error=False,
                     fill_value=np.nan)(x_interp)
            for x_, y_ in zip(xs, ys)
        ]
        return np.nanmean(y_interp, axis=0)

    def plot(self, results: dict, baseline: Optional[dict] = None,
             title: str = "Pareto Frontier", schedule: str = 'end', **kwargs):
        """
        Generate Pareto frontier plot matching the original visual style.

        Parameters
        ----------
        results : dict
            Results data dictionary
        baseline : dict, optional
            Baseline results for comparison
        title : str
            Plot title. Accepted for interface compatibility; it is not drawn
            on the figure.
        schedule : str
            Schedule name forwarded to ``get_n_schedule``
        **kwargs
            Additional plotting parameters:

            sample_levels_to_plot : iterable of int, optional
                Sample levels to draw frontiers for. Defaults to
                ``self.get_sample_levels(n_total_samples)``.
            frontier_method : str
                Method forwarded to ``_pareto_frontier`` (default ``'basic'``).
            metric : tuple
                Metric key path forwarded to ``_preprocess_data``.
            plot_points : bool
                Draw the (thinned) point cloud (default True).
            plot_frontiers : bool
                Draw one averaged frontier per sample level (default True).
            plot_envelope : bool
                Draw the upper envelope over all sample levels (default False).
            threshold : float, optional
                Forwarded to ``_preprocess_data``; also selects the y-label.
            color_scale : str
                Forwarded to ``setup_color_normalization`` (default ``'linear'``).
            n_x_points : int
                Size of the interpolation grid (default 10000).
            x_scale : str
                ``'log'`` (default) or any other matplotlib x-scale name.

        Returns
        -------
        tuple
            ``(fig, ax1)``: the figure and the main Pareto axes.

        Raises
        ------
        ValueError
            If no points could be generated for plotting.
        """
        # Extract parameters with defaults
        sample_levels_to_plot = kwargs.get('sample_levels_to_plot', None)
        frontier_method = kwargs.get('frontier_method', 'basic')
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        plot_points = kwargs.get('plot_points', True)
        plot_frontiers = kwargs.get('plot_frontiers', True)
        plot_envelope = kwargs.get('plot_envelope', False)
        threshold = kwargs.get('threshold', None)
        color_scale = kwargs.get('color_scale', 'linear')
        n_x_points = kwargs.get('n_x_points', 10000)
        x_scale = kwargs.get('x_scale', 'log')

        self._validate_inputs(results, baseline)

        # Preprocess data
        y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = self._preprocess_data(
            results, metric, threshold
        )
        _, n_steps, n_total_samples = y.shape
        if sample_levels_to_plot is None:
            sample_levels_to_plot = self.get_sample_levels(n_total_samples)

        flops = (flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation)

        # Generate the point cloud: every (sample level, step count) combination.
        # A fixed seed keeps the figure reproducible.
        rng = np.random.default_rng(42)
        pts = []
        for j in range(1, n_total_samples + 1):
            pts.extend(self._pareto_sweep_steps(
                j, n_steps, schedule, y, flops, rng, return_ratio=False, n_smoothing=1
            ))

        if not pts:
            raise ValueError("No points generated for plotting")

        cost, _, n_samp, mean_p = np.array(pts).T

        # Common x grid on which all frontiers are interpolated.
        max_cost = np.max(cost)
        if x_scale == "log":
            # Lower bound is hard-coded to 1e11; the upper bound sits slightly
            # above the largest observed cost.
            x_interp = np.logspace(11, np.log10(max_cost) + 0.001, n_x_points)
        else:
            x_interp = np.linspace(0, max_cost + 1, n_x_points)

        # Create figure with original layout and size
        fig = plt.figure(figsize=(5.4, 2.4))
        # Main Pareto plot (right side, spanning both rows)
        ax1 = plt.subplot2grid((2, 3), (0, 1), colspan=2, rowspan=2)

        # Setup color normalization
        color_norm = setup_color_normalization(color_scale, n_samp)
        cmap = plt.get_cmap("viridis")

        # Plot points with original style
        if plot_points:
            if len(cost) > 1000:
                # Thin the cloud to at most 1000 points, evenly spaced in cost rank.
                order = np.argsort(np.log10(cost + 1e-10))
                stride = len(order) // 1000
                keep = order[::stride][:1000]
                cost_sub, mean_p_sub, n_samp_sub = cost[keep], mean_p[keep], n_samp[keep]
            else:
                cost_sub, mean_p_sub, n_samp_sub = cost, mean_p, n_samp

            plt.scatter(cost_sub, mean_p_sub, c=n_samp_sub, cmap="viridis",
                        alpha=0.15, s=3, norm=color_norm)

        # Plot frontiers with original smoothing and style
        if plot_frontiers:
            for j in sample_levels_to_plot:
                # Decide the number of repeated draws per level. Previously the
                # count was lowered to 1 once and never restored, so any level
                # listed after n_total_samples silently lost its smoothing.
                n_repeats = 1 if j == n_total_samples else 50
                xs = []
                ys = []
                for _ in range(n_repeats):
                    pts_smooth = np.asarray(self._pareto_sweep_steps(
                        j, n_steps, schedule, y, flops, rng, return_ratio=False, n_smoothing=1
                    ))
                    cost_smooth, _, _, mean_p_smooth = pts_smooth.T
                    fx, fy = _pareto_frontier(cost_smooth, mean_p_smooth, method=frontier_method)
                    xs.append(fx)
                    ys.append(fy)

                y_mean = self._pareto_mean_frontier(xs, ys, x_interp)

                # Filter out NaN values and zeros
                valid_mask = ~np.isnan(y_mean) & (y_mean > 0)

                # Plot with original style - markers and linewidth
                plt.plot(
                    x_interp[valid_mask],
                    y_mean[valid_mask],
                    marker="o",
                    linewidth=1.8,
                    markersize=2,
                    label=f"{j} samples",
                    color=cmap(color_norm(j)),
                )

        if plot_envelope:
            n_repeats = n_total_samples
            level_means = []
            for j in range(1, n_total_samples + 1):
                xs = []
                ys = []
                for _ in range(n_repeats):
                    # Dedicated names: the point-cloud arrays above must not be clobbered.
                    env_pts = np.asarray(self._pareto_sweep_steps(
                        j, n_steps, schedule, y, flops, rng
                    ))
                    env_cost, _, _, env_mean_p = env_pts.T

                    fx, fy = _pareto_frontier(env_cost, env_mean_p, method=frontier_method)
                    xs.append(fx)
                    ys.append(fy)

                level_means.append(self._pareto_mean_frontier(xs, ys, x_interp))
            y_interps = np.array(level_means)

            # np.nanargmax raises ValueError on all-NaN columns, which occur
            # wherever the grid lies outside every frontier. Replace NaN by -inf
            # so those columns are handled, then mask them out explicitly.
            has_value = ~np.all(np.isnan(y_interps), axis=0)
            filled = np.where(np.isnan(y_interps), -np.inf, y_interps)
            best_row = np.maximum.accumulate(np.argmax(filled, axis=0))
            y_envelope = np.where(has_value, np.max(filled, axis=0), np.nan)

            # Filter out NaN values and zeros
            valid_mask = has_value & (y_envelope > 0)
            # Row r of y_interps belongs to sample level r + 1; colour by the
            # level so the envelope matches the frontier curves.
            color = [cmap(color_norm(row + 1)) for row in best_row[valid_mask]]
            plt.scatter(x_interp[valid_mask], y_envelope[valid_mask], c=color, s=2)

        # Process baseline data if provided
        if baseline is not None:
            y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = self._preprocess_data(
                baseline, metric, threshold
            )

            if y_baseline is not None:
                _, n_steps_baseline, n_total_samples_baseline = y_baseline.shape
                assert n_total_samples_baseline == 1

                # Generate baseline points using the original method
                pts_baseline = []
                for i in range(1, n_steps_baseline + 1):
                    for j in range(1, n_total_samples_baseline + 1):
                        n_vec = get_n_schedule(i, j, schedule)
                        pts_baseline.append(subsample_and_aggregate_n(
                            n_vec, y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache,
                            baseline_flops_sampling_generation, rng, return_ratio=False, n_smoothing=1
                        ))

                cost_baseline, _, n_samp_baseline, mean_p_baseline = np.array(pts_baseline).T
                max_cost_baseline = np.max(cost_baseline)

                # Create Pareto frontier for baseline like the original
                if plot_frontiers or plot_envelope:
                    mask = n_samp_baseline == 1
                    fx, fy = _pareto_frontier(cost_baseline[mask], mean_p_baseline[mask], method=frontier_method)
                    y_interp_baseline = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp)
                    # A baseline that stops well short of the main run (< 70 % of
                    # its cost) is not clipped at its own maximum cost. Written as
                    # a product so a zero max_cost cannot divide by zero.
                    if max_cost_baseline < 0.7 * max_cost:
                        max_cost_baseline = max_cost
                    valid_mask_baseline = ~np.isnan(y_interp_baseline) & (y_interp_baseline > 0) & (x_interp < max_cost_baseline)
                    plt.plot(
                        x_interp[valid_mask_baseline],
                        y_interp_baseline[valid_mask_baseline],
                        marker="o",
                        linewidth=1.8,
                        markersize=2,
                        label="Baseline (greedy)",
                        color="r",
                    )

        # Setup main plot axes
        plt.xlabel("Total FLOPs", fontsize=13)
        if threshold is None:
            plt.ylabel(r"$\mathcal{H}_b$", fontsize=18)
        else:
            plt.ylabel(r"$\text{ASR}_b$", fontsize=18)
        plt.grid(True, alpha=0.3)
        plt.xscale(x_scale)

        # Get legend handles and labels from main plot
        handles, labels = plt.gca().get_legend_handles_labels()

        # Create dedicated legend subplot (left column)
        ax0 = plt.subplot2grid((2, 3), (0, 0), colspan=1, rowspan=2)
        ax0.axis('off')  # Remove all axes

        if handles:
            # With a baseline entry last: list sample levels in reverse order and
            # keep the baseline at the bottom with a shortened label. Handles and
            # labels are always reordered together so each curve keeps its label.
            if len(labels) > 1 and labels[-1].endswith(" (greedy)"):
                handles = [*handles[:-1][::-1], handles[-1]]
                labels = [*labels[:-1][::-1], labels[-1].replace(" (greedy)", "")]
            ax0.legend(handles, labels, loc='center', fontsize=12)

        plt.tight_layout()

        return fig, ax1

    def _plot_baseline(self, ax, baseline: dict, metric: tuple,
                       threshold: Optional[float]):
        """
        Plot baseline comparison as red 'x' markers on ``ax``.

        Parameters
        ----------
        ax : matplotlib axes
            Axes to draw on
        baseline : dict
            Baseline results, preprocessed with ``_preprocess_data``
        metric : tuple
            Metric key path forwarded to ``_preprocess_data``
        threshold : float, optional
            Threshold forwarded to ``_preprocess_data``
        """
        baseline_y, baseline_opt_flops, baseline_prefill_flops, baseline_gen_flops = self._preprocess_data(
            baseline, metric, threshold
        )

        baseline_pts = get_schedule_points(
            baseline_y, baseline_opt_flops, baseline_prefill_flops,
            baseline_gen_flops, return_ratio=False, cumulative=False
        )

        # Nothing is drawn when no baseline points were produced.
        if baseline_pts:
            baseline_array = np.array(baseline_pts)
            baseline_cost, _, _, baseline_mean = baseline_array.T
            ax.scatter(baseline_cost, baseline_mean,
                       color='red', marker='x', s=50, label='Baseline', zorder=10)



class NonCumulativeParetoPlotter(ParetoPlotter):
    """Handles non-cumulative Pareto frontier plotting."""

    def plot(self, results: dict, baseline: Optional[dict] = None,
             title: str = "Non-Cumulative Pareto Frontier", schedule: str = 'end', **kwargs):
        """
        Generate non-cumulative Pareto frontier plot using step indices on x-axis.

        This implementation matches the original non_cumulative_pareto_plot function,
        using step indices (as percentages) instead of FLOP costs on the x-axis.

        Parameters
        ----------
        results : dict
            Results data dictionary
        baseline : dict, optional
            Baseline results for comparison
        title : str
            Plot title. Accepted for interface compatibility; it is not drawn
            on the figure.
        schedule : str
            Schedule name forwarded to ``get_n_schedule``
        **kwargs
            Additional plotting parameters: ``sample_levels_to_plot``,
            ``metric``, ``plot_points``, ``plot_frontiers``, ``plot_envelope``
            (only gates the baseline curve here), ``threshold``,
            ``color_scale`` and ``n_x_points``. The frontier method is always
            ``'non_cumulative'``.

        Returns
        -------
        tuple
            ``(fig, ax1)``: the figure and the main Pareto axes.

        Raises
        ------
        ValueError
            If no points could be generated for plotting.
        """
        # Extract parameters with defaults
        sample_levels_to_plot = kwargs.get('sample_levels_to_plot', None)
        frontier_method = 'non_cumulative'  # Force non-cumulative method
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        plot_points = kwargs.get('plot_points', True)
        plot_frontiers = kwargs.get('plot_frontiers', True)
        plot_envelope = kwargs.get('plot_envelope', False)
        threshold = kwargs.get('threshold', None)
        color_scale = kwargs.get('color_scale', 'linear')
        n_x_points = kwargs.get('n_x_points', 10000)

        self._validate_inputs(results, baseline)

        # Preprocess data
        y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = self._preprocess_data(
            results, metric, threshold
        )
        _, n_steps, n_total_samples = y.shape
        if sample_levels_to_plot is None:
            sample_levels_to_plot = self.get_sample_levels(n_total_samples)

        # Generate points for plotting using step indices instead of FLOP costs.
        # A fixed seed keeps the figure reproducible.
        pts = []
        rng = np.random.default_rng(42)
        for j in range(1, n_total_samples + 1):
            for i in range(1, n_steps + 1):
                n_vec = get_n_schedule(i, j, schedule)
                pts.append(subsample_and_aggregate_n(
                    n_vec, y, flops_optimization, flops_sampling_prefill_cache,
                    flops_sampling_generation, rng, return_ratio=False, n_smoothing=1
                ))

        if not pts:
            raise ValueError("No points generated for plotting")

        _, step_idx, n_samp, mean_p = np.array(pts).T

        # Use step indices converted to percentages for x-axis (like original).
        # The span is clamped to 1 so a single-step run does not divide by zero.
        step_span = max(n_steps - 1, 1)
        step_percentages = (step_idx / step_span) * 100
        max_step = np.max(step_percentages)

        # Create x interpolation range for step percentages
        x_interp = np.linspace(0, max_step + 1, n_x_points)

        # Create figure with original layout and size
        fig = plt.figure(figsize=(5.4, 2.8))
        # Main Pareto plot (right side, spanning both rows)
        ax1 = plt.subplot2grid((2, 3), (0, 1), colspan=2, rowspan=2)

        # Setup color normalization
        color_norm = setup_color_normalization(color_scale, n_samp)
        cmap = plt.get_cmap("viridis")

        # Plot points with original style (using step percentages)
        if plot_points:
            if len(step_percentages) > 1000:
                # Thin the cloud to at most 1000 points, evenly spaced in step rank.
                order = np.argsort(step_percentages)
                stride = len(order) // 1000
                keep = order[::stride][:1000]
                step_sub, mean_p_sub, n_samp_sub = step_percentages[keep], mean_p[keep], n_samp[keep]
            else:
                step_sub, mean_p_sub, n_samp_sub = step_percentages, mean_p, n_samp

            plt.scatter(step_sub, mean_p_sub, c=n_samp_sub, cmap="viridis",
                        alpha=0.15, s=3, norm=color_norm)

        # Plot frontiers with step indices on x-axis
        if plot_frontiers:
            for j in sample_levels_to_plot:
                # Decide the number of repeated draws per level. Previously the
                # count was lowered to 1 once and never restored, so any level
                # listed after n_total_samples silently lost its smoothing.
                n_repeats = 1 if j == n_total_samples else 50
                xs = []
                ys = []
                for _ in range(n_repeats):
                    pts_smooth = []
                    for i in range(1, n_steps + 1):
                        n_vec = get_n_schedule(i, j, schedule)
                        pts_smooth.append(subsample_and_aggregate_n(
                            n_vec, y, flops_optimization, flops_sampling_prefill_cache,
                            flops_sampling_generation, rng, return_ratio=False, n_smoothing=1
                        ))
                    _, step_idx_smooth, _, mean_p_smooth = np.asarray(pts_smooth).T
                    # Convert step indices to percentages and use for frontier
                    step_percentages_smooth = (step_idx_smooth / step_span) * 100
                    fx, fy = _pareto_frontier(step_percentages_smooth, mean_p_smooth, method=frontier_method)
                    xs.append(fx)
                    ys.append(fy)

                # Step-interpolate each draw onto the common grid and average,
                # ignoring grid positions a draw does not cover.
                y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False,
                                     fill_value=np.nan)(x_interp) for x_, y_ in zip(xs, ys)]
                y_mean = np.nanmean(y_interp, axis=0)

                # Filter out NaN values and zeros
                valid_mask = ~np.isnan(y_mean) & (y_mean > 0)

                # Plot with original style
                plt.plot(
                    x_interp[valid_mask],
                    y_mean[valid_mask],
                    marker="o",
                    linewidth=1.8,
                    markersize=2,
                    label=f"{j} samples",
                    color=cmap(color_norm(j)),
                )

        # Process baseline data if provided
        if baseline is not None:
            y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = self._preprocess_data(
                baseline, metric, threshold
            )

            if y_baseline is not None:
                _, n_steps_baseline, n_total_samples_baseline = y_baseline.shape
                assert n_total_samples_baseline == 1

                # Generate baseline points using step indices
                pts_baseline = []
                for i in range(1, n_steps_baseline + 1):
                    for j in range(1, n_total_samples_baseline + 1):
                        n_vec = get_n_schedule(i, j, schedule)
                        pts_baseline.append(subsample_and_aggregate_n(
                            n_vec, y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache,
                            baseline_flops_sampling_generation, rng, return_ratio=False, n_smoothing=1
                        ))

                _, step_idx_baseline, n_samp_baseline, mean_p_baseline = np.array(pts_baseline).T
                # Same single-step guard as for the main run.
                baseline_span = max(n_steps_baseline - 1, 1)
                step_percentages_baseline = (step_idx_baseline / baseline_span) * 100
                max_step_baseline = np.max(step_percentages_baseline)

                # Create Pareto frontier for baseline using step percentages
                if plot_frontiers or plot_envelope:
                    mask = n_samp_baseline == 1
                    fx, fy = _pareto_frontier(step_percentages_baseline[mask], mean_p_baseline[mask], method=frontier_method)
                    y_interp_baseline = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp)
                    # A baseline that stops well short of the main run (< 70 %)
                    # is not clipped at its own maximum. Written as a product so
                    # a zero max_step cannot divide by zero.
                    if max_step_baseline < 0.7 * max_step:
                        max_step_baseline = max_step
                    valid_mask_baseline = ~np.isnan(y_interp_baseline) & (y_interp_baseline > 0) & (x_interp < max_step_baseline)
                    plt.plot(
                        x_interp[valid_mask_baseline],
                        y_interp_baseline[valid_mask_baseline],
                        marker="o",
                        linewidth=1.8,
                        markersize=2,
                        label="Baseline (greedy)",
                        color="r",
                    )

        # Setup main plot axes with step percentages.
        # The backslash is escaped explicitly; the label text itself is unchanged.
        plt.xlabel("Optimization Step (\\%)", fontsize=13)
        if threshold is None:
            plt.ylabel(r"$\mathcal{H}_b$", fontsize=18)
        else:
            plt.ylabel(r"$\text{ASR}_b$", fontsize=18)
        plt.grid(True, alpha=0.3)
        plt.xlim(0, 100)

        # Get legend handles and labels from main plot
        handles, labels = plt.gca().get_legend_handles_labels()

        # Create dedicated legend subplot (left column)
        ax0 = plt.subplot2grid((2, 3), (0, 0), colspan=1, rowspan=2)
        ax0.axis('off')  # Remove all axes

        if handles:
            # With a baseline entry last: list sample levels in reverse order and
            # keep the baseline at the bottom with a shortened label. Handles and
            # labels are always reordered together so each curve keeps its label.
            if len(labels) > 1 and labels[-1].endswith(" (greedy)"):
                handles = [*handles[:-1][::-1], handles[-1]]
                labels = [*labels[:-1][::-1], labels[-1].replace(" (greedy)", "")]
            ax0.legend(handles, labels, loc='center', fontsize=12)

        plt.tight_layout()
        return fig, ax1
