import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import os
from typing import Dict, List, Tuple, Optional
import json

# Global publication style: large fonts so figures stay legible when scaled
# down into a paper. This mutates matplotlib's global rcParams at import time,
# so it affects every figure created afterwards in the process.
plt.rcParams.update({
    'font.size': 33,
    'axes.labelsize': 39,
    'axes.titlesize': 43,
    'xtick.labelsize': 31,
    'ytick.labelsize': 31,
    'legend.fontsize': 34,
    'figure.titlesize': 47,
    'font.family': 'serif',
    # Labels in this module use $...$ math; with usetex=False they are
    # rendered by matplotlib's built-in mathtext instead of a LaTeX install.
    'text.usetex': False,  # Set to True if LaTeX is available
})

# Seaborn "white" style with darker (0.2 gray), thicker axis spines.
sns.set_style("white", {
    'axes.edgecolor': '0.2',
    'axes.linewidth': 2.0,
})

# Colorblind-safe palette shared by all plots in this module.
COLORS = sns.color_palette("colorblind", n_colors=8)
# Semantic color roles. 'agent1' is drawn as "Agent U" and 'agent2' as
# "Agent W" in the convergence/alignment plots.
# NOTE(review): 'theoretical' and 'convergence' are not referenced by the
# functions visible in this file section; confirm whether they are still used.
AGENT_COLORS = {
    'agent1': COLORS[0],  # Blue
    'agent2': COLORS[1],  # Orange
    'theoretical': COLORS[2],  # Green
    'convergence': COLORS[3],  # Red
}




def create_axa_convergence_dynamics(
    agent1_distances: List[float],
    agent2_distances: List[float],
    angle_deg: float,
    theoretical_results: Optional[Dict] = None,
    convergence_step: Optional[int] = None,
    save_path: Optional[str] = None
) -> None:
    """
    Create AxA convergence dynamics plot.

    Draws, against the (1-based) inference step, each agent's distance to its
    optimal weight, optional dashed lines at the theoretical plateau levels,
    and a text box with the angle between the two objectives. The y-axis
    starts at 0 and extends 10% above the largest plotted value.

    Args:
        agent1_distances: Distance from agent U to u* at each step. Any 1-D
            sequence of numbers (list, NumPy array, ...) is accepted.
        agent2_distances: Distance from agent W to w* at each step; plotted
            against the same step axis, so it should have the same length.
        angle_deg: Angle between objective vectors in degrees.
        theoretical_results: Optional dict; when truthy it must provide
            'plateau_error_agent1_boxed' and 'plateau_error_agent2_boxed'.
        convergence_step: Accepted for API compatibility; currently not drawn.
        save_path: If given, the figure is written to ``<dir>/pdf/<stem>.pdf``
            and ``<dir>/png/<stem>.png`` (directories are created as needed).
    """
    from matplotlib.ticker import FuncFormatter

    fig, ax = plt.subplots(figsize=(18, 12))

    steps = np.arange(1, len(agent1_distances) + 1)

    ax.plot(steps, agent1_distances,
            color=AGENT_COLORS['agent1'], linewidth=4.0,
            label=r'Agent $U$ distance to $u^*$', alpha=0.9)

    ax.plot(steps, agent2_distances,
            color=AGENT_COLORS['agent2'], linewidth=4.0,
            label=r'Agent $W$ distance to $w^*$', alpha=0.9)

    # Collect every plotted y-value as plain floats. Concatenating the inputs
    # with ``+`` would add NumPy arrays element-wise (wrong limits), and the
    # later ``extend`` would fail on them.
    all_distances = [float(d) for d in agent1_distances]
    all_distances += [float(d) for d in agent2_distances]

    if theoretical_results:
        plateau_agent1 = theoretical_results['plateau_error_agent1_boxed']
        plateau_agent2 = theoretical_results['plateau_error_agent2_boxed']

        ax.axhline(y=plateau_agent1, color=AGENT_COLORS['agent1'],
                   linestyle='--', alpha=0.7, linewidth=3.0,
                   label='Theory: Agent $U$ plateau')

        ax.axhline(y=plateau_agent2, color=AGENT_COLORS['agent2'],
                   linestyle='--', alpha=0.7, linewidth=3.0,
                   label='Theory: Agent $W$ plateau')

        all_distances.extend([float(plateau_agent1), float(plateau_agent2)])

    ax.set_xlabel('Inference Step')
    ax.set_ylabel('Distance to Optimal Objective')

    y_max = max(all_distances, default=0.0)
    # Fall back to (0, 1) when there is no data, everything is zero, or the
    # maximum is not finite: matplotlib rejects NaN/inf limits and warns on
    # an empty (0, 0) range.
    if np.isfinite(y_max) and y_max > 0:
        extended_y_max = y_max * 1.1
    else:
        extended_y_max = 1.0
    ax.set_ylim(0, extended_y_max)

    def format_func(value, tick_number):
        # More decimals for smaller values so small plateaus stay readable.
        if value == 0:
            return '0'
        elif value >= 1:
            return f'{value:.1f}'
        elif value >= 0.1:
            return f'{value:.2f}'
        elif value >= 0.01:
            return f'{value:.3f}'
        else:
            return f'{value:.4f}'

    ax.yaxis.set_major_formatter(FuncFormatter(format_func))

    textstr = f'Angle between objectives: {angle_deg:.1f}°'
    props = dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8)
    ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=32,
            verticalalignment='top', bbox=props)

    # Heuristic: for the "opposite"/"orthogonal" runs (as named in the save
    # path) the legend goes to the lower right.
    path_lower = str(save_path).lower() if save_path else ''
    if 'opposite' in path_lower or 'orthogonal' in path_lower:
        legend_loc = 'lower right'
    else:
        legend_loc = 'upper right'

    ax.legend(loc=legend_loc, frameon=True, fancybox=True, shadow=True)

    plt.tight_layout()

    if save_path:
        base_dir = os.path.dirname(save_path)
        base_filename = os.path.splitext(os.path.basename(save_path))[0]

        pdf_dir = os.path.join(base_dir, "pdf")
        png_dir = os.path.join(base_dir, "png")
        os.makedirs(pdf_dir, exist_ok=True)
        os.makedirs(png_dir, exist_ok=True)

        pdf_path = os.path.join(pdf_dir, f"{base_filename}.pdf")
        png_path = os.path.join(png_dir, f"{base_filename}.png")

        plt.savefig(pdf_path, dpi=300, bbox_inches='tight')
        plt.savefig(png_path, dpi=300, bbox_inches='tight')
        print(f"Saved  figure: {pdf_path} and {png_path}")

    plt.show()


