#!/usr/bin/env python3
"""
unrestricted_plotting.py

A specialized plotting module that creates plots with both standard and unrestricted hybrid models,
with improved formatting for publication-quality figures and better visibility in black and white prints.
"""
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, Optional
from matplotlib.lines import Line2D

def plot_cost_vs_accuracy_unrestricted(final_results: Dict[str, Any],
                                       use_kernel: bool = False,
                                       output_dir: Optional[str] = None):
    """
    Plot accuracy against average cost per example for every model configuration.

    Drawn, in order: one degradation band per alpha (the large model's accuracy
    scaled by ``1 - alpha``), the linear cost interpolation between the two
    baseline models, the conformal hybrid points, the optional unrestricted
    hybrid / random baseline / unrestricted random baseline points, and the
    two baseline model markers.

    - Red/blue color palette matching simple_plotting.py
    - Asymmetric dashed error bars for the random baselines
    - No individual trial points
    - The legend lists only the series that were actually drawn

    Args:
        final_results: Aggregated results. Reads ``iterations``, the
            ``small_model`` and ``large_model`` entries (``avg_accuracy``,
            ``std_accuracy``, ``avg_cost``, ``std_cost``), ``hybrid_models``
            (the same four keys per alpha) and, when present, the sections
            ``unrestricted_hybrid``, ``random_baseline`` and
            ``unrestricted_random_baseline``. Alphas missing from an optional
            section are skipped for that section.
        use_kernel: Selects the ``localized`` (True) or ``standard`` (False)
            tag used in the output file name.
        output_dir: Directory the PNG is written to; ``None`` is treated as
            the empty string, i.e. a path relative to the working directory.

    Returns:
        The file name (not the full path) of the saved figure.
    """
    if output_dir is None:
        output_dir = ""

    hybrid = final_results["hybrid_models"]
    alphas = list(hybrid.keys())
    iters = final_results["iterations"]

    # Baseline models
    small = final_results["small_model"]
    large = final_results["large_model"]
    small_avg, small_std = small["avg_accuracy"], small["std_accuracy"]
    small_cost_avg, small_cost_std = small["avg_cost"], small["std_cost"]
    large_avg, large_std = large["avg_accuracy"], large["std_accuracy"]
    large_cost_avg, large_cost_std = large["avg_cost"], large["std_cost"]

    # Optional sections; an absent section behaves like an empty one.
    unrestricted = final_results.get("unrestricted_hybrid") or {}
    random_base = final_results.get("random_baseline") or {}
    unres_random = final_results.get("unrestricted_random_baseline") or {}

    # Colormap: the first half of the alphas samples the 0.1-0.3 range of
    # ``bwr`` and the rest samples 0.7-0.9, skipping the pale middle.
    n = len(alphas)
    n_low = n // 2
    ts = np.concatenate([
        np.linspace(0.1, 0.3, n_low, endpoint=False),
        np.linspace(0.7, 0.9, n - n_low, endpoint=True),
    ])
    colors = plt.cm.bwr(ts)

    fig, ax = plt.subplots(figsize=(11, 9))

    def draw_solid(section, marker):
        """Draw filled markers with regular error bars; return True if any point was drawn."""
        drawn = False
        for color, a in zip(colors, alphas):
            if a not in section:
                continue
            stats = section[a]
            ax.errorbar(stats["avg_cost"], stats["avg_accuracy"],
                        xerr=stats["std_cost"], yerr=stats["std_accuracy"],
                        fmt=marker, color=color, capsize=5, capthick=1.5, markersize=8,
                        zorder=5, elinewidth=1.5, markeredgecolor='black', markeredgewidth=1)
            drawn = True
        return drawn

    def draw_dashed(section, marker, cap):
        """Draw hollow markers with offset dashed error bars; return True if any point was drawn."""
        drawn = False
        for color, a in zip(colors, alphas):
            if a not in section:
                continue
            stats = section[a]
            x, y = stats["avg_cost"], stats["avg_accuracy"]
            yerr = stats["std_accuracy"]
            # The horizontal bar spans 60% of the cost std on each side; the
            # lower whisker hangs off its left end and the upper one off its right end.
            hoff = stats["std_cost"] * 0.6
            xl, xr = x - hoff, x + hoff
            ybot, ytop = y - yerr, y + yerr
            ax.plot(x, y, marker, color=color, markersize=8, markerfacecolor='white',
                    markeredgewidth=1.5, markeredgecolor='black', zorder=5)
            ax.hlines(y, xl, xr, colors=color, linestyles='dashed', linewidth=1.5, zorder=4)
            ax.vlines(xl, ybot, y, colors=color, linestyles='dashed', linewidth=1.5, zorder=4)
            ax.vlines(xr, y, ytop, colors=color, linestyles='dashed', linewidth=1.5, zorder=4)
            ax.plot([xl - cap, xl + cap], [ybot, ybot], '-', color=color, linewidth=1.5)
            ax.plot([xr - cap, xr + cap], [ytop, ytop], '-', color=color, linewidth=1.5)
            drawn = True
        return drawn

    try:
        # Degradation lines
        for color, a in zip(colors, alphas):
            base = 1 - float(a)
            mid = large_avg * base
            ax.axhline(mid, color=color, linestyle=':', alpha=0.7)
            ax.axhspan((large_avg - large_std) * base, (large_avg + large_std) * base,
                       color=color, alpha=0.2)
            ax.annotate(f'α={a}', xy=(0.01, mid), xycoords=('axes fraction', 'data'),
                        va='center', ha='left', color=color, fontweight='bold',
                        bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.2',
                                  edgecolor=color))

        # Linear interpolation between the two baselines, with a +/- 1 std band
        xint = [small_cost_avg, large_cost_avg]
        ax.plot(xint, [small_avg, large_avg], '--', color='gray', alpha=0.6, linewidth=1.5,
                label='Linear Cost Interpolation')
        ax.fill_between(xint,
                        [small_avg - small_std, large_avg - large_std],
                        [small_avg + small_std, large_avg + large_std],
                        color='gray', alpha=0.15)

        # Baseline horizontal lines (same colors as the baseline markers)
        ax.axhline(large_avg, color='blue', linestyle='--', alpha=0.4, linewidth=1.5)
        ax.axhline(small_avg, color='red', linestyle='--', alpha=0.4, linewidth=1.5)

        # Per-alpha series
        hybrid_drawn = draw_solid(hybrid, 'o')
        unrestricted_drawn = draw_solid(unrestricted, 's')
        random_drawn = draw_dashed(random_base, 'D', cap=0.00001)
        unres_random_drawn = draw_dashed(unres_random, '^', cap=0.00001)

        # Baseline markers
        ax.errorbar(small_cost_avg, small_avg, yerr=small_std, xerr=small_cost_std, fmt='*',
                    color='red', markeredgecolor='black', markeredgewidth=1.5, capsize=5,
                    markersize=15, zorder=10, label='Small Model Only')
        ax.errorbar(large_cost_avg, large_avg, yerr=large_std, xerr=large_cost_std, fmt='*',
                    color='blue', markeredgecolor='black', markeredgewidth=1.5, capsize=5,
                    markersize=15, zorder=10, label='Large Model Only')

        # Labels & legend
        ax.set_xlabel('Average Cost per Example ($)', fontsize=14)
        ax.set_ylabel('Accuracy', fontsize=14)

        # Baselines and the interpolation line are always listed; a per-alpha
        # series is listed only when at least one of its points was drawn.
        # Series handles use the first alpha's color (a drawn series implies
        # ``colors`` is non-empty).
        entries = [
            (Line2D([], [], marker='*', color='red', linestyle='None', markersize=15,
                    markeredgecolor='black', markeredgewidth=1.5), 'Primary'),
            (Line2D([], [], marker='*', color='blue', linestyle='None', markersize=15,
                    markeredgecolor='black', markeredgewidth=1.5), 'Guardian'),
        ]
        if hybrid_drawn:
            entries.append((Line2D([], [], marker='o', color=colors[0], linestyle='None',
                                   markersize=8, markeredgecolor='black', markeredgewidth=1),
                            'Conformal Arbitrage'))
        if unrestricted_drawn:
            entries.append((Line2D([], [], marker='s', color=colors[0], linestyle='None',
                                   markersize=8, markeredgecolor='black', markeredgewidth=1),
                            'Calibrated Full Routing'))
        if random_drawn:
            entries.append((Line2D([], [], marker='D', color=colors[0], linestyle='--',
                                   markerfacecolor='white', markersize=8,
                                   markeredgecolor='black', markeredgewidth=1.5),
                            'Random Routing'))
        if unres_random_drawn:
            entries.append((Line2D([], [], marker='^', color=colors[0], linestyle='--',
                                   markerfacecolor='white', markersize=8,
                                   markeredgecolor='black', markeredgewidth=1.5),
                            'Unrestricted Random Routing'))
        entries.append((Line2D([], [], color='gray', linestyle='--', alpha=0.6, linewidth=1.5),
                        'Linear Interpolation'))

        ax.legend([h for h, _ in entries], [lbl for _, lbl in entries],
                  loc='lower right', fontsize=12, frameon=True)
        ax.grid(True, alpha=0.3)

        fname = f"unrestricted_{'localized' if use_kernel else 'standard'}_cost_vs_accuracy_{iters}_trials.png"
        fig.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    return fname


