import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import seaborn as sns
from pathlib import Path
from sklearn.calibration import calibration_curve, CalibratedClassifierCV
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import brier_score_loss

# Global matplotlib defaults shared by every figure in this script: a sans-serif
# typeface, enlarged text, and heavier strokes sized to match the text.
plt.rc('font', **{
    'family': 'sans-serif',
    'sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif'],
    'size': 36,                                   # base font size
})
plt.rc('axes', titlesize=42, labelsize=36)        # axes title / axis label text
plt.rc(('xtick', 'ytick'), labelsize=33)          # tick label text on both axes
plt.rc('legend', fontsize=33)                     # legend text
plt.rc('figure', titlesize=48)                    # figure-level title text
plt.rc('axes.spines', top=False, right=False)     # hide the top and right spines
plt.rc('axes', grid=True, linewidth=3)            # grid on, thicker axis lines
plt.rc('grid', alpha=0.3)                         # faint grid lines
plt.rc('lines', linewidth=4, markersize=18)       # thicker lines, larger markers

# Folder that holds the per-model/per-dataset experiment results read by this script.
EXPERIMENT_FOLDER_NAME = "experiment_results"
# Root folder under which analysis outputs are written.
ANALYSIS_OUTPUT_FOLDER_NAME = "experiment_analyses"
# Sub-folder of the analysis root used by this particular analysis.
ANALYSIS_NAME = "calibration_plots"

def format_metric_name(metric_name):
    """Return the display label for a metric column name.

    Known metrics map to a fixed label. Any other name is turned into a label
    by replacing underscores with spaces and title-casing the result.
    """
    # Derive the generic label up front; it is only used for unknown metrics.
    generic_label = metric_name.replace('_', ' ').title()

    display_labels = {
        'binoculars_score': 'Binoculars Score',
        'telescope_perplexity': 'Telescope Perplexity',
        'telescope_score': 'Telescope Score',
    }
    return display_labels.get(metric_name, generic_label)

def create_output_folders(datasets):
    """Ensure the analysis output tree exists.

    Creates the analysis root, its 'combined' sub-folder, and one sub-folder
    per entry of `datasets`. Folders that already exist are left untouched.
    """
    analysis_root = f"{ANALYSIS_OUTPUT_FOLDER_NAME}/{ANALYSIS_NAME}"

    # Root first, then the folder for the cross-dataset (combined) outputs.
    for folder in (analysis_root, f"{analysis_root}/combined"):
        os.makedirs(folder, exist_ok=True)

    # One folder per dataset for the dataset-specific outputs.
    for dataset in datasets:
        os.makedirs(f"{analysis_root}/{dataset}", exist_ok=True)

def load_and_combine_data(model_name, dataset_name):
    """Read the raw results CSV for one model/dataset pair.

    The path is printed before reading. The returned DataFrame carries an
    extra 'dataset' column set to `dataset_name` on every row, so frames from
    different datasets can be told apart after concatenation.
    """
    csv_path = "/".join([
        EXPERIMENT_FOLDER_NAME,
        f"{model_name}_{dataset_name}_dataset",
        "raw_data.csv",
    ])
    print(csv_path)

    frame = pd.read_csv(csv_path)
    frame['dataset'] = dataset_name
    return frame