def compute_eigenvalue_bounds_from_matrices(S_W, S_U, S):
    """
    Compute the exact eigenvalue parameters from the corollary when matrices are available.

    With S^+ the Moore-Penrose pseudo-inverse of S:
        alpha_U = sqrt(lambda_min(S_W S^+ S^+ S_W)), beta_U = sqrt(lambda_max(...))
        alpha_W = sqrt(lambda_min(S_U S^+ S^+ S_U)), beta_W = sqrt(lambda_max(...))

    Args:
        S_W: Second moment matrix for agent W (torch tensor, shape [d, d]).
        S_U: Second moment matrix for agent U (torch tensor, shape [d, d]).
        S: Combined matrix S = S_W + S_U (torch tensor, shape [d, d]).

    Returns:
        Dictionary with Python-float entries 'alpha_U', 'beta_U',
        'alpha_W', 'beta_W'.
    """
    import torch

    S_inv = torch.pinverse(S)
    S_inv_squared = S_inv @ S_inv

    # Matrices from the corollary.
    M_U = S_W @ S_inv_squared @ S_W  # For alpha_U, beta_U
    M_W = S_U @ S_inv_squared @ S_U  # For alpha_W, beta_W

    # For symmetric S_W / S_U these products are positive semi-definite, but
    # round-off (e.g. a singular S or rank-deficient S_W) can push the
    # smallest eigenvalue slightly below zero, where sqrt would yield NaN.
    eigenvals_U = torch.linalg.eigvals(M_U).real.clamp(min=0.0)
    eigenvals_W = torch.linalg.eigvals(M_W).real.clamp(min=0.0)

    alpha_U = torch.sqrt(eigenvals_U.min()).item()
    beta_U = torch.sqrt(eigenvals_U.max()).item()
    alpha_W = torch.sqrt(eigenvals_W.min()).item()
    beta_W = torch.sqrt(eigenvals_W.max()).item()

    return {
        'alpha_U': alpha_U, 'beta_U': beta_U,
        'alpha_W': alpha_W, 'beta_W': beta_W
    }


