#!/usr/bin/env python3
"""
simple_plotting.py

A simplified plotting module that creates cleaner plots without unrestricted hybrid models.
Designed for publication-quality figures.
"""
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, Optional, List, Tuple

def plot_cost_vs_accuracy_simple(final_results: Dict[str, Any],
                                 use_kernel: bool = False,
                                 output_dir: Optional[str] = None):
    """
    Create a simplified accuracy vs. cost plot and save it as a PNG.

    The figure contains:
    - star markers with error bars for the small-only (red) and large-only
      (blue) baselines, plus faint horizontal reference lines in the same
      colours
    - a dashed interpolation line between the two baselines with a shaded band
    - one dotted degradation line per alpha at ``large_avg * (1 - alpha)``,
      labelled with the alpha value at the left edge of the axes
    - faint scatter points for individual trials when "all_trials" is present
    - the hybrid ("Conformal Method") mean per alpha with normal error bars
    - the random baseline per alpha as hollow diamonds when
      "random_baseline" is present
    - no unrestricted hybrid, no title, legend in bottom right

    Parameters:
    -----------
    final_results : dict
        Dictionary containing the evaluation results. Reads "iterations",
        "small_model" and "large_model" (each with "avg_accuracy",
        "std_accuracy", "avg_cost", "std_cost") and "hybrid_models"
        (alpha -> the same four keys; alpha keys must be convertible with
        float()). Optionally reads "subject", "random_baseline" and
        "all_trials".
    use_kernel : bool
        Whether the results used the kernel-based localized conformal method.
        Only changes the tag in the output file name.
    output_dir : str
        Directory to save the plots. If None, saves in current directory.

    Returns:
    --------
    str
        File name (without directory) of the saved figure.
    """
    if output_dir is None:
        output_dir = ""

    # Read everything from the results first so that a malformed dictionary
    # fails before a figure is created.
    hybrid = final_results["hybrid_models"]
    alphas = list(hybrid.keys())
    iters = final_results["iterations"]
    subject = final_results.get("subject", "Unknown")

    small = final_results["small_model"]
    large = final_results["large_model"]
    small_avg, small_std = small["avg_accuracy"], small["std_accuracy"]
    small_cost_avg, small_cost_std = small["avg_cost"], small["std_cost"]
    large_avg, large_std = large["avg_accuracy"], large["std_accuracy"]
    large_cost_avg, large_cost_std = large["avg_cost"], large["std_cost"]

    # alpha -> (cost mean, accuracy mean, cost std, accuracy std)
    hybrid_stats = {
        alpha: (hybrid[alpha]["avg_cost"], hybrid[alpha]["avg_accuracy"],
                hybrid[alpha]["std_cost"], hybrid[alpha]["std_accuracy"])
        for alpha in alphas
    }

    random_results = final_results.get("random_baseline") or {}
    random_stats = {
        alpha: (random_results[alpha]["avg_cost"],
                random_results[alpha]["avg_accuracy"],
                random_results[alpha]["std_cost"],
                random_results[alpha]["std_accuracy"])
        for alpha in alphas if alpha in random_results
    }

    trials = final_results.get("all_trials") or []

    # One colour per alpha: first half from the blue end of the bwr colormap,
    # second half from the red end, skipping the near-white middle.
    n_low = len(alphas) // 2
    shades = np.concatenate([
        np.linspace(0.1, 0.3, n_low, endpoint=False),
        np.linspace(0.7, 0.9, len(alphas) - n_low, endpoint=True),
    ])
    colors = plt.cm.bwr(shades)

    tag = "localized" if use_kernel else "standard"
    # The dataset name is the part of the subject before the first underscore.
    dataset = subject.split("_")[0]
    fname = f"simple_{dataset}_{tag}_cost_vs_accuracy_{iters}_trials.png"

    fig = plt.figure(figsize=(10, 8))
    try:
        # Baseline models. Colours match the horizontal reference lines below
        # and the other plots in this module (small = red, large = blue).
        plt.errorbar(small_cost_avg, small_avg,
                     yerr=small_std, xerr=small_cost_std,
                     fmt='*', color='red',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, markersize=15,
                     label='Small Model Only', zorder=10)
        plt.errorbar(large_cost_avg, large_avg,
                     yerr=large_std, xerr=large_cost_std,
                     fmt='*', color='blue',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, markersize=15,
                     label='Large Model Only', zorder=10)

        # Linear interpolation between the baselines with a +/- std band.
        x_interp = [small_cost_avg, large_cost_avg]
        plt.plot(x_interp, [small_avg, large_avg],
                 '--', color='gray', alpha=0.5, label='Interpolation')
        plt.fill_between(x_interp,
                         [small_avg - small_std, large_avg - large_std],
                         [small_avg + small_std, large_avg + large_std],
                         color='gray', alpha=0.1)

        # Degradation line per alpha: large-model accuracy scaled by (1 - alpha).
        for i, alpha in enumerate(alphas):
            keep = 1 - float(alpha)
            degraded = large_avg * keep
            plt.axhline(degraded, color=colors[i], linestyle=':', alpha=0.6)
            plt.axhspan((large_avg - large_std) * keep,
                        (large_avg + large_std) * keep,
                        color=colors[i], alpha=0.15)
            plt.annotate(f'α={alpha}',
                         xy=(0.01, degraded), xycoords=('axes fraction', 'data'),
                         va='center', ha='left',
                         color=colors[i], fontweight='bold',
                         bbox=dict(facecolor='white', alpha=0.9,
                                   boxstyle='round,pad=0.2', edgecolor=colors[i]))

        # Individual trials (only those that recorded a cost for this alpha).
        if trials:
            for i, alpha in enumerate(alphas):
                trial_costs = []
                trial_accuracy = []
                for trial in trials:
                    results = trial["hybrid_results"]
                    if alpha in results and "avg_cost" in results[alpha]:
                        trial_costs.append(results[alpha]["avg_cost"])
                        trial_accuracy.append(results[alpha]["accuracy"])
                plt.scatter(trial_costs, trial_accuracy,
                            color=colors[i], alpha=0.3, s=20, marker='o',
                            edgecolors='black', linewidths=1)

        # Hybrid model means. The first plotted point supplies the legend handle.
        conformal_handle = None
        for i, alpha in enumerate(alphas):
            x, y, xerr, yerr = hybrid_stats[alpha]
            err = plt.errorbar(x, y, xerr=xerr, yerr=yerr,
                               fmt='o', color=colors[i],
                               markeredgecolor='black', markeredgewidth=1.5,
                               capsize=6, capthick=2, markersize=10,
                               zorder=5, elinewidth=2)
            if conformal_handle is None:
                conformal_handle = err

        # Random baseline (hollow diamonds). The legend handle comes from the
        # first diamond actually drawn, not from the first alpha, so the entry
        # is not lost when the first alpha has no random-baseline result.
        random_handle = None
        for alpha in alphas:
            if alpha not in random_stats:
                continue
            x, y, xerr, yerr = random_stats[alpha]
            err = plt.errorbar(x, y, xerr=xerr, yerr=yerr,
                               fmt='D',
                               markerfacecolor='white',
                               markeredgecolor='black', markeredgewidth=1.5,
                               capsize=6, markersize=9,
                               zorder=5, elinewidth=2)
            if random_handle is None:
                random_handle = err

        # Horizontal reference lines for the baselines.
        plt.axhline(large_avg, color='blue', linestyle='--', alpha=0.3)
        plt.axhline(small_avg, color='red', linestyle='--', alpha=0.3)

        plt.xlabel('Average Cost per Example ($)', fontsize=14)
        plt.ylabel('Accuracy', fontsize=14)

        handles, labels = plt.gca().get_legend_handles_labels()
        if conformal_handle is not None:
            handles.append(conformal_handle)
            labels.append('Conformal Method')
        if random_handle is not None:
            handles.append(random_handle)
            labels.append('Random Baseline')
        plt.legend(handles, labels, loc='lower right', fontsize=12)

        plt.grid(True, alpha=0.3)

        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return fname


