"""Bar chart plotting functionality."""

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from data_processor import _pareto_frontier, get_n_schedule, subsample_and_aggregate_n
from plot_base import FrontierPlotter, PlotterMixin, setup_color_normalization
from scipy.interpolate import interp1d


class BarChartPlotter(FrontierPlotter, PlotterMixin):
    """Handles standalone bar chart plotting for analysis comparisons.

    :meth:`plot` builds one averaged Pareto frontier per requested sample level
    (and, optionally, one for a baseline run) and turns them into up to three
    separate figures:

    * :meth:`plot_asr_delta` - best score per method relative to the baseline,
    * :meth:`plot_flops_efficiency` - FLOPs each method needs to match the
      baseline's best score,
    * :meth:`plot_speedup` - baseline FLOPs divided by each method's FLOPs.
    """

    # Number of resampled frontiers averaged per sample level; the level equal
    # to the total number of samples uses a single frontier instead.
    _N_FRONTIER_RESAMPLES = 50
    # Scales the vertical padding added around the linear-scale bar charts.
    _BAR_CHART_MARGIN_MULTIPLIER = 5

    def plot(self, results: dict, baseline: Optional[dict] = None,
             title: str = "Bar Chart Analysis", schedule: str = 'end', **kwargs) -> list:
        """
        Generate standalone bar charts for analysis comparison.

        Parameters
        ----------
        results : dict
            Results data dictionary (passed to ``self._preprocess_data``).
        baseline : dict, optional
            Baseline results for comparison. Its preprocessed score array must
            have exactly one sample per step.
        title : str
            Plot title prefix
        schedule : str
            Schedule type forwarded to ``get_n_schedule``.
        **kwargs
            Additional plotting parameters:

            - ``sample_levels_to_plot``: sample counts to chart (default:
              ``self.get_sample_levels(n_total_samples)``).
            - ``metric``: metric key path forwarded to preprocessing (default
              ``('scores', 'strong_reject', 'p_harmful')``).
            - ``threshold``: forwarded to preprocessing; also selects the y-axis
              label of the score chart (default ``None``).
            - ``color_scale``: forwarded to ``setup_color_normalization``
              (default ``'linear'``).
            - ``n_x_points``: size of the FLOPs grid the frontiers are
              resampled onto (default 10000).
            - ``x_scale``: ``'log'`` (default) for a log-spaced grid, anything
              else for a linear grid starting at 0.
            - ``x_min_log10``: log10 of the lower end of the log-spaced grid
              (default 11, i.e. 1e11).

        Returns
        -------
        list
            The generated figures in the order score chart, FLOPs chart,
            speedup chart; charts that could not be produced are omitted.

        Raises
        ------
        ValueError
            If no points could be generated, or if the baseline has more than
            one sample per step.
        """
        # Extract parameters with defaults
        sample_levels_to_plot = kwargs.get('sample_levels_to_plot', None)
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        threshold = kwargs.get('threshold', None)
        color_scale = kwargs.get('color_scale', 'linear')
        n_x_points = kwargs.get('n_x_points', 10000)
        x_scale = kwargs.get('x_scale', 'log')
        x_min_log10 = kwargs.get('x_min_log10', 11)

        self._validate_inputs(results, baseline)

        y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = self._preprocess_data(
            results, metric, threshold
        )
        flops = (flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation)
        _, n_steps, n_total_samples = y.shape
        if sample_levels_to_plot is None:
            sample_levels_to_plot = self.get_sample_levels(n_total_samples)

        # One shared, fixed-seed generator: every subsampling call below draws
        # from it in a fixed order, so repeated calls give identical charts.
        rng = np.random.default_rng(42)

        # Sample every (sample budget, step) combination once; these points only
        # determine the FLOPs range and the colour normalisation.
        pts = [
            self._sample_point(i, j, schedule, y, flops, rng)
            for j in range(1, n_total_samples + 1)
            for i in range(1, n_steps + 1)
        ]
        if not pts:
            raise ValueError("No points generated for plotting")
        cost, _, n_samp, _ = np.asarray(pts).T

        max_cost = np.max(cost)
        if x_scale == "log":
            x_interp = np.logspace(x_min_log10, np.log10(max_cost) + 0.001, n_x_points)
        else:
            x_interp = np.linspace(0, max_cost + 1, n_x_points)

        color_norm = setup_color_normalization(color_scale, n_samp)
        cmap = plt.get_cmap("viridis")

        # Averaged frontier per requested sample level.
        frontier_data = {}
        for j in sample_levels_to_plot:
            n_reps = self._N_FRONTIER_RESAMPLES if j != n_total_samples else 1
            x_pts, y_pts = self._smoothed_frontier(j, n_reps, n_steps, schedule, y, flops, rng, x_interp)
            frontier_data[j] = {
                'x': x_pts,
                'y': y_pts,
                'color': cmap(color_norm(j)),
                'max_asr': np.max(y_pts) if y_pts.size else 0,
            }

        baseline_max_asr = 0
        baseline_frontier_data = None
        if baseline is not None:
            baseline_frontier_data = self._baseline_frontier(
                baseline, metric, threshold, schedule, rng, x_interp, max_cost
            )
            if baseline_frontier_data is not None:
                baseline_max_asr = baseline_frontier_data['max_asr']

        # Closes pyplot's current figure before the chart figures are created.
        plt.close()
        figures = [
            self.plot_asr_delta(frontier_data, baseline_frontier_data, baseline_max_asr,
                                sample_levels_to_plot, threshold, title),
            self.plot_flops_efficiency(frontier_data, baseline_frontier_data, baseline_max_asr,
                                       sample_levels_to_plot, title),
            self.plot_speedup(frontier_data, baseline_frontier_data, baseline_max_asr,
                              sample_levels_to_plot, title),
        ]
        return [fig for fig in figures if fig is not None]

    @staticmethod
    def _sample_point(step: int, n_samples: int, schedule: str, y: np.ndarray,
                      flops: tuple, rng: np.random.Generator):
        """Subsample one aggregated point for ``step`` steps and ``n_samples`` samples.

        Returns whatever ``subsample_and_aggregate_n`` returns; callers unpack
        its transpose as ``(cost, _, n_samp, mean_p)``.
        """
        n_vec = get_n_schedule(step, n_samples, schedule)
        return subsample_and_aggregate_n(
            n_vec, y, *flops, rng, return_ratio=False, n_smoothing=1
        )

    def _smoothed_frontier(self, n_samples: int, n_reps: int, n_steps: int, schedule: str,
                           y: np.ndarray, flops: tuple, rng: np.random.Generator,
                           x_interp: np.ndarray) -> tuple:
        """Average ``n_reps`` resampled Pareto frontiers on ``x_interp``.

        Returns ``(x, y)`` restricted to grid points where the averaged frontier
        is defined (not NaN) and strictly positive.
        """
        curves = []
        for _ in range(n_reps):
            pts = np.asarray([
                self._sample_point(i, n_samples, schedule, y, flops, rng)
                for i in range(1, n_steps + 1)
            ])
            cost, _, _, mean_p = pts.T
            fx, fy = _pareto_frontier(cost, mean_p, method="non_cumulative")
            # Step interpolation; grid points outside the frontier's range are NaN
            # so nanmean below ignores them.
            curves.append(interp1d(fx, fy, kind="previous", bounds_error=False,
                                   fill_value=np.nan)(x_interp))
        y_mean = np.nanmean(curves, axis=0)
        valid = ~np.isnan(y_mean) & (y_mean > 0)
        return x_interp[valid], y_mean[valid]

    def _baseline_frontier(self, baseline: dict, metric, threshold, schedule: str,
                           rng: np.random.Generator, x_interp: np.ndarray,
                           max_cost: float) -> Optional[dict]:
        """Build the baseline frontier dict (``'x'``, ``'y'``, ``'max_asr'``).

        Returns ``None`` when preprocessing yields no score array.

        Raises
        ------
        ValueError
            If the baseline has more than one sample per step.
        """
        y_b, flops_opt_b, flops_prefill_b, flops_gen_b = self._preprocess_data(
            baseline, metric, threshold
        )
        if y_b is None:
            return None

        _, n_steps_b, n_samples_b = y_b.shape
        if n_samples_b != 1:
            raise ValueError(
                f"Baseline must have exactly one sample per step, got {n_samples_b}"
            )

        flops_b = (flops_opt_b, flops_prefill_b, flops_gen_b)
        pts = np.asarray([
            self._sample_point(i, 1, schedule, y_b, flops_b, rng)
            for i in range(1, n_steps_b + 1)
        ])
        cost_b, _, n_samp_b, mean_p_b = pts.T
        max_cost_b = np.max(cost_b)

        single = n_samp_b == 1
        fx, fy = _pareto_frontier(cost_b[single], mean_p_b[single], method="non_cumulative")
        y_interp_b = interp1d(fx, fy, kind="previous", bounds_error=False,
                              fill_value=np.nan)(x_interp)
        # When the baseline's budget is below 70% of the main results' budget,
        # the main budget is used as the FLOPs cutoff instead.
        if max_cost_b / max_cost < 0.7:
            max_cost_b = max_cost
        valid = ~np.isnan(y_interp_b) & (y_interp_b > 0) & (x_interp < max_cost_b)
        y_valid = y_interp_b[valid]
        return {
            'x': x_interp[valid],
            'y': y_valid,
            'max_asr': np.max(y_valid) if y_valid.size else 0,
        }

    @staticmethod
    def _level_label(n_samples) -> str:
        """Bar label for a sample level ("1 sample", "4 samples", ...)."""
        return "1 sample" if n_samples == 1 else f"{n_samples} samples"

    @staticmethod
    def _annotate_bars(ax, bars, values, label_fmt: str, offset_pt: int) -> None:
        """Write each value just outside its bar (above if >= 0, below otherwise)."""
        for bar, value in zip(bars, values):
            above = value >= 0
            ax.annotate(label_fmt.format(value),
                        xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                        xytext=(0, offset_pt if above else -offset_pt),
                        textcoords='offset points',
                        ha='center', va='bottom' if above else 'top', fontsize=10)

    @staticmethod
    def _min_x_reaching(x_vals: np.ndarray, y_vals: np.ndarray, target: float):
        """Smallest ``x`` whose ``y`` is >= ``target``, or ``None`` if none is."""
        reached = y_vals >= target
        return np.min(x_vals[reached]) if np.any(reached) else None

    def _flops_to_reach_target(self, frontier_data: dict, baseline_frontier_data: dict,
                               target_asr: float, sample_levels_to_plot) -> tuple:
        """FLOPs each method (and the baseline) needs to reach ``target_asr``.

        Returns
        -------
        tuple
            ``(baseline_flops, entries)``. ``baseline_flops`` is ``None`` when the
            baseline frontier is empty; if the baseline never reaches the target
            its smallest FLOPs value is used. ``entries`` lists
            ``(label, flops, color)`` for each plotted level that reaches the
            target, in ``sample_levels_to_plot`` order.
        """
        entries = []
        for j in sample_levels_to_plot:
            data = frontier_data.get(j)
            if data is None:
                continue
            min_flops = self._min_x_reaching(data['x'], data['y'], target_asr)
            if min_flops is not None:
                entries.append((self._level_label(j), min_flops, data['color']))

        baseline_flops = None
        baseline_x = baseline_frontier_data['x']
        if baseline_x.size > 0:
            baseline_flops = self._min_x_reaching(baseline_x, baseline_frontier_data['y'], target_asr)
            if baseline_flops is None:
                # Fallback to minimum FLOPs if no point reaches target ASR
                baseline_flops = np.min(baseline_x)
        return baseline_flops, entries

    def plot_asr_delta(self, frontier_data: dict, baseline_frontier_data: dict|None,
                      baseline_max_asr: float, sample_levels_to_plot: tuple,
                      threshold: float|None, title: str = "ASR Delta"):
        """Plot the best score of each method minus the baseline's best score.

        The baseline bar (present only when ``baseline_frontier_data`` is given)
        is 0 by definition; without a baseline every method's delta is 0.
        ``threshold`` only chooses the y-axis label. Returns the figure, or
        ``None`` if there is nothing to plot.
        """
        has_baseline = baseline_frontier_data is not None
        labels, deltas, colors = [], [], []
        if has_baseline:
            labels.append("Baseline")
            deltas.append(0.0)  # Delta from itself is 0
            colors.append("red")

        for j in sample_levels_to_plot:
            if j in frontier_data:
                labels.append(self._level_label(j))
                deltas.append(frontier_data[j]['max_asr'] - baseline_max_asr if has_baseline else 0)
                colors.append(frontier_data[j]['color'])

        if not labels:
            return None

        fig, ax = plt.subplots(figsize=(6, 4))
        bars = ax.bar(labels, deltas, color=colors, alpha=0.7, edgecolor='black')
        if threshold is None:
            ax.set_ylabel(r"$\Delta$ $\mathcal{H}_q$", fontsize=17)
        else:
            ax.set_ylabel(r"$\Delta$ $\text{ASR}_q$", fontsize=17)
        ax.set_title(f"{title}", fontsize=14)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

        # Pad the y-range so the value labels fit inside the axes.
        ymin, ymax = ax.get_ylim()
        margin = (ymax - ymin) * 0.03 * self._BAR_CHART_MARGIN_MULTIPLIER
        ax.set_ylim(ymin - margin, ymax + margin)

        self._annotate_bars(ax, bars, deltas, '{:.2f}', offset_pt=4)

        fig.tight_layout()
        return fig

    def plot_flops_efficiency(self, frontier_data: dict, baseline_frontier_data: dict|None,
                             baseline_max_asr: float, sample_levels_to_plot: tuple, title: str = "FLOPs Efficiency"):
        """Plot the FLOPs each method needs to match the baseline's best score (log y-axis).

        Methods that never reach ``baseline_max_asr`` are omitted. Returns
        ``None`` without a baseline, when ``baseline_max_asr <= 0``, or when no
        bar can be drawn.
        """
        if baseline_frontier_data is None or baseline_max_asr <= 0:
            return None

        baseline_flops, entries = self._flops_to_reach_target(
            frontier_data, baseline_frontier_data, baseline_max_asr, sample_levels_to_plot
        )
        if baseline_flops is not None:
            entries.insert(0, ("Baseline", baseline_flops, "red"))
        if not entries:
            return None

        labels = [label for label, _, _ in entries]
        flops_required = [flops for _, flops, _ in entries]
        colors = [color for _, _, color in entries]

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.bar(labels, flops_required, color=colors, alpha=0.7, edgecolor='black')
        ax.set_ylabel("FLOPs to match baseline", fontsize=12)
        ax.set_title(f"{title} - FLOPs Efficiency", fontsize=14)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3, axis='y')

        # Extend the top of the log axis by a factor that grows with the number
        # of decades currently shown.
        ymin, ymax = ax.get_ylim()
        margin = (np.log10(ymax) - np.log10(ymin)) * 0.2
        ax.set_ylim(ymin, ymax * (1 + margin))

        fig.tight_layout()
        return fig

    def plot_speedup(self, frontier_data: dict, baseline_frontier_data: dict|None,
                    baseline_max_asr: float, sample_levels_to_plot: tuple, title: str = "Speedup Analysis"):
        """Plot ``baseline_flops / method_flops`` for each method reaching the baseline's best score.

        A dashed line marks speedup 1. Returns ``None`` without a baseline,
        when ``baseline_max_asr <= 0``, when the baseline frontier is empty, or
        when no method reaches the target.
        """
        if baseline_frontier_data is None or baseline_max_asr <= 0:
            return None

        baseline_flops, entries = self._flops_to_reach_target(
            frontier_data, baseline_frontier_data, baseline_max_asr, sample_levels_to_plot
        )
        if baseline_flops is None or not entries:
            return None

        labels = [label for label, _, _ in entries]
        speedups = [baseline_flops / flops if flops > 0 else 0 for _, flops, _ in entries]
        colors = [color for _, _, color in entries]

        fig, ax = plt.subplots(figsize=(6, 4))
        bars = ax.bar(labels, speedups, color=colors, alpha=0.7, edgecolor='black')
        ax.set_ylabel("Speedup (FLOPs)", fontsize=12)
        ax.set_title(f"{title} - Speedup Analysis", fontsize=14)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

        # Horizontal reference line at y=1 (no speedup).
        ax.axhline(y=1, color='red', linestyle='--', alpha=0.7, linewidth=1)

        # Pad the y-range (never below 0) so the value labels fit.
        ymin, ymax = ax.get_ylim()
        margin = (ymax - ymin) * 0.05 * self._BAR_CHART_MARGIN_MULTIPLIER
        ax.set_ylim(max(0, ymin - margin), ymax + margin)

        self._annotate_bars(ax, bars, speedups, '{:.1f}x', offset_pt=5)

        fig.tight_layout()
        return fig