def compute_corollary_bounds(angle_deg: float, w_star_norm: float, u_star_norm: float, 
                           alpha_U: float , beta_U: float , 
                           alpha_W: float , beta_W: float ) -> Dict:
    """
    Compute upper and lower bounds from Corollary (angle-only bounds).
    
    From the corollary:
    α_U·r_min(θ) ≤ ||u_∞-u*||₂/√(||w*||₂²+||u*||₂²) ≤ β_U·r_max(θ) + O(η)
    α_W·r_min(θ) ≤ ||w_∞-w*||₂/√(||w*||₂²+||u*||₂²) ≤ β_W·r_max(θ) + O(η)
    
    where:
    α_U = √(λ_min(S_W S^(-2) S_W)), β_U = √(λ_max(S_W S^(-2) S_W))
    α_W = √(λ_min(S_U S^(-2) S_U)), β_W = √(λ_max(S_U S^(-2) S_U))
    
    Args:
        angle_deg: Angle between objectives in degrees
        w_star_norm: Norm of w* (default: 1.0)
        u_star_norm: Norm of u* (default: 1.0)
        alpha_U, beta_U: Eigenvalue bounds for agent U (estimated if matrices not available)
        alpha_W, beta_W: Eigenvalue bounds for agent W (estimated if matrices not available)
        
    Returns:
        Dictionary with upper and lower bounds for both agents
    """
    import math
    
    # Convert angle to radians
    theta = math.radians(angle_deg)
    
    # Compute r_min and r_max functions
    sqrt_1_minus_cos_theta = math.sqrt(1 - math.cos(theta))
    r_min = min(1.0, sqrt_1_minus_cos_theta)
    r_max = max(1.0, sqrt_1_minus_cos_theta)
    
    # Normalization factor from corollary
    norm_factor = math.sqrt(w_star_norm**2 + u_star_norm**2)
    
    # Compute bounds for agent U (multiply by norm_factor since we're plotting ||u_∞-u*||₂)
    u_lower = alpha_U * r_min * norm_factor
    u_upper = beta_U * r_max * norm_factor
    
    # Compute bounds for agent W (multiply by norm_factor since we're plotting ||w_∞-w*||₂)
    w_lower = alpha_W * r_min * norm_factor
    w_upper = beta_W * r_max * norm_factor
    
    return {
        'agent_U_lower': u_lower,
        'agent_U_upper': u_upper,
        'agent_W_lower': w_lower,
        'agent_W_upper': w_upper,
        'r_min': r_min,
        'r_max': r_max
    }


