import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pathlib import Path

# Global figure styling: whitegrid theme, husl colour cycle, and a serif font
# stack that tries Times New Roman first while keeping the existing fallbacks.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'] + plt.rcParams['font.serif'],
})

# Shared line styling keyed by perturbation type, so that a given series is
# drawn with the same colour, dash pattern and marker in every subplot and in
# the figure legend.
STYLE_MAP = {
    'CharFlip': {
        'color': '#f18275',
        'linestyle': '--',
        'marker': 'o',
        'label': 'Character Reversal (Syntactic)',
    },
    'WordFlip': {
        'color': '#c0392b',
        'linestyle': '-',
        'marker': 's',
        'label': 'Word Reversal (Syntactic)',
    },
    'Irrelevant': {
        'color': '#5dbaf7',
        'linestyle': '--',
        'marker': '^',
        'label': 'Irrelevant (Semantic)',
    },
    'CounterFactual': {
        'color': '#2980b9',
        'linestyle': '-',
        'marker': 'D',
        'label': 'CounterFactual (Semantic)',
    },
}


def _percent_column_to_float(series):
    """Converts a percentage column ('85.5%' strings or plain numbers) to floats."""
    if pd.api.types.is_numeric_dtype(series):
        # read_csv already parsed the column as numbers (no '%' suffix present),
        # so there is no string accessor to call.
        return series.astype(float)
    return series.str.replace('%', '', regex=False).astype(float)


def load_and_process_data_part1(csv_files, filter_instruct_only=True, contamination_levels=None):
    """
    Loads data for Adherence and Accuracy plots (Plots A & B).

    Args:
        csv_files: Mapping of contamination type -> CSV path. Each CSV must
            provide the columns 'csv_name', 'accuracy' and 'variation_check'.
        filter_instruct_only: When True, keep only rows whose 'csv_name'
            contains the substring 'Instruct'.
        contamination_levels: Optional collection of contamination percentages
            to keep; None keeps every level.

    Returns:
        A DataFrame with the columns 'contamination_type', 'contamination_pct',
        'accuracy_clean' and 'variation_check_clean'. When no rows survive the
        filters, an empty DataFrame with those same columns is returned.

    Raises:
        The integer cast propagates an exception when a 'csv_name' value
        contains no '_<digits>_' segment.
    """
    output_columns = ['contamination_type', 'contamination_pct', 'accuracy_clean', 'variation_check_clean']
    all_data = []
    for contamination_type, file_path in csv_files.items():
        df = pd.read_csv(file_path)
        # The contamination percentage is the first '_<digits>_' run in the name.
        df['contamination_pct'] = df['csv_name'].str.extract(r'_(\d+)_', expand=False).astype(int)
        df['model_type'] = df['csv_name'].apply(lambda x: 'Instruct' if 'Instruct' in x else 'Base')
        df['accuracy_clean'] = _percent_column_to_float(df['accuracy'])
        df['variation_check_clean'] = _percent_column_to_float(df['variation_check'])
        df['contamination_type'] = contamination_type

        if filter_instruct_only:
            df = df[df['model_type'] == 'Instruct']

        if contamination_levels is not None:
            df = df[df['contamination_pct'].isin(contamination_levels)]

        print(f"Part 1 ({contamination_type}): Using {len(df)} entries")
        if df.empty:
            continue

        all_data.append(df[output_columns].copy())

    if not all_data:
        # pd.concat raises on an empty list; hand back an empty frame that still
        # carries the expected columns so callers can test `.empty` safely.
        return pd.DataFrame(columns=output_columns)
    return pd.concat(all_data, ignore_index=True)