def plot_enhanced_performance_simple(final_results: Dict[str, Any],
                                     use_kernel: bool = False,
                                     output_dir: Optional[str] = None):
    """
    Create a simplified plot of accuracy vs. fraction of small model calls.

    The figure contains:
    - star markers for the large-only (blue) and small-only (red) baselines
      at their "x_position" (defaults: large at 0.0, small at 1.0), plus
      faint horizontal reference lines in the same colours
    - a dashed interpolation line between the baselines with a shaded band
    - one dotted degradation line per alpha at ``large_avg * (1 - alpha)``,
      labelled with the alpha value at the left edge of the axes
    - faint scatter points for individual trials when "all_trials" is present
    - the hybrid ("Conformal Method") mean per alpha with clean error bars
    - the random baseline per alpha as hollow diamonds, drawn at the hybrid's
      small-model fraction, when "random_baseline" is present
    - no unrestricted hybrid, no annotations on data points, legend in
      bottom right

    Parameters:
    -----------
    final_results : dict
        Dictionary containing the evaluation results. Reads "iterations",
        "small_model" and "large_model" (each with "avg_accuracy" and
        "std_accuracy", optionally "x_position" and "x_std") and
        "hybrid_models" (alpha -> "avg_accuracy", "std_accuracy",
        "avg_large_model_usage", "std_large_model_usage"; alpha keys must be
        convertible with float()). Optionally reads "subject",
        "random_baseline" and "all_trials".
    use_kernel : bool
        Whether the results used the kernel-based localized conformal method.
        Only changes the tag in the output file name.
    output_dir : str
        Directory to save the plots. If None, saves in current directory.

    Returns:
    --------
    str
        File name (without directory) of the saved figure.
    """
    if output_dir is None:
        output_dir = ""

    # Read everything from the results first so that a malformed dictionary
    # fails before a figure is created.
    hybrid = final_results["hybrid_models"]
    alphas = list(hybrid.keys())
    iters = final_results["iterations"]
    subject = final_results.get("subject", "Unknown")

    small = final_results["small_model"]
    large = final_results["large_model"]
    small_avg, small_std = small["avg_accuracy"], small["std_accuracy"]
    large_avg, large_std = large["avg_accuracy"], large["std_accuracy"]
    small_x = small.get("x_position", 1.0)
    small_x_std = small.get("x_std", 0.0)
    large_x = large.get("x_position", 0.0)
    large_x_std = large.get("x_std", 0.0)

    hybrid_avg = {alpha: hybrid[alpha]["avg_accuracy"] for alpha in alphas}
    hybrid_std = {alpha: hybrid[alpha]["std_accuracy"] for alpha in alphas}
    # The x axis shows small-model usage = 1 - large-model usage; the spread
    # of the complement equals the spread of the large-model usage itself.
    small_model_usage = {alpha: 1.0 - hybrid[alpha]["avg_large_model_usage"]
                         for alpha in alphas}
    small_model_usage_std = {alpha: hybrid[alpha]["std_large_model_usage"]
                             for alpha in alphas}

    random_results = final_results.get("random_baseline") or {}
    # alpha -> (accuracy mean, accuracy std)
    random_stats = {
        alpha: (random_results[alpha]["avg_accuracy"],
                random_results[alpha]["std_accuracy"])
        for alpha in alphas if alpha in random_results
    }

    trials = final_results.get("all_trials") or []

    # One colour per alpha: first half from the blue end of the bwr colormap,
    # second half from the red end, skipping the near-white middle.
    n_low = len(alphas) // 2
    shades = np.concatenate([
        np.linspace(0.1, 0.3, n_low, endpoint=False),
        np.linspace(0.7, 0.9, len(alphas) - n_low, endpoint=True),
    ])
    colors = plt.cm.bwr(shades)

    tag = "localized" if use_kernel else "standard"
    # The dataset name is the part of the subject before the first underscore.
    dataset = subject.split("_")[0]
    fname = f"simple_{dataset}_{tag}_enhanced_performance_{iters}_trials.png"

    fig = plt.figure(figsize=(10, 8))
    try:
        # Baseline models
        plt.errorbar(large_x, large_avg, yerr=large_std, xerr=large_x_std,
                     fmt='*', color='blue',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, markersize=15,
                     label='Large Model Only', zorder=10)
        plt.errorbar(small_x, small_avg, yerr=small_std, xerr=small_x_std,
                     fmt='*', color='red',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, markersize=15,
                     label='Small Model Only', zorder=10)

        # Linear interpolation between the baselines with a +/- std band.
        x_interp = [large_x, small_x]
        plt.plot(x_interp, [large_avg, small_avg],
                 '--', color='gray', alpha=0.5, label='Interpolation')
        plt.fill_between(x_interp,
                         [large_avg - large_std, small_avg - small_std],
                         [large_avg + large_std, small_avg + small_std],
                         color='gray', alpha=0.1)

        # Degradation line per alpha: large-model accuracy scaled by (1 - alpha).
        for i, alpha in enumerate(alphas):
            keep = 1 - float(alpha)
            degraded = large_avg * keep
            plt.axhline(degraded, color=colors[i], linestyle=':', alpha=0.6)
            plt.axhspan((large_avg - large_std) * keep,
                        (large_avg + large_std) * keep,
                        color=colors[i], alpha=0.15)
            plt.annotate(f'α={alpha}',
                         xy=(0.01, degraded), xycoords=('axes fraction', 'data'),
                         va='center', ha='left',
                         color=colors[i], fontweight='bold',
                         bbox=dict(facecolor='white', alpha=0.9,
                                   boxstyle='round,pad=0.2', edgecolor=colors[i]))

        # Individual trials
        if trials:
            for i, alpha in enumerate(alphas):
                trial_small_usage = []
                trial_accuracy = []
                for trial in trials:
                    results = trial["hybrid_results"]
                    if alpha in results:
                        trial_small_usage.append(1.0 - results[alpha]["large_model_usage"])
                        trial_accuracy.append(results[alpha]["accuracy"])
                plt.scatter(trial_small_usage, trial_accuracy,
                            color=colors[i], alpha=0.3, s=20, marker='o',
                            edgecolors='black', linewidths=1)

        # Hybrid model means. The first plotted point supplies the legend handle.
        conformal_handle = None
        for i, alpha in enumerate(alphas):
            err = plt.errorbar(small_model_usage[alpha], hybrid_avg[alpha],
                               xerr=small_model_usage_std[alpha],
                               yerr=hybrid_std[alpha],
                               fmt='o', color=colors[i],
                               markeredgecolor='black', markeredgewidth=1.5,
                               capsize=6, capthick=2, markersize=10,
                               zorder=5, elinewidth=2)
            if conformal_handle is None:
                conformal_handle = err

        # Random baseline (hollow diamonds, y error bars only). The legend
        # handle comes from the first diamond actually drawn, not from the
        # first alpha, so the entry is not lost when the first alpha has no
        # random-baseline result.
        random_handle = None
        for alpha in alphas:
            if alpha not in random_stats:
                continue
            rand_avg, rand_std = random_stats[alpha]
            err = plt.errorbar(small_model_usage[alpha], rand_avg,
                               yerr=rand_std,
                               fmt='D',
                               markerfacecolor='white',
                               markeredgecolor='black', markeredgewidth=1.5,
                               capsize=6, markersize=9,
                               zorder=5, elinewidth=2)
            if random_handle is None:
                random_handle = err

        # Horizontal reference lines for the baselines.
        plt.axhline(large_avg, color='blue', linestyle='--', alpha=0.3)
        plt.axhline(small_avg, color='red', linestyle='--', alpha=0.3)

        plt.xlabel('Fraction of Small Model Calls', fontsize=14)
        plt.ylabel('Accuracy', fontsize=14)

        handles, labels = plt.gca().get_legend_handles_labels()
        if conformal_handle is not None:
            handles.append(conformal_handle)
            labels.append('Conformal Method')
        if random_handle is not None:
            handles.append(random_handle)
            labels.append('Random Baseline')
        plt.legend(handles, labels, loc='lower right', fontsize=12)

        plt.grid(True, alpha=0.3)
        plt.xlim(-0.1, 1.1)

        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return fname