def create_objective_alignment_study(
    alignment_data: List[Dict],
    save_path: Optional[str] = None
) -> None:
    """
    Create plot showing how objective alignment affects plateau errors.

    Empirical plateau errors are scattered against the alignment angle. The
    Corollary's angle-only lower/upper bounds are drawn as dashed curves over
    0-180 degrees. The curves use the smallest alphas and largest betas seen
    in ``alignment_data``, so they form the widest envelope over all runs.

    Args:
        alignment_data: Non-empty list of dicts. Required keys: 'angle_deg',
            'plateau_agent1' (agent U), 'plateau_agent2' (agent W),
            'alpha_U', 'beta_U', 'alpha_W', 'beta_W'. Optional keys
            'w_star_norm' / 'u_star_norm' are averaged over the entries that
            have them; 1.0 is used when no entry provides them.
        save_path: If given, the figure is written to ``<dir>/pdf/<stem>.pdf``
            and ``<dir>/png/<stem>.png`` (directories are created as needed).

    Raises:
        ValueError: If ``alignment_data`` is empty.
    """
    if not alignment_data:
        raise ValueError("alignment_data must contain at least one entry")

    from matplotlib.ticker import FuncFormatter

    fig, ax = plt.subplots(figsize=(18, 12))

    angles = [d['angle_deg'] for d in alignment_data]
    plateaus_agent1 = [d['plateau_agent1'] for d in alignment_data]
    plateaus_agent2 = [d['plateau_agent2'] for d in alignment_data]

    ax.scatter(angles, plateaus_agent1,
               color=AGENT_COLORS['agent1'], s=120, alpha=0.8,
               label=r'Agent $U$ plateau error', edgecolors='black', linewidth=1.0)

    ax.scatter(angles, plateaus_agent2,
               color=AGENT_COLORS['agent2'], s=120, alpha=0.8,
               label=r'Agent $W$ plateau error', edgecolors='black', linewidth=1.0)

    alpha_U_values = [d['alpha_U'] for d in alignment_data]
    beta_U_values = [d['beta_U'] for d in alignment_data]
    alpha_W_values = [d['alpha_W'] for d in alignment_data]
    beta_W_values = [d['beta_W'] for d in alignment_data]

    alpha_U_min, alpha_U_max = np.min(alpha_U_values), np.max(alpha_U_values)
    beta_U_min, beta_U_max = np.min(beta_U_values), np.max(beta_U_values)
    alpha_W_min, alpha_W_max = np.min(alpha_W_values), np.max(alpha_W_values)
    beta_W_min, beta_W_max = np.min(beta_W_values), np.max(beta_W_values)

    # The norms scale the bounds. Previously these were only bound when the
    # first entry had both keys, so a missing key crashed with an unbound
    # local; average whatever is provided and fall back to 1.0.
    w_norms = [d['w_star_norm'] for d in alignment_data if 'w_star_norm' in d]
    u_norms = [d['u_star_norm'] for d in alignment_data if 'u_star_norm' in d]
    w_star_norm_est = np.mean(w_norms) if w_norms else 1.0
    u_star_norm_est = np.mean(u_norms) if u_norms else 1.0

    print("Using exact eigenvalue parameter ranges from corollary:")
    print(f"  α_U range: [{alpha_U_min:.4f}, {alpha_U_max:.4f}]")
    print(f"  β_U range: [{beta_U_min:.4f}, {beta_U_max:.4f}]")
    print(f"  α_W range: [{alpha_W_min:.4f}, {alpha_W_max:.4f}]")
    print(f"  β_W range: [{beta_W_min:.4f}, {beta_W_max:.4f}]")
    print(f"  ||w*|| = {w_star_norm_est:.4f}")
    print(f"  ||u*|| = {u_star_norm_est:.4f}")

    # Lower bounds depend only on alpha and upper bounds only on beta, so one
    # evaluation with the smallest alphas and largest betas yields the widest
    # envelope for both agents.
    angle_grid = np.linspace(0, 180, 100)
    bounds_U_upper, bounds_U_lower = [], []
    bounds_W_upper, bounds_W_lower = [], []
    for angle in angle_grid:
        bounds = compute_corollary_bounds(angle, w_star_norm=w_star_norm_est, u_star_norm=u_star_norm_est,
                                          alpha_U=alpha_U_min, beta_U=beta_U_max,
                                          alpha_W=alpha_W_min, beta_W=beta_W_max)
        bounds_U_upper.append(bounds['agent_U_upper'])
        bounds_U_lower.append(bounds['agent_U_lower'])
        bounds_W_upper.append(bounds['agent_W_upper'])
        bounds_W_lower.append(bounds['agent_W_lower'])

    # Only the lower curve of each pair carries a label so each agent gets a
    # single legend entry.
    ax.plot(angle_grid, bounds_U_upper, '--', color=AGENT_COLORS['agent1'], alpha=0.6, linewidth=2.5)
    ax.plot(angle_grid, bounds_U_lower, '--', color=AGENT_COLORS['agent1'], alpha=0.6, linewidth=2.5,
            label=r'Agent $U$ bounds (Corollary)')
    ax.plot(angle_grid, bounds_W_upper, '--', color=AGENT_COLORS['agent2'], alpha=0.6, linewidth=2.5)
    ax.plot(angle_grid, bounds_W_lower, '--', color=AGENT_COLORS['agent2'], alpha=0.6, linewidth=2.5,
            label=r'Agent $W$ bounds (Corollary)')

    # Tag each curve's right end ("upper"/"lower") just past the last angle.
    end_angle = angle_grid[-1]
    for y_end, tag, color_key in ((bounds_U_upper[-1], 'upper', 'agent1'),
                                  (bounds_U_lower[-1], 'lower', 'agent1'),
                                  (bounds_W_upper[-1], 'upper', 'agent2'),
                                  (bounds_W_lower[-1], 'lower', 'agent2')):
        ax.text(end_angle + 2, y_end, tag, color=AGENT_COLORS[color_key],
                fontsize=20, va='center', ha='left', weight='bold')

    ax.set_xlabel('Objective Alignment Angle (degrees)')
    ax.set_ylabel('Distance to Optimal Objective')
    ax.set_xlim(-5, 195)

    all_values = (plateaus_agent1 + plateaus_agent2
                  + bounds_U_upper + bounds_U_lower
                  + bounds_W_upper + bounds_W_lower)

    y_max = max(all_values)
    y_min = min(min(all_values), 0)

    # Generous headroom above the data for the region labels drawn below.
    margin = (y_max - y_min) * 0.25
    extended_y_max = y_max + margin
    extended_y_min = max(0, y_min - margin * 0.4)

    ax.set_ylim(extended_y_min, extended_y_max)

    # Shaded angle regions with their labels near the top of the axes.
    for (x0, x1), name, shade in (((0, 30), 'Aligned', 'green'),
                                  ((60, 120), 'Orthogonal', 'orange'),
                                  ((150, 180), 'Opposite', 'red')):
        ax.axvspan(x0, x1, alpha=0.1, color=shade)
    for x_mid, name, shade in ((15, 'Aligned', 'green'),
                               (90, 'Orthogonal', 'orange'),
                               (165, 'Opposite', 'red')):
        ax.text(x_mid, extended_y_max * 0.95, name, ha='center', va='top', fontsize=24,
                bbox=dict(boxstyle='round,pad=0.3', facecolor=shade, alpha=0.3))

    def format_func(value, tick_number):
        # More decimals for smaller values so small plateaus stay readable.
        if value == 0:
            return '0'
        elif value >= 1:
            return f'{value:.1f}'
        elif value >= 0.1:
            return f'{value:.2f}'
        elif value >= 0.01:
            return f'{value:.3f}'
        else:
            return f'{value:.4f}'

    ax.yaxis.set_major_formatter(FuncFormatter(format_func))

    ax.legend(bbox_to_anchor=(0.02, 0.85), loc='upper left', frameon=True, fancybox=True, shadow=True,
              fontsize=24, ncol=1, columnspacing=1.0, handletextpad=0.5)

    plt.tight_layout()

    if save_path:
        base_dir = os.path.dirname(save_path)
        base_filename = os.path.splitext(os.path.basename(save_path))[0]

        pdf_dir = os.path.join(base_dir, "pdf")
        png_dir = os.path.join(base_dir, "png")
        os.makedirs(pdf_dir, exist_ok=True)
        os.makedirs(png_dir, exist_ok=True)

        pdf_path = os.path.join(pdf_dir, f"{base_filename}.pdf")
        png_path = os.path.join(png_dir, f"{base_filename}.png")

        plt.savefig(pdf_path, dpi=300, bbox_inches='tight')
        plt.savefig(png_path, dpi=300, bbox_inches='tight')
        print(f"Saved  figure: {pdf_path} and {png_path}")

    plt.show()


