#! /usr/bin/env python

from __future__ import annotations

from typing import Iterable, Optional, Dict, List, Tuple


import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import wandb
from rich import print

from experiments.exp_stats import analyze_area_under_curve


class Visualizer:
    r"""
    A class to visualize experiment results with common configuration across multiple plotting methods.
    """
    
    def __init__(
        self,
        entity: str,
        problem_name: str,
        methods_list: list[str],
        seed_list: Iterable[int],
        api,
        *,
        project: str = "rescue_HPOXGBoost",
        save_runs_dir: Optional[str] = "wandb_runs_data",
        method_colors: Optional[Dict[str, str]] = None,
        save_results_csv: bool = False,
        results_dir: Optional[str] = None,
    ):
        r"""
        Initialize the Visualizer with common parameters.
        
        Args:
            problem_name: Name of the problem (used for filtering runs and CSV filename).
            methods_list: List of method names to compare.
            seed_list: List of random seeds to aggregate over.
            api: wandb.Api instance.
            entity: W&B entity name.
            project: W&B project name.
            save_runs_dir: Directory to cache run history data.
            method_colors: Dictionary mapping method names to colors.
            save_results_csv: If True, save results to CSV file.
            results_dir: Directory to save results CSV. Defaults to rescue/experiments/results.
        """
        self.problem_name = problem_name
        self.methods_list = methods_list
        self.seed_list = seed_list
        self.api = api
        self.entity = entity
        self.project = project
        self.save_runs_dir = save_runs_dir
        self.method_colors = method_colors
        self.save_results_csv = save_results_csv
        self.results_dir = results_dir
    
    def _load_or_build_full_history(self) -> pd.DataFrame:
        r"""
        Load project_full_history.csv from cache_dir (or CWD).
        If missing, build it by concatenating run.history() for all runs.
        """
        entity_project = f"{self.entity}/{self.project}"
        if self.save_runs_dir is not None:
            os.makedirs(self.save_runs_dir, exist_ok=True)
            csv_path = os.path.join(self.save_runs_dir, f"{self.problem_name}-{self.project}.csv")
        else:
            csv_path = f"{self.problem_name}-{self.project}.csv"

        if os.path.exists(csv_path):
            return pd.read_csv(csv_path)

        # Build full history from the project
        runs = self.api.runs(entity_project)
        frames = []
        for run in runs:
            df = run.history()  # full timeline for this run
            if df is None or df.empty:
                continue
            df["run_id"] = run.id
            df["run_name"] = run.name  # e.g., "{problem}_{method}_run.{seed}"
            frames.append(df)

        if not frames:
            raise ValueError(f"No histories found in project '{self.project}'.")

        full = pd.concat(frames, ignore_index=True)
        full.to_csv(csv_path, index=False)
        return full

    def plot_regret_vs_cost(
        self,
        *,
        n_grid: Optional[int] = None,
        cost_key: str = "cost",
        regret_key: str = "best_nsga2_regret",
        budget_key: str = "budget",
        budget: Optional[float] = None,  # Target budget to clip all runs to
        title: Optional[str] = None,
        use_sem_shading: bool = False,
        normalized_bounds: Optional[Tuple[float, float]] = None,  # e.g., (1.0, 15.0)
        step_interval: Optional[int] = None,  # e.g., 5 to plot every 5th point
    ) -> Dict[str, Dict[str, np.ndarray]]:
        """
        Plot mean regret vs. cost for every method, aggregated over seeds.

        Run histories come from the cached full-history CSV (built by
        ``_load_or_build_full_history`` if absent). Each run is selected by
        ``run_name == "{problem}_{method}_run.{seed}"``. Rows whose cost or
        regret value is missing (NaN) are ignored.

        A seed with no usable point within the effective budget for *any*
        method is dropped for *all* methods, so every method is averaged over
        the same seeds.

        Each run is linearly interpolated onto the common grid
        ``np.linspace(1.0, common_budget, n_grid)``. Grid points outside a run's
        cost range are held at its first/last regret. Mean, std and SEM across
        runs are then plotted as a line with a shaded band.

        Args:
            n_grid: Number of interpolation grid points. If None, uses
                ``int(common_budget)``, but never fewer than 2.
            cost_key: Column name for cost values.
            regret_key: Column name for regret values.
            budget_key: Column name for budget values (must be constant per run).
            budget: If provided, clip all runs to this target budget; every run
                must have data reaching it. If None, all runs must log exactly
                the same budget.
            title: Custom plot title.
            use_sem_shading: If True, shade ±SEM; otherwise shade ±std.
            normalized_bounds: If provided as ``(low, high)``, the plotted x axis
                is ``low + (cost / common_budget) * (high - low)``.
            step_interval: If provided (must be >= 1), interpolate from every Nth
                logged point of each run, always including its last point. If
                None, use all points.

        Returns:
            Mapping ``method -> {"grid", "x_plot", "mean", "std", "sem",
            "per_run" (shape [num_runs, n_grid]), "budget"}``. ``std`` is the
            population standard deviation (ddof=0) and ``sem = std / sqrt(num_runs)``
            (zero for a single run).

        Raises:
            ValueError: On missing runs/columns, non-constant or mismatched
                budgets, seed mismatches, non-monotone costs, runs not reaching
                ``budget``, invalid ``normalized_bounds`` or ``step_interval < 1``.
        """
        if step_interval is not None and step_interval < 1:
            raise ValueError(f"step_interval must be a positive integer, got {step_interval!r}.")

        # ---------- 0) Load the single full-history CSV ----------
        full_history_df = self._load_or_build_full_history()
        # Materialise once: the seeds are traversed once per method in each pass,
        # which would silently yield nothing after the first method for a one-shot iterator.
        seeds = list(self.seed_list)

        # ---------- 1) First pass: identify invalid seeds across all methods ----------
        invalid_seeds = set()
        discarded_runs: List[str] = []

        for method in self.methods_list:
            for seed in seeds:
                run_name = f"{self.problem_name}_{method}_run.{seed}"
                df = self._history_for_run(full_history_df, run_name)

                # Required columns
                missing = [k for k in (cost_key, regret_key, budget_key) if k not in df.columns]
                if missing:
                    raise ValueError(
                        f"Missing columns {missing} in {run_name}"
                    )

                run_budget = self._constant_run_budget(df, budget_key, run_name)
                costs, _ = self._clean_cost_regret(df, cost_key, regret_key)

                # Use target budget if specified, otherwise use run's budget
                effective_budget = budget if budget is not None else run_budget

                # Only discard if NO usable points lie within the effective budget
                if not np.any(costs <= effective_budget):
                    if seed not in invalid_seeds:
                        print(f"Warning: Seed {seed} has 0 points ≤ effective budget {effective_budget} in {method}; excluding seed {seed} from ALL methods.")
                        invalid_seeds.add(seed)
                        discarded_runs.append(f"Seed {seed} (0 points ≤ effective budget {effective_budget} in {method})")
                    continue

                # If target budget is specified, check that this run has actual data up to that budget.
                # NaN rows were dropped above, so max() reflects real logged costs.
                if budget is not None:
                    max_cost_in_run = costs.max()
                    if max_cost_in_run < budget:
                        raise ValueError(
                            f"Run '{run_name}' has maximum cost {max_cost_in_run:.2f} which is less than "
                            f"specified target budget {budget:.2f}. All runs must have actual data up to "
                            f"the target budget. Consider using a lower budget (≤ {max_cost_in_run:.2f}) "
                            f"or ensure all runs have data up to budget {budget:.2f}."
                        )

        # ---------- 2) Second pass: load valid runs only ----------
        valid_seed_list = [seed for seed in seeds if seed not in invalid_seeds]
        all_payloads: Dict[str, List[dict]] = {m: [] for m in self.methods_list}
        all_budgets: List[float] = []

        for method in self.methods_list:
            for seed in valid_seed_list:
                run_name = f"{self.problem_name}_{method}_run.{seed}"
                df = self._history_for_run(full_history_df, run_name)

                # Seed consistency check: prefer an explicit "seed" column, else parse the run name
                if "seed" in df.columns and not df["seed"].isna().all():
                    seed_in_df = int(pd.to_numeric(df["seed"].dropna().iloc[0]))
                else:
                    try:
                        seed_in_df = int(run_name.rsplit("run.", 1)[1])
                    except (IndexError, ValueError) as err:
                        raise ValueError(f"Cannot determine seed for '{run_name}'.") from err

                if seed != seed_in_df:
                    raise ValueError(
                        f"Seed mismatch in {run_name}: expected {seed}, got {seed_in_df}"
                    )

                costs, regrets = self._clean_cost_regret(df, cost_key, regret_key)
                # Budget validation (already done in first pass, repeated for safety)
                run_budget = self._constant_run_budget(df, budget_key, run_name)

                if not np.all(np.diff(costs) >= 0):
                    raise ValueError(f"{run_name}: non-monotone costs.")

                # Use target budget if specified, otherwise use run's budget
                effective_budget = budget if budget is not None else run_budget

                # Clip within effective budget (first pass guarantees at least one point)
                m = costs <= effective_budget
                costs, regrets = costs[m], regrets[m]

                if costs.size == 1:
                    print(f"Warning: {run_name}: only 1 point ≤ effective budget {effective_budget}; padding with that value.")
                    # interp1d needs two points. The second abscissa must be strictly
                    # larger than the first: a duplicated x (lone cost == budget)
                    # makes interp1d return NaN at that x.
                    pad_cost = max(effective_budget, np.nextafter(costs[0], np.inf))
                    costs = np.array([costs[0], pad_cost])
                    regrets = np.array([regrets[0], regrets[0]])

                all_payloads[method].append(
                    {"seed": seed, "costs": costs, "regrets": regrets, "budget": effective_budget}
                )
                all_budgets.append(run_budget)

        # Print summary of discarded seeds
        if invalid_seeds:
            print(f"\nDiscarded seeds {sorted(invalid_seeds)} from ALL methods due to insufficient data points:")
            for run_info in discarded_runs:
                print(f"  - {run_info}")
            print(f"Using {len(valid_seed_list)}/{len(seeds)} seeds for comparison.")

        if len(all_budgets) == 0:
            raise ValueError("No runs loaded; cannot determine common budget.")

        # ---------- 3) Determine common budget ----------
        if budget is not None:
            # Use specified target budget
            common_budget = float(budget)
            print(f"Using specified target budget: {common_budget}")
        else:
            # Enforce EXACT same budget across all methods & runs (original behavior)
            unique_budgets = set(all_budgets)
            if len(unique_budgets) != 1:
                details = []
                for method in self.methods_list:
                    for rp in all_payloads[method]:
                        details.append(
                            f"{self.problem_name}_{method}_run.{rp['seed']}: budget={rp['budget']!r}"
                        )
                detail_str = "\n  ".join(details)
                raise ValueError(
                    "Budget mismatch across methods/runs. Budgets must be EXACTLY equal.\n"
                    f"Unique budgets observed: {sorted(unique_budgets)!r}\n"
                    f"Per-run details:\n  {detail_str}\n"
                    f"Consider specifying a target budget parameter to clip all runs to a common budget."
                )
            common_budget = float(next(iter(unique_budgets)))

        # Default: one grid point per unit of budget, but at least two so the grid
        # still spans [1, budget] for small or fractional budgets.
        if n_grid is None:
            n_grid = max(int(common_budget), 2)

        # The grid always starts at cost 1.0.
        grid = np.linspace(1.0, common_budget, n_grid)

        # If normalized bounds are provided, validate and prepare x for plotting
        if normalized_bounds is not None:
            import numbers  # local: only needed here; accepts numpy scalars too

            low, high = normalized_bounds
            if not (isinstance(low, numbers.Real) and isinstance(high, numbers.Real)):
                raise ValueError("normalized_bounds must be a pair of numbers (low, high).")
            if not (low < high):
                raise ValueError("normalized_bounds must satisfy low < high.")
            x_plot = low + (grid / common_budget) * (high - low)
            xlabel = f"Cost (normalized to [{low}, {high}])"
        else:
            x_plot = grid
            xlabel = "Cost"

        # ---------- 4) Interpolate per run to the common grid, aggregate ----------
        results: Dict[str, Dict[str, np.ndarray]] = {}
        plt.figure(figsize=(7, 4.5))

        for method in self.methods_list:
            if not all_payloads[method]:
                raise ValueError(f"No runs found for method '{method}'.")

            per_run = []
            for rp in all_payloads[method]:
                # Apply step interval to actual data points if specified
                if step_interval is not None:
                    indices = np.arange(0, len(rp["costs"]), step_interval)
                    # Always include the last point to ensure we show the full range
                    if indices[-1] != len(rp["costs"]) - 1:
                        indices = np.append(indices, len(rp["costs"]) - 1)
                    costs_sub = rp["costs"][indices]
                    regrets_sub = rp["regrets"][indices]
                else:
                    costs_sub = rp["costs"]
                    regrets_sub = rp["regrets"]

                # Outside the run's cost range, hold the first/last regret constant.
                f = interp1d(
                    costs_sub,
                    regrets_sub,
                    kind="linear",
                    bounds_error=False,
                    fill_value=(regrets_sub[0], regrets_sub[-1]),
                    assume_sorted=True,
                )
                per_run.append(f(grid))

            arr = np.vstack(per_run)  # [num_runs, n_grid]
            mean = arr.mean(axis=0)
            std = arr.std(axis=0)  # population std (ddof=0)
            if arr.shape[0] == 1:
                sem = np.zeros_like(std)
            else:
                sem = std / np.sqrt(arr.shape[0])

            spread = sem if use_sem_shading else std
            spread_label = "SEM" if use_sem_shading else "Std. Dev."

            # Get color for this method if specified
            color = self.method_colors.get(method) if self.method_colors else None

            plt.plot(x_plot, mean, label=f"{method} (mean)", color=color)
            plt.fill_between(x_plot, mean - spread, mean + spread, alpha=0.1,
                             label=f"{method} ± {spread_label}", color=color)

            results[method] = {
                "grid": grid,           # original cost grid
                "x_plot": x_plot,       # plotted x (same as grid, or normalized)
                "mean": mean,
                "std": std,
                "sem": sem,
                "per_run": arr,
                "budget": common_budget,
            }

        # ---------- 5) Final formatting ----------
        plt.xlabel(xlabel)
        plt.ylabel("Regret")
        ttl = title or f"{self.project}: {regret_key} vs. Cost"
        if normalized_bounds is not None:
            ttl += f" (normalized bounds = [{normalized_bounds[0]}, {normalized_bounds[1]}])"
        plt.title(ttl)
        plt.grid(alpha=0.3)
        plt.legend(ncol=2, fontsize=9)
        plt.tight_layout()
        plt.show()

        # ---------- 6) Save results to CSV if requested ----------
        if self.save_results_csv:
            results_dir = self.results_dir
            if results_dir is None:
                # Default to a results/data directory next to this source file
                script_dir = os.path.dirname(os.path.abspath(__file__))
                results_dir = os.path.join(script_dir, "results", "data")

            os.makedirs(results_dir, exist_ok=True)
            csv_filename = f"{self.problem_name}_{regret_key}.csv"
            csv_path = os.path.join(results_dir, csv_filename)

            # Columns: cost, [cost_normalized], method1_mean, method1_std, method1_sem, method2_mean, ...
            df_data = {"cost": grid}
            if normalized_bounds is not None:
                df_data["cost_normalized"] = x_plot

            for method in self.methods_list:
                df_data[f"{method}_mean"] = results[method]["mean"]
                df_data[f"{method}_std"] = results[method]["std"]
                df_data[f"{method}_sem"] = results[method]["sem"]

            df = pd.DataFrame(df_data)
            df.to_csv(csv_path, index=False)
            print(f"Results saved to: {csv_path}")

        return results

    @staticmethod
    def _history_for_run(full_history_df: pd.DataFrame, run_name: str) -> pd.DataFrame:
        """Return a copy of the history rows for ``run_name``; raise ValueError if there are none."""
        df = full_history_df[full_history_df["run_name"] == run_name].copy()
        if df.empty:
            raise ValueError(f"No rows found for run '{run_name}' in full history.")
        return df

    @staticmethod
    def _constant_run_budget(df: pd.DataFrame, budget_key: str, run_name: str) -> float:
        """Return the single non-NaN budget value logged for a run; raise ValueError if absent or varying."""
        bud_vals = df[budget_key].to_numpy(dtype=float)
        bud_vals = bud_vals[~np.isnan(bud_vals)]
        if bud_vals.size == 0:
            raise ValueError(f"No finite '{budget_key}' in {run_name}")
        unique_b = np.unique(bud_vals)
        if unique_b.size != 1:
            raise ValueError(
                f"Non-constant '{budget_key}' in {run_name}: {unique_b.tolist()}"
            )
        return float(unique_b[0])

    @staticmethod
    def _clean_cost_regret(
        df: pd.DataFrame, cost_key: str, regret_key: str
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return float (costs, regrets) arrays keeping only rows where both values are non-NaN."""
        costs = df[cost_key].to_numpy(dtype=float)
        regrets = df[regret_key].to_numpy(dtype=float)
        # History rows that did not log one of the metrics appear as NaN; they
        # would break the monotonicity check and poison the interpolation.
        keep = ~(np.isnan(costs) | np.isnan(regrets))
        return costs[keep], regrets[keep]

    def plot_fidelity_and_iterations(
        self,
        *,
        fidelity_key: str = "new_fidelity",
        cost_key: str = "cost",
        budget_key: str = "budget",
        budget: Optional[float] = None,
        title: Optional[str] = None,
        colormap: str = "viridis",
        scatter_size: int = 15,
        iteration_color: str = "#09F0BE",
        discrete_fidelities: Optional[List[float]] = None,
    ) -> Dict[str, Dict[str, any]]:
        r"""
        Analyze fidelity distribution and iteration counts for each method.
        
        Args:
            fidelity_key: Column name for fidelity values.
            cost_key: Column name for cost values (to clip to budget).
            budget_key: Column name for budget values.
            budget: If provided, clip all runs to this target budget. If None, use each run's budget.
            title: Custom plot title.
            colormap: Matplotlib colormap name for fidelity gradient (default: 'viridis').
            scatter_size: Size of scatter points (default: 25).
            iteration_color: Color for iteration line and right y-axis (default: '#E67E22' - orange).
            discrete_fidelities: Optional list of discrete fidelity values. If provided, 
                fidelity distribution will be displayed as counts per fidelity level instead of continuous distribution.
        
        Returns:
            Dictionary of results for each method containing iteration counts and fidelity statistics.
        """
        # Load the full history CSV
        full_history_df = self._load_or_build_full_history()

        # First pass: identify invalid seeds (same logic as plot_regret_vs_cost)
        invalid_seeds = set()
        discarded_runs: List[str] = []

        for method in self.methods_list:
            for seed in self.seed_list:
                run_name = f"{self.problem_name}_{method}_run.{seed}"
                df = full_history_df[full_history_df["run_name"] == run_name].copy()

                if df.empty:
                    raise ValueError(f"No rows found for run '{run_name}' in full history.")

                # Required columns
                missing = [k for k in (cost_key, budget_key) if k not in df.columns]
                if missing:
                    raise ValueError(f"Missing columns {missing} in {run_name}")

                costs = df[cost_key].to_numpy(dtype=float)
                
                # Budget validation
                bud_vals = df[budget_key].to_numpy(dtype=float)
                bud_vals = bud_vals[~np.isnan(bud_vals)]
                if bud_vals.size == 0:
                    raise ValueError(f"No finite '{budget_key}' in {run_name}")
                unique_b = np.unique(bud_vals)
                if unique_b.size != 1:
                    raise ValueError(f"Non-constant '{budget_key}' in {run_name}: {unique_b.tolist()}")
                run_budget = float(unique_b[0])

                effective_budget = budget if budget is not None else run_budget
                
                # Check after budget clipping
                m = costs <= effective_budget
                costs_clipped = costs[m]
                if costs_clipped.size == 0:
                    if seed not in invalid_seeds:
                        print(f"Warning: Seed {seed} has 0 points ≤ effective budget {effective_budget} in {method}; excluding seed {seed} from ALL methods.")
                        invalid_seeds.add(seed)
                        discarded_runs.append(f"Seed {seed} (0 points ≤ effective budget {effective_budget} in {method})")
                    continue

        # Second pass: collect valid data
        valid_seed_list = [seed for seed in self.seed_list if seed not in invalid_seeds]
        results: Dict[str, Dict[str, any]] = {}
        
        for method in self.methods_list:
            max_iterations_per_seed = []
            fidelity_values_all = []
            iteration_indices_all = []
            
            for seed in valid_seed_list:
                run_name = f"{self.problem_name}_{method}_run.{seed}"
                df = full_history_df[full_history_df["run_name"] == run_name].copy()

                if df.empty:
                    continue

                # Get iteration from wandb data - count all rows
                if "iteration" not in df.columns:
                    raise ValueError(f"Column 'iteration' not found in {run_name}")
                
                # Count total number of iterations (rows in data)
                num_iterations = len(df)
                
                if num_iterations == 0:
                    raise ValueError(f"No iterations found in {run_name}")
                
                max_iterations_per_seed.append(num_iterations)
                
                # Get all fidelity values and their iteration indices
                if fidelity_key in df.columns:
                    fidelities = df[fidelity_key].to_numpy(dtype=float)
                    iterations = df["iteration"].to_numpy(dtype=float)
                    
                    # Filter out NaN values
                    valid_mask = ~np.isnan(fidelities)
                    fidelities = fidelities[valid_mask]
                    iterations = iterations[valid_mask]
                    
                    fidelity_values_all.extend(fidelities.tolist())
                    iteration_indices_all.extend(iterations.tolist())
                else:
                    # Single fidelity - assume 1.0 for all iterations
                    fidelity_values_all.extend([1.0] * num_iterations)
                    iteration_indices_all.extend(list(range(num_iterations)))
            
            if not max_iterations_per_seed:
                print(f"Warning: No valid runs found for method '{method}'. Skipping.")
                continue
            
            # Calculate statistics
            mean_iterations = np.mean(max_iterations_per_seed)
            std_iterations = np.std(max_iterations_per_seed)
            
            mean_fidelity = np.mean(fidelity_values_all)
            std_fidelity = np.std(fidelity_values_all)
            
            results[method] = {
                "iteration_mean": mean_iterations,
                "iteration_std": std_iterations,
                "iteration_per_seed": max_iterations_per_seed,
                "fidelity_mean": mean_fidelity,
                "fidelity_std": std_fidelity,
                "fidelity_values": fidelity_values_all,
                "iteration_indices": iteration_indices_all,
            }
            
            print(f"{method}: Iterations = {mean_iterations:.2f} ± {std_iterations:.2f}, Mean Fidelity = {mean_fidelity:.3f} ± {std_fidelity:.3f}")

        # Print summary of discarded seeds
        if invalid_seeds:
            print(f"\nDiscarded seeds {sorted(invalid_seeds)} from ALL methods due to insufficient data points:")
            for run_info in discarded_runs:
                print(f"  - {run_info}")
            print(f"Using {len(valid_seed_list)}/{len(self.seed_list)} seeds for comparison.")

        # Determine if we're handling discrete fidelities
        use_discrete_fidelities = discrete_fidelities is not None
        if use_discrete_fidelities:
            discrete_fidelities = sorted(discrete_fidelities)
            print(f"Using discrete fidelities: {discrete_fidelities}")

        # Create plot with boxplot for fidelity and line for iterations (dual y-axis)
        fig, ax1 = plt.subplots(figsize=(12, 7))
        
        iteration_means = [results[m]["iteration_mean"] for m in self.methods_list]
        iteration_stds = [results[m]["iteration_std"] for m in self.methods_list]
        
        x_pos = np.arange(len(self.methods_list))
        
        if use_discrete_fidelities:
            # For discrete fidelities: create grouped bar chart showing counts per fidelity level
            from matplotlib.cm import ScalarMappable
            
            # Count occurrences of each discrete fidelity per method
            fidelity_counts = {}
            for method in self.methods_list:
                fid_vals = np.array(results[method]["fidelity_values"])
                counts = {fid: np.sum(np.abs(fid_vals - fid) < 1e-6) for fid in discrete_fidelities}
                fidelity_counts[method] = counts
            
            # Create grouped bar chart
            bar_width = 0.8 / len(discrete_fidelities)
            cmap_obj = plt.get_cmap(colormap)
            colors_discrete = [cmap_obj(i / (len(discrete_fidelities) - 1)) if len(discrete_fidelities) > 1 else cmap_obj(0.5) 
                             for i in range(len(discrete_fidelities))]
            
            for fid_idx, fid_val in enumerate(discrete_fidelities):
                positions = x_pos + (fid_idx - len(discrete_fidelities)/2 + 0.5) * bar_width
                counts = [fidelity_counts[m][fid_val] for m in self.methods_list]
                ax1.bar(positions, counts, bar_width, label=f'Fidelity {fid_val:.2f}',
                       color=colors_discrete[fid_idx], alpha=0.7, edgecolor='#34495E', linewidth=1.2)
            
            ax1.set_ylabel('Fidelity Query Counts', fontsize=14, fontweight='bold', color='#34495E')
            ax1.legend(title='Discrete Fidelities', loc='upper left', fontsize=10, framealpha=0.95, 
                      edgecolor='#34495E', fancybox=True, shadow=True)
        else:
            # Left y-axis: Boxplot for continuous fidelity distribution
            fidelity_data = [results[m]["fidelity_values"] for m in self.methods_list]
            
            # Create colormap for iteration-based coloring
            from matplotlib.cm import ScalarMappable
            
            # Find global min/max iterations for consistent colormap
            all_iterations = []
            for m in self.methods_list:
                all_iterations.extend(results[m]["iteration_indices"])
            iter_min, iter_max = np.min(all_iterations), np.max(all_iterations)
            
            cmap_obj = plt.get_cmap(colormap)
            norm = plt.Normalize(vmin=iter_min, vmax=iter_max)
            sm = ScalarMappable(cmap=cmap_obj, norm=norm)
            
            # Draw violin plot first
            parts = ax1.violinplot(fidelity_data, positions=x_pos, widths=0.5,
                                   showmeans=False, showmedians=True, showextrema=True)
            
            # Customize violin plot with elegant colors
            for pc in parts['bodies']:
                pc.set_facecolor('#E8F4F8')
                pc.set_edgecolor('#34495E')
                pc.set_alpha(0.5)
                pc.set_linewidth(2)
                pc.set_zorder(1)
            
            # Median line
            parts['cmedians'].set_edgecolor('#34495E')
            parts['cmedians'].set_linewidth(2.5)
            parts['cmedians'].set_zorder(2)
            
            # Extrema bars
            parts['cmaxes'].set_edgecolor('#34495E')
            parts['cmaxes'].set_linewidth(1.5)
            parts['cmins'].set_edgecolor('#34495E')
            parts['cmins'].set_linewidth(1.5)
            parts['cbars'].set_edgecolor('#34495E')
            parts['cbars'].set_linewidth(1.5)
            
            # Add scattered raw sample points colored by iteration (on top)
            for i, method in enumerate(self.methods_list):
                fid_vals = np.array(results[method]["fidelity_values"])
                iter_vals = np.array(results[method]["iteration_indices"])
                
                # Add jitter to x-position
                x_jitter = np.random.normal(x_pos[i], 0.04, size=len(fid_vals))
                
                # Color by iteration
                colors = cmap_obj(norm(iter_vals))
                ax1.scatter(x_jitter, fid_vals, s=scatter_size, c=colors, 
                           edgecolors='none', zorder=3)
            
            ax1.set_ylabel('Fidelity Distribution', fontsize=14, fontweight='bold', color='#34495E')
            ax1.set_ylim(-0.05, 1.05)
        
        # Format left y-axis
        ax1.set_xlabel('Method', fontsize=14, fontweight='bold')
        ax1.tick_params(axis='y', labelcolor='#34495E', labelsize=12)
        ax1.set_xticks(x_pos)
        ax1.set_xticklabels(self.methods_list, rotation=45, ha='right', fontsize=12)
        ax1.tick_params(axis='x', labelsize=12)
        ax1.grid(axis='y', alpha=0.2, linestyle='--', linewidth=0.8)
        ax1.spines['top'].set_visible(False)
        
        # Right y-axis: Line plot for iterations
        ax2 = ax1.twinx()
        ax2.plot(x_pos, iteration_means, marker='s', linewidth=3, markersize=10,
                color=iteration_color, markeredgecolor='white', markeredgewidth=1.5, zorder=5)
        
        # Darker shade for error bars
        import matplotlib.colors as mcolors
        error_color = mcolors.to_rgb(iteration_color)
        error_color = tuple(max(0, c * 0.7) for c in error_color)  # Darken by 30%
        
        ax2.errorbar(x_pos, iteration_means, yerr=iteration_stds,
                    fmt='none', ecolor=error_color, capsize=7, capthick=2.5, elinewidth=2.5, zorder=4)
        
        ax2.set_ylabel('Number of Iterations', fontsize=14, fontweight='bold', color=iteration_color)
        ax2.tick_params(axis='y', labelcolor=iteration_color, labelsize=12)
        ax2.spines['top'].set_visible(False)
        
        # Add colorbar for iteration (only for continuous fidelities)
        if not use_discrete_fidelities:
            cbar = plt.colorbar(sm, ax=ax2, fraction=0.03, pad=0.08)
            cbar.set_label('Iteration Index', fontsize=12, fontweight='bold')
            cbar.ax.tick_params(labelsize=11)
        
        # Add legend
        from matplotlib.lines import Line2D
        if use_discrete_fidelities:
            legend_elements = [
                Line2D([0], [0], marker='s', color=iteration_color, linewidth=3, 
                       markersize=10, markeredgecolor='white', markeredgewidth=1.5,
                       label='Mean iterations')
            ]
            ax2.legend(handles=legend_elements, loc='upper right', fontsize=11, framealpha=0.95, 
                      edgecolor='#34495E', fancybox=True, shadow=True)
        else:
            legend_elements = [
                Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', 
                       markeredgecolor='white', markersize=9, label='Fidelity queries'),
                Line2D([0], [0], marker='s', color=iteration_color, linewidth=3, 
                       markersize=10, markeredgecolor='white', markeredgewidth=1.5,
                       label='Mean iterations')
            ]
            ax1.legend(handles=legend_elements, loc='best', fontsize=11, framealpha=0.95, 
                      edgecolor='#34495E', fancybox=True, shadow=True)
        
        # Title
        ttl = title or f"{self.problem_name}: Iterations & Fidelity Distribution"
        ax1.set_title(ttl, fontsize=15, fontweight='bold', pad=15)
        
        plt.tight_layout()
        plt.show()
        
        # Save results to CSV if requested
        if self.save_results_csv:
            results_dir = self.results_dir
            if results_dir is None:
                script_dir = os.path.dirname(os.path.abspath(__file__))
                results_dir = os.path.join(script_dir, "results", "data")
            
            os.makedirs(results_dir, exist_ok=True)
            csv_filename = f"{self.problem_name}_fidelity_iterations.csv"
            csv_path = os.path.join(results_dir, csv_filename)
            
            # Build DataFrame for summary statistics
            df_data = {
                "method": [],
                "iteration_mean": [],
                "iteration_std": [],
                "fidelity_mean": [],
                "fidelity_std": [],
            }
            
            # Add per-seed iteration counts
            max_seeds = max(len(results[m]["iteration_per_seed"]) for m in self.methods_list)
            for i in range(max_seeds):
                df_data[f"iterations_seed_{i+1}"] = []
            
            for method in self.methods_list:
                df_data["method"].append(method)
                df_data["iteration_mean"].append(results[method]["iteration_mean"])
                df_data["iteration_std"].append(results[method]["iteration_std"])
                df_data["fidelity_mean"].append(results[method]["fidelity_mean"])
                df_data["fidelity_std"].append(results[method]["fidelity_std"])
                
                # Add per-seed values
                iter_per_seed = results[method]["iteration_per_seed"]
                for i in range(max_seeds):
                    if i < len(iter_per_seed):
                        df_data[f"iterations_seed_{i+1}"].append(iter_per_seed[i])
                    else:
                        df_data[f"iterations_seed_{i+1}"].append(np.nan)
            
            # Save everything in ONE CSV: raw fidelity values + iteration indices + iteration stats
            combined_data = []
            for method in self.methods_list:
                iter_mean = results[method]["iteration_mean"]
                iter_std = results[method]["iteration_std"]
                fid_vals = results[method]["fidelity_values"]
                iter_indices = results[method]["iteration_indices"]
                
                # Add all raw fidelity values with their iteration indices for this method
                for fid_val, iter_idx in zip(fid_vals, iter_indices):
                    combined_data.append({
                        "method": method,
                        "fidelity": fid_val,
                        "iteration": iter_idx,
                        "iteration_mean": iter_mean,
                        "iteration_std": iter_std,
                    })
            
            df_combined = pd.DataFrame(combined_data)
            df_combined.to_csv(csv_path, index=False)
            print(f"Fidelity and iteration data saved to single CSV: {csv_path}")
        
        return results

    def plot_violations_bar(
        self,
        *,
        violation_key: str = "curr_nsga2_violation",
        cost_key: str = "cost",
        budget_key: str = "budget",
        budget: Optional[float] = None,
        use_rate: bool = True,
        title: Optional[str] = None,
    ) -> Dict[str, Dict[str, float]]:
        r"""
        Plot total number of violations and standard deviation for each method as a bar plot.
        
        Args:
            violation_key: Column name for violation values (1 = violation, 0 = no violation).
            cost_key: Column name for cost values (to clip to budget).
            budget_key: Column name for budget values.
            budget: If provided, clip all runs to this target budget. If None, use each run's budget.
            use_rate: If True, compute violation rate (violations/iterations) instead of total count.
            title: Custom plot title.
        
        Returns:
            Dictionary of results for each method containing total violations and std.
        """
        # Load the full history CSV
        full_history_df = self._load_or_build_full_history()

        results: Dict[str, Dict[str, float]] = {}
        
        for method in self.methods_list:
            violation_counts_per_seed = []
            
            for seed in self.seed_list:
                run_name = f"{self.problem_name}_{method}_run.{seed}"
                df = full_history_df[full_history_df["run_name"] == run_name].copy()

                if df.empty:
                    print(f"Warning: No rows found for run '{run_name}' in full history. Skipping.")
                    continue

                # Check if violation_key exists
                if violation_key not in df.columns:
                    print(f"Warning: Column '{violation_key}' not found in {run_name}. Skipping.")
                    continue

                # Get costs and violations
                costs = df[cost_key].to_numpy(dtype=float)
                violations = df[violation_key].to_numpy(dtype=float)
                
                # Remove NaN values
                valid_mask = ~(np.isnan(costs) | np.isnan(violations))
                costs = costs[valid_mask]
                violations = violations[valid_mask]
                
                # Get budget for this run if budget clipping is needed
                if budget is not None or cost_key in df.columns:
                    bud_vals = df[budget_key].to_numpy(dtype=float)
                    bud_vals = bud_vals[~np.isnan(bud_vals)]
                    if len(bud_vals) > 0:
                        unique_b = np.unique(bud_vals)
                        run_budget = float(unique_b[0])
                        
                        # Use target budget if specified, otherwise use run's budget
                        effective_budget = budget if budget is not None else run_budget
                        
                        # Clip within effective budget
                        m = costs <= effective_budget
                        violations = violations[m]
                
                if use_rate:
                    # Violation rate: violations per iteration
                    num_iterations = len(violations)
                    if num_iterations > 0:
                        violation_metric = np.sum(violations) / num_iterations
                    else:
                        violation_metric = 0.0
                else:
                    # Total count
                    violation_metric = int(np.sum(violations))
                
                violation_counts_per_seed.append(violation_metric)
            
            if not violation_counts_per_seed:
                print(f"Warning: No valid runs found for method '{method}'. Skipping.")
                continue
            
            # Calculate statistics across seeds
            violation_array = np.array(violation_counts_per_seed)
            mean_violations = violation_array.mean()
            std_violations = violation_array.std()
            
            results[method] = {
                "total_mean": mean_violations,
                "total_std": std_violations,
                "per_seed": violation_counts_per_seed,
            }
            
            print(f"{method}: Mean violations = {mean_violations:.2f} ± {std_violations:.2f}")

        # Create bar plot
        fig, ax = plt.subplots(figsize=(10, 6))
        
        x_pos = np.arange(len(self.methods_list))
        means = [results[m]["total_mean"] for m in self.methods_list]
        stds = [results[m]["total_std"] for m in self.methods_list]
        
        # Get colors for each method
        colors = [self.method_colors.get(m, None) if self.method_colors else None for m in self.methods_list]
        
        bars = ax.bar(x_pos, means, yerr=stds, capsize=5, alpha=0.7, 
                       color=colors, edgecolor='black', linewidth=1.2)
        
        ax.set_xlabel('Method', fontsize=12)
        ax.set_ylabel('Total Number of Violations', fontsize=12)
        ttl = title or f"{self.problem_name}: Total Violations per Method"
        ax.set_title(ttl, fontsize=14)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(self.methods_list, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # ---------- Save results to CSV if requested ----------
        if self.save_results_csv:
            results_dir = self.results_dir
            if results_dir is None:
                # Default to rescue/experiments/results/data
                script_dir = os.path.dirname(os.path.abspath(__file__))
                results_dir = os.path.join(script_dir, "results", "data")
            
            os.makedirs(results_dir, exist_ok=True)
            csv_filename = f"{self.problem_name}_violations.csv"
            csv_path = os.path.join(results_dir, csv_filename)
            
            # Build a DataFrame with columns: method, mean_violations, std_violations, seed1, seed2, ...
            df_data = {"method": [], "mean_violations": [], "std_violations": []}
            
            # Add per-seed columns
            max_seeds = max(len(results[m]["per_seed"]) for m in self.methods_list)
            for i in range(max_seeds):
                df_data[f"seed_{i+1}"] = []
            
            for method in self.methods_list:
                df_data["method"].append(method)
                df_data["mean_violations"].append(results[method]["total_mean"])
                df_data["std_violations"].append(results[method]["total_std"])
                
                # Add per-seed values
                per_seed = results[method]["per_seed"]
                for i in range(max_seeds):
                    if i < len(per_seed):
                        df_data[f"seed_{i+1}"].append(per_seed[i])
                    else:
                        df_data[f"seed_{i+1}"].append(np.nan)
            
            df = pd.DataFrame(df_data)
            df.to_csv(csv_path, index=False)
            print(f"Violation results saved to: {csv_path}")
        
        return results

    def plot_runtime_boxplot(
        self,
        *,
        runtime_key: str = "_runtime",
        cost_key: str = "cost",
        budget_key: str = "budget",
        budget: Optional[float] = None,
        title: Optional[str] = None,
    ) -> Dict[str, Dict[str, any]]:
        r"""
        Plot runtime distribution for each method as a boxplot.
        
        Args:
            runtime_key: Column name for runtime values.
            cost_key: Column name for cost values (to clip to budget).
            budget_key: Column name for budget values.
            budget: If provided, clip all runs to this target budget. If None, use each run's budget.
            title: Custom plot title.
        
        Returns:
            Dictionary of results for each method containing runtime statistics.
        """
        # Load the full history CSV
        full_history_df = self._load_or_build_full_history()

        # First pass: identify invalid seeds
        invalid_seeds = set()
        discarded_runs: List[str] = []

        for method in self.methods_list:
            for seed in self.seed_list:
                run_name = f"{self.problem_name}_{method}_run.{seed}"
                df = full_history_df[full_history_df["run_name"] == run_name].copy()

                if df.empty:
                    raise ValueError(f"No rows found for run '{run_name}' in full history.")

                # Required columns
                missing = [k for k in (cost_key, budget_key) if k not in df.columns]
                if missing:
                    raise ValueError(f"Missing columns {missing} in {run_name}")

                costs = df[cost_key].to_numpy(dtype=float)
                
                # Budget validation
                bud_vals = df[budget_key].to_numpy(dtype=float)
                bud_vals = bud_vals[~np.isnan(bud_vals)]
                if bud_vals.size == 0:
                    raise ValueError(f"No finite '{budget_key}' in {run_name}")
                unique_b = np.unique(bud_vals)
                if unique_b.size != 1:
                    raise ValueError(f"Non-constant '{budget_key}' in {run_name}: {unique_b.tolist()}")
                run_budget = float(unique_b[0])

                effective_budget = budget if budget is not None else run_budget
                
                # Check after budget clipping
                m = costs <= effective_budget
                costs_clipped = costs[m]
                if costs_clipped.size == 0:
                    if seed not in invalid_seeds:
                        print(f"Warning: Seed {seed} has 0 points ≤ effective budget {effective_budget} in {method}; excluding seed {seed} from ALL methods.")
                        invalid_seeds.add(seed)
                        discarded_runs.append(f"Seed {seed} (0 points ≤ effective budget {effective_budget} in {method})")
                    continue

        # Second pass: collect valid runtime data
        valid_seed_list = [seed for seed in self.seed_list if seed not in invalid_seeds]
        results: Dict[str, Dict[str, any]] = {}
        
        for method in self.methods_list:
            runtime_values_per_seed = []
            
            for seed in valid_seed_list:
                run_name = f"{self.problem_name}_{method}_run.{seed}"
                df = full_history_df[full_history_df["run_name"] == run_name].copy()

                if df.empty:
                    continue

                # Check if runtime_key exists
                if runtime_key not in df.columns:
                    print(f"Warning: Column '{runtime_key}' not found in {run_name}. Skipping.")
                    continue

                # Get costs and runtimes
                costs = df[cost_key].to_numpy(dtype=float)
                runtimes = df[runtime_key].to_numpy(dtype=float)
                
                # Remove NaN values
                valid_mask = ~(np.isnan(costs) | np.isnan(runtimes))
                costs = costs[valid_mask]
                runtimes = runtimes[valid_mask]
                
                # Get budget for this run if budget clipping is needed
                if budget is not None or cost_key in df.columns:
                    bud_vals = df[budget_key].to_numpy(dtype=float)
                    bud_vals = bud_vals[~np.isnan(bud_vals)]
                    if len(bud_vals) > 0:
                        unique_b = np.unique(bud_vals)
                        run_budget = float(unique_b[0])
                        
                        # Use target budget if specified, otherwise use run's budget
                        effective_budget = budget if budget is not None else run_budget
                        
                        # Clip within effective budget
                        m = costs <= effective_budget
                        runtimes = runtimes[m]
                
                # Store all runtime values for this seed
                runtime_values_per_seed.extend(runtimes.tolist())
            
            if not runtime_values_per_seed:
                print(f"Warning: No valid runtime data found for method '{method}'. Skipping.")
                continue
            
            # Calculate statistics
            runtime_array = np.array(runtime_values_per_seed)
            mean_runtime = runtime_array.mean()
            std_runtime = runtime_array.std()
            median_runtime = np.median(runtime_array)
            
            results[method] = {
                "mean": mean_runtime,
                "std": std_runtime,
                "median": median_runtime,
                "all_values": runtime_values_per_seed,
            }
            
            print(f"{method}: Runtime = {mean_runtime:.4f}s ± {std_runtime:.4f}s (median: {median_runtime:.4f}s)")

        # Print summary of discarded seeds
        if invalid_seeds:
            print(f"\nDiscarded seeds {sorted(invalid_seeds)} from ALL methods due to insufficient data points:")
            for run_info in discarded_runs:
                print(f"  - {run_info}")
            print(f"Using {len(valid_seed_list)}/{len(self.seed_list)} seeds for comparison.")

        # Create boxplot
        fig, ax = plt.subplots(figsize=(10, 6))
        
        runtime_data = [results[m]["all_values"] for m in self.methods_list]
        
        # Get colors for each method
        colors = [self.method_colors.get(m) if self.method_colors else None for m in self.methods_list]
        
        bp = ax.boxplot(runtime_data, tick_labels=self.methods_list, patch_artist=True,
                        showmeans=True, meanline=False,
                        boxprops=dict(linewidth=1.5, edgecolor='#34495E'),
                        whiskerprops=dict(linewidth=1.5, color='#34495E'),
                        capprops=dict(linewidth=1.5, color='#34495E'),
                        medianprops=dict(linewidth=2, color='#E74C3C'),
                        meanprops=dict(marker='D', markerfacecolor='#27AE60', markeredgecolor='#27AE60', markersize=8))
        
        # Color the boxes
        for patch, color in zip(bp['boxes'], colors):
            if color:
                patch.set_facecolor(color)
                patch.set_alpha(0.6)
            else:
                patch.set_facecolor('#E8F4F8')
                patch.set_alpha(0.7)
        
        ax.set_xlabel('Method', fontsize=14, fontweight='bold')
        ax.set_ylabel('Runtime (seconds)', fontsize=14, fontweight='bold')
        ttl = title or f"{self.problem_name}: Runtime Distribution"
        ax.set_title(ttl, fontsize=15, fontweight='bold', pad=15)
        ax.tick_params(axis='both', labelsize=12)
        plt.xticks(rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.3, linestyle='--', linewidth=0.8)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        
        # Add legend for mean and median
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color='#E74C3C', linewidth=2, label='Median'),
            Line2D([0], [0], marker='D', color='w', markerfacecolor='#27AE60', 
                   markeredgecolor='#27AE60', markersize=8, label='Mean')
        ]
        ax.legend(handles=legend_elements, loc='best', fontsize=11, framealpha=0.95,
                 edgecolor='#34495E', fancybox=True, shadow=True)
        
        plt.tight_layout()
        plt.show()
        
        # Save results to CSV if requested
        if self.save_results_csv:
            results_dir = self.results_dir
            if results_dir is None:
                script_dir = os.path.dirname(os.path.abspath(__file__))
                results_dir = os.path.join(script_dir, "results", "data")
            
            os.makedirs(results_dir, exist_ok=True)
            csv_filename = f"{self.problem_name}_runtime.csv"
            csv_path = os.path.join(results_dir, csv_filename)
            
            # Build DataFrame with all runtime values
            df_data = {"method": [], "runtime": []}
            
            for method in self.methods_list:
                runtime_vals = results[method]["all_values"]
                df_data["method"].extend([method] * len(runtime_vals))
                df_data["runtime"].extend(runtime_vals)
            
            df = pd.DataFrame(df_data)
            df.to_csv(csv_path, index=False)
            print(f"Runtime results saved to: {csv_path}")
        
        return results


def parse_seed_list(spec: str) -> List[int]:
    r"""
    Parse a ``--seed_list`` specification into a list of integer seeds.

    The specification is a comma-separated list whose items are either a single seed
    (``"7"``) or an inclusive range (``"1-20"``), e.g. ``"1-5,8,10-12"``. Whitespace
    around items is ignored and seeds keep the order in which they are given.

    Args:
        spec: Seed specification string.

    Returns:
        The expanded list of seeds.

    Raises:
        ValueError: If an item is neither an integer nor a ``start-end`` range, a range
            is descending, or the specification yields no seeds at all.
    """
    seeds: List[int] = []
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue
        start_str, sep, end_str = item.partition("-")
        if not sep:
            seeds.append(int(item))
            continue
        start, end = int(start_str), int(end_str)
        if end < start:
            raise ValueError(f"descending seed range '{item}'")
        seeds.extend(range(start, end + 1))
    if not seeds:
        raise ValueError(f"no seeds in specification '{spec}'")
    return seeds


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Visualize experiment results from W&B")
    parser.add_argument("--entity", type=str, required=True,
                        help="W&B entity name")
    parser.add_argument("--project", type=str, required=True,
                        help="W&B project name (e.g., rescue_Park4D)")
    parser.add_argument("--problem", type=str, required=True,
                        help="Problem name (e.g., Park4D, BraninCurrin)")
    parser.add_argument("--seed_list", type=str, default="1-20",
                        help="Seed list specification: '1-20' for an inclusive range, '1,2,3' for "
                             "specific seeds, or a mix such as '1-5,8,10-12'")
    parser.add_argument("--save_results", action="store_true",
                        help="Save results to CSV files")
    parser.add_argument("--has_constraints", action="store_true",
                        help="If set, plot violations bar chart")
    parser.add_argument("--n_grid", type=int, default=None,
                        help="Number of points in the interpolation grid (default: None, automatically set to budget)")
    parser.add_argument("--discrete_fidelities", type=str, default=None,
                        help="Comma-separated list of discrete fidelity values (e.g., '0.25,0.5,0.75,1.0'). If provided, displays fidelity counts instead of distribution.")

    args = parser.parse_args()

    try:
        seed_list = parse_seed_list(args.seed_list)
    except ValueError as err:
        # Report malformed specs as a CLI usage error instead of a traceback.
        parser.error(f"invalid --seed_list '{args.seed_list}': {err}")

    discrete_fidelities = None
    if args.discrete_fidelities:
        discrete_fidelities = [float(x.strip()) for x in args.discrete_fidelities.split(",")]

    api = wandb.Api()

    base_method = "rescue"
    methods = ["mf_gp_qehvi", "mf_gp_momf", "mf_gp_hvkg", "rescue"]
    method_colors = {
        "rescue": "red",
        "mf_gp_momf": "blue",
        "mf_gp_hvkg": "green",
        "mf_gp_qehvi": "purple"
    }

    # CSV output of every plot honours the --save_results flag.
    viz = Visualizer(
        problem_name=args.problem,
        methods_list=methods,
        seed_list=seed_list,
        api=api,
        entity=args.entity,
        project=args.project,
        method_colors=method_colors,
        save_results_csv=args.save_results,
    )

    results = viz.plot_regret_vs_cost(
        regret_key="best_nsga2_regret",
        n_grid=args.n_grid,
    )

    if args.has_constraints:
        viz.plot_violations_bar(
            title=f"{args.problem}: Total Violations Comparison"
        )

    viz.plot_fidelity_and_iterations(
        title=f"{args.problem}: Fidelity Usage and Iterations",
        colormap="autumn",
        discrete_fidelities=discrete_fidelities,
    )

    viz.plot_runtime_boxplot(
        title=f"{args.problem}: Runtime Distribution"
    )

    auc_results = analyze_area_under_curve(
        results=results,
        base_method=base_method,
        compare_methods=[m for m in methods if m != base_method],
    )

    # The AUC statistics CSV is written only when CSV output was requested.
    if args.save_results and auc_results and 'comparisons' in auc_results:
        base_auc = auc_results.get('base_method_auc', {})
        auc_rows = [
            {
                'base_method': auc_results['base_method'],
                'comparison_method': method_name,
                'base_mean_auc': base_auc.get('mean'),
                'base_std_auc': base_auc.get('std'),
                'base_n_seeds': base_auc.get('n_seeds'),
                'method_mean_auc': stats.get('mean_auc'),
                'method_std_auc': stats.get('std_auc'),
                'n_seeds': stats.get('n_seeds'),
                'auc_difference': stats.get('auc_difference'),
                'percent_difference': stats.get('percent_difference'),
                't_statistic': stats.get('t_statistic'),
                'p_value': stats.get('p_value'),
                'is_significant': stats.get('is_significant'),
                'cohens_d': stats.get('cohens_d')
            }
            for method_name, stats in auc_results['comparisons'].items()
        ]

        if auc_rows:
            results_dir = viz.results_dir or os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "results", "data"
            )
            os.makedirs(results_dir, exist_ok=True)
            auc_csv_path = os.path.join(results_dir, f"{args.problem}_auc_results.csv")
            pd.DataFrame(auc_rows).to_csv(auc_csv_path, index=False)
            print(f"AUC analysis results saved to: {auc_csv_path}")