def create_calibration_plot(data, metric, model_name, output_path, dataset_name=None, n_bins=10):
    """Create and save a calibration plot for one metric, raw and isotonic-recalibrated.

    The metric column is min-max normalized to [0, 1] and, when the label-1
    rows have the lower mean, inverted so that higher values point to label 1.
    The result is treated as a predicted probability and plotted against the
    observed fraction of positives, together with the same curve after an
    isotonic regression fitted on (and applied to) these same samples.

    Args:
        data: DataFrame with a 'y_labels' column (rows are compared against 0
            and 1) and a column named `metric`.
        metric: Name of the score column to evaluate.
        model_name: Model label shown in the plot title.
        output_path: Destination path of the saved image.
        dataset_name: Optional dataset label appended to the title.
        n_bins: Number of bins passed to `calibration_curve`.
    """
    # --- Compute everything before any figure exists, so a failure here
    # --- cannot leave an open figure behind.
    y_true = data['y_labels'].values
    metric_values = data[metric].values

    # Min-max normalization; a constant column has no range and maps to zeros.
    min_val = np.min(metric_values)
    max_val = np.max(metric_values)
    if max_val > min_val:
        normalized_values = (metric_values - min_val) / (max_val - min_val)
    else:
        normalized_values = np.zeros_like(metric_values)

    # Orient the score so that higher values indicate AI text (label 1).
    ai_mean = np.mean(normalized_values[y_true == 1])
    human_mean = np.mean(normalized_values[y_true == 0])
    y_prob = 1 - normalized_values if ai_mean < human_mean else normalized_values

    # Recalibrate with isotonic regression (in-sample fit and prediction).
    ir = IsotonicRegression(out_of_bounds='clip')
    ir.fit(y_prob, y_true)
    y_prob_isotonic = ir.predict(y_prob)

    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)
    prob_true_isotonic, prob_pred_isotonic = calibration_curve(y_true, y_prob_isotonic, n_bins=n_bins)
    brier_original = brier_score_loss(y_true, y_prob)
    brier_isotonic = brier_score_loss(y_true, y_prob_isotonic)

    formatted_metric = format_metric_name(metric)
    if dataset_name:
        title = f'{formatted_metric} Calibration - {model_name} - {dataset_name}'
    else:
        title = f'{formatted_metric} Calibration - {model_name}'

    # --- Draw. The style is applied only for the duration of this figure:
    # a global plt.style.use() after figure creation would leave the first
    # figure of a run styled differently from all later ones and would
    # permanently alter rcParams for the rest of the process.
    # NOTE(review): inside this context the style sheet's settings take
    # precedence over module-level rcParams for the keys it defines - confirm
    # that is the intended precedence.
    with plt.style.context('seaborn-v0_8-whitegrid'):
        fig, ax = plt.subplots(figsize=(16, 16))
        try:
            fig.subplots_adjust(bottom=0.25)  # Make room for legend below the plot

            # Reference diagonal for a perfectly calibrated score
            ax.plot([0, 1], [0, 1], linestyle='--', color='gray', linewidth=3.5, alpha=0.7,
                    label='Perfectly calibrated')

            # Raw (normalized) score, standard blue
            ax.plot(prob_pred, prob_true, marker='o', markersize=14, linewidth=5,
                    color='#1f77b4',
                    label=f'Original (Brier: {brier_original:.3f})')

            # Isotonic-recalibrated score, standard orange
            ax.plot(prob_pred_isotonic, prob_true_isotonic, marker='s', markersize=14, linewidth=5,
                    color='#ff7f0e',
                    label=f'Isotonic (Brier: {brier_isotonic:.3f})')

            ax.set_title(title, fontsize=42, pad=20)
            ax.set_xlabel('Predicted Probability', labelpad=20)
            ax.set_ylabel('True Probability (Fraction of Positives)', labelpad=20)

            # Legend outside and below the plot area, laid out horizontally
            legend = ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15),
                               frameon=True, framealpha=0.9, borderpad=1.2,
                               fontsize=33, ncol=3)
            legend.get_frame().set_linewidth(3)

            # Square unit axes so the diagonal sits at 45 degrees
            ax.set_aspect('equal', adjustable='box')
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])

            # Light grid with thicker lines, and larger ticks
            ax.grid(True, alpha=0.3, linestyle='-', linewidth=2)
            ax.tick_params(width=3, length=10, pad=10)

            # Save this specific figure at high DPI
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

def _mean_or_nan(values):
    """Return the mean of `values`, or NaN (without a warning) when it is empty."""
    return float(np.mean(values)) if len(values) else float('nan')


def _percentile_lines(values, percentiles):
    """Return one formatted '<p>th percentile' line per entry of `percentiles`.

    An empty `values` array yields 'nan' for every percentile instead of
    calling np.percentile on empty input.
    """
    if len(values) == 0:
        return [f"{p}th percentile: {float('nan'):.4f}\n" for p in percentiles]
    return [f"{p}th percentile: {np.percentile(values, p):.4f}\n" for p in percentiles]