def create_theoretical_validation(
    empirical_plateaus: List[float],
    theoretical_plateaus: List[float],
    agent_name: str = "U",
    save_path: Optional[str] = None
) -> None:
    """
    Scatter empirical plateau errors against theoretical predictions on log-log axes.

    A dashed y = x reference line spans the joint range of both series. A
    text box reports R^2, the squared Pearson correlation from
    ``np.corrcoef``, and the number of points. Agent "U" is drawn in the
    agent1 color; any other name uses the agent2 color.

    Args:
        empirical_plateaus: Measured plateau errors (y-axis).
        theoretical_plateaus: Predicted plateau errors (x-axis), paired with
            ``empirical_plateaus`` by position.
        agent_name: Agent label used in the axis titles and for the color.
        save_path: If given, the figure is also written to
            ``<dir>/pdf/<stem>.pdf`` and ``<dir>/png/<stem>.png``.
    """
    from matplotlib.ticker import FuncFormatter

    def tick_label(value, tick_number):
        # First matching threshold picks the precision; tiny values get 4 decimals.
        for threshold, spec in ((1, '.1f'), (0.1, '.2f'), (0.01, '.3f')):
            if value >= threshold:
                return format(value, spec)
        return format(value, '.4f')

    _, ax = plt.subplots(figsize=(14, 14))

    # Reference diagonal over the joint data range.
    series = (empirical_plateaus, theoretical_plateaus)
    lo = min(map(min, series))
    hi = max(map(max, series))
    ax.plot([lo, hi], [lo, hi],
            'k--', alpha=0.5, linewidth=3, label='Perfect correlation')

    point_color = AGENT_COLORS['agent1' if agent_name == "U" else 'agent2']
    ax.scatter(theoretical_plateaus, empirical_plateaus,
               color=point_color,
               s=150, alpha=0.8, edgecolors='black', linewidth=1.0)

    r_squared = np.corrcoef(theoretical_plateaus, empirical_plateaus)[0, 1] ** 2
    ax.text(0.05, 0.95, f'$R^2 = {r_squared:.3f}$\n$n = {len(empirical_plateaus)}$',
            transform=ax.transAxes, fontsize=31, verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8))

    ax.set_xlabel(f'Theoretical Plateau Error (Agent {agent_name})')
    ax.set_ylabel(f'Empirical Plateau Error (Agent {agent_name})')
    ax.set_xscale('log')
    ax.set_yscale('log')

    # One formatter instance per axis (formatters are bound to a single axis).
    ax.xaxis.set_major_formatter(FuncFormatter(tick_label))
    ax.yaxis.set_major_formatter(FuncFormatter(tick_label))

    ax.legend(loc='lower right', frameon=True, fancybox=True, shadow=True)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal', adjustable='box')

    plt.tight_layout()

    if save_path:
        out_root, out_name = os.path.split(save_path)
        stem = os.path.splitext(out_name)[0]
        targets = {ext: os.path.join(out_root, ext, f"{stem}.{ext}") for ext in ("pdf", "png")}
        # Create both output directories before writing either file.
        for ext in targets:
            os.makedirs(os.path.join(out_root, ext), exist_ok=True)
        for target in targets.values():
            plt.savefig(target, dpi=300, bbox_inches='tight')
        print(f"Saved  figure: {targets['pdf']} and {targets['png']}")

    plt.show()




