"""Comparative plotting functionality for multi-attack analysis."""

import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Dict, Set
from matplotlib.patches import Patch

from plot_base import FrontierPlotter, PlotterMixin, get_schedule_points
from config import GROUP_BY, ATTACKS
from data_processor import fetch_data
from src.io_utils import num_model_params


class ComparativeParetoPlotter(FrontierPlotter, PlotterMixin):
    """Handles comparative Pareto plotting across multiple attacks."""

    def plot(self, model: str, model_title: str, attacks_data: Dict,
             title: str = "Comparative Pareto Analysis",
             baseline_attacks: Optional[Set[str]] = None, schedule: str = 'end', **kwargs):
        """
        Generate comparative Pareto plot for multiple attacks.

        For every attack, the (cost, value) points produced by
        ``get_schedule_points`` are scattered, and, when there is more than one
        point, a frontier computed by ``self._compute_frontier`` is drawn.

        Parameters
        ----------
        model : str
            Model identifier; used to fetch baseline data for ``baseline_attacks``.
        model_title : str
            Model title for display (not used by this method).
        attacks_data : dict
            Dictionary mapping attack names to (results_dict, config) tuples.
            ``config.get('cumulative', False)`` is forwarded to
            ``get_schedule_points``.
        title : str
            Plot title
        baseline_attacks : set, optional
            Set of attack names to treat as baselines. An entry of
            ``attacks_data`` is styled as a baseline when its key contains any
            of these names (case-insensitive). Each name is also looked up in
            ``ATTACKS`` to draw a dedicated baseline marker.
        schedule : str
            Schedule type forwarded to ``get_schedule_points``.
        **kwargs
            Additional plotting parameters: ``metric``, ``threshold``,
            ``frontier_method`` (default ``'basic'``) and ``plot_envelope``
            (default ``False``).

        Returns
        -------
        tuple
            ``(fig, ax)``.

        Raises
        ------
        ValueError
            If ``attacks_data`` is empty.
        """
        # Extract parameters
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        threshold = kwargs.get('threshold', None)
        frontier_method = kwargs.get('frontier_method', 'basic')
        plot_envelope = kwargs.get('plot_envelope', False)

        if not attacks_data:
            raise ValueError("attacks_data cannot be empty")

        # Stored so _add_baseline_points reuses the same data/metric/threshold.
        self.current_attacks_data = attacks_data
        self.current_metric = metric
        self.current_threshold = threshold

        fig, ax = self._setup_figure(title, figsize=(12, 8))

        # One color per attack, aligned with the iteration order of attacks_data.
        colors = plt.cm.tab10(np.linspace(0, 1, len(attacks_data)))

        # Lowercase both sides of the substring test so matching is truly
        # case-insensitive (previously only the config key was lowercased).
        baseline_names = [name.lower() for name in (baseline_attacks or ())]

        all_frontiers = []
        legend_elements = []

        for i, (config_key, (sampled_data, config)) in enumerate(attacks_data.items()):
            try:
                y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = self._preprocess_data(
                    sampled_data, metric, threshold
                )

                pts = get_schedule_points(
                    y, flops_optimization, flops_sampling_prefill_cache,
                    flops_sampling_generation, return_ratio=False,
                    cumulative=config.get('cumulative', False), schedule_type=schedule
                )

                if not pts:
                    continue

                # Point columns: (cost, step index, n samples, mean value).
                cost, _, _, mean_p = np.array(pts).T

                config_key_lower = config_key.lower()
                is_baseline = any(name in config_key_lower for name in baseline_names)

                color = colors[i]
                ax.scatter(cost, mean_p, c=[color],
                           alpha=0.3 if is_baseline else 0.6,
                           marker='x' if is_baseline else 'o',
                           s=30 if is_baseline else 20,
                           label=f'{config_key} (points)')

                # A frontier needs at least two points.
                if len(cost) > 1:
                    frontier_x, frontier_y = self._compute_frontier(
                        cost, mean_p, method=frontier_method
                    )

                    # Kept for the optional envelope.
                    all_frontiers.append((frontier_x, frontier_y, config_key, color))

                    ax.plot(frontier_x, frontier_y, color=color,
                            linestyle='--' if is_baseline else '-',
                            linewidth=1 if is_baseline else 2,
                            label=f'{config_key} frontier')

                # Every attack that produced points gets a legend entry,
                # including single-point attacks without a frontier line.
                legend_elements.append(
                    Patch(facecolor=color, alpha=0.7, label=config_key)
                )

            except Exception as e:
                # Best effort: one broken attack should not abort the figure.
                print(f"Warning: Could not process {config_key}: {e}")
                continue

        baseline_handles = self._add_baseline_points(ax, model, baseline_attacks, colors)

        if plot_envelope and len(all_frontiers) > 1:
            self._plot_envelope(ax, all_frontiers)

        self._setup_axes(ax)
        ax.set_xscale('log')

        # Curated legend: one patch per attack plus the baseline markers (whose
        # labels would otherwise be dropped by the explicit handles list).
        ax.legend(handles=legend_elements + list(baseline_handles or []),
                  bbox_to_anchor=(1.05, 1), loc='upper left')

        fig.tight_layout()
        return fig, ax

    def _plot_envelope(self, ax, all_frontiers):
        """
        Shade the pointwise maximum of all frontiers on a log-spaced x grid.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw on.
        all_frontiers : list of tuple
            ``(frontier_x, frontier_y, config_key, color)`` entries.

        Notes
        -----
        At each grid x, the envelope is the maximum over the frontiers whose
        x-range covers that x, floored at 0. Non-positive or non-finite x
        values are ignored for the grid range because the grid is log-spaced.
        """
        all_x = np.concatenate([np.asarray(fx, dtype=float).ravel()
                                for fx, _, _, _ in all_frontiers])
        positive_x = all_x[np.isfinite(all_x) & (all_x > 0)]
        if positive_x.size == 0:
            # Nothing can be placed on a log-spaced grid.
            return
        x_min, x_max = positive_x.min(), positive_x.max()

        x_interp = np.logspace(np.log10(x_min), np.log10(x_max), 1000)
        envelope_y = np.zeros_like(x_interp)

        for fx, fy, _, _ in all_frontiers:
            fx = np.asarray(fx, dtype=float)
            fy = np.asarray(fy, dtype=float)
            if fx.size < 2:
                continue
            # np.interp requires ascending sample points.
            order = np.argsort(fx, kind='stable')
            fx, fy = fx[order], fy[order]
            inside = (x_interp >= fx[0]) & (x_interp <= fx[-1])
            if np.any(inside):
                # fmax ignores NaNs, matching the previous builtin max(0, nan) == 0.
                envelope_y[inside] = np.fmax(
                    envelope_y[inside], np.interp(x_interp[inside], fx, fy)
                )

        ax.fill_between(x_interp, envelope_y, alpha=0.2, color='gray',
                        label='Envelope', zorder=0)

    def _add_baseline_points(self, ax, model: str, baseline_attacks: Optional[Set[str]], colors):
        """
        Draw one marker per baseline attack at its final-step (cost, value).

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw on.
        model : str
            Model identifier passed to ``fetch_data``.
        baseline_attacks : set, optional
            Attack names to look up in ``ATTACKS``; nothing is drawn when
            empty or None.
        colors : array-like
            Colors aligned with ``self.current_attacks_data`` (set by ``plot``).

        Returns
        -------
        list
            Scatter artists that were drawn, usable as legend handles.
            Attacks that fail are reported with a warning and skipped.
        """
        handles = []
        if not baseline_attacks:
            return handles

        # Give each ATTACKS entry the color of the plotted config whose key
        # equals that attack's 'title_suffix'.
        color_map = {}
        for i, config_key in enumerate(self.current_attacks_data):
            for atk_name, atk_cfg in ATTACKS:
                if atk_cfg.get('title_suffix') == config_key:
                    color_map[atk_name] = colors[i]
                    break

        for baseline_attack_name in baseline_attacks:
            try:
                # First ATTACKS entry with this exact name.
                matching_config = next(
                    (cfg for atk_name, cfg in ATTACKS if atk_name == baseline_attack_name),
                    None,
                )
                if not matching_config:
                    continue

                # Default: a single return sequence at temperature 0.0.
                baseline_params = matching_config.get("baseline_params", lambda: {
                    "generation_config": {"num_return_sequences": 1, "temperature": 0.0}
                })()
                baseline_attack = matching_config.get("baseline_attack", baseline_attack_name)
                # NOTE(review): list(range(100)) presumably selects the first 100
                # runs/prompts to load — confirm against fetch_data.
                baseline_data = fetch_data(model, baseline_attack, baseline_params,
                                           list(range(100)), GROUP_BY)

                y_baseline, b_flops_opt, b_flops_prefill, b_flops_gen = self._preprocess_data(
                    baseline_data, self.current_metric, self.current_threshold
                )
                if y_baseline is None:
                    continue

                _, n_steps_baseline, _ = y_baseline.shape

                pts = get_schedule_points(
                    y_baseline, b_flops_opt, b_flops_prefill, b_flops_gen,
                    return_ratio=False,
                    cumulative=matching_config.get("cumulative", False), schedule_type='end'
                )
                if not pts:
                    continue

                cost_baseline, step_idx_baseline, _, mean_p_baseline = np.array(pts).T

                # Only the point recorded at the last optimization step is shown.
                max_step_mask = step_idx_baseline == (n_steps_baseline - 1)
                if not np.any(max_step_mask):
                    continue

                handle = ax.scatter(cost_baseline[max_step_mask][0],
                                    mean_p_baseline[max_step_mask][0],
                                    s=100, marker="^",
                                    color=color_map.get(baseline_attack_name, "black"),
                                    edgecolors='black', linewidth=1.5, alpha=0.9,
                                    label=f"{baseline_attack_name} Baseline", zorder=10)
                handles.append(handle)

            except Exception as e:
                print(f"Warning: Could not add baseline point for {baseline_attack_name}: {e}")
                continue

        return handles