def plot_enhanced_performance_unrestricted(final_results: Dict[str, Any],
                                           use_kernel: bool = False,
                                           output_dir: Optional[str] = None):
    """
    Plot accuracy against the fraction of small-model calls for every model configuration.

    Shows the conformal hybrid together with the optional unrestricted hybrid,
    random baseline and unrestricted random baseline, on top of per-alpha
    degradation bands and the linear interpolation between the two baselines.

    - Red/blue color palette matching simple_plotting.py
    - No individual trial dots
    - Asymmetric dashed error bars for the random baselines
    - The legend lists only the series that were actually drawn
    - Y-limits cover the baselines, the hybrid points and every optional
      series, so no drawn point or error bar is clipped

    Args:
        final_results: Aggregated results. Reads the ``small_model`` and
            ``large_model`` entries (``avg_accuracy``, ``std_accuracy`` and the
            optional ``x_position`` / ``x_std``), ``hybrid_models``
            (``avg_accuracy``, ``std_accuracy``, ``avg_large_model_usage``,
            ``std_large_model_usage`` per alpha) and, when present, the sections
            ``unrestricted_hybrid``, ``random_baseline`` and
            ``unrestricted_random_baseline`` (``avg_accuracy``,
            ``std_accuracy``). Alphas missing from an optional section are
            skipped for that section.
        use_kernel: Selects the ``localized`` (True) or ``standard`` (False)
            tag used in the output file name.
        output_dir: Directory the PNG is written to; ``None`` is treated as
            the empty string, i.e. a path relative to the working directory.

    Returns:
        The file name (not the full path) of the saved figure.
    """
    if output_dir is None:
        output_dir = ""

    hybrid = final_results["hybrid_models"]
    alphas = list(hybrid.keys())

    small = final_results["small_model"]
    large = final_results["large_model"]
    small_avg, small_std = small["avg_accuracy"], small["std_accuracy"]
    large_avg, large_std = large["avg_accuracy"], large["std_accuracy"]
    # Baselines default to the ends of the x axis: all-small at 1.0, all-large at 0.0.
    small_x, small_x_std = small.get("x_position", 1.0), small.get("x_std", 0.0)
    large_x, large_x_std = large.get("x_position", 0.0), large.get("x_std", 0.0)

    # NOTE(review): every series is placed at the conformal hybrid's small-model
    # usage for that alpha; confirm the optional sections are not expected to
    # supply their own usage figures.
    usage_x = {a: 1.0 - hybrid[a]["avg_large_model_usage"] for a in alphas}
    usage_xerr = {a: hybrid[a]["std_large_model_usage"] for a in alphas}

    # Optional sections; an absent section behaves like an empty one.
    unrestricted = final_results.get("unrestricted_hybrid") or {}
    random_base = final_results.get("random_baseline") or {}
    unres_random = final_results.get("unrestricted_random_baseline") or {}

    # Colormap: the first half of the alphas samples the 0.1-0.3 range of
    # ``bwr`` and the rest samples 0.7-0.9, skipping the pale middle.
    n = len(alphas)
    n_low = n // 2
    ts = np.concatenate([
        np.linspace(0.1, 0.3, n_low, endpoint=False),
        np.linspace(0.7, 0.9, n - n_low, endpoint=True),
    ])
    colors = plt.cm.bwr(ts)

    fig, ax = plt.subplots(figsize=(11, 9))

    def draw_solid(section, marker):
        """Draw filled markers with regular error bars; return the (low, high) y-extent of each point."""
        extents = []
        for color, a in zip(colors, alphas):
            if a not in section:
                continue
            y, yerr = section[a]["avg_accuracy"], section[a]["std_accuracy"]
            ax.errorbar(usage_x[a], y, xerr=usage_xerr[a], yerr=yerr, fmt=marker, color=color,
                        capsize=5, capthick=1.5, markersize=8, zorder=5, elinewidth=1.5,
                        markeredgecolor='black', markeredgewidth=1)
            extents.append((y - yerr, y + yerr))
        return extents

    def draw_dashed(section, marker, cap):
        """Draw hollow markers with offset dashed error bars; return the (low, high) y-extent of each point."""
        extents = []
        for color, a in zip(colors, alphas):
            if a not in section:
                continue
            x = usage_x[a]
            y, yerr = section[a]["avg_accuracy"], section[a]["std_accuracy"]
            # The horizontal bar spans 60% of the usage std on each side; the
            # lower whisker hangs off its left end and the upper one off its right end.
            hoff = usage_xerr[a] * 0.6
            xl, xr = x - hoff, x + hoff
            ybot, ytop = y - yerr, y + yerr
            ax.plot(x, y, marker, color=color, markersize=8, markerfacecolor='white',
                    markeredgewidth=1.5, markeredgecolor='black', zorder=5)
            ax.hlines(y, xl, xr, colors=color, linestyles='dashed', linewidth=1.5, zorder=4)
            ax.vlines(xl, ybot, y, colors=color, linestyles='dashed', linewidth=1.5, zorder=4)
            ax.vlines(xr, y, ytop, colors=color, linestyles='dashed', linewidth=1.5, zorder=4)
            ax.plot([xl - cap, xl + cap], [ybot, ybot], '-', color=color, linewidth=1.5)
            ax.plot([xr - cap, xr + cap], [ytop, ytop], '-', color=color, linewidth=1.5)
            extents.append((ybot, ytop))
        return extents

    try:
        # Degradation bands: large-model accuracy scaled by (1 - alpha)
        for color, a in zip(colors, alphas):
            base = 1 - float(a)
            mid = large_avg * base
            pct = float(a) * 100
            ax.axhline(mid, color=color, linestyle=':', alpha=0.7)
            ax.axhspan((large_avg - large_std) * base, (large_avg + large_std) * base,
                       color=color, alpha=0.2)
            ax.annotate(f'α={a} (-{pct:.0f}%)', xy=(0.01, mid),
                        xycoords=('axes fraction', 'data'), va='center', ha='left',
                        color=color, fontweight='bold',
                        bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.2',
                                  edgecolor=color))

        # Linear interpolation between the two baselines, with a +/- 1 std band
        xint = [large_x, small_x]
        ax.plot(xint, [large_avg, small_avg], '--', color='gray', alpha=0.6, linewidth=1.5,
                label='Linear Interpolation')
        ax.fill_between(xint,
                        [large_avg - large_std, small_avg - small_std],
                        [large_avg + large_std, small_avg + small_std],
                        color='gray', alpha=0.15)

        # Per-alpha series
        hybrid_extents = draw_solid(hybrid, 'o')
        unrestricted_extents = draw_solid(unrestricted, 's')
        random_extents = draw_dashed(random_base, 'D', cap=0.005)
        unres_random_extents = draw_dashed(unres_random, '^', cap=0.005)

        # Baseline horizontals
        ax.axhline(large_avg, color='blue', linestyle='--', alpha=0.4, linewidth=1.5)
        ax.axhline(small_avg, color='red', linestyle='--', alpha=0.4, linewidth=1.5)
        # Baseline markers
        ax.errorbar(large_x, large_avg, yerr=large_std, xerr=large_x_std, fmt='*', color='blue',
                    markeredgecolor='black', markeredgewidth=1.5, capsize=5, markersize=15,
                    zorder=10, label='Guardian')
        ax.errorbar(small_x, small_avg, yerr=small_std, xerr=small_x_std, fmt='*', color='red',
                    markeredgecolor='black', markeredgewidth=1.5, capsize=5, markersize=15,
                    zorder=10, label='Primary')
        ax.set_xlabel('Fraction of Small Model Calls', fontsize=14)
        ax.set_ylabel('Accuracy', fontsize=14)

        # Baselines and the interpolation line are always listed; a per-alpha
        # series is listed only when at least one of its points was drawn.
        # Series handles use the first alpha's color (a drawn series implies
        # ``colors`` is non-empty).
        entries = [
            (Line2D([], [], marker='*', color='red', linestyle='None', markersize=15,
                    markeredgecolor='black', markeredgewidth=1.5), 'Primary'),
            (Line2D([], [], marker='*', color='blue', linestyle='None', markersize=15,
                    markeredgecolor='black', markeredgewidth=1.5), 'Guardian'),
        ]
        if hybrid_extents:
            entries.append((Line2D([], [], marker='o', color=colors[0], linestyle='None',
                                   markersize=8, markeredgecolor='black', markeredgewidth=1),
                            'Conformal Arbitrage'))
        if unrestricted_extents:
            entries.append((Line2D([], [], marker='s', color=colors[0], linestyle='None',
                                   markersize=8, markeredgecolor='black', markeredgewidth=1),
                            'Calibrated Full Routing'))
        if random_extents:
            entries.append((Line2D([], [], marker='D', color=colors[0], linestyle='--',
                                   markerfacecolor='white', markersize=8,
                                   markeredgecolor='black', markeredgewidth=1.5),
                            'Random Routing'))
        if unres_random_extents:
            entries.append((Line2D([], [], marker='^', color=colors[0], linestyle='--',
                                   markerfacecolor='white', markersize=8,
                                   markeredgecolor='black', markeredgewidth=1.5),
                            'Unrestricted Random Routing'))
        entries.append((Line2D([], [], color='gray', linestyle='--', alpha=0.6, linewidth=1.5),
                        'Linear Interpolation'))
        ax.legend([h for h, _ in entries], [lbl for _, lbl in entries],
                  loc='upper right', fontsize=12, frameon=True)

        ax.grid(True, alpha=0.3)
        ax.set_xlim(-0.05, 1.05)

        # Y-limits: two std of margin around the baselines and the hybrid
        # points, widened only as far as needed so the error bars of the
        # optional series are not clipped.
        hybrid_avgs = [hybrid[a]["avg_accuracy"] for a in alphas]
        margin = max((hybrid[a]["std_accuracy"] for a in alphas), default=0) * 2
        optional_extents = unrestricted_extents + random_extents + unres_random_extents
        y_low = min([small_avg - small_std * 2]
                    + [v - margin for v in hybrid_avgs]
                    + [lo for lo, _ in optional_extents])
        y_high = max([large_avg + large_std * 2]
                     + [v + margin for v in hybrid_avgs]
                     + [hi for _, hi in optional_extents])
        ax.set_ylim(y_low, y_high)

        # The file name carries no trial count; kept as-is for existing consumers.
        fname = f"unrestricted_{'localized' if use_kernel else 'standard'}_enhanced_performance_trials.png"
        fig.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    return fname