def get_alignment_description(angle_deg: float) -> str:
    """
    Return a human-readable label for the angle between two objectives.

    Bands (upper bounds exclusive): < 15 "Highly Aligned", < 60
    "Moderately Aligned", < 120 "Weakly Aligned", otherwise "Conflicting"
    (including NaN, which fails every comparison).
    """
    bands = (
        (15, "Highly Aligned"),
        (60, "Moderately Aligned"),
        (120, "Weakly Aligned"),
    )
    for upper_exclusive, label in bands:
        if angle_deg < upper_exclusive:
            return label
    return "Conflicting"


def load_experimental_data(data_path: str) -> Dict:
    """Load experimental data from JSON file."""
    with open(data_path, 'r') as f:
        return json.load(f)


def plot_adversarial_result(interaction_result: Dict, trial_num: int, R_type: str, save_path: str, agent_type: str = "model"):
    """
    Plot adversarial attack results with success indicators.

    Left panel: per-step distances for the victim (agent W, taken from
    ``agent1_distances``) and the attacker (agent U, ``agent2_distances``).
    A summary box is shaded green on success and red on failure, as decided
    by ``check_attack_success``. Right panel: polar plot of the angle between
    ``w_star1`` and ``w_star2``. The figure is saved to ``save_path`` and is
    always closed, even if plotting or saving raises.

    Args:
        interaction_result: Results from adversarial interaction. Keys read:
            'agent1_distances', 'agent2_distances', 'angle_deg', 'w_star1',
            'w_star2' (tensors or plain number sequences), and optionally
            'theoretical_results' with 'plateau_error_agent1_boxed' /
            'plateau_error_agent2_boxed'.
        trial_num: Trial number for labeling.
        R_type: Type of R matrix transformation (shown in title and box).
        save_path: Path to save the plot.
        agent_type: Type of agents used ("model" or "openai"); forwarded to
            ``check_attack_success`` and affects the success criteria.
    """
    from adversarial_attack import check_attack_success

    fig = plt.figure(figsize=(20, 10))
    try:
        ax_main = fig.add_subplot(1, 2, 1)

        agent1_distances = interaction_result['agent1_distances']
        agent2_distances = interaction_result['agent2_distances']
        steps = range(1, len(agent1_distances) + 1)

        # agent1 is the victim (W) and agent2 the attacker (U) in this plot,
        # so the color roles are swapped relative to the other figures.
        ax_main.plot(steps, agent1_distances, color=AGENT_COLORS['agent2'], marker='o',
                     markersize=4, label=r'Agent $W$ (victim)', linewidth=3.0)
        ax_main.plot(steps, agent2_distances, color=AGENT_COLORS['agent1'], marker='s',
                     markersize=4, label=r'Agent $U$ (attacker)', linewidth=3.0)

        ax_main.set_xlabel('Inference Step')
        ax_main.set_ylabel(r'Distance to respective $w^*$')
        ax_main.set_title(f'Adversarial Attack Trial {trial_num} (R: {R_type})')
        ax_main.legend()

        attack_success = check_attack_success(interaction_result, agent_type=agent_type)
        success_text = "SUCCESS" if attack_success else "FAILED"
        success_color = 'green' if attack_success else 'red'

        textstr = (f'Attack: {success_text}\n'
                   f'Agent W final dist: {agent1_distances[-1]:.4f}\n'
                   f'Agent U final dist: {agent2_distances[-1]:.4f}\n'
                   f'R type: {R_type}\n'
                   f'Angle: {interaction_result["angle_deg"]:.1f}°')

        if 'theoretical_results' in interaction_result:
            theoretical = interaction_result['theoretical_results']
            textstr += (f'\n--- Theorem 1 ---\n'
                        f'Agent W: {theoretical["plateau_error_agent1_boxed"]:.4f}\n'
                        f'Agent U: {theoretical["plateau_error_agent2_boxed"]:.4f}')

        props = dict(boxstyle='round', facecolor=success_color, alpha=0.3)
        ax_main.text(0.02, 0.98, textstr, transform=ax_main.transAxes, fontsize=23,
                     verticalalignment='top', bbox=props)

        # Accept lists as well as tensors (e.g. results reloaded from JSON) and
        # use one float dtype so torch.dot / torch.norm are well defined.
        w_star1 = torch.as_tensor(interaction_result['w_star1'], dtype=torch.float64).flatten()
        w_star2 = torch.as_tensor(interaction_result['w_star2'], dtype=torch.float64).flatten()

        cosine_sim = torch.dot(w_star1, w_star2) / (torch.norm(w_star1) * torch.norm(w_star2))
        # Clamp guards acos against round-off just outside [-1, 1].
        cosine_sim = torch.clamp(cosine_sim, -1.0, 1.0)
        angle_rad = torch.acos(cosine_sim).item()
        angle_deg = np.degrees(angle_rad)

        ax_circle = fig.add_subplot(1, 2, 2, projection='polar')

        theta = np.linspace(0, 2 * np.pi, 100)
        ax_circle.plot(theta, np.ones_like(theta), 'k-', alpha=0.3, linewidth=1)
        ax_circle.plot([0, angle_rad], [0, 1], color=COLORS[3], marker='o', markersize=12,
                       linewidth=4, label=f'Angle: {angle_deg:.1f}°')

        ax_circle.set_ylim(0, 1)
        ax_circle.set_title(r'Objective Alignment ($w_W^*$ vs $w_U^*$)', pad=20)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Called once per trial: always release the figure so open figures
        # do not accumulate in pyplot when a trial fails mid-plot.
        plt.close(fig)

    print(f"Adversarial plot saved: {save_path}")


