import pandas as pd
import numpy as np
import os
import glob
from sklearn.metrics import roc_auc_score, accuracy_score
import argparse

def _parse_prediction_file_name(file_name):
    """
    Extract the training ratio and repetition number from a prediction file name

    Parameters:
    file_name: Base file name, e.g. 'predictions_ratio_0.100_rep1.csv'

    Returns:
    Tuple (train_ratio, repetition) as (float, int)

    Raises:
    ValueError: If the name does not contain a numeric ratio followed by a
                'rep<N>' part separated by an underscore
    """
    stem = file_name.replace('predictions_ratio_', '').replace('.csv', '')
    parts = stem.split('_')
    try:
        return float(parts[0]), int(parts[1].replace('rep', ''))
    except (IndexError, ValueError) as err:
        # Report the offending name instead of a bare index/float-conversion error
        raise ValueError(
            f"Cannot parse train ratio and repetition from file name '{file_name}'; "
            "expected a name like 'predictions_ratio_0.100_rep1.csv'"
        ) from err


def _score_predictions(true_values, pred_values):
    """
    Compute AUC and accuracy for one prediction column

    Parameters:
    true_values: Series of true labels
    pred_values: Series of predicted scores, aligned with true_values

    Returns:
    Tuple (auc, accuracy, sample_count); auc and accuracy are NaN when no
    row has both values present, and auc is NaN when only one distinct true
    value remains after dropping missing rows
    """
    # notna() also works for non-float dtypes, where np.isnan raises TypeError
    usable = true_values.notna() & pred_values.notna()
    valid_true = true_values[usable]
    valid_pred = pred_values[usable]

    if valid_true.empty:
        return np.nan, np.nan, 0

    # AUC is only computed when more than one distinct true value is present
    if valid_true.nunique() > 1:
        auc = roc_auc_score(valid_true, valid_pred)
    else:
        auc = np.nan

    # Accuracy uses 0.5 as the decision threshold
    pred_binary = (valid_pred >= 0.5).astype(int)
    accuracy = accuracy_score(valid_true, pred_binary)

    return auc, accuracy, len(valid_true)


def calculate_metrics_for_file(file_path):
    """
    Calculate AUC and Accuracy metrics for all prediction methods in a single CSV file

    Parameters:
    file_path: CSV file path; the base name must look like
               'predictions_ratio_0.100_rep1.csv'

    Returns:
    DataFrame with one row per prediction method and the columns file_name,
    train_ratio, repetition, prediction_method, auc, accuracy, sample_count.
    Methods whose column is missing from the file, or whose scoring fails,
    get NaN metrics and a sample_count of 0.

    Raises:
    ValueError: If the train ratio / repetition cannot be parsed from the file name
    KeyError: If the CSV has no 'true_value' column
    """
    # Validate the file name first so a badly named file fails before any CSV is read
    file_name = os.path.basename(file_path)
    train_ratio, repetition = _parse_prediction_file_name(file_name)

    df = pd.read_csv(file_path)
    true_values = df['true_value']

    # Prediction column list
    prediction_columns = [
        'global_mean_pred',
        'model_mean_pred',
        'question_mean_pred',
        'irt_1pl_pred',
        'irt_2pl_pred',
        'weighted_irt_global',
        'weighted_irt_model',
        'weighted_irt_question'
    ]

    results = []
    for pred_col in prediction_columns:
        # Defaults used when the column is absent or scoring fails
        auc, accuracy, sample_count = np.nan, np.nan, 0

        if pred_col in df.columns:
            try:
                auc, accuracy, sample_count = _score_predictions(true_values, df[pred_col])
            except Exception as e:
                # Best effort: one failing method must not abort the other methods
                print(f"Error calculating metrics for {pred_col}: {e}")

        results.append({
            'file_name': file_name,
            'train_ratio': train_ratio,
            'repetition': repetition,
            'prediction_method': pred_col,
            'auc': auc,
            'accuracy': accuracy,
            'sample_count': sample_count
        })

    return pd.DataFrame(results)