def calculate_summary_stats_part1(data, baseline_accuracy, baseline_std, baseline_n):
    """
    Calculates summary stats for Adherence and Accuracy, adding a baseline row.

    Args:
        data: DataFrame with the columns 'contamination_type',
            'contamination_pct', 'accuracy_clean' and 'variation_check_clean'.
        baseline_accuracy: Mean accuracy to report at 0% contamination.
        baseline_std: Standard deviation of the baseline accuracy.
        baseline_n: Number of observations behind the baseline.

    Returns:
        One row per (contamination_type, contamination_pct), sorted by those
        two keys, with mean/std/count/se columns for accuracy and variation.
        Every contamination type additionally gets a contamination_pct == 0 row
        holding the baseline accuracy and a variation of 0. An empty input
        yields an empty DataFrame with the same columns.
    """
    summary_columns = [
        'contamination_type', 'contamination_pct',
        'accuracy_mean', 'accuracy_std', 'accuracy_count',
        'variation_mean', 'variation_std', 'variation_count',
        'accuracy_se', 'variation_se',
    ]
    if data.empty:
        # Nothing to group (and the grouping columns may not even exist).
        return pd.DataFrame(columns=summary_columns)

    summary = (
        data.groupby(['contamination_type', 'contamination_pct'])
        .agg(
            accuracy_mean=('accuracy_clean', 'mean'),
            accuracy_std=('accuracy_clean', 'std'),
            accuracy_count=('accuracy_clean', 'count'),
            variation_mean=('variation_check_clean', 'mean'),
            variation_std=('variation_check_clean', 'std'),
            variation_count=('variation_check_clean', 'count'),
        )
        .reset_index()
    )

    for metric in ('accuracy', 'variation'):
        count = summary[f'{metric}_count']
        standard_error = summary[f'{metric}_std'] / np.sqrt(count)
        # A group with at most one observation has an undefined (NaN) std; use
        # 0.0 there, the same convention applied to the baseline row below.
        summary[f'{metric}_se'] = standard_error.where(count > 1, 0.0)

    baseline_se = 0.0 if baseline_n <= 1 else baseline_std / np.sqrt(baseline_n)
    baseline_rows = [
        {
            'contamination_type': cont_type, 'contamination_pct': 0,
            'accuracy_mean': baseline_accuracy, 'accuracy_std': baseline_std, 'accuracy_count': baseline_n,
            'variation_mean': 0.0, 'variation_std': 0.0, 'variation_count': 1,
            'accuracy_se': baseline_se, 'variation_se': 0.0,
        }
        for cont_type in summary['contamination_type'].unique()
    ]

    summary = pd.concat([summary, pd.DataFrame(baseline_rows)], ignore_index=True)
    return summary.sort_values(['contamination_type', 'contamination_pct']).reset_index(drop=True)

def load_and_process_data_part2(csv_files, filter_instruct_only=True, noise_levels=None):
    """
    Loads data for Semantic Similarity and Grammatical Correctness plots (Plots C & D).

    Args:
        csv_files: Mapping of noise type -> CSV path. Each CSV must provide a
            'csv_name' column; 'semantic_similarity' and
            'grammatical_correctness' are optional.
        filter_instruct_only: When True, keep only rows whose 'csv_name'
            contains the substring 'Instruct'.
        noise_levels: Optional collection of noise levels to keep; None keeps
            every level.

    Returns:
        A DataFrame with the columns 'noise_type', 'noise_level',
        'semantic_similarity_clean' and 'grammatical_correctness_clean'. When
        no rows survive the filters, an empty DataFrame with those same
        columns is returned.
    """
    output_columns = ['noise_type', 'noise_level', 'semantic_similarity_clean', 'grammatical_correctness_clean']
    all_data = []
    for noise_type, file_path in csv_files.items():
        df = pd.read_csv(file_path)
        # The noise level is the first '_<digits>_' run in the name.
        df['noise_level'] = df['csv_name'].str.extract(r'_(\d+)_', expand=False).astype(int)
        df['model_type'] = df['csv_name'].apply(lambda x: 'Instruct' if 'Instruct' in x else 'Base')

        for metric in ['semantic_similarity', 'grammatical_correctness']:
            if metric in df.columns:
                # Going through str keeps this working for both '85%' strings and
                # plain numbers; unparseable cells become NaN.
                numeric_values = pd.to_numeric(
                    df[metric].astype(str).str.replace('%', '', regex=False), errors='coerce'
                )
                # semantic_similarity is rescaled by 100 to sit on the same axis
                # as the percentages -- presumably stored as a 0-1 score; TODO confirm.
                df[f'{metric}_clean'] = numeric_values * 100 if metric == 'semantic_similarity' else numeric_values
            else:
                # A missing metric column is reported as 0.0 rather than failing.
                df[f'{metric}_clean'] = 0.0

        df['noise_type'] = noise_type
        if filter_instruct_only:
            df = df[df['model_type'] == 'Instruct']
        if noise_levels is not None:
            df = df[df['noise_level'].isin(noise_levels)]

        print(f"Part 2 ({noise_type}): Using {len(df)} entries")
        if df.empty:
            continue

        all_data.append(df[output_columns].copy())

    if not all_data:
        # Keep the expected columns so callers can still index e.g. 'noise_type'
        # on an empty result instead of hitting a KeyError.
        return pd.DataFrame(columns=output_columns)
    return pd.concat(all_data, ignore_index=True)