def plot_lambda_vs_alpha(final_results: Dict[str, Any],
                         use_kernel: bool = False,
                         output_dir: Optional[str] = None):
    """
    Plot the average lambda threshold against alpha, with std error bars.

    Points are ordered by numeric alpha before plotting so the connecting
    line always runs left to right, whatever the key order of
    ``final_results["hybrid_models"]``.

    Args:
        final_results: Aggregated results. Reads ``iterations`` and, for each
            alpha in ``hybrid_models``, ``avg_lambda`` and ``std_lambda``.
        use_kernel: Accepted for signature parity with the other plotting
            functions; the figure and its file name do not depend on it.
        output_dir: Directory the PNG is written to; ``None`` is treated as
            the empty string, i.e. a path relative to the working directory.

    Returns:
        The file name (not the full path) of the saved figure.
    """
    if output_dir is None:
        output_dir = ""

    hybrid = final_results["hybrid_models"]
    points = sorted(
        ((float(a), stats["avg_lambda"], stats["std_lambda"]) for a, stats in hybrid.items()),
        key=lambda point: point[0],
    )
    iters = final_results["iterations"]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.errorbar([p[0] for p in points], [p[1] for p in points],
                    yerr=[p[2] for p in points],
                    fmt='o-', color='blue', capsize=5, markersize=10, linewidth=2,
                    markeredgecolor='black', markeredgewidth=1)
        ax.grid(True, alpha=0.3)
        ax.set_xlabel('Alpha (Risk Level)', fontsize=14)
        ax.set_ylabel('Lambda Threshold', fontsize=14)
        # NOTE: unlike the sibling plots, this name carries no
        # localized/standard tag; it is kept unchanged so existing consumers
        # of the file keep working.
        fname = f"unrestricted_lambda_vs_alpha_{iters}_trials.png"
        fig.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    return fname

