"""Analysis plotting functionality (FLOPs ratio, breakdown, etc.)."""

import numpy as np
import matplotlib.pyplot as plt
from typing import Optional

from plot_base import AnalysisPlotter, PlotterMixin, setup_color_normalization, get_schedule_points


class FlopsRatioPlotter(AnalysisPlotter, PlotterMixin):
    """Handles FLOPs ratio analysis plotting."""

    def plot(self, results: dict, baseline: Optional[dict] = None,
             title: str = "FLOPs Ratio Analysis", schedule: str = 'end', **kwargs):
        """
        Scatter the aggregated metric against the sampling-FLOPs ratio.

        Every point returned by ``get_schedule_points`` with a finite ratio is
        drawn faintly and coloured by its sample count; the sample counts in
        ``sample_levels_to_plot`` are redrawn as larger outlined markers.

        Parameters
        ----------
        results : dict
            Results data dictionary, passed to ``_preprocess_data``.
        baseline : dict, optional
            Baseline results for comparison. When given, its finite points are
            overlaid as red triangles.
        title : str
            Plot title, passed to ``_setup_figure``.
        schedule : str
            Forwarded to ``get_schedule_points`` as ``schedule_type``.
        **kwargs
            metric : tuple
                Metric key path handed to ``_preprocess_data``.
            cumulative : bool
                Forwarded to ``get_schedule_points``. Default ``False``.
            threshold : optional
                Forwarded to ``_preprocess_data``; when not ``None`` the
                y-axis is labelled as ASR.
            color_scale : str
                Scale name used for the total-FLOPs normalisation.
            sample_levels_to_plot : iterable, optional
                Sample counts to highlight. Defaults to
                ``generate_sample_sizes(n_total_samples)``.
            draw_iso_flops : bool
                Whether to draw the fitted iso-FLOP curves. Defaults to
                ``n_total_samples == 500``, the previously hard-coded rule.
            verbose : bool
                Print value ranges and fit failures. Default ``True``.

        Returns
        -------
        tuple
            ``(fig, ax)`` as created by ``_setup_figure``.

        Raises
        ------
        ValueError
            If ``get_schedule_points`` yields no points for ``results``.
        """
        from data_processor import generate_sample_sizes

        # Extract parameters
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        cumulative = kwargs.get('cumulative', False)
        threshold = kwargs.get('threshold', None)
        color_scale = kwargs.get('color_scale', 'linear')
        sample_levels_to_plot = kwargs.get('sample_levels_to_plot', None)
        verbose = kwargs.get('verbose', True)

        self._validate_inputs(results, baseline)

        # Preprocess data
        y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = self._preprocess_data(
            results, metric, threshold
        )
        # The three-way unpack also enforces that y is a 3-D array.
        _, _, n_total_samples = y.shape

        if sample_levels_to_plot is None:
            sample_levels_to_plot = generate_sample_sizes(n_total_samples)

        # Backward-compatible default: iso-FLOP curves only for 500-sample runs.
        draw_iso_flops = kwargs.get('draw_iso_flops', n_total_samples == 500)

        # Generate points with ratio calculation
        pts = get_schedule_points(
            y, flops_optimization, flops_sampling_prefill_cache,
            flops_sampling_generation, return_ratio=True, cumulative=cumulative, schedule_type=schedule
        )

        if not pts:
            raise ValueError("No points generated for plotting")

        # Each point is unpacked as six columns; the step index is not used here.
        ratio, _, n_samp, mean_p, opt_flop, sampling_flop = np.array(pts).T
        total_flop = opt_flop + sampling_flop

        # Filter out non-finite ratios for plotting
        finite_mask = np.isfinite(ratio)
        ratio_finite = ratio[finite_mask]
        mean_p_finite = mean_p[finite_mask]
        n_samp_finite = n_samp[finite_mask]
        total_flop_finite = total_flop[finite_mask]

        # Create figure
        fig, ax = self._setup_figure(title, figsize=(10, 6))

        # Hue encodes the sample count.
        sample_norm = setup_color_normalization("linear", n_samp_finite)
        # NOTE(review): this total-FLOPs normalisation is built but not applied to
        # any artist below; kept because the helper is opaque from this file.
        flops_norm = setup_color_normalization(color_scale, total_flop_finite)

        cmap = plt.get_cmap("viridis")
        base_colors = cmap(sample_norm(n_samp_finite))

        # Faint background cloud of every finite point
        ax.scatter(ratio_finite, mean_p_finite, c=base_colors, s=15, alpha=0.05)

        # Highlight specific sample levels with the same hue mapping
        for j in sample_levels_to_plot:
            mask = (n_samp == j) & finite_mask
            if np.any(mask):
                ax.scatter(ratio[mask], mean_p[mask],
                           c=[cmap(sample_norm(j))], s=50, alpha=0.9,
                           edgecolors='black', linewidth=0.5,
                           label=f"{j} samples")

        # Add baseline if provided
        if baseline is not None:
            (y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache,
             baseline_flops_sampling_generation) = self._preprocess_data(baseline, metric, threshold)

            baseline_pts = get_schedule_points(
                y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache,
                baseline_flops_sampling_generation, return_ratio=True, cumulative=cumulative, schedule_type=schedule
            )

            # An empty point list cannot be unpacked into six columns, so skip
            # the overlay rather than crash.
            if baseline_pts:
                baseline_ratio, _, _, baseline_mean_p, _, _ = np.array(baseline_pts).T

                baseline_finite_mask = np.isfinite(baseline_ratio) & np.isfinite(baseline_mean_p)
                if np.any(baseline_finite_mask):
                    ax.scatter(baseline_ratio[baseline_finite_mask], baseline_mean_p[baseline_finite_mask],
                               color="red", s=60, alpha=0.9, marker="^",
                               edgecolors='black', linewidth=0.5, label="Baseline", zorder=6)
            elif verbose:
                print("No baseline points generated; skipping baseline overlay")

        # Add subtle iso-FLOP lines (fitted quadratics)
        if draw_iso_flops:
            self._draw_iso_flop_lines(ax, ratio_finite, mean_p_finite, total_flop_finite, verbose)

        # Setup axes
        ax.set_xlabel("Sampling FLOPs / Total FLOPs", fontsize=14)
        if threshold is None:
            ax.set_ylabel(r"$\mathbb{E}[h(Y)]$", fontsize=14)
        else:
            ax.set_ylabel(r"$ASR$", fontsize=14)

        ax.grid(True, alpha=0.3)
        ax.set_xscale("log")
        ax.set_xlim(1e-5, 1)
        ax.set_ylim(bottom=0)
        # plt.legend(loc='upper left')
        plt.tight_layout()

        if verbose:
            if ratio_finite.size:
                print(f"FLOPs ratio range: {ratio_finite.min():.2e} to {ratio_finite.max():.2e}")
                print(f"Mean s_harm range: {mean_p_finite.min():.4f} to {mean_p_finite.max():.4f}")
                print(f"Total FLOPs range: {total_flop_finite.min():.2e} to {total_flop_finite.max():.2e}")
            else:
                # min()/max() raise on empty arrays; report the situation instead.
                print("No finite FLOPs ratios to summarise")

        return fig, ax

    @staticmethod
    def _draw_iso_flop_lines(ax, ratio, mean_p, total_flop, verbose,
                             n_levels=5, tolerance=0.15):
        """Fit and draw one quadratic (in log10 ratio) per total-FLOPs level."""
        # Logs are taken of both the totals and the ratios, so only strictly
        # positive, finite values can take part in the fits.
        usable = (ratio > 0) & (total_flop > 0) & np.isfinite(total_flop)
        if not np.any(usable):
            return
        ratio, mean_p, total_flop = ratio[usable], mean_p[usable], total_flop[usable]

        # Log-spaced FLOP levels spanning the observed range
        iso_flop_levels = np.logspace(np.log10(total_flop.min()), np.log10(total_flop.max()), n_levels)

        labelled = False
        for flop_level in iso_flop_levels:
            # Points within the relative tolerance of this FLOP level
            near_level = np.abs(total_flop - flop_level) / flop_level < tolerance
            if np.count_nonzero(near_level) < 3:  # a quadratic needs at least 3 points
                continue

            log_x = np.log10(ratio[near_level])
            try:
                coeffs = np.polyfit(log_x, mean_p[near_level], 2)
            except (np.linalg.LinAlgError, ValueError) as e:
                if verbose:
                    print(f"Polynomial fit failed: {e}")
                continue

            # Evaluate the fit on a curve extended a quarter decade past the data.
            log_x_smooth = np.linspace(log_x.min() - 0.25, log_x.max() + 0.25, 50)
            x_smooth = 10.0 ** log_x_smooth
            y_smooth = np.polyval(coeffs, log_x_smooth)

            # Only the first curve actually drawn carries the legend label.
            ax.plot(x_smooth, y_smooth, '--', color='gray', alpha=0.8,
                    linewidth=1, zorder=1, label=None if labelled else "Iso-FLOP lines")
            labelled = True

            # Annotate the FLOP level at the left-hand start of the curve.
            ax.text(x_smooth[0], y_smooth[0], f"{flop_level:.1e}", fontsize=8, alpha=0.8,
                    ha='center', va='top', color='black',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                              alpha=0.7, edgecolor='none'))