def calculate_summary_stats_part2(data):
    """
    Calculates summary stats for Similarity and Grammaticalness.

    Args:
        data: DataFrame with the columns 'noise_type', 'noise_level',
            'semantic_similarity_clean' and 'grammatical_correctness_clean'.

    Returns:
        One row per (noise_type, noise_level), sorted by those two keys, with
        mean/std/count/se columns for 'similarity' and 'grammaticalness'. An
        empty input yields an empty DataFrame with the same columns.
    """
    summary_columns = [
        'noise_type', 'noise_level',
        'similarity_mean', 'similarity_std', 'similarity_count',
        'grammaticalness_mean', 'grammaticalness_std', 'grammaticalness_count',
        'similarity_se', 'grammaticalness_se',
    ]
    if data.empty:
        # Nothing to group (and the grouping columns may not even exist).
        return pd.DataFrame(columns=summary_columns)

    summary = (
        data.groupby(['noise_type', 'noise_level'])
        .agg(
            similarity_mean=('semantic_similarity_clean', 'mean'),
            similarity_std=('semantic_similarity_clean', 'std'),
            similarity_count=('semantic_similarity_clean', 'count'),
            grammaticalness_mean=('grammatical_correctness_clean', 'mean'),
            grammaticalness_std=('grammatical_correctness_clean', 'std'),
            grammaticalness_count=('grammatical_correctness_clean', 'count'),
        )
        .reset_index()
    )

    for metric in ('similarity', 'grammaticalness'):
        count = summary[f'{metric}_count']
        standard_error = summary[f'{metric}_std'] / np.sqrt(count)
        # A group with at most one observation has an undefined (NaN) std, which
        # would punch a hole in the plotted error band; report 0.0 instead.
        summary[f'{metric}_se'] = standard_error.where(count > 1, 0.0)

    return summary.sort_values(['noise_type', 'noise_level']).reset_index(drop=True)


def _draw_metric_panel(ax, summary, group_col, x_col, metric, title, ylabel):
    """Draws one mean line with a +/- one-standard-error band per group onto `ax`."""
    # Applied when a group has no STYLE_MAP entry, so an unexpected type is still
    # drawn instead of raising a KeyError.
    fallback_style = {'color': 'gray', 'linestyle': '-', 'marker': 'o'}

    ax.set_title(title, fontsize=8, fontweight='bold')
    for group_name in summary[group_col].unique():
        subset = summary[summary[group_col] == group_name]
        style = {**fallback_style, **STYLE_MAP.get(group_name, {})}
        mean = subset[f'{metric}_mean']
        se = subset[f'{metric}_se']
        ax.plot(subset[x_col], mean, color=style['color'],
                linestyle=style['linestyle'], marker=style['marker'], markersize=3, linewidth=1.5)
        ax.fill_between(subset[x_col], mean - se, mean + se, color=style['color'], alpha=0.15)
    ax.set_ylabel(ylabel, fontsize=8, fontweight='bold')