def create_adversarial_statistics_plots(results_by_type: Dict, max_steps: int, plots_dir: str, learning_rate: float, agent_type: str = "model"):
    """
    Create statistical plots showing mean and standard deviation of adversarial attack results.

    Two side-by-side panels show, for every R-matrix type, the mean distance
    per step with a +/-1 std band. The left panel is the victim (agent W,
    'agent1_*' series) and the right panel is the attacker (agent U,
    'agent2_*' series). All curves are truncated to the shortest available
    series, and to at most ``max_steps``.

    A compact matplotlib-default style is applied only while this figure is
    built and saved. The global rcParams (e.g. this module's publication
    style) are restored afterwards.

    Args:
        results_by_type: Mapping from R type ('orthogonal', 'scaled' or
            'opposite') to None (skipped) or a dict with 'agent1_mean',
            'agent1_std', 'agent2_mean', 'agent2_std' (sequences/arrays) and
            'success_rate' (fraction, shown in the right-hand legend).
        max_steps: Maximum number of interaction steps to plot.
        plots_dir: NOTE(review): currently unused. The PNG is written to the
            working-directory-relative ``../latex-paper/figs/publication_improved/png``;
            confirm which location is intended.
        learning_rate: Learning rate used in experiments (part of filename).
        agent_type: "openai" prefixes the filename with ``openai_``.
    """
    style_overrides = {
        'font.size': 12,
        'font.family': 'serif',
        'axes.linewidth': 1.2,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'xtick.major.size': 4,
        'ytick.major.size': 4,
        'legend.frameon': False,
        'figure.dpi': 300
    }
    colors = {'orthogonal': COLORS[0], 'scaled': COLORS[1], 'opposite': COLORS[2]}

    # Common step axis: the shortest series (both agents) caps the length.
    actual_max_steps = max_steps
    for results in results_by_type.values():
        if results is not None:
            actual_max_steps = min(actual_max_steps,
                                   len(results['agent1_mean']),
                                   len(results['agent2_mean']))
    steps = np.arange(1, actual_max_steps + 1)

    # Scope the style to this figure instead of mutating global rcParams.
    # Saving happens inside the context because rcParams also apply at draw time.
    with plt.style.context('default'), plt.rc_context(style_overrides):
        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        try:
            ax1, ax2 = axes

            # Plot 1: Agent W (victim) distances
            for R_type, results in results_by_type.items():
                if results is None:
                    continue
                mean = np.asarray(results['agent1_mean'])[:actual_max_steps]
                std = np.asarray(results['agent1_std'])[:actual_max_steps]

                ax1.plot(steps, mean, color=colors[R_type], linewidth=3,
                         label=f'{R_type.title()}')
                ax1.fill_between(steps, mean - std, mean + std,
                                 color=colors[R_type], alpha=0.1)

            ax1.set_xlabel('Interaction Step')
            ax1.set_ylabel(r'Distance to $w^*$')
            ax1.set_title('Agent W (victim) error')
            ax1.legend()

            # Plot 2: Agent U (attacker) distances
            for R_type, results in results_by_type.items():
                if results is None:
                    continue
                mean = np.asarray(results['agent2_mean'])[:actual_max_steps]
                std = np.asarray(results['agent2_std'])[:actual_max_steps]
                success_rate = results['success_rate']

                ax2.plot(steps, mean, color=colors[R_type], linewidth=3,
                         label=f'{R_type.title()} (Attack Success Rate: {success_rate:.1%})')
                ax2.fill_between(steps, mean - std, mean + std,
                                 color=colors[R_type], alpha=0.1)

            ax2.set_xlabel('Interaction Step')
            ax2.set_ylabel(r'Distance to $u^*$')
            ax2.set_title('Agent U (attacker) error')
            ax2.legend()

            fig.tight_layout()

            pub_dir = "../latex-paper/figs/publication_improved"
            png_dir = os.path.join(pub_dir, "png")
            os.makedirs(png_dir, exist_ok=True)
            agent_tag = "openai_" if agent_type == "openai" else ""
            save_path = os.path.join(png_dir, f"{agent_tag}adversarial_statistics_lr{learning_rate}.png")
            fig.savefig(save_path, format='png', bbox_inches='tight', dpi=300)

            print(f"  PNG: {save_path}")
        finally:
            plt.close(fig)