class IdealRatioPlotter(AnalysisPlotter, PlotterMixin):
    """Plots the lowest observed sampling-FLOPs ratio per metric level."""

    def plot(self, results: dict, baseline: Optional[dict] = None,
             title: str = "Ideal Sampling FLOPs Ratio", schedule: str = 'end', **kwargs):
        """
        Draw the minimum-ratio curve over a scatter of all schedule points.

        Parameters
        ----------
        results : dict
            Results data dictionary, passed to ``_preprocess_data``.
        baseline : dict, optional
            Only handed to ``_validate_inputs``; nothing is drawn for it.
        title : str
            Plot title, passed to ``_setup_figure``.
        schedule : str
            Forwarded to ``get_schedule_points`` as ``schedule_type``.
        **kwargs
            ``metric``, ``cumulative``, ``threshold`` and
            ``n_p_harmful_points`` (number of metric levels, default 100).

        Returns
        -------
        tuple
            ``(fig, ax)`` as created by ``_setup_figure``.

        Raises
        ------
        ValueError
            If ``get_schedule_points`` yields no points.
        """
        # Half-width of the band of metric values grouped into one level.
        band_half_width = 0.01

        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        cumulative = kwargs.get('cumulative', False)
        threshold = kwargs.get('threshold', None)
        n_p_harmful_points = kwargs.get('n_p_harmful_points', 100)

        self._validate_inputs(results, baseline)

        preprocessed = self._preprocess_data(results, metric, threshold)
        y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocessed

        pts = get_schedule_points(
            y, flops_optimization, flops_sampling_prefill_cache,
            flops_sampling_generation, return_ratio=True, cumulative=cumulative, schedule_type=schedule
        )
        if not pts:
            raise ValueError("No points generated for plotting")

        # Six columns per point; only the ratio and the mean metric are needed.
        ratio, _, _, mean_p, _, _ = np.array(pts).T

        fig, ax = self._setup_figure(title, figsize=(10, 6))

        # Evenly spaced metric levels from zero up to the largest observed value.
        score_levels = np.linspace(0, max(mean_p), n_p_harmful_points)

        # For each level take the smallest ratio among points inside its band;
        # levels with no nearby point are marked NaN and left off the curve.
        lowest_ratio = np.array([
            np.min(ratio[in_band]) if np.any(in_band) else np.nan
            for in_band in (np.abs(mean_p - level) < band_half_width for level in score_levels)
        ])
        has_value = ~np.isnan(lowest_ratio)

        ax.plot(score_levels[has_value], lowest_ratio[has_value],
                'r-', linewidth=2, label='Ideal Ratio')

        # Every raw point, for comparison against the curve
        ax.scatter(mean_p, ratio, alpha=0.3, c='blue', label='All Points')

        ax.set_xlabel(r"$\mathcal{H}$")
        ax.set_ylabel('Optimal Sampling FLOPs Ratio')
        ax.grid(True, alpha=0.3)
        ax.legend()

        return fig, ax


