import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset

# Module-wide matplotlib defaults applied at import time.
# NOTE: plot_unbalance_ci passes an explicit figsize and applies a seaborn
# theme, so some of these defaults may be superseded while it runs.
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams.update({
    'figure.figsize': [16, 10],
    'font.size': 11,
})

def _load_unbalance_data(base_path, model_lower, unbalance_coefs, split, csv_suffix):
    """Load the per-k result CSVs and keep only the rows for *split*.

    For every k in *unbalance_coefs* this reads
    ``base_path / f'{model_lower}_{k}' / 'f1_results' / f'final_scores_with_CI{csv_suffix}.csv'``,
    keeps the rows whose ``split`` column equals *split*, and tags them with an
    ``unbalance_coef`` column equal to k. Missing files and files without rows
    for *split* are reported and skipped.

    Returns:
        One concatenated DataFrame, or None if no file contributed any rows.
    """
    csv_filename = f'final_scores_with_CI{csv_suffix}.csv'
    frames = []
    for k in unbalance_coefs:
        csv_path = base_path / f'{model_lower}_{k}' / 'f1_results' / csv_filename
        if not csv_path.exists():
            print(f"Warning: file {csv_path} not found, skipping k={k}")
            continue

        df = pd.read_csv(csv_path)
        df_split = df[df['split'] == split].copy()
        if df_split.empty:
            print(f"Warning: no data for split={split} in file {csv_path}, skipping k={k}")
            continue

        df_split['unbalance_coef'] = k
        frames.append(df_split)

    if not frames:
        return None
    return pd.concat(frames, ignore_index=True)


def _assign_optimizer_colors(optimizers, optimizer_names, preferred_colors, palette):
    """Map each raw optimizer name to a plot color.

    Optimizers whose display name (looked up in *optimizer_names*) has an entry
    in *preferred_colors* get that color. All other optimizers cycle through the
    *palette* colors that are not already reserved in *preferred_colors*. A
    fallback optimizer therefore never shares a color with a preferred one,
    unless the palette has no unreserved colors left.
    """
    reserved = list(preferred_colors.values())
    free_colors = [c for c in palette if c not in reserved] or list(palette)

    assignment = {}
    n_fallback = 0
    for optimizer in optimizers:
        legend_name = optimizer_names.get(optimizer, optimizer)
        if legend_name in preferred_colors:
            assignment[optimizer] = preferred_colors[legend_name]
        else:
            assignment[optimizer] = free_colors[n_fallback % len(free_colors)]
            n_fallback += 1
    return assignment