def process_all_files(directory_path, file_pattern="*.csv"):
    """
    Process all CSV files in directory and calculate metrics

    Parameters:
    directory_path: Directory path containing CSV files
    file_pattern: File matching pattern (glob syntax)

    Returns:
    DataFrame containing metrics results for all files, or an empty
    DataFrame when every matched file failed to process

    Raises:
    ValueError: If no file in the directory matches the pattern
    """
    search_pattern = os.path.join(directory_path, file_pattern)
    # glob returns matches in arbitrary order; sort so the processing log and
    # the row order of the returned DataFrame are reproducible across runs
    csv_files = sorted(glob.glob(search_pattern))

    if not csv_files:
        raise ValueError(f"No CSV files found in {directory_path} matching pattern '{file_pattern}'")

    total = len(csv_files)
    print(f"Found {total} CSV files")

    all_results = []
    for position, file_path in enumerate(csv_files, start=1):
        try:
            print(f"Processing file {position}/{total}: {os.path.basename(file_path)}")
            all_results.append(calculate_metrics_for_file(file_path))
        except Exception as e:
            # Best effort: report the failing file and continue with the rest
            print(f"Error processing file {file_path}: {e}")

    if not all_results:
        return pd.DataFrame()

    return pd.concat(all_results, ignore_index=True)

def save_results(results_df, output_path):
    """
    Save results to CSV file

    Parameters:
    results_df: DataFrame containing metric results; must have the columns
                train_ratio, repetition and prediction_method
    output_path: Output file path; missing parent directories are created

    Returns:
    The results sorted by train_ratio, repetition, prediction_method
    (the same rows that were written to the file)
    """
    # Sort by train_ratio, repetition, prediction_method
    results_df = results_df.sort_values(['train_ratio', 'repetition', 'prediction_method'])

    # to_csv does not create missing directories, so make sure the parent
    # exists; skipped for non-path targets such as open file objects
    if isinstance(output_path, (str, os.PathLike)):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    # Save to CSV file
    results_df.to_csv(output_path, index=False)
    print(f"Results saved to: {output_path}")

    return results_df

def main():
    """
    Command line entry point

    Parses --directory/-d, --output/-o and --pattern/-p, computes the metrics
    for every matching file, saves them to the output CSV and prints a short
    summary. Any error raised during processing is printed, not re-raised.
    """
    cli = argparse.ArgumentParser(
        description="Calculate AUC and Accuracy metrics for prediction results"
    )
    cli.add_argument(
        "--directory", "-d",
        default="data/sample_predictions",
        help="Directory path containing CSV files",
    )
    cli.add_argument(
        "--output", "-o",
        default="results/prediction_metrics.csv",
        help="Output CSV file path",
    )
    cli.add_argument(
        "--pattern", "-p",
        default="*.csv",
        help="File matching pattern (default: *.csv)",
    )
    options = cli.parse_args()

    print(f"Starting to process directory: {options.directory}")
    print(f"Output file: {options.output}")

    try:
        metrics = process_all_files(options.directory, options.pattern)

        # Nothing to save or summarise
        if metrics.empty:
            print("No results generated")
            return

        save_results(metrics, options.output)

        # Summary statistics
        print("\nProcessing completed!")
        print(f"Total processed {len(metrics)} rows of metric data")
        print(f"Including {metrics['file_name'].nunique()} different files")
        print(f"Including {metrics['prediction_method'].nunique()} prediction methods")

        # Preview of the computed metrics
        print("\nFirst 10 rows of results:")
        print(metrics.head(10))
    except Exception as e:
        print(f"Error during processing: {e}")

# Example usage function
def example_usage():
    """
    Example run with hard-coded paths

    Reads the CSV files in data/sample_predictions, writes the metrics to
    results/prediction_metrics.csv and prints the mean AUC and accuracy per
    prediction method. Any error raised during processing is printed, not
    re-raised.
    """
    source_dir = "data/sample_predictions"
    target_csv = "results/prediction_metrics.csv"

    print("Starting to calculate prediction metrics...")

    try:
        metrics = process_all_files(source_dir)

        # Nothing to save or summarise
        if metrics.empty:
            print("No results generated")
            return

        save_results(metrics, target_csv)

        # Summary statistics
        print("\nProcessing completed!")
        print(f"Total processed {len(metrics)} rows of metric data")
        print(f"Including {metrics['file_name'].nunique()} different files")
        print(f"Including {metrics['prediction_method'].nunique()} prediction methods")

        # Mean AUC and accuracy per prediction method
        print("\nAverage metrics for each prediction method:")
        per_method = metrics.groupby('prediction_method')[['auc', 'accuracy']]
        print(per_method.mean())
    except Exception as e:
        print(f"Error during processing: {e}")

if __name__ == "__main__":
    # Runs only when the file is executed as a script: main() parses the
    # command line arguments and processes the requested directory.
    # To run the hard-coded example instead, comment out main() and
    # uncomment example_usage() below.
    main()
    # example_usage()