class MultiAttackNonCumulativePlotter(ComparativeParetoPlotter):
    """Handles multi-attack non-cumulative Pareto plotting with independent step analysis."""

    def plot(self, model: str, model_title: str, attacks_data: Dict,
             title: str = "Multi-Attack Non-Cumulative Analysis", **kwargs):
        """
        Generate multi-attack non-cumulative plot with independent step analysis.

        This implementation looks at each step independently and computes ASR/harm
        at that specific step, regardless of the sampling schedule. Each attack
        is drawn as the change (delta) relative to its first step's value,
        against optimization progress in percent.

        Parameters
        ----------
        model : str
            Model identifier (not used by this method).
        model_title : str
            Model title for display (not used by this method).
        attacks_data : dict
            Maps config keys to ``(results_dict, config)`` tuples.
        title : str
            Plot title.
        **kwargs
            ``metric``; ``threshold`` (when not None the y-axis is labelled
            as ASR); ``target_samples`` (default 50); ``seed`` (default 42,
            seeds the subsampling RNG).

        Returns
        -------
        tuple
            ``(fig, ax)``.

        Raises
        ------
        ValueError
            If ``attacks_data`` is empty.
        """
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        threshold = kwargs.get('threshold', None)
        target_samples = kwargs.get('target_samples', 50)
        seed = kwargs.get('seed', 42)

        if not attacks_data:
            raise ValueError("attacks_data cannot be empty")

        fig, ax = self._setup_figure(title, figsize=(12, 8))

        # One color per attack, aligned with the iteration order of attacks_data.
        colors = plt.cm.tab10(np.linspace(0, 1, len(attacks_data)))

        legend_elements = []

        for i, (config_key, (sampled_data, _config)) in enumerate(attacks_data.items()):
            try:
                # FLOPs are required by subsample_and_aggregate_n.
                y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = self._preprocess_data(
                    sampled_data, metric, threshold
                )

                pts = self._get_independent_step_points(
                    y, flops_optimization, flops_sampling_prefill_cache,
                    flops_sampling_generation, target_samples, seed=seed
                )

                if not pts:
                    continue

                # Point columns: (step percentage, step index, n samples, mean value).
                step_percentage, _, _, mean_p = np.array(pts).T

                # Delta relative to the first step's value.
                y_delta = mean_p - mean_p[0]

                color = colors[i]
                if len(step_percentage) > 1:
                    # Sort by x so the connecting line is monotone in progress.
                    order = np.argsort(step_percentage)
                    ax.plot(step_percentage[order], y_delta[order], color=color,
                            linestyle='-', linewidth=1.2, label=config_key)
                else:
                    # A single step has delta 0 by construction.
                    ax.scatter([step_percentage[0]], [0], c=[color], s=50,
                               marker='o', label=config_key)

                # Legend entry for every plotted attack, including the
                # single-point case (previously omitted).
                legend_elements.append(
                    Patch(facecolor=color, alpha=0.7, label=config_key)
                )

            except Exception as e:
                # Best effort: one broken attack should not abort the figure.
                print(f"Warning: Could not process {config_key}: {e}")
                continue

        # No envelope plotting needed for independent step analysis

        ax.set_xlabel(r"Optimization Progress (\%)", fontsize=15)
        if threshold is None:
            ax.set_ylabel(r"$\Delta$ $\mathcal{H}$", fontsize=16)
        else:
            ax.set_ylabel(r"$\Delta$ $ASR$", fontsize=16)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 100)
        # Reference line at delta = 0 (no change from the first step).
        ax.axhline(y=0, color='black', linestyle='--', alpha=0.5, linewidth=1)

        ax.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')

        fig.tight_layout()
        return fig, ax

    def _get_independent_step_points(self, y: np.ndarray, flops_optimization: np.ndarray,
                                   flops_sampling_prefill_cache: np.ndarray,
                                   flops_sampling_generation: np.ndarray, target_samples: int = 50,
                                   seed: int = 42):
        """
        Get points for independent step analysis using subsample_and_aggregate_n.

        For each step, compute the ASR/harm at that step using target_samples.
        This is independent of any sampling schedule.

        Parameters
        ----------
        y : np.ndarray
            Shape (n_runs, n_steps, n_samples)
        flops_optimization : np.ndarray
            Shape (n_runs, n_steps)
        flops_sampling_prefill_cache : np.ndarray
            Shape (n_runs, n_steps)
        flops_sampling_generation : np.ndarray
            Shape (n_runs, n_steps)
        target_samples : int
            Number of samples to use per step (capped at the samples available).
        seed : int
            Seed for the RNG handed to ``subsample_and_aggregate_n``.

        Returns
        -------
        list
            List of (step_percentage, step_idx, n_samples, mean_value) tuples,
            one per step, in step order.
        """
        from data_processor import subsample_and_aggregate_n

        _, n_steps, n_total_samples = y.shape
        pts = []
        rng = np.random.default_rng(seed)

        actual_target_samples = min(target_samples, n_total_samples)

        for step_idx in range(n_steps):
            # n-schedule that samples only at this step (zeros before it).
            n_vec = np.zeros(step_idx + 1, dtype=int)
            n_vec[step_idx] = actual_target_samples

            _, _, _, mean_value = subsample_and_aggregate_n(
                n_vec, y, flops_optimization, flops_sampling_prefill_cache,
                flops_sampling_generation, rng, return_ratio=False, n_smoothing=1
            )

            # Map step index to 0..100; a single step sits at 0.
            step_percentage = (step_idx / (n_steps - 1)) * 100 if n_steps > 1 else 0

            pts.append((step_percentage, step_idx, actual_target_samples, mean_value))

        return pts

    def _plot_step_envelope(self, ax, all_frontiers):
        """
        Shade the pointwise maximum of all frontiers on a linear 0-100 % grid.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw on.
        all_frontiers : list of tuple
            ``(frontier_x, frontier_y, config_key, color)`` entries where x is
            a step percentage.

        Notes
        -----
        At each grid x, the envelope is the maximum over the frontiers whose
        x-range covers that x, floored at 0.
        """
        all_x = np.concatenate([np.asarray(fx, dtype=float).ravel()
                                for fx, _, _, _ in all_frontiers])
        x_min, x_max = max(0, all_x.min()), min(100, all_x.max())

        x_interp = np.linspace(x_min, x_max, 1000)
        envelope_y = np.zeros_like(x_interp)

        for fx, fy, _, _ in all_frontiers:
            fx = np.asarray(fx, dtype=float)
            fy = np.asarray(fy, dtype=float)
            if fx.size < 2:
                continue
            # np.interp requires ascending sample points.
            order = np.argsort(fx, kind='stable')
            fx, fy = fx[order], fy[order]
            inside = (x_interp >= fx[0]) & (x_interp <= fx[-1])
            if np.any(inside):
                # fmax ignores NaNs, matching the previous builtin max(0, nan) == 0.
                envelope_y[inside] = np.fmax(
                    envelope_y[inside], np.interp(x_interp[inside], fx, fy)
                )

        ax.fill_between(x_interp, envelope_y, alpha=0.2, color='gray',
                        label='Envelope', zorder=0)