def plot_unbalance_ci(user='user1', model='SimpleCNNBinClass', unbalance_coefs=(2, 5, 30, 40, 50),
                      optimizers_list=None, split='test', csv_suffix='',
                      zoom_x_range=(1, 10), zoom_y_min=55):
    """Plot mean F1-score with 95% confidence bands against the unbalance coefficient k.

    One curve per optimizer is drawn on the main axes. A zoomed inset shows the
    k range given by *zoom_x_range*; it is skipped if no data falls in that
    range. The figure is saved as PNG (dpi=300) and PDF into the dataset
    folder, shown, and then closed.

    Each CSV is expected to contain the columns 'split', 'optimizer', 'mean',
    'ci_95%_lower' and 'ci_95%_upper'.

    Args:
        user: user name in path
        model: model name (e.g., 'SimpleCNNBinClass' or 'ResNet18_32x32BinClass').
            'swin_tiny' reads from unbalanced_tiny_imagenet; any other model
            reads from unbalanced_cifar10.
        unbalance_coefs: sequence of k values (unbalance_coef)
        optimizers_list: list of optimizers to plot (if None, then all)
        split: data type to plot ('test' or 'val')
        csv_suffix: suffix for CSV filename (e.g., '_recalculated_40seeds')
        zoom_x_range: (min_k, max_k) inclusive range of k shown in the inset
        zoom_y_min: lower y limit of the inset. The upper limit is the main
            axes' upper y limit. If zoom_y_min is not below that limit, the
            main axes' lower y limit is used instead, which avoids an
            inverted axis.

    Returns:
        None. Prints an error and returns early when no data is available.
    """
    if model == 'swin_tiny':
        base_path = Path(f'./tuning/{user}/unbalanced_tiny_imagenet')
    else:
        base_path = Path(f'./tuning/{user}/unbalanced_cifar10')

    # Directory and file names use the lowercase model name
    model_lower = model.lower()

    combined_df = _load_unbalance_data(base_path, model_lower, unbalance_coefs, split, csv_suffix)
    if combined_df is None:
        print("Error: no data files found")
        return

    all_optimizers = sorted(combined_df['optimizer'].unique())

    if optimizers_list is not None:
        optimizers = [opt for opt in all_optimizers if opt in optimizers_list]
        if not optimizers:
            print("Error: none of the specified optimizers found in data")
            print(f"Available optimizers: {all_optimizers}")
            return
        missing = [opt for opt in optimizers_list if opt not in all_optimizers]
        if missing:
            print(f"Warning: optimizers not found in data: {missing}")
        combined_df = combined_df[combined_df['optimizer'].isin(optimizers)]
    else:
        optimizers = all_optimizers

    print(f"Found optimizers: {len(optimizers)}")
    print(f"Optimizers: {optimizers}")

    # Display names used in the legend (raw CSV name -> legend label)
    optimizer_names = {
        'AdamWBetas': 'AdamW',
        'SGD': 'SGD',
        'SGDLinearLR': 'SGD with scheduler',
        'Signum': 'Signum',
        'Signum+SGD_not_decoupled_wd': 'HardSwitchSign',
        'SoftSignumSGD_not_decoupled_wd': 'SoftSignum',
    }

    import seaborn as sns
    sns.set_theme(style="whitegrid", context="talk", font_scale=1.3)

    fig, ax = plt.subplots(figsize=(16, 8))

    # Fixed colors keyed by legend name; other optimizers get unreserved palette colors
    colors = sns.color_palette("bright")
    preferred_colors = {
        'SoftSignum': colors[4],
        'Signum': colors[1],
        'SGD with scheduler': colors[2],
        'HardSwitchSign': colors[6],
    }
    optimizer_to_color = _assign_optimizer_colors(optimizers, optimizer_names, preferred_colors, colors)

    # Keep the curves so the inset can redraw the same data
    curves = {}
    main_lines = []
    for optimizer in optimizers:
        opt_data = combined_df[combined_df['optimizer'] == optimizer].sort_values('unbalance_coef')

        k_values = opt_data['unbalance_coef'].to_numpy()
        means = opt_data['mean'].to_numpy()
        ci_lower = opt_data['ci_95%_lower'].to_numpy()
        ci_upper = opt_data['ci_95%_upper'].to_numpy()

        color = optimizer_to_color[optimizer]
        curves[optimizer] = (k_values, means, ci_lower, ci_upper, color)

        # Confidence interval as a shaded band (excluded from the legend)
        ax.fill_between(k_values, ci_lower, ci_upper,
                        alpha=0.2, color=color,
                        label='_nolegend_')

        # Mean values with '^' markers; labelled with the display name for the legend
        (line,) = ax.plot(k_values, means, marker='^', markersize=12,
                          linewidth=3, color=color,
                          label=optimizer_names.get(optimizer, optimizer), linestyle='-')
        main_lines.append(line)

    ax.set_xlabel('Unbalance Coefficient (k)', fontsize=25, fontweight='bold')
    ax.set_ylabel('F1-Score', fontsize=25, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.tick_params(labelsize=30)

    # Zoom window: x from zoom_x_range, y from zoom_y_min up to the main axes' top
    zoom_x_min, zoom_x_max = zoom_x_range
    y_main_min, y_main_max = ax.get_ylim()
    zoom_y_max = y_main_max
    if zoom_y_min >= zoom_y_max:
        # A fixed lower bound at or above the data would give an inverted/empty y-range
        zoom_y_min = y_main_min

    # Slice each curve down to the zoom range; drop curves with no points there
    zoomed = {}
    for optimizer, (k_vals, mean_vals, lo, hi, color) in curves.items():
        mask = (k_vals >= zoom_x_min) & (k_vals <= zoom_x_max)
        if mask.any():
            zoomed[optimizer] = (k_vals[mask], mean_vals[mask], lo[mask], hi[mask], color)

    if zoomed:
        ax_inset = inset_axes(
            ax,
            width="55%",
            height="55%",
            loc="upper right",
            borderpad=1.0,
        )
        ax_inset.set_facecolor((1, 1, 1, 0.95))

        # Solid black border around the inset
        for spine in ax_inset.spines.values():
            spine.set_edgecolor('black')
            spine.set_linewidth(2)

        for k_zoom, mean_zoom, lo_zoom, hi_zoom, color in zoomed.values():
            ax_inset.fill_between(k_zoom, lo_zoom, hi_zoom, alpha=0.2, color=color)
            ax_inset.plot(k_zoom, mean_zoom, marker='^', markersize=10, linewidth=3.0,
                          color=color, linestyle='-')

        ax_inset.set_xlim(zoom_x_min - 0.5, zoom_x_max + 0.5)
        ax_inset.set_ylim(zoom_y_min, zoom_y_max)
        ax_inset.tick_params(labelsize=30)
        ax_inset.grid(True, alpha=0.3)

        # Dashed connectors from the zoomed region on the main axes to the inset
        mark_inset(ax, ax_inset, loc1=1, loc2=3, fc="none", ec="red", linestyle='--', alpha=0.6, linewidth=4.0)
        legend_ax = ax_inset
    else:
        print(f"Warning: no data with k in [{zoom_x_min}, {zoom_x_max}], skipping zoom inset")
        legend_ax = ax

    # Build the legend from the main-axes lines so every plotted optimizer is
    # listed, including those without points in the zoom range
    legend = legend_ax.legend(handles=main_lines, loc='best', fontsize=18)
    for legend_line in legend.get_lines():
        legend_line.set_linewidth(4.0)

    fig.tight_layout()

    # File stem keeps the historical pattern: an empty csv_suffix leaves a
    # trailing '_', and a suffix that starts with '_' yields '__'
    stem = f'{model_lower}_{split}_{csv_suffix}'

    output_path = base_path / f'{stem}.png'
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\nPlot saved: {output_path}")

    output_path_pdf = base_path / f'{stem}.pdf'
    fig.savefig(output_path_pdf, bbox_inches='tight')
    print(f"PDF saved: {output_path_pdf}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser(description='Plot confidence intervals for different unbalance coefficients')

    # (flags, add_argument keyword options) per CLI option, in help-display order.
    cli_options = (
        (('--user',), dict(type=str, default='user1',
                           help='User folder name')),
        (('--model',), dict(type=str, default='SimpleCNNBinClass',
                            choices=['SimpleCNNBinClass', 'ResNet18_32x32BinClass', 'swin_tiny'],
                            help='Model name')),
        (('--unbalance_coefs',), dict(type=int, nargs='+',
                                      default=[2, 5, 30, 40, 50],
                                      help='List of unbalance coefficients (k values)')),
        (('--optimizers',), dict(type=str, nargs='+', default=None,
                                 help='List of optimizers to plot (if not specified, all optimizers will be plotted)')),
        (('--split',), dict(type=str, default='test', choices=['test', 'val'],
                            help='Data split to plot (test or val)')),
        (('--csv_suffix',), dict(type=str, default='',
                                 help='Suffix for CSV filename (e.g., "_recalculated_40seeds" for final_scores_with_CI_recalculated_40seeds.csv)')),
    )
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args()

    # The CLI flag is --optimizers, but the function parameter is optimizers_list.
    plot_unbalance_ci(
        user=parsed.user,
        model=parsed.model,
        unbalance_coefs=parsed.unbalance_coefs,
        optimizers_list=parsed.optimizers,
        split=parsed.split,
        csv_suffix=parsed.csv_suffix,
    )