def plot_cost_vs_accuracy_improved(final_results: dict,
                                   use_kernel: bool = False,
                                   output_dir: Optional[str] = None):
    """
    Create an improved plot showing accuracy vs. cost with clean formatting.
    - Different marker styles for conformal vs random baseline
    - Random baseline is a hollow diamond with manually drawn dashed error
      bars: a horizontal bar spanning ±SD in cost, a vertical bar from -SD to
      the mean at the left end, a vertical bar from the mean to +SD at the
      right end, and short solid caps
    - Hand-built legend: red star 'Primary' (small model), blue star
      'Guardian' (large model), 'Conformal Arbitrage', 'Linear Interpolation',
      and 'Random Routing' only when a random-baseline point was drawn
    - No title

    Parameters:
    -----------
    final_results : dict
        Dictionary containing the evaluation results. Reads "iterations",
        "small_model" and "large_model" (each with "avg_accuracy",
        "std_accuracy", "avg_cost", "std_cost") and "hybrid_models"
        (alpha -> the same four keys; alpha keys must be convertible with
        float()). Optionally reads "subject", "random_baseline" and
        "all_trials".
    use_kernel : bool
        Whether the results used the kernel-based localized conformal method.
        Only changes the tag in the output file name.
    output_dir : str
        Directory to save the plots. If None, saves in current directory.

    Returns:
    --------
    str
        File name (without directory) of the saved figure.
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    if output_dir is None:
        output_dir = ""

    # Read everything from the results first so that a malformed dictionary
    # fails before a figure is created.
    hybrid = final_results["hybrid_models"]
    alphas = list(hybrid.keys())
    iters = final_results["iterations"]
    subject = final_results.get("subject", "Unknown")

    small = final_results["small_model"]
    large = final_results["large_model"]
    small_avg, small_std = small["avg_accuracy"], small["std_accuracy"]
    small_cost, small_cost_sd = small["avg_cost"], small["std_cost"]
    large_avg, large_std = large["avg_accuracy"], large["std_accuracy"]
    large_cost, large_cost_sd = large["avg_cost"], large["std_cost"]

    # alpha -> (cost mean, accuracy mean, cost std, accuracy std)
    hybrid_stats = {
        a: (hybrid[a]["avg_cost"], hybrid[a]["avg_accuracy"],
            hybrid[a]["std_cost"], hybrid[a]["std_accuracy"])
        for a in alphas
    }

    random_results = final_results.get("random_baseline") or {}
    random_stats = {
        a: (random_results[a]["avg_cost"], random_results[a]["avg_accuracy"],
            random_results[a]["std_cost"], random_results[a]["std_accuracy"])
        for a in alphas if a in random_results
    }

    trials = final_results.get("all_trials") or []

    # One colour per alpha: first half from the blue end of the bwr colormap,
    # second half from the red end, skipping the near-white middle. The same
    # colours are used for the hybrid points and the random-baseline bars.
    n_low = len(alphas) // 2
    shades = np.concatenate([
        np.linspace(0.1, 0.3, n_low, endpoint=False),
        np.linspace(0.7, 0.9, len(alphas) - n_low, endpoint=True),
    ])
    colors = plt.cm.bwr(shades)

    tag = "localized" if use_kernel else "standard"
    # The dataset name is the part of the subject before the first underscore.
    dataset = subject.split("_")[0]
    fname = f"improved_{dataset}_{tag}_cost_vs_accuracy_{iters}_trials.png"

    fig = plt.figure(figsize=(10, 8))
    try:
        # Degradation line per alpha: large-model accuracy scaled by (1 - alpha).
        for i, a in enumerate(alphas):
            keep = 1 - float(a)
            degraded = large_avg * keep
            plt.axhline(degraded, color=colors[i], linestyle=':', alpha=0.6)
            plt.axhspan((large_avg - large_std) * keep,
                        (large_avg + large_std) * keep,
                        color=colors[i], alpha=0.15)
            plt.annotate(f'α={a}',
                         xy=(0.01, degraded), xycoords=('axes fraction', 'data'),
                         va='center', ha='left',
                         color=colors[i], fontweight='bold',
                         bbox=dict(facecolor='white', alpha=0.9,
                                   boxstyle='round,pad=0.2',
                                   edgecolor=colors[i]))

        # Linear interpolation between the baselines with a +/- std band.
        x_interp = [small_cost, large_cost]
        plt.plot(x_interp, [small_avg, large_avg], '--', color='gray', alpha=0.5)
        plt.fill_between(x_interp,
                         [small_avg - small_std, large_avg - large_std],
                         [small_avg + small_std, large_avg + large_std],
                         color='gray', alpha=0.1)

        # Individual trials. Records without a cost for this alpha are
        # skipped instead of raising KeyError.
        if trials:
            for i, a in enumerate(alphas):
                trial_costs, trial_accuracy = [], []
                for trial in trials:
                    results = trial["hybrid_results"]
                    if a in results and "avg_cost" in results[a]:
                        trial_costs.append(results[a]["avg_cost"])
                        trial_accuracy.append(results[a]["accuracy"])
                plt.scatter(trial_costs, trial_accuracy,
                            marker='o',
                            edgecolors='black', linewidths=1,
                            color=colors[i], alpha=0.3, s=20)

        # Hybrid model means
        for i, a in enumerate(alphas):
            x, y, xsd, ysd = hybrid_stats[a]
            plt.errorbar(x, y, xerr=xsd, yerr=ysd,
                         fmt='o', color=colors[i],
                         markeredgecolor='black', markeredgewidth=1.5,
                         capsize=6, capthick=2, markersize=10,
                         elinewidth=2, zorder=5)

        # Random baseline: hollow diamond plus hand-drawn dashed bars and caps.
        random_plotted = False
        for i, a in enumerate(alphas):
            if a not in random_stats:
                continue
            random_plotted = True
            x, y, xsd, ysd = random_stats[a]

            plt.plot(x, y, 'D',
                     markerfacecolor='white',
                     markeredgecolor='black',
                     markeredgewidth=1.5,
                     markersize=9, zorder=5)

            # Exact +/- SD end points.
            xl, xr = x - xsd, x + xsd
            yb, yt = y - ysd, y + ysd

            # Dashed bars: horizontal through the mean, the lower half of the
            # vertical bar at the left end, the upper half at the right end.
            plt.hlines(y, xl, xr,
                       colors=colors[i], linestyles='dashed',
                       linewidth=2, zorder=4)
            plt.vlines(xl, yb, y,
                       colors=colors[i], linestyles='dashed',
                       linewidth=2, zorder=4)
            plt.vlines(xr, y, yt,
                       colors=colors[i], linestyles='dashed',
                       linewidth=2, zorder=4)

            # Cap half-length is 20% of the smaller of the two SDs.
            cap = min(xsd, ysd) * 0.2
            # Horizontal caps closing the vertical bars at (xl, yb) and (xr, yt).
            plt.plot([xl - cap, xl + cap], [yb, yb],
                     color=colors[i], linewidth=2, zorder=5)
            plt.plot([xr - cap, xr + cap], [yt, yt],
                     color=colors[i], linewidth=2, zorder=5)
            # Vertical caps at both ends of the horizontal bar.
            plt.plot([xl, xl], [y - cap, y + cap],
                     color=colors[i], linewidth=2, zorder=5)
            plt.plot([xr, xr], [y - cap, y + cap],
                     color=colors[i], linewidth=2, zorder=5)

        # Baseline stars, drawn last with the highest zorder.
        plt.errorbar(small_cost, small_avg, xerr=small_cost_sd, yerr=small_std,
                     fmt='*', color='red',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, capthick=2, markersize=15, zorder=10)
        plt.errorbar(large_cost, large_avg, xerr=large_cost_sd, yerr=large_std,
                     fmt='*', color='blue',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, capthick=2, markersize=15, zorder=10)

        plt.xlabel('Average Cost per Example ($)', fontsize=14)
        plt.ylabel('Accuracy', fontsize=14)

        # Proxy artists so that the legend shows bare markers.
        legend_entries = [
            (Line2D([], [], marker='*', color='red',
                    markeredgecolor='black', markeredgewidth=1.5,
                    linestyle='None', markersize=15),
             'Primary'),
            (Line2D([], [], marker='*', color='blue',
                    markeredgecolor='black', markeredgewidth=1.5,
                    linestyle='None', markersize=15),
             'Guardian'),
            (Line2D([], [], marker='o', color=plt.cm.bwr(0.3),
                    markeredgecolor='black', markeredgewidth=1.5,
                    linestyle='None', markersize=10),
             'Conformal Arbitrage'),
        ]
        # Only advertise the random baseline when it is actually on the plot.
        if random_plotted:
            legend_entries.append(
                (Line2D([], [], marker='D', color='black',
                        markerfacecolor='white',
                        markeredgecolor='black', markeredgewidth=1.5,
                        linestyle='None', markersize=9),
                 'Random Routing'))
        legend_entries.append(
            (Line2D([], [], color='gray', linestyle='--', alpha=0.5),
             'Linear Interpolation'))

        plt.legend(handles=[handle for handle, _ in legend_entries],
                   labels=[label for _, label in legend_entries],
                   loc='lower right', fontsize=12, frameon=True)

        plt.grid(True, alpha=0.3)

        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return fname


def plot_enhanced_performance_improved(final_results: dict,
                                       use_kernel: bool = False,
                                       output_dir: Optional[str] = None):
    """
    Create an enhanced plot of conformal alignment results with improved styling.
    - Shows model performance vs fraction of small model calls
    - Different marker styles for different methods
    - Small model star is red, large model star is blue (matching the legend);
      both carry a text annotation with mean ± std
    - Random baseline is drawn as hollow diamonds at the hybrid's small-model
      fraction, without error bars
    - Degradation lines are labelled with alpha and the percentage drop
    - Hand-built legend; the conformal and random-baseline entries are only
      listed when such points were drawn
    - No title

    Parameters:
    -----------
    final_results : dict
        Dictionary containing the evaluation results. Reads "iterations",
        "small_model" and "large_model" (each with "avg_accuracy" and
        "std_accuracy", optionally "x_position" and "x_std") and
        "hybrid_models" (alpha -> "avg_accuracy", "std_accuracy",
        "avg_large_model_usage", "std_large_model_usage"; alpha keys must be
        convertible with float()). Optionally reads "subject",
        "random_baseline" and "all_trials".
    use_kernel : bool
        Whether the results used the kernel-based localized conformal method.
        Only changes the tag in the output file name.
    output_dir : str
        Directory to save the plots. If None, saves in current directory.

    Returns:
    --------
    str
        File name (without directory) of the saved figure.
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    if output_dir is None:
        output_dir = ""

    # Read everything from the results first so that a malformed dictionary
    # fails before a figure is created.
    hybrid = final_results["hybrid_models"]
    alphas = list(hybrid.keys())
    iters = final_results["iterations"]
    subject = final_results.get("subject", "Unknown")

    small = final_results["small_model"]
    large = final_results["large_model"]
    small_avg, small_std = small["avg_accuracy"], small["std_accuracy"]
    large_avg, large_std = large["avg_accuracy"], large["std_accuracy"]
    small_x = small.get("x_position", 1.0)
    small_x_std = small.get("x_std", 0.0)
    large_x = large.get("x_position", 0.0)
    large_x_std = large.get("x_std", 0.0)

    hybrid_avg = {alpha: hybrid[alpha]["avg_accuracy"] for alpha in alphas}
    hybrid_std = {alpha: hybrid[alpha]["std_accuracy"] for alpha in alphas}
    # The x axis shows small-model usage = 1 - large-model usage; the spread
    # of the complement equals the spread of the large-model usage itself.
    small_model_usage = {alpha: 1.0 - hybrid[alpha]["avg_large_model_usage"]
                         for alpha in alphas}
    small_model_usage_std = {alpha: hybrid[alpha]["std_large_model_usage"]
                             for alpha in alphas}

    random_results = final_results.get("random_baseline") or {}
    random_avg = {alpha: random_results[alpha]["avg_accuracy"]
                  for alpha in alphas if alpha in random_results}

    trials = final_results.get("all_trials") or []

    # One viridis colour per alpha.
    colors = plt.cm.viridis(np.linspace(0, 1, len(alphas)))

    tag = "localized" if use_kernel else "standard"
    # The dataset name is the part of the subject before the first underscore.
    dataset = subject.split("_")[0]
    fname = f"improved_{dataset}_{tag}_enhanced_performance_{iters}_trials.png"

    fig = plt.figure(figsize=(10, 8))
    try:
        # Degradation line per alpha: large-model accuracy scaled by (1 - alpha).
        for i, alpha in enumerate(alphas):
            keep = 1 - float(alpha)
            degraded = large_avg * keep
            degradation_pct = float(alpha) * 100

            plt.axhline(degraded, color=colors[i], linestyle=':', alpha=0.6)
            plt.axhspan((large_avg - large_std) * keep,
                        (large_avg + large_std) * keep,
                        color=colors[i], alpha=0.15)
            plt.annotate(f'α={alpha} (-{degradation_pct:.0f}%)',
                         xy=(0.01, degraded), xycoords=('axes fraction', 'data'),
                         va='center', ha='left',
                         color=colors[i], fontweight='bold',
                         bbox=dict(facecolor='white', alpha=0.9,
                                   boxstyle='round,pad=0.2', edgecolor=colors[i]))

        # Linear interpolation between the baselines with a +/- std band.
        x_interp = [large_x, small_x]
        plt.plot(x_interp, [large_avg, small_avg], '--', color='gray', alpha=0.5)
        plt.fill_between(x_interp,
                         [large_avg - large_std, small_avg - small_std],
                         [large_avg + large_std, small_avg + small_std],
                         color='gray', alpha=0.1)

        # Individual trials
        if trials:
            for i, alpha in enumerate(alphas):
                trial_small_usage = []
                trial_accuracy = []
                for trial in trials:
                    results = trial["hybrid_results"]
                    if alpha in results:
                        trial_small_usage.append(1.0 - results[alpha]["large_model_usage"])
                        trial_accuracy.append(results[alpha]["accuracy"])
                plt.scatter(trial_small_usage, trial_accuracy,
                            color=colors[i], alpha=0.3, s=20, marker='o',
                            edgecolors='black', linewidths=1)

        # Hybrid model means
        for i, alpha in enumerate(alphas):
            plt.errorbar(small_model_usage[alpha], hybrid_avg[alpha],
                         xerr=small_model_usage_std[alpha],
                         yerr=hybrid_std[alpha],
                         fmt='o', color=colors[i],
                         markeredgecolor='black', markeredgewidth=1.5,
                         capsize=6, capthick=2, markersize=10,
                         zorder=5, elinewidth=2)

        # Random baseline (markers only)
        random_plotted = False
        for alpha in alphas:
            if alpha not in random_avg:
                continue
            random_plotted = True
            plt.plot(small_model_usage[alpha], random_avg[alpha], 'D',
                     markerfacecolor='white',
                     markeredgecolor='black', markeredgewidth=1.5,
                     markersize=9, zorder=5)

        # Baselines with annotations. The large model is blue so that it can
        # be told apart from the small model and agrees with the legend.
        plt.annotate(f'Large Model\n({large_avg:.4f} ± {large_std:.4f})',
                     xy=(large_x+0.03, large_avg), ha='left', va='center', fontsize=8)
        plt.errorbar(large_x, large_avg, yerr=large_std, xerr=large_x_std,
                     fmt='*', color='blue',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, markersize=15, zorder=10)

        plt.annotate(f'Small Model\n({small_avg:.4f} ± {small_std:.4f})',
                     xy=(small_x-0.03, small_avg), ha='right', va='center', fontsize=8)
        plt.errorbar(small_x, small_avg, yerr=small_std, xerr=small_x_std,
                     fmt='*', color='red',
                     markeredgecolor='black', markeredgewidth=1.5,
                     capsize=5, markersize=15, zorder=10)

        plt.xlabel('Fraction of Small Model Calls', fontsize=14)
        plt.ylabel('Accuracy', fontsize=14)

        # Proxy artists so that the legend shows bare markers.
        legend_entries = [
            (Line2D([], [], color='gray', linestyle='--', alpha=0.5),
             'Linear Interpolation'),
            (Line2D([], [], marker='*', color='red', linestyle='None', markersize=15),
             'Small Model Only'),
            (Line2D([], [], marker='*', color='blue', linestyle='None', markersize=15),
             'Large Model Only'),
        ]
        # colors[0] only exists when there is at least one alpha.
        if alphas:
            legend_entries.append(
                (Line2D([], [], marker='o', color=colors[0],
                        linestyle='None', markersize=10),
                 'Conformal Method'))
        # Only advertise the random baseline when it is actually on the plot.
        if random_plotted:
            legend_entries.append(
                (Line2D([], [], marker='D', color=colors[0], linestyle='None',
                        markerfacecolor='white', markersize=9),
                 'Random Baseline'))

        plt.legend(handles=[handle for handle, _ in legend_entries],
                   labels=[label for _, label in legend_entries],
                   loc='lower right', fontsize=12, frameon=True)

        plt.grid(True, alpha=0.3)
        plt.xlim(-0.05, 1.05)

        # Y range: two standard deviations around the baselines, widened to
        # include the hybrid means when there are any.
        y_lows = [small_avg - 2*small_std]
        y_highs = [large_avg + 2*large_std]
        if alphas:
            hybrid_margin = 2*max(hybrid_std.values())
            y_lows.append(min(hybrid_avg.values()) - hybrid_margin)
            y_highs.append(max(hybrid_avg.values()) + hybrid_margin)
        plt.ylim(min(y_lows), max(y_highs))

        plt.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return fname