class AbsoluteBarChartPlotter(BarChartPlotter):
    """Bar-chart plotter that reports absolute best scores instead of deltas.

    Only the first chart differs from the parent class: each bar shows the raw
    ``max_asr`` of a method (and the baseline's best score), not its difference
    from the baseline. The FLOPs-efficiency and speedup charts are inherited
    unchanged.
    """

    def plot_asr_delta(self, frontier_data: dict, baseline_frontier_data: dict|None,
                      baseline_max_asr: float, sample_levels_to_plot: tuple,
                      threshold: float|None, title: str = "Absolute ASR"):
        """Plot the absolute best score of the baseline and each sampling method.

        Keeps the parent's method name and signature so ``plot`` can call it
        unchanged. The baseline bar (red, first) appears only when
        ``baseline_frontier_data`` is given; ``threshold`` only picks the y-axis
        label. Returns the figure, or ``None`` if there is nothing to plot.
        """
        # Each row is (bar label, bar height, bar colour).
        rows = []
        if baseline_frontier_data is not None:
            rows.append(("Baseline", baseline_max_asr, "red"))
        rows.extend(
            ("1 sample" if level == 1 else f"{level} samples",
             frontier_data[level]['max_asr'],
             frontier_data[level]['color'])
            for level in sample_levels_to_plot
            if level in frontier_data
        )
        if not rows:
            return None

        labels = [row[0] for row in rows]
        heights = [row[1] for row in rows]
        bar_colors = [row[2] for row in rows]

        fig, ax = plt.subplots(figsize=(6, 4))
        bars = ax.bar(labels, heights, color=bar_colors, alpha=0.7, edgecolor='black')
        # Same symbols as the parent chart, without the Delta prefix.
        ylabel = r"$\mathcal{H}_q$" if threshold is None else r"$\text{ASR}_q$"
        ax.set_ylabel(ylabel, fontsize=17)
        ax.set_title(f"{title}", fontsize=14)
        plt.xticks(rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

        # Widen the y-range by 15% of its span on each side (0.03 * 5).
        lo, hi = ax.get_ylim()
        pad = (hi - lo) * 0.03 * 5
        ax.set_ylim(lo - pad, hi + pad)

        # Print each height just outside its bar.
        for bar, height in zip(bars, heights):
            nonneg = height >= 0
            ax.annotate(f'{height:.2f}',
                        xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                        xytext=(0, 4 if nonneg else -4),
                        textcoords='offset points',
                        ha='center', va='bottom' if nonneg else 'top', fontsize=10)

        plt.tight_layout()
        return fig