def create_combined_plot(summary_part1, summary_part2, save_path_base=None, figsize=(5.5, 4.5)):
    """
    Creates a single 2x2 plot combining all four metrics with a proper legend.

    Args:
        summary_part1: Summary with 'contamination_type', 'contamination_pct'
            and the 'variation_*' / 'accuracy_*' mean and se columns
            (panels A and B).
        summary_part2: Summary with 'noise_type', 'noise_level' and the
            'similarity_*' / 'grammaticalness_*' mean and se columns
            (panels C and D).
        save_path_base: Optional path without extension; when given, the figure
            is written to '<base>.png' (300 dpi) and '<base>.pdf'.
        figsize: Figure size passed to matplotlib.

    Returns:
        The matplotlib Figure, or None when either summary is empty.
    """
    if summary_part1.empty or summary_part2.empty:
        print("One of the summary dataframes is empty. Cannot create plot.")
        return None

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle('Sensitivity of Small Language Models to Data Contamination', fontsize=10, fontweight='bold', y=1.00)
    axA, axB = axes[0, 0], axes[0, 1]
    axC, axD = axes[1, 0], axes[1, 1]

    # (A) Adherence to Data Contamination
    _draw_metric_panel(axA, summary_part1, 'contamination_type', 'contamination_pct', 'variation',
                       '(A) Adherence to Data Contamination', 'Contamination Adherence (%)')
    # (B) Task Accuracy
    _draw_metric_panel(axB, summary_part1, 'contamination_type', 'contamination_pct', 'accuracy',
                       '(B) Task Accuracy', 'Task Accuracy (%)')
    # (C) Semantic Similarity
    _draw_metric_panel(axC, summary_part2, 'noise_type', 'noise_level', 'similarity',
                       '(C) Semantic Similarity', 'Semantic Similarity (%)')
    # (D) Grammatical Correctness
    _draw_metric_panel(axD, summary_part2, 'noise_type', 'noise_level', 'grammaticalness',
                       '(D) Grammatical Correctness', 'Grammatical Correctness (%)')

    # Each row takes its x ticks from the data it actually plots.
    ticks_part1 = sorted(summary_part1['contamination_pct'].unique())
    ticks_part2 = sorted(summary_part2['noise_level'].unique())
    for ax, ticks in ((axA, ticks_part1), (axB, ticks_part1), (axC, ticks_part2), (axD, ticks_part2)):
        ax.set_ylim(0, 105)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)
        ax.set_xticks(ticks)
        ax.tick_params(axis='both', labelsize=8)

    # Only the bottom row carries an x label.
    axC.set_xlabel('Contamination Percentage (%)', fontsize=8, fontweight='bold')
    axD.set_xlabel('Contamination Percentage (%)', fontsize=8, fontweight='bold')

    # Build proxy handles from STYLE_MAP so the legend order is fixed and does
    # not depend on the order in which the series were drawn.
    desired_order_keys = ['CharFlip', 'WordFlip', 'Irrelevant', 'CounterFactual']
    handles = [
        plt.Line2D([], [], color=STYLE_MAP[key]['color'], marker=STYLE_MAP[key]['marker'],
                   linestyle=STYLE_MAP[key]['linestyle'], label=STYLE_MAP[key]['label'])
        for key in desired_order_keys
    ]
    labels = [STYLE_MAP[key]['label'] for key in desired_order_keys]

    # Create the single figure-level legend at the bottom.
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.02), ncol=2,
               fontsize=8, frameon=True, handlelength=2.5, columnspacing=1.5)

    fig.tight_layout()
    # Adjust top and bottom to make space for suptitle and legend
    fig.subplots_adjust(top=0.92, bottom=0.15)

    if save_path_base:
        png_path = f"{save_path_base}.png"
        pdf_path = f"{save_path_base}.pdf"
        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(png_path, dpi=300, bbox_inches='tight')
        fig.savefig(pdf_path, bbox_inches='tight')
        print(f"\nCombined plot saved to {png_path} and {pdf_path}")

    plt.show()
    return fig


