import argparse
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os


def load_linear_trend_data(dataset, results_dir, models=None):
    """Load per-model linear-trend CSVs for one dataset and stack them.

    For every model name, the file ``{results_dir}/linear_{dataset}_{model}.csv``
    is read if it exists. Each non-empty file gets a ``model`` column holding
    the model name, and all frames are concatenated.

    Missing files, zero-byte files and header-only files are reported and
    skipped. A zero-byte file makes ``pd.read_csv`` raise ``EmptyDataError``,
    which is a ``ValueError`` subclass. Left unhandled, it would propagate to
    callers that treat ``ValueError`` as "no data at all".

    Args:
        dataset: Dataset name embedded in the file names (e.g. "SynER1").
        results_dir: Directory containing the CSV files.
        models: Optional iterable of model names to load. Defaults to
            CAM, SCORE, DAS, NoGAM, DiffAN, CaPS, OURS.

    Returns:
        A DataFrame with the rows of every loaded file plus a ``model`` column.

    Raises:
        ValueError: If no non-empty CSV was found for any model.
    """
    if models is None:
        models = ["CAM", "SCORE", "DAS", "NoGAM", "DiffAN", "CaPS", "OURS"]

    all_data = []
    for model in models:
        filename = f"{results_dir}/linear_{dataset}_{model}.csv"

        if not os.path.exists(filename):
            print(f"Missing: {filename}")
            continue

        try:
            df = pd.read_csv(filename)
        except pd.errors.EmptyDataError:
            # Zero-byte file: no header, no rows. Skip it instead of aborting.
            print(f"Empty: {filename}")
            continue

        if df.empty:
            # Header present but no data rows.
            print(f"Empty: {filename}")
            continue

        df['model'] = model
        all_data.append(df)
        print(f"Loaded: {filename}")

    if not all_data:
        raise ValueError(f"No linear trend data found for {dataset}")

    return pd.concat(all_data, ignore_index=True)