class OptimizationProgressPlotter(FrontierPlotter, PlotterMixin):
    """Handles optimization progress plotting."""

    def plot(self, results: dict, title: str = "Optimization Progress Analysis",
             **kwargs):
        """
        Generate optimization progress plot.

        Left panel: for every step, take the best value over that step's
        samples in each run, then show the mean across runs with a
        +/- one standard deviation band.
        Right panel: at the final step, the mean across runs of the best value
        among the first ``k`` samples, for ``k = 1 .. n_samples``.

        Parameters
        ----------
        results : dict
            Results data dictionary
        title : str
            Plot title
        **kwargs
            ``metric`` and ``threshold``. When ``threshold`` is not None the
            y-axes are labelled as ASR instead of harm.

        Returns
        -------
        tuple
            ``(fig, (ax1, ax2))``.
        """
        metric = kwargs.get('metric', ('scores', 'strong_reject', 'p_harmful'))
        threshold = kwargs.get('threshold', None)

        self._validate_inputs(results)

        y, _, _, _ = self._preprocess_data(results, metric, threshold)
        # Float cast so thresholded (boolean) outcomes average into rates.
        y = np.asarray(y, dtype=float)
        _, n_steps, n_samples = y.shape

        # Same labelling convention as the multi-attack plot: ASR when thresholded.
        y_label = r"$\mathcal{H}$" if threshold is None else r"$ASR$"

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        fig.suptitle(title)

        # Plot 1: best-of-samples per (run, step), summarised across runs.
        best_per_run_step = y.max(axis=2)  # (n_runs, n_steps)
        step_means = best_per_run_step.mean(axis=0)
        step_stds = best_per_run_step.std(axis=0)

        steps = np.arange(n_steps)
        ax1.plot(steps, step_means, 'b-', linewidth=2, label='Mean Performance')
        ax1.fill_between(steps,
                         step_means - step_stds,
                         step_means + step_stds,
                         alpha=0.3, color='blue')

        ax1.set_xlabel('Optimization Step')
        ax1.set_ylabel(y_label)
        ax1.set_title('Performance vs Optimization Steps')
        ax1.grid(True, alpha=0.3)
        ax1.legend()

        # Plot 2: running maximum over the FIRST k samples of the final step
        # (column k-1 of the cumulative max), averaged across runs.
        sample_counts = np.arange(1, n_samples + 1)
        running_best = np.maximum.accumulate(y[:, -1, :], axis=1)
        final_step_performance = running_best.mean(axis=0)

        ax2.semilogx(sample_counts, final_step_performance, 'r-', linewidth=2)
        ax2.set_xlabel('Number of Samples Used')
        ax2.set_ylabel(y_label)
        ax2.set_title('Sample Efficiency (Final Step)')
        ax2.grid(True, alpha=0.3)

        return fig, (ax1, ax2)