def create_simple_plots(results_dir: str, use_kernel: bool = False):
    """
    Load the latest final-results file from a directory and render all plots.

    Looks in ``results_dir`` for files whose name starts with "mmlu_" or
    "truthqa_" and contains "_<method>_final_results_", where <method> is
    "localized" or "standard". The lexicographically last match is parsed as
    JSON and passed to the four plotting functions of this module (simple and
    improved cost/accuracy and performance plots), which save their figures
    into ``results_dir``. The names of the created plots are printed.

    Parameters:
    -----------
    results_dir : str
        Directory containing the result files
    use_kernel : bool
        Whether the results used kernel-based localized conformal method

    Returns:
    --------
    dict or None
        The loaded results, or None when no matching file was found.
    """
    if use_kernel:
        method_name = "localized"
    else:
        method_name = "standard"
    name_marker = f"_{method_name}_final_results_"

    # Collect result files for either dataset (mmlu or truthqa).
    candidates = []
    for entry in os.listdir(results_dir):
        if not entry.startswith(("mmlu_", "truthqa_")):
            continue
        if name_marker in entry:
            candidates.append(entry)

    if len(candidates) == 0:
        print(f"No final results found in {results_dir} for {method_name} method")
        return None

    # The lexicographically greatest file name wins.
    final_file = max(candidates)
    results_path = os.path.join(results_dir, final_file)
    with open(results_path, "r") as handle:
        final_results = json.load(handle)
    print(f"Loaded results from {final_file}")

    plotters = (
        plot_cost_vs_accuracy_simple,
        plot_enhanced_performance_simple,
        plot_cost_vs_accuracy_improved,
        plot_enhanced_performance_improved,
    )
    created = [
        plotter(final_results, use_kernel=use_kernel, output_dir=results_dir)
        for plotter in plotters
    ]

    print("Created plots:")
    for plot_name in created:
        print(f"  - {plot_name}")

    return final_results


if __name__ == "__main__":
    # Command-line entry point: render every plot for one results directory.
    import argparse

    cli = argparse.ArgumentParser(
        description="Create simplified, publication-quality plots from existing results"
    )
    cli.add_argument(
        "--results_dir",
        type=str,
        required=True,
        help="Directory containing result files",
    )
    cli.add_argument(
        "--use_kernel",
        action="store_true",
        help="Whether results used kernel-based method",
    )

    options = cli.parse_args()
    create_simple_plots(options.results_dir, options.use_kernel)
