import pandas as pd
import matplotlib.pyplot as plt
import argparse
import os

_REQUIRED_COLUMNS = ('position', 'mean_prob', 'mean_logprob', 'count')


def _require_columns(df, csv_path):
    """Raise KeyError naming every required column that ``df`` lacks."""
    missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise KeyError(f"{csv_path} is missing required column(s): {', '.join(missing)}")


def _plot_metric_panel(ax, series, column, suffix, ylabel, title, xlabel=None):
    """Plot ``column`` for each model on ``ax`` and the sample counts on a twin y-axis."""
    for df, label, color, style in series:
        ax.plot(df['position'], df[column], label=f'{label} {suffix}', color=color, linewidth=2, **style)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel, color='black')
    ax.tick_params(axis='y', labelcolor='black')
    ax.grid(True, alpha=0.3)
    ax.set_title(title)

    # Counts live on a secondary y-axis so their scale does not squash the metric.
    count_ax = ax.twinx()
    for df, label, color, _ in series:
        count_ax.plot(df['position'], df['count'], label=f'{label} Count', color=color, alpha=0.3, linestyle=':')
    count_ax.set_ylabel('Sample Count', color='gray')
    count_ax.tick_params(axis='y', labelcolor='gray')

    # A twin axis keeps its own legend entries; merge both sets into one legend.
    handles, labels = ax.get_legend_handles_labels()
    count_handles, count_labels = count_ax.get_legend_handles_labels()
    ax.legend(handles + count_handles, labels + count_labels, loc='upper right', fontsize='small')


def plot_comparison(csv1_path, csv2_path, label1, label2, output_dir):
    """Plot two per-position CSV summaries against each other and save one PNG.

    The figure has two stacked panels sharing the x-axis: mean probability on
    top and mean log-probability below. Each panel also shows the ``count``
    column of both files on a secondary y-axis. The image is written to
    ``<output_dir>/comparison_combined.png`` at 300 dpi.

    Args:
        csv1_path: Path to the first CSV file.
        csv2_path: Path to the second CSV file.
        label1: Legend label prefix for the first file's curves.
        label2: Legend label prefix for the second file's curves.
        output_dir: Directory for the PNG; created if it does not exist.

    Raises:
        KeyError: If either CSV lacks one of the columns ``position``,
            ``mean_prob``, ``mean_logprob`` or ``count``. This is checked
            before the output directory or the figure is created.

    Errors raised by ``pd.read_csv`` propagate unchanged.
    """
    df1 = pd.read_csv(csv1_path)
    df2 = pd.read_csv(csv2_path)

    # Fail early with a message that names the file and the missing columns,
    # instead of a bare KeyError halfway through plotting.
    _require_columns(df1, csv1_path)
    _require_columns(df2, csv2_path)

    os.makedirs(output_dir, exist_ok=True)

    # (frame, legend label, colour, extra line style) for each model.
    series = [
        (df1, label1, 'tab:blue', {}),
        (df2, label2, 'tab:red', {'linestyle': '--'}),
    ]

    # Single figure with two subplots (2 rows, 1 column).
    fig, (ax_prob, ax_lp) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
    try:
        _plot_metric_panel(
            ax_prob, series, 'mean_prob', 'Prob',
            ylabel='Mean Probability',
            title='Token Probability Comparison',
        )
        _plot_metric_panel(
            ax_lp, series, 'mean_logprob', 'Log-Prob',
            ylabel='Mean Log-Probability',
            title='Token Log-Probability Comparison',
            xlabel='Token Position',
        )

        fig.tight_layout()
        output_path = os.path.join(output_dir, 'comparison_combined.png')
        fig.savefig(output_path, dpi=300)
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Saved combined comparison plot to {output_path}")

if __name__ == "__main__":
    # Command-line entry point: collect the two CSV paths, their legend
    # labels and the output directory, then render the combined figure.
    cli = argparse.ArgumentParser(
        description='Compare token probabilities and counts from two models in one figure.',
    )

    # Mandatory inputs: the two CSV files to compare.
    for flag, help_text in (
        ('--csv1', 'Path to first CSV file'),
        ('--csv2', 'Path to second CSV file'),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional settings, each with a default value.
    for flag, fallback, help_text in (
        ('--label1', 'Model 1', 'Label for first model'),
        ('--label2', 'Model 2', 'Label for second model'),
        ('--output_dir', 'eval_scripts/analysis/plots', 'Directory to save plots'),
    ):
        cli.add_argument(flag, type=str, default=fallback, help=help_text)

    opts = cli.parse_args()
    plot_comparison(opts.csv1, opts.csv2, opts.label1, opts.label2, opts.output_dir)