def plot_trends(args):
    """Plot every metric against the linear proportion for all models and save a PNG.

    Reads ``args.dataset``, ``args.results_dir`` and ``args.output_dir``. Data
    is loaded with ``load_linear_trend_data``; if that raises ``ValueError`` the
    message is printed and the function returns without plotting. The figure
    has one subplot per metric in a single row and one shared legend on top.
    It is written to ``{output_dir}/{dataset}_linear_trends.png`` at 300 dpi.

    Args:
        args: Parsed CLI namespace with ``dataset``, ``results_dir`` and
            ``output_dir`` attributes.
    """
    from matplotlib.lines import Line2D

    results_dir = args.results_dir
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    try:
        df = load_linear_trend_data(args.dataset, results_dir)
    except ValueError as e:
        print(f"Error: {e}")
        return

    # Sort by x within each model so each line is drawn left to right.
    df = df.sort_values(by=['model', 'linear_rate'])

    # --- Plotting configuration ---
    metrics = ['shd', 'sid', 'precision', 'recall', 'F1']
    # Arrow appended to the y-label: down = lower is better, up = higher is better.
    metric_arrows = {
        'shd': r'$\downarrow$',
        'sid': r'$\downarrow$',
        'precision': r'$\uparrow$',
        'recall': r'$\uparrow$',
        'F1': r'$\uparrow$',
    }
    metric_display_names = {
        'shd': 'SHD',
        'sid': 'SID',
        'precision': 'Precision',
        'recall': 'Recall',
        'F1': 'F1',
    }

    # Results stored under "OURS" are displayed as "CaPS w/ PEP". That model is
    # plotted separately, after the baselines, so its line renders on top.
    ours = 'CaPS w/ PEP'
    df['model'] = df['model'].replace({'OURS': ours})
    model_order = ["CAM", "SCORE", "DAS", "NoGAM", "DiffAN", "CaPS", ours]

    # Markers keyed by model name, so every model keeps its own marker.
    marker_map = dict(zip(model_order, ['o', 'o', 'o', 'o', 'o', 'o', 'o']))

    # Same color mapping as plot_box_model.py.
    colors_2 = sns.color_palette("tab20c", 20)
    colors = sns.color_palette("Paired", 12)
    color_map = {
        "CAM": colors_2[17],
        "SCORE": colors[6],
        "DAS": colors[0],
        "NoGAM": colors[2],
        "DiffAN": colors[8],
        "CaPS": colors[4],
        ours: colors[5],
    }

    # Loop-invariant per-baseline styling (same for every subplot).
    baseline_models = [m for m in model_order if m != ours]
    baseline_palette = [color_map[m] for m in baseline_models]
    baseline_markers = [marker_map[m] for m in baseline_models]

    sns.set_theme(style="whitegrid")
    plt.rcParams.update({'font.size': 14})

    n_metrics = len(metrics)
    # squeeze=False keeps `axes` indexable even if only one metric is plotted.
    fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 4), squeeze=False)
    axes = axes[0]

    ours_data = df[df['model'] == ours]
    baseline_data = df[df['model'] != ours]

    for ax, metric in zip(axes, metrics):
        sns.lineplot(
            data=baseline_data,
            x='linear_rate', y=metric, hue='model', style='model',
            hue_order=baseline_models,
            style_order=baseline_models,
            palette=baseline_palette,
            markers=baseline_markers,
            dashes=False, markersize=10, ax=ax, zorder=5, errorbar='sd'
        )

        # Same zorder as the baselines but drawn later, so it appears on top.
        if not ours_data.empty:
            sns.lineplot(
                data=ours_data,
                x='linear_rate', y=metric,
                color=color_map[ours],
                marker=marker_map[ours],
                label=ours,
                dashes=False, markersize=10, ax=ax, zorder=5, errorbar='sd'
            )

        arrow = metric_arrows.get(metric, '')
        display_name = metric_display_names.get(metric, metric.upper())

        ax.set_title('')
        ax.set_xlabel('Linear proportion', fontsize=16, fontweight='normal')
        ax.set_ylabel(f'{display_name} {arrow}', fontsize=16, fontweight='normal')
        ax.grid(True)
        ax.tick_params(axis='x', labelsize=14)
        ax.tick_params(axis='y', labelsize=14)

        # Per-axes legends are replaced by one shared figure legend below.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # Build the shared legend from the first subplot's labelled artists.
    handles, labels = axes[0].get_legend_handles_labels()
    handles, labels = list(handles), list(labels)

    if not ours_data.empty and ours not in labels:
        # Fallback handle styled like the plotted line (same color and marker).
        handles.append(
            Line2D([0], [0], marker=marker_map[ours], color=color_map[ours],
                   label=ours, markersize=10)
        )
        labels.append(ours)

    # Order legend entries by model_order; unknown labels go last (stable sort).
    rank = {label: i for i, label in enumerate(model_order)}
    pairs = sorted(zip(handles, labels), key=lambda hl: rank.get(hl[1], len(rank)))
    handles = [h for h, _ in pairs]
    labels = [lbl for _, lbl in pairs]

    # Legend anchored above the subplots, one row with one column per model.
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.04),
               ncol=len(model_order), fontsize=16)
    # rect=[left, bottom, right, top]: keep the top 10% free for the legend.
    fig.tight_layout(rect=[0, 0, 1, 0.9])

    output_path = os.path.join(output_dir, f"{args.dataset}_linear_trends.png")
    fig.savefig(output_path, dpi=300, bbox_inches='tight')

    print(f"Trend plot saved to {output_path}")
    plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: parse the three options and render the plot.
    cli = argparse.ArgumentParser(description="Plot linear rate performance trends.")
    cli.add_argument("--results_dir", type=str,
                     default="./experiments/plot_data/linear",
                     help="Directory containing CSV files")
    cli.add_argument("--dataset", type=str, required=True,
                     help="Dataset name (e.g., SynER1)")
    cli.add_argument("--output_dir", type=str, default="fig/linear",
                     help="Output directory for plots")
    plot_trends(cli.parse_args())