def plot_conformal_set_size_vs_alpha(final_results: Dict[str, Any],
                                     use_kernel: bool = False,
                                     output_dir: Optional[str] = None):
    """
    Plot the average conformal set size for each alpha value.

    The sizes come from the first source that is available:

    1. ``avg_conformal_set_size`` (and optional ``std_conformal_set_size``)
       stored directly on each ``hybrid_models`` entry;
    2. a ``conformal_set_stats`` dict (``avg_size``, optional ``std_size``) on
       each entry;
    3. ``scored_examples_trial_<i>.json`` files in ``output_dir``: for every
       example with ``small_scores`` and ``choices``, the set size is the
       number of scores within ``avg_lambda`` of the best score, averaged per
       trial and then across trials;
    4. a rough estimate from ``avg_lambda`` alone (assuming 4 choices), used
       for any alpha that step 3 could not produce a value for.

    Args:
        final_results: Aggregated results. Reads ``iterations`` and
            ``hybrid_models``.
        use_kernel: Selects the ``localized`` (True) or ``standard`` (False)
            tag used in the output file name.
        output_dir: Directory searched for trial files and written to;
            ``None`` is treated as the empty string, i.e. paths relative to
            the working directory.

    Returns:
        The file name (not the full path) of the saved figure.
    """
    if output_dir is None:
        output_dir = ""

    hybrid = final_results["hybrid_models"]
    alphas = list(hybrid.keys())
    iters = final_results["iterations"]
    # With no alphas at all there is nothing to probe; an empty plot is produced.
    probe = hybrid[alphas[0]] if alphas else {}

    # Check for directly available data first
    if "avg_conformal_set_size" in probe:
        confset_avg = {a: hybrid[a]["avg_conformal_set_size"] for a in alphas}
        confset_std = {a: hybrid[a].get("std_conformal_set_size", 0) for a in alphas}
        print("Using direct conformal set size data")
    elif "conformal_set_stats" in probe:
        confset_avg = {a: hybrid[a]["conformal_set_stats"]["avg_size"] for a in alphas}
        confset_std = {a: hybrid[a]["conformal_set_stats"].get("std_size", 0) for a in alphas}
        print("Using conformal set stats data")
    else:
        # Need to estimate from scored examples files
        print("Looking for scored examples files...")

        # Look for files like "scored_examples_trial_1.json"
        trial_data = {}
        for i in range(1, iters + 1):
            filename = f"scored_examples_trial_{i}.json"
            path = os.path.join(output_dir, filename)
            if not os.path.exists(path):
                continue
            try:
                with open(path, "r", encoding="utf-8") as f:
                    trial_data[i] = json.load(f)
            except (OSError, ValueError) as e:
                # Unreadable or malformed trial files are reported and skipped.
                print(f"Error loading {filename}: {e}")
            else:
                print(f"Loaded {filename}")

        # Filled per alpha as values become available, so the lambda-based
        # fallback below can cover whatever the trial files did not.
        confset_avg = {}
        confset_std = {}

        if trial_data:
            print(f"Calculating conformal set sizes from {len(trial_data)} trial files")
            for a in alphas:
                lambda_threshold = hybrid[a]["avg_lambda"]
                trial_averages = []
                for trial in trial_data.values():
                    set_sizes = []
                    for example in trial.get("examples", []):
                        if "small_scores" not in example or "choices" not in example:
                            continue
                        scores = example["small_scores"]
                        if not scores:
                            # No scores means no best score to compare against.
                            continue
                        cutoff = max(scores) - lambda_threshold
                        set_sizes.append(sum(1 for score in scores if score >= cutoff))
                    if set_sizes:
                        trial_averages.append(sum(set_sizes) / len(set_sizes))

                # Average and std dev across trials
                if trial_averages:
                    confset_avg[a] = sum(trial_averages) / len(trial_averages)
                    confset_std[a] = np.std(trial_averages) if len(trial_averages) > 1 else 0

        # If no trial data or couldn't calculate, use a simpler estimation
        missing = [a for a in alphas if a not in confset_avg]
        if missing:
            print("Falling back to simple lambda threshold estimation")
            num_choices = 4  # Default assumption
            for a in missing:
                lambda_val = hybrid[a]["avg_lambda"]
                # Rough estimate: the set grows linearly with lambda and is
                # capped at num_choices once lambda reaches 0.5.
                confset_avg[a] = 1 + (num_choices - 1) * min(lambda_val * 2, 1.0)
                confset_std[a] = 0  # No std dev data for this estimation

    # Order the points by numeric alpha so the connecting line runs left to right
    points = sorted(
        ((float(a), confset_avg[a], confset_std[a]) for a in alphas),
        key=lambda point: point[0],
    )

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Plot with error bars
        ax.errorbar([p[0] for p in points], [p[1] for p in points],
                    yerr=[p[2] for p in points],
                    fmt='o-', color='purple', capsize=5, markersize=10, linewidth=2,
                    markeredgecolor='black', markeredgewidth=1)

        ax.grid(True, alpha=0.3)
        ax.set_xlabel('Alpha (Risk Level)', fontsize=14)
        ax.set_ylabel('Average Conformal Set Size', fontsize=14)
        ax.set_title('Conformal Set Size vs. Alpha', fontsize=16)

        # Horizontal line at y=1 to indicate the minimum set size
        ax.axhline(y=1, color='gray', linestyle='--', alpha=0.6)
        ax.annotate('Minimum set size', xy=(0.01, 1.05), xycoords=('axes fraction', 'data'),
                    va='center', ha='left', color='gray', fontsize=10)

        fig.tight_layout()
        tag = 'localized' if use_kernel else 'standard'
        fname = f"conformal_set_size_{tag}_{iters}_trials.png"
        fig.savefig(os.path.join(output_dir, fname), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    return fname

def calculate_and_print_conformal_set_sizes(final_results: Dict[str, Any], output_dir: str,
                                            max_trials: int = 30):
    """
    Calculate conformal set sizes from the saved per-trial score files and print them for tables.

    For every alpha in ``final_results["hybrid_models"]`` the alpha's ``avg_lambda`` is used
    as the threshold: a choice counts as part of an example's conformal set when its
    small-model score is at least ``max(scores) - lambda``. Set sizes are averaged per trial,
    and the mean and standard deviation across trials are printed as a tab-separated table.

    Args:
        final_results: Aggregated results; ``final_results["hybrid_models"][alpha]["avg_lambda"]``
            must exist for every alpha key.
        output_dir: Directory searched for ``scored_examples_trial_<i>.json`` files.
        max_trials: Highest trial index probed (files 1..max_trials). Defaults to 30.

    Returns:
        None. Results are only printed. If no trial file is found a message is printed
        and the function returns early.
    """
    alphas = list(final_results["hybrid_models"].keys())

    # Lambda threshold for each alpha, keyed by the original alpha string
    lambda_values = {a: final_results["hybrid_models"][a]["avg_lambda"] for a in alphas}

    print("\n----- CALCULATING ACTUAL CONFORMAL SET SIZES -----")

    # Probe for trial files numbered 1..max_trials; gaps in the numbering are tolerated
    candidate_names = (f"scored_examples_trial_{i}.json" for i in range(1, max_trials + 1))
    trial_files = [name for name in candidate_names
                   if os.path.exists(os.path.join(output_dir, name))]

    if not trial_files:
        print("No trial files found. Cannot calculate conformal set sizes.")
        return

    print(f"Found {len(trial_files)} trial files")

    # For each alpha, collect the per-trial average set size
    alpha_set_sizes = {a: [] for a in alphas}

    for filename in trial_files:
        try:
            with open(os.path.join(output_dir, filename), "r", encoding="utf-8") as f:
                trial_data = json.load(f)

            # The top score does not depend on alpha, so compute it once per example.
            # Examples with an empty score list are skipped: max() would raise on them
            # and the whole trial file would be discarded.
            scored = [
                (example["small_scores"], max(example["small_scores"]))
                for example in trial_data.get("examples", [])
                if "small_scores" in example and "choices" in example
                and example["small_scores"]
            ]

            # Build all averages for this file first so that a failure part-way through
            # never leaves different alphas with different numbers of trials.
            trial_avgs = {}
            for a in alphas:
                lambda_threshold = lambda_values[a]
                trial_sizes = [
                    sum(1 for score in scores if score >= top_score - lambda_threshold)
                    for scores, top_score in scored
                ]
                if trial_sizes:
                    trial_avgs[a] = sum(trial_sizes) / len(trial_sizes)
        except (OSError, ValueError, TypeError, AttributeError, KeyError) as e:
            # Best effort: an unreadable or malformed trial file is reported and skipped.
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            print(f"Error processing {filename}: {e}")
            continue

        for a, trial_avg in trial_avgs.items():
            alpha_set_sizes[a].append(trial_avg)

    # Calculate overall averages and standard deviations across trials
    conf_set_avg = {}
    conf_set_std = {}

    for a in alphas:
        per_trial = alpha_set_sizes[a]
        if per_trial:
            conf_set_avg[a] = sum(per_trial) / len(per_trial)
            # np.std with its default arguments; a single trial has no spread to report
            conf_set_std[a] = np.std(per_trial) if len(per_trial) > 1 else 0
        else:
            conf_set_avg[a] = 0
            conf_set_std[a] = 0

    # Print the results
    print("\n----- CONFORMAL SET SIZES (ACTUAL) -----")
    print("Alpha\tLambda\tSet_Size\tStd_Dev")

    # Sort the original keys numerically and look results up by those same keys.
    # Rebuilding the key with str(float(a)) does not round-trip for keys such as
    # "0.10" and would raise KeyError.
    for a in sorted(alphas, key=float):
        lambda_val = lambda_values[a]
        avg_size = conf_set_avg[a]
        std_size = conf_set_std[a]
        print(f"{float(a):.2f}\t{lambda_val:.4f}\t{avg_size:.4f}\t{std_size:.4f}")

# def create_unrestricted_plots(results_dir: str, use_kernel: bool=False):
#     """
#     Create publication-quality plots from existing results that always include unrestricted hybrids.
#     """
#     method_name='localized' if use_kernel else 'standard'
#     final_files=[f for f in os.listdir(results_dir) if (f.startswith("mmlu_") or f.startswith("truthqa_")) and f"_{method_name}_final_results_" in f and f.endswith(".json")]
#     if not final_files:
#         print(f"No final results in {results_dir} for {method_name}")
#         return None
#     final_file=sorted(final_files)[-1]
#     print(f"Loading {final_file}")
#     try:
#         with open(os.path.join(results_dir,final_file),"r",encoding="utf-8") as f:
#             final_results=json.load(f)
#     except UnicodeDecodeError:
#         with open(os.path.join(results_dir,final_file),"r",encoding="latin-1") as f:
#             final_results=json.load(f)
#     if "unrestricted_hybrid" not in final_results:
#         print("Warning: no unrestricted hybrid data.")
    
#     # Generate the plots
#     cost=plot_cost_vs_accuracy_unrestricted(final_results,use_kernel,results_dir)
#     perf=plot_enhanced_performance_unrestricted(final_results,use_kernel,results_dir)
#     lam=plot_lambda_vs_alpha(final_results,use_kernel,results_dir)
    
#     # Try to generate conformal set size plot
#     try:
#         confset=plot_conformal_set_size_vs_alpha(final_results,use_kernel,results_dir)
#         print(f"Created: {cost}, {perf}, {lam}, {confset}")
#     except Exception as e:
#         confset = None
#         print(f"Created: {cost}, {perf}, {lam}")
#         print(f"Warning: Couldn't create conformal set size plot: {e}")
    
#     # Print numerical results in a format suitable for tables
#     print("\n===== NUMERICAL RESULTS FOR TABLES =====")
    
#     # Sort alphas for consistent presentation
#     alphas = sorted([float(a) for a in final_results["hybrid_models"].keys()])
    
#     # Print small model usage data (from enhanced performance plot)
#     print("\n----- FRACTION OF SMALL MODEL CALLS -----")
#     print("Alpha\tSmall_Frac\tStd_Dev\tAccuracy\tStd_Dev")
#     for a in alphas:
#         a_str = str(a)
#         small_usage = 1.0 - final_results["hybrid_models"][a_str]["avg_large_model_usage"]
#         small_usage_std = final_results["hybrid_models"][a_str]["std_large_model_usage"]
#         accuracy = final_results["hybrid_models"][a_str]["avg_accuracy"]
#         accuracy_std = final_results["hybrid_models"][a_str]["std_accuracy"]
#         print(f"{a:.2f}\t{small_usage:.4f}\t{small_usage_std:.4f}\t{accuracy:.4f}\t{accuracy_std:.4f}")
    
#     # Add baseline models
#     small_acc = final_results["small_model"]["avg_accuracy"]
#     small_std = final_results["small_model"]["std_accuracy"]
#     large_acc = final_results["large_model"]["avg_accuracy"]
#     large_std = final_results["large_model"]["std_accuracy"]
#     print(f"Small\t1.0000\t0.0000\t{small_acc:.4f}\t{small_std:.4f}")
#     print(f"Large\t0.0000\t0.0000\t{large_acc:.4f}\t{large_std:.4f}")
    
#     # Unrestricted hybrid data if available
#     if "unrestricted_hybrid" in final_results:
#         print("\n----- UNRESTRICTED HYBRID ACCURACY -----")
#         print("Alpha\tAccuracy\tStd_Dev")
#         for a in alphas:
#             a_str = str(a)
#             if a_str in final_results["unrestricted_hybrid"]:
#                 accuracy = final_results["unrestricted_hybrid"][a_str]["avg_accuracy"]
#                 accuracy_std = final_results["unrestricted_hybrid"][a_str]["std_accuracy"]
#                 print(f"{a:.2f}\t{accuracy:.4f}\t{accuracy_std:.4f}")
    
#     # Print Lambda thresholds
#     print("\n----- LAMBDA THRESHOLDS -----")
#     print("Alpha\tLambda\tStd_Dev")
#     for a in alphas:
#         a_str = str(a)
#         lambda_val = final_results["hybrid_models"][a_str]["avg_lambda"]
#         lambda_std = final_results["hybrid_models"][a_str]["std_lambda"]
#         print(f"{a:.2f}\t{lambda_val:.4f}\t{lambda_std:.4f}")
    
#     # Print conformal set size data
#     print("\n----- CONFORMAL SET SIZES -----")
#     print("Alpha\tSet_Size\tStd_Dev")
    
#     # First try to find direct data
#     has_direct_data = False
#     if alphas and str(alphas[0]) in final_results["hybrid_models"]:
#         if "avg_conformal_set_size" in final_results["hybrid_models"][str(alphas[0])]:
#             has_direct_data = True
#             for a in alphas:
#                 a_str = str(a)
#                 set_size = final_results["hybrid_models"][a_str]["avg_conformal_set_size"]
#                 set_std = final_results["hybrid_models"][a_str].get("std_conformal_set_size", 0)
#                 print(f"{a:.2f}\t{set_size:.4f}\t{set_std:.4f}")
#         elif "conformal_set_stats" in final_results["hybrid_models"][str(alphas[0])]:
#             has_direct_data = True
#             for a in alphas:
#                 a_str = str(a)
#                 set_size = final_results["hybrid_models"][a_str]["conformal_set_stats"]["avg_size"]
#                 set_std = final_results["hybrid_models"][a_str]["conformal_set_stats"].get("std_size", 0)
#                 print(f"{a:.2f}\t{set_size:.4f}\t{set_std:.4f}")
    
#     # If no direct data, calculate from lambda thresholds
#     if not has_direct_data:
#         print("(Estimated from lambda thresholds)")
#         num_choices = 4  # Default assumption
#         for a in alphas:
#             a_str = str(a)
#             lambda_val = final_results["hybrid_models"][a_str]["avg_lambda"]
#             # Estimate conformal set size
#             est_confset_size = 1 + (num_choices - 1) * min(lambda_val * 2, 1.0)
#             print(f"{a:.2f}\t{est_confset_size:.4f}\t0.0000")
    
#     print("\n========================================")
    
#     return final_results

def create_unrestricted_plots(results_dir: str, use_kernel: bool=False):
    """
    Create publication-quality plots from existing results that always include unrestricted hybrids.

    Looks in ``results_dir`` for files named ``mmlu_*`` or ``truthqa_*`` that contain
    ``_<method>_final_results_`` and end in ``.json`` (method is ``localized`` when
    ``use_kernel`` is true, otherwise ``standard``). The lexicographically last match is
    loaded, the three plots are generated, and the numerical results are printed as
    tab-separated tables.

    Args:
        results_dir: Directory holding the final-results JSON and the per-trial score files;
            it is also passed to the plot functions as their output directory.
        use_kernel: Selects the ``localized`` result files instead of the ``standard`` ones.

    Returns:
        The loaded final-results dict, or None when no matching results file exists.
    """
    method_name = 'localized' if use_kernel else 'standard'
    final_files = [
        f for f in os.listdir(results_dir)
        if f.startswith(("mmlu_", "truthqa_"))
        and f"_{method_name}_final_results_" in f
        and f.endswith(".json")
    ]
    if not final_files:
        print(f"No final results in {results_dir} for {method_name}")
        return None

    # Pick the lexicographically last file name
    final_file = sorted(final_files)[-1]
    print(f"Loading {final_file}")
    final_path = os.path.join(results_dir, final_file)
    try:
        with open(final_path, "r", encoding="utf-8") as f:
            final_results = json.load(f)
    except UnicodeDecodeError:
        # Fall back to latin-1 for files that are not valid UTF-8
        with open(final_path, "r", encoding="latin-1") as f:
            final_results = json.load(f)
    if "unrestricted_hybrid" not in final_results:
        print("Warning: no unrestricted hybrid data.")

    # Generate the plots (the returned file names are not needed here)
    plot_cost_vs_accuracy_unrestricted(final_results, use_kernel, results_dir)
    plot_enhanced_performance_unrestricted(final_results, use_kernel, results_dir)
    plot_lambda_vs_alpha(final_results, use_kernel, results_dir)

    # Print numerical results in a format suitable for tables
    print("\n===== NUMERICAL RESULTS FOR TABLES =====")

    hybrid_models = final_results["hybrid_models"]

    # Sort the original alpha keys numerically and keep using those keys for lookups.
    # Rebuilding a key with str(float(key)) does not round-trip for keys such as "0.10"
    # and would raise KeyError or silently drop rows.
    alpha_keys = sorted(hybrid_models, key=float)

    # Print small model usage data (from enhanced performance plot)
    print("\n----- FRACTION OF SMALL MODEL CALLS -----")
    print("Alpha\tSmall_Frac\tStd_Dev\tAccuracy\tStd_Dev")
    for key in alpha_keys:
        entry = hybrid_models[key]
        small_usage = 1.0 - entry["avg_large_model_usage"]
        # The std of (1 - x) equals the std of x, so the large-model std is reused
        small_usage_std = entry["std_large_model_usage"]
        accuracy = entry["avg_accuracy"]
        accuracy_std = entry["std_accuracy"]
        print(f"{float(key):.2f}\t{small_usage:.4f}\t{small_usage_std:.4f}\t{accuracy:.4f}\t{accuracy_std:.4f}")

    # Add baseline models
    small_acc = final_results["small_model"]["avg_accuracy"]
    small_std = final_results["small_model"]["std_accuracy"]
    large_acc = final_results["large_model"]["avg_accuracy"]
    large_std = final_results["large_model"]["std_accuracy"]
    print(f"Small\t1.0000\t0.0000\t{small_acc:.4f}\t{small_std:.4f}")
    print(f"Large\t0.0000\t0.0000\t{large_acc:.4f}\t{large_std:.4f}")

    # Unrestricted hybrid data if available
    if "unrestricted_hybrid" in final_results:
        unrestricted = final_results["unrestricted_hybrid"]
        print("\n----- UNRESTRICTED HYBRID ACCURACY -----")
        print("Alpha\tAccuracy\tStd_Dev")
        for key in alpha_keys:
            if key in unrestricted:
                accuracy = unrestricted[key]["avg_accuracy"]
                accuracy_std = unrestricted[key]["std_accuracy"]
                print(f"{float(key):.2f}\t{accuracy:.4f}\t{accuracy_std:.4f}")

    # Print Lambda thresholds
    print("\n----- LAMBDA THRESHOLDS -----")
    print("Alpha\tLambda\tStd_Dev")
    for key in alpha_keys:
        lambda_val = hybrid_models[key]["avg_lambda"]
        lambda_std = hybrid_models[key]["std_lambda"]
        print(f"{float(key):.2f}\t{lambda_val:.4f}\t{lambda_std:.4f}")

    # Calculate and print actual conformal set sizes
    calculate_and_print_conformal_set_sizes(final_results, results_dir)

    print("\n========================================")

    return final_results

if __name__ == "__main__":
    # Script entry point: read the command-line options and build all plots and tables.
    import argparse

    cli = argparse.ArgumentParser(description="Create plots with unrestricted hybrids")
    cli.add_argument(
        "--results_dir",
        type=str,
        required=True,
        help="Dir with result files",
    )
    cli.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel-based method",
    )
    options = cli.parse_args()

    # The option names match the function's parameter names one-to-one
    create_unrestricted_plots(**vars(options))