class FlopsBreakdownPlotter(AnalysisPlotter, PlotterMixin):
    """Handles FLOPs breakdown plotting with 2D contour surface."""

    def plot(self, results: dict, baseline: Optional[dict] = None,
             title: str = "FLOPs Breakdown Analysis", schedule: str = 'end', **kwargs):
        """
        Draw the metric as a filled contour over optimization vs sampling FLOPs.

        The scattered schedule points are interpolated onto a 100x100 grid and
        drawn with ``contourf``. A compute-optimal frontier, traced on a
        Gaussian-blurred copy of the grid, is overlaid as a black line.

        Parameters
        ----------
        results : dict
            Results data dictionary, passed to ``_preprocess_data``.
        baseline : dict, optional
            Only handed to ``_validate_inputs``; nothing is drawn for it.
        title : str
            Accepted for interface compatibility; this method does not apply
            it to the figure.
        schedule : str
            Forwarded to ``get_schedule_points`` as ``schedule_type``.
        **kwargs
            ``metric``, ``cumulative``, ``threshold`` (switches the colorbar
            label to ASR when not ``None``) and ``verbose`` (emit
            ``logging.info`` diagnostics, default ``True``).

        Returns
        -------
        tuple
            ``(plt.gcf(), plt.gca())`` after drawing.

        Raises
        ------
        ValueError
            If ``get_schedule_points`` yields no points.
        """
        from scipy.interpolate import griddata
        from scipy.ndimage import gaussian_filter
        import logging

        # Extract parameters
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        cumulative = kwargs.get('cumulative', False)
        threshold = kwargs.get('threshold', None)
        verbose = kwargs.get('verbose', True)

        self._validate_inputs(results, baseline)

        # Preprocess data
        y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = self._preprocess_data(
            results, metric, threshold
        )

        # Generate points using schedule-based approach
        pts = get_schedule_points(
            y, flops_optimization, flops_sampling_prefill_cache,
            flops_sampling_generation, return_ratio=True, cumulative=cumulative, schedule_type=schedule,
        )

        if not pts:
            raise ValueError("No points generated for plotting")

        # Six columns per point; ratio, step index and sample count are unused here.
        _, _, _, p_harmful_vals, opt_flops, sampling_flops = np.array(pts).T

        sampling_min, sampling_max = sampling_flops.min(), sampling_flops.max()
        opt_min, opt_max = opt_flops.min(), opt_flops.max()

        # 1-D axis coordinates (log-spaced for wide positive ranges) and their mesh.
        # The mesh has shape (len(opt_grid), len(sampling_grid)).
        sampling_grid = self._axis_grid(sampling_min, sampling_max)
        opt_grid = self._axis_grid(opt_min, opt_max)
        sampling_mesh, opt_mesh = np.meshgrid(sampling_grid, opt_grid)

        # Interpolate the metric onto the grid
        sample_xy = (sampling_flops, opt_flops)
        mesh_xy = (sampling_mesh, opt_mesh)
        try:
            p_harmful_grid = griddata(sample_xy, p_harmful_vals, mesh_xy,
                                      method='linear', rescale=True)
        except Exception as e:
            # Deliberate best-effort: any failure of the linear scheme falls
            # back to nearest-neighbour interpolation.
            if verbose:
                logging.info(f"Linear interpolation failed: {e}, trying nearest neighbor")
            p_harmful_grid = griddata(sample_xy, p_harmful_vals, mesh_xy,
                                      method='nearest', rescale=True)
        else:
            # NaN cells are left unfilled on purpose: the frontier clipping
            # below relies on NaN to mark cells without interpolated data.
            if verbose and np.isnan(p_harmful_grid).any():
                logging.info("Linear interpolation left NaN grid cells; leaving them unfilled")

        # Blurred copy used only for tracing the frontier; the contour shows
        # the unblurred grid.
        sigma = 1.25  # Blur strength - adjust as needed
        p_harmful_blurred = gaussian_filter(p_harmful_grid, sigma=sigma)

        # Close the current pyplot figure and start a fresh small one.
        plt.close()
        plt.figure(figsize=(4, 2.8))

        # Contour plot with optimization FLOPs on x and sampling FLOPs on y
        p_lo, p_hi = np.nanmin(p_harmful_vals), np.nanmax(p_harmful_vals)
        contour = plt.contourf(opt_mesh, sampling_mesh, p_harmful_grid,
                               levels=np.linspace(p_lo, p_hi, 50),
                               cmap='viridis', extend='both')

        # Add colorbar
        cbar = plt.colorbar(contour, ticks=np.linspace(p_lo, p_hi, 5))
        cbar.ax.set_yticklabels([f'{tick:.2f}' for tick in cbar.get_ticks()])
        if threshold is None:
            cbar.set_label(r"$\mathcal{H}_b$", fontsize=17)
        else:
            cbar.set_label(r"$ASR$", fontsize=17)

        # Add compute-optimal frontier
        frontier_points = self._compute_frontier(opt_mesh, sampling_mesh, p_harmful_blurred)

        show_legend = False
        if len(frontier_points) > 1:
            frontier_points = np.array(frontier_points)

            # Keep only frontier points that sit on grid cells with real data.
            clipped_points = self._clip_frontier(frontier_points, opt_grid, sampling_grid, p_harmful_grid)
            points_to_use = clipped_points if len(clipped_points) > 1 else frontier_points

            # Remove duplicate coordinates, keeping the first occurrence.
            unique_points = []
            seen_coords = set()
            for point in points_to_use:
                coord_key = (round(point[0], 10), round(point[1], 10))  # Round to avoid floating point issues
                if coord_key not in seen_coords:
                    unique_points.append(point)
                    seen_coords.add(coord_key)

            if len(unique_points) > 1:
                unique_points = np.array(unique_points)
                # Sort by optimization FLOPs for a smooth curve
                sorted_points = unique_points[np.argsort(unique_points[:, 0])]

                # Smooth the sampling coordinate with a moving average; the column
                # is padded at both ends so the 'valid' result keeps its length.
                if len(sorted_points) > 2:
                    window_size = min(7, len(sorted_points))
                    column = sorted_points[:, 1]
                    padded = np.concatenate((column[:window_size // 2],
                                             column,
                                             column[-window_size // 2 + 1:]))
                    sorted_points[:, 1] = np.convolve(padded, np.ones(window_size) / window_size, mode='valid')

                plt.plot(sorted_points[:, 0], sorted_points[:, 1],
                         color='black', linewidth=2, linestyle="-", alpha=0.8, label="Compute-optimal frontier")
                show_legend = True

        # Setup axes
        plt.xlabel('Optimization FLOPs', fontsize=13)
        plt.ylabel('Sampling FLOPs', fontsize=13)

        # Log axes when the range spans more than one decade. The product form
        # avoids dividing by a zero minimum.
        if sampling_max > 10 * sampling_min:
            plt.yscale('log')
        if opt_max > 10 * opt_min:
            plt.xscale('log')

        # Add legend if compute-optimal frontier is present
        if show_legend:
            plt.legend(loc='lower left', fontsize=13, bbox_to_anchor=(-0.1, 0.95))

        plt.tight_layout()

        if verbose:
            logging.info(f"Sampling FLOPs range: {sampling_flops.min():.2e} to {sampling_flops.max():.2e}")
            logging.info(f"Optimization FLOPs range: {opt_flops.min():.2e} to {opt_flops.max():.2e}")

        return plt.gcf(), plt.gca()

    @staticmethod
    def _axis_grid(lo, hi, n_points=100):
        """Return ``n_points`` axis coordinates from ``lo`` to ``hi``, log-spaced for wide positive ranges."""
        # Log spacing needs a strictly positive lower bound; otherwise stay linear.
        if lo > 0 and hi / lo > 100:
            return np.logspace(np.log10(lo), np.log10(hi), n_points)
        return np.linspace(lo, hi, n_points)

    @staticmethod
    def _compute_frontier(opt_mesh, sampling_mesh, p_surface):
        """Trace the (opt, sampling, value) points that raise the best value as total FLOPs grow."""
        # Always start from the origin so callers get a defined list even when
        # the surface holds no valid cell.
        frontier = [(0, 0, 0)]
        valid = ~np.isnan(p_surface)
        if not np.any(valid):
            return frontier

        opt_coords = opt_mesh[valid]
        sampling_coords = sampling_mesh[valid]
        p_coords = p_surface[valid]

        # Walk the grid cells from cheapest to most expensive total cost and
        # keep only strict improvements.
        order = np.argsort(opt_coords + sampling_coords)
        best_so_far = 0
        for idx in order:
            p_val = p_coords[idx]
            if p_val > best_so_far:
                frontier.append((opt_coords[idx], sampling_coords[idx], p_val))
                best_so_far = p_val

        # Extend the frontier to the most expensive cell at the best value found.
        if len(frontier) > 1:
            last = order[-1]
            frontier.append((opt_coords[last], sampling_coords[last], frontier[-1][2]))
        return frontier

    @staticmethod
    def _clip_frontier(frontier, opt_grid, sampling_grid, p_grid):
        """Keep frontier points that lie inside the grid on a non-NaN cell."""
        kept = []
        for point in frontier:
            opt_val, sampling_val, _ = point
            inside = (opt_grid[0] <= opt_val <= opt_grid[-1]
                      and sampling_grid[0] <= sampling_val <= sampling_grid[-1])
            if not inside:
                continue
            # Look the point up on the actual axis coordinates, which may be
            # log-spaced, so the index refers to the cell the point came from.
            opt_idx = np.clip(np.searchsorted(opt_grid, opt_val), 0, len(opt_grid) - 1)
            sampling_idx = np.clip(np.searchsorted(sampling_grid, sampling_val), 0, len(sampling_grid) - 1)
            if not np.isnan(p_grid[opt_idx, sampling_idx]):
                kept.append(point)
        return kept