def main():
    """
    Builds the combined 2x2 figure from the perturbation CSVs and the baseline CSV.

    Reads the four perturbation CSVs plus 'Instruct.csv' from the current
    working directory, derives the 0% baseline from 'Instruct.csv', and writes
    'figure2.png' / 'figure2.pdf'. Exits with status 1 when a required file is
    missing or when no rows remain after filtering.
    """
    csv_files = {
        'CharFlip': 'CharFlipped.csv',
        'WordFlip': 'WordFlipped.csv',
        'Irrelevant': 'Irrelevant.csv',
        'CounterFactual': 'CounterFactual.csv'
    }
    baseline_csv_path = 'Instruct.csv'
    levels = [25, 50, 75, 100]

    print("Checking file existence...")
    all_files_to_check = list(csv_files.values()) + [baseline_csv_path]
    missing_files = [p for p in all_files_to_check if not Path(p).exists()]
    if missing_files:
        print("\nERROR: One or more required CSV files are missing.")
        for p in missing_files:
            print(f"  - Missing: {p}")
        raise SystemExit(1)
    print("All required files found.")

    print("\n--- Processing Data for Plots (A) Adherence & (B) Accuracy ---")

    # The baseline file feeds both halves of the figure, so read it once.
    baseline_df = pd.read_csv(baseline_csv_path)
    accuracy_clean = baseline_df['accuracy'].str.replace('%', '').astype(float)
    calculated_baseline_accuracy = accuracy_clean.mean()
    # NOTE(review): this is the population std (ddof=0), whereas the grouped
    # summaries use pandas' default std -- confirm the mismatch is intended.
    calculated_baseline_std = accuracy_clean.std(ddof=0)
    calculated_baseline_n = len(accuracy_clean)
    print(f"Calculated baseline for Part 1: Mean={calculated_baseline_accuracy:.2f}%, StdDev={calculated_baseline_std:.2f}%, N={calculated_baseline_n}")

    data_part1 = load_and_process_data_part1(csv_files, contamination_levels=levels)

    print("\n--- Processing Data for Plots (C) Similarity & (D) Grammaticalness ---")

    noisy_data_part2 = load_and_process_data_part2(csv_files, noise_levels=levels)

    # Bail out before any aggregation: the steps below index columns and
    # concatenate per-type frames, both of which fail on empty input.
    if data_part1.empty or noisy_data_part2.empty:
        print("\nERROR: No data was loaded after processing. Cannot create plot.")
        raise SystemExit(1)

    summary_part1 = calculate_summary_stats_part1(data_part1,
                                                  baseline_accuracy=calculated_baseline_accuracy,
                                                  baseline_std=calculated_baseline_std,
                                                  baseline_n=calculated_baseline_n)

    baseline_df_part2 = baseline_df.copy()
    for metric in ['semantic_similarity', 'grammatical_correctness']:
        if metric in baseline_df_part2.columns:
            cleaned_values = pd.to_numeric(baseline_df_part2[metric].astype(str).str.replace('%', ''), errors='coerce')
            baseline_df_part2[f'{metric}_clean'] = cleaned_values * 100 if metric == 'semantic_similarity' else cleaned_values
        else:
            baseline_df_part2[f'{metric}_clean'] = 0.0

    # Replicate the baseline rows once per noise type at noise level 0 so that
    # every plotted series starts from the same uncontaminated point.
    baseline_df_part2['noise_level'] = 0
    baseline_columns = ['semantic_similarity_clean', 'grammatical_correctness_clean', 'noise_level']
    unique_noise_types = noisy_data_part2['noise_type'].unique()
    all_baseline_rows = [
        baseline_df_part2[baseline_columns].assign(noise_type=n_type)
        for n_type in unique_noise_types
    ]

    baseline_df_part2_processed = pd.concat(all_baseline_rows, ignore_index=True)
    print(f"Processed baseline data for {len(unique_noise_types)} noise types for Part 2.")

    combined_data_part2 = pd.concat([noisy_data_part2, baseline_df_part2_processed], ignore_index=True)
    summary_part2 = calculate_summary_stats_part2(combined_data_part2)

    print("\n--- Creating Combined Plot ---")

    create_combined_plot(summary_part1, summary_part2, save_path_base='figure2')

    print("\nDone!")


if __name__ == "__main__":
    main()