def create_adversarial_convergence_comparison_plot(results_by_type: Dict, max_steps: int, plots_dir: str,
                                                 learning_rate: float):
    """
    Create a focused convergence comparison plot for adversarial attacks.

    For every R-matrix type the victim (agent W, dashed) and attacker
    (agent U, solid) mean distances are drawn with +/-1 std bands on a single
    axis. All series are truncated to a common length (at most ``max_steps``)
    so x and y always match.

    Args:
        results_by_type: Mapping from R type ('orthogonal', 'scaled' or
            'opposite') to None (skipped) or a dict with 'agent1_mean',
            'agent1_std', 'agent2_mean', 'agent2_std' (sequences/arrays).
        max_steps: Maximum number of interaction steps to plot.
        plots_dir: Directory to save plots (created if missing).
        learning_rate: Learning rate used in experiments (part of filename).
    """
    colors = {'orthogonal': COLORS[0], 'scaled': COLORS[1], 'opposite': COLORS[2]}

    # Common step axis: plotting full-length series against
    # arange(max_steps) raises when the lengths differ.
    n_steps = max_steps
    for results in results_by_type.values():
        if results is not None:
            n_steps = min(n_steps, len(results['agent1_mean']), len(results['agent2_mean']))
    steps = np.arange(1, n_steps + 1)

    fig, ax = plt.subplots(1, 1, figsize=(16, 12))
    try:
        for R_type, results in results_by_type.items():
            if results is None:
                continue
            color = colors[R_type]

            mean1 = np.asarray(results['agent1_mean'])[:n_steps]
            std1 = np.asarray(results['agent1_std'])[:n_steps]
            ax.plot(steps, mean1, color=color, linewidth=3, linestyle='--',
                    label=f'Agent W ({R_type})')
            ax.fill_between(steps, mean1 - std1, mean1 + std1,
                            color=color, alpha=0.1)

            mean2 = np.asarray(results['agent2_mean'])[:n_steps]
            std2 = np.asarray(results['agent2_std'])[:n_steps]
            ax.plot(steps, mean2, color=color, linewidth=3, linestyle='-',
                    label=f'Agent U ({R_type})')
            ax.fill_between(steps, mean2 - std2, mean2 + std2,
                            color=color, alpha=0.1)

        ax.set_xlabel('Interaction Step')
        ax.set_ylabel('Distance to respective optimal weights')
        ax.set_title('Adversarial Attack Convergence Comparison')
        ax.legend()

        fig.tight_layout()

        if plots_dir:
            os.makedirs(plots_dir, exist_ok=True)
        save_path = os.path.join(plots_dir, f"convergence_comparison_lr{learning_rate}.png")
        fig.savefig(save_path, format='png', bbox_inches='tight', dpi=300)
        print(f"Convergence comparison saved: {save_path}")
    finally:
        plt.close(fig)


def convert_tensors_to_lists(obj):
    """
    Recursively convert tensors/arrays into JSON-serializable Python objects.

    Conversion rules:
      - ``torch.Tensor`` / ``numpy.ndarray`` -> nested lists (a scalar for
        0-d inputs) via ``.tolist()``.
      - NumPy scalars (``np.int64``, ``np.float32``, ``np.bool_``, ...) ->
        the matching Python scalar via ``.item()``. Previously these were
        stringified, e.g. ``np.int64(5)`` became ``"5"``.
      - ``dict`` -> dict with converted values (keys unchanged).
      - ``list`` / ``tuple`` -> list of converted items (JSON has no tuples).
      - ``int``, ``float``, ``str``, ``bool``, ``None`` -> unchanged.
      - Anything else -> ``str(obj)``.
    """
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_tensors_to_lists(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_tensors_to_lists(item) for item in obj]
    if obj is None or isinstance(obj, (int, float, str, bool)):
        return obj
    return str(obj)