def save_calibration_stats(data, metric, model_name, output_path, dataset_name=None):
    """Save detailed calibration statistics for one metric to a text file.

    The metric is min-max normalized to [0, 1] and inverted when the label-1
    rows have the lower mean, then scored with the Brier loss before and after
    an isotonic regression fitted on (and applied to) these same samples. The
    report also lists sample counts, per-class means and percentiles of the
    raw metric.

    Args:
        data: DataFrame with a 'y_labels' column (rows are compared against 0
            and 1) and a column named `metric`.
        metric: Name of the score column to evaluate.
        model_name: Model label written in the report header.
        output_path: Destination path of the text report.
        dataset_name: Optional dataset label appended to the header.
    """
    y_true = data['y_labels'].values
    metric_values = data[metric].values

    # Min-max normalization; a constant column has no range and maps to zeros.
    min_val = np.min(metric_values)
    max_val = np.max(metric_values)
    if max_val > min_val:
        normalized_values = (metric_values - min_val) / (max_val - min_val)
    else:
        normalized_values = np.zeros_like(metric_values)

    ai_mask = y_true == 1
    human_mask = y_true == 0

    # Orient the score so that higher values indicate AI text (label 1).
    ai_mean = _mean_or_nan(normalized_values[ai_mask])
    human_mean = _mean_or_nan(normalized_values[human_mask])
    if ai_mean < human_mean:
        y_prob = 1 - normalized_values
        direction = "flipped"
        direction_note = "(Lower original values indicate AI text)\n\n"
    else:
        y_prob = normalized_values
        direction = "original"
        direction_note = "(Higher original values indicate AI text)\n\n"

    # Recalibrate with isotonic regression (in-sample fit and prediction).
    ir = IsotonicRegression(out_of_bounds='clip')
    ir.fit(y_prob, y_true)
    y_prob_isotonic = ir.predict(y_prob)

    brier_original = brier_score_loss(y_true, y_prob)
    brier_isotonic = brier_score_loss(y_true, y_prob_isotonic)

    # Relative improvement is undefined when the original score is already 0;
    # report NaN rather than dividing by zero.
    if brier_original > 0:
        improvement = (brier_original - brier_isotonic) / brier_original * 100
    else:
        improvement = float('nan')

    # Per-class views of the raw (un-normalized) metric.
    ai_values = metric_values[ai_mask]
    human_values = metric_values[human_mask]
    ai_raw_mean = _mean_or_nan(ai_values)
    human_raw_mean = _mean_or_nan(human_values)

    n_samples = len(y_true)
    n_positive = np.sum(y_true)
    percentiles = [0, 10, 25, 50, 75, 90, 100]

    header = f"Calibration statistics for {model_name} - {format_metric_name(metric)}"
    if dataset_name:
        header += f" - {dataset_name}"

    # Assemble the whole report first so the file is written in one call and
    # is not left half-written if a computation above fails.
    lines = [
        f"{header}\n\n",
        f"Score direction: {direction}\n",
        direction_note,
        f"Number of samples: {n_samples}\n",
        f"Number of positives (AI): {n_positive}\n",
        f"Number of negatives (human): {n_samples - n_positive}\n\n",
        f"Original Brier score: {brier_original:.4f}\n",
        f"Isotonic calibration Brier score: {brier_isotonic:.4f}\n",
        f"Improvement: {improvement:.2f}%\n\n",
        f"AI mean (original metric): {ai_raw_mean:.4f}\n",
        f"Human mean (original metric): {human_raw_mean:.4f}\n",
        f"Difference: {ai_raw_mean - human_raw_mean:.4f}\n\n",
        "Metric percentiles:\n",
        "Overall:\n",
        *_percentile_lines(metric_values, percentiles),
        "\nHuman texts:\n",
        *_percentile_lines(human_values, percentiles),
        "\nAI texts:\n",
        *_percentile_lines(ai_values, percentiles),
    ]

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("".join(lines))

def main():
    """Generate calibration plots and statistics for every model.

    For each model, every dataset whose results CSV exists gets one plot and
    one stats file per feature in that dataset's folder. The frames that were
    loaded are then concatenated and the same outputs are produced in the
    'combined' folder. Datasets whose CSV is missing are reported and skipped.
    """
    # Every model is evaluated on the same three features.
    shared_features = ["binoculars_score", "telescope_perplexity", "telescope_score"]
    model_names = [
        "falcon_7B",
        "gemma2_9B",
        "smollm_360M",
        "smollm_135M",
        "smollm_1_7B",
        "smollm2_360M",
        "smollm2_135M",
        "smollm2_1_7B",
    ]
    model_features = {name: list(shared_features) for name in model_names}

    datasets = [
        "essay",
        "ai_human",
        "hc3",
        "hc3_plus",
        "custom4o"
    ]

    # Create folder structure
    create_output_folders(datasets)

    output_root = f"{ANALYSIS_OUTPUT_FOLDER_NAME}/{ANALYSIS_NAME}"

    for model_name, features in model_features.items():
        # Frames loaded for this model, kept so the combined pass below does
        # not have to read every CSV a second time.
        loaded_frames = []

        # Per-dataset outputs
        for dataset in datasets:
            # Only the load is guarded: a missing input file is expected, while
            # a FileNotFoundError from writing outputs is a different problem
            # and must not be reported as missing data.
            try:
                df = load_and_combine_data(model_name, dataset)
            except FileNotFoundError:
                print(f"Warning: No data found for {model_name} on {dataset}")
                continue
            loaded_frames.append(df)

            for feature in features:
                output_path = f"{output_root}/{dataset}/{model_name}_{feature}_calibration.png"
                create_calibration_plot(df, feature, model_name, output_path, dataset)
                print(f"Created calibration plot for {model_name} - {feature} - {dataset}")

                # Save calibration statistics to text file
                stats_path = f"{output_root}/{dataset}/{model_name}_{feature}_calibration_stats.txt"
                save_calibration_stats(df, feature, model_name, stats_path, dataset)

        # Combined outputs across all datasets that loaded for this model
        if not loaded_frames:
            continue

        # Best effort: a failure on one model's combined data is reported and
        # the remaining models are still processed.
        try:
            combined_data = pd.concat(loaded_frames, ignore_index=True)

            for feature in features:
                output_path = f"{output_root}/combined/{model_name}_{feature}_calibration.png"
                create_calibration_plot(combined_data, feature, model_name, output_path)
                print(f"Created combined calibration plot for {model_name} - {feature}")

                # Save combined calibration statistics to text file
                stats_path = f"{output_root}/combined/{model_name}_{feature}_calibration_stats.txt"
                save_calibration_stats(combined_data, feature, model_name, stats_path)

        except Exception as e:
            print(f"Error processing combined data for {model_name}: {str(e)}")

# Run the full analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()