import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.metrics import cohen_kappa_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

def normalize_correctness_value(value):
    """
    Map a correctness label onto 1 (correct) or 0 (incorrect).

    The value is compared as a stripped, lower-cased string, so 1/0,
    1.0/0.0, 'Correct'/'Incorrect' and True/False are all accepted.
    Missing values (as judged by ``pd.isna``) yield NaN. Any other label
    is reported on stdout and also yields NaN.
    """
    if pd.isna(value):
        return np.nan

    # Every spelling we understand, keyed by its canonical lower-case text.
    recognized = {
        '1': 1, '1.0': 1, 'correct': 1, 'true': 1,
        '0': 0, '0.0': 0, 'incorrect': 0, 'false': 0,
    }

    label = str(value).strip().lower()
    if label in recognized:
        return recognized[label]

    print(f"Warning: Unrecognized value '{value}' - treating as NaN")
    return np.nan

def extract_file_category(filename):
    """
    Classify a file name by its trailing suffix (case-insensitive).

    Returns 'Irrelevant', 'CounterFactual', 'CharFlipped' or 'WordFlipped'
    when the name ends with that word, and 'Other' otherwise.
    """
    lowered = filename.lower()

    # The category label doubles as the suffix to look for once lower-cased.
    for category in ('Irrelevant', 'CounterFactual', 'CharFlipped', 'WordFlipped'):
        if lowered.endswith(category.lower()):
            return category

    return 'Other'

def calculate_agreement_metrics(y_true, y_pred, metric_name):
    """
    Calculate percentage agreement and Cohen's kappa for paired ratings.

    Args:
        y_true: Human labels (NaN or None for missing). Any 1-D sequence
            that can be converted to a float array.
        y_pred: Model labels, same length and encoding as ``y_true``.
        metric_name: Label copied into the result and used in warnings.

    Returns:
        dict with the keys
        'metric'               -- ``metric_name`` unchanged,
        'total_pairs'          -- number of input pairs, valid or not,
        'valid_pairs'          -- pairs where neither side is NaN,
        'percentage_agreement' -- 0-100, NaN when there are no valid pairs,
        'cohens_kappa'         -- NaN when it cannot be computed,
        'confusion_matrix'     -- 2x2 array over labels [0, 1], or None.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Coerce up front so lists and None entries work, not only float arrays.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch for {metric_name}: {y_true.shape} vs {y_pred.shape}"
        )

    total_pairs = len(y_true)

    # Keep only the pairs where both raters supplied a value
    mask = ~(np.isnan(y_true) | np.isnan(y_pred))
    y_true_clean = y_true[mask]
    y_pred_clean = y_pred[mask]
    valid_pairs = len(y_true_clean)

    if valid_pairs == 0:
        return {
            'metric': metric_name,
            # Report the real input size, matching the normal return below.
            'total_pairs': total_pairs,
            'valid_pairs': 0,
            'percentage_agreement': np.nan,
            'cohens_kappa': np.nan,
            'confusion_matrix': None
        }

    # Calculate percentage agreement
    agreement = np.sum(y_true_clean == y_pred_clean)
    percentage_agreement = (agreement / valid_pairs) * 100

    # Calculate Cohen's kappa; a failure degrades to NaN instead of aborting
    try:
        kappa = cohen_kappa_score(y_true_clean, y_pred_clean)
    except Exception as e:
        print(f"Warning: Could not calculate Cohen's kappa for {metric_name}: {e}")
        print(f"  Data summary - Human: {np.unique(y_true_clean, return_counts=True)}")
        print(f"  Data summary - Model: {np.unique(y_pred_clean, return_counts=True)}")
        print(f"  Valid pairs: {valid_pairs}")
        kappa = np.nan

    # Create confusion matrix with a fixed label order so the shape is stable
    try:
        cm = confusion_matrix(y_true_clean, y_pred_clean, labels=[0, 1])
    except Exception as e:
        print(f"Warning: Could not build confusion matrix for {metric_name}: {e}")
        cm = None

    return {
        'metric': metric_name,
        'total_pairs': total_pairs,
        'valid_pairs': valid_pairs,
        'percentage_agreement': percentage_agreement,
        'cohens_kappa': kappa,
        'confusion_matrix': cm
    }

def interpret_kappa(kappa_value):
    """
    Translate a Cohen's kappa value into its Landis & Koch (1977) label.

    NaN maps to "Cannot calculate" and negative values to
    "Poor (worse than chance)". Otherwise the label of the first band
    whose inclusive upper bound is not exceeded is returned.
    """
    if np.isnan(kappa_value):
        return "Cannot calculate"
    if kappa_value < 0:
        return "Poor (worse than chance)"

    # (inclusive upper bound, label), in ascending order
    bands = (
        (0.20, "Slight"),
        (0.40, "Fair"),
        (0.60, "Moderate"),
        (0.80, "Substantial"),
    )
    for upper_bound, label in bands:
        if kappa_value <= upper_bound:
            return label

    return "Almost perfect"

def plot_category_comparison(category_results, save_path=None):
    """
    Draw a 2x2 grid of bar charts comparing agreement metrics per category.

    Top row: percentage agreement for accuracy and grammar.
    Bottom row: Cohen's kappa for accuracy and grammar.

    Each value of ``category_results`` must provide the keys
    'accuracy_agreement', 'accuracy_kappa', 'grammar_agreement' and
    'grammar_kappa'. When ``save_path`` is given the figure is written
    there before being shown.
    """
    categories = list(category_results.keys())

    def column(key):
        # One value per category, in the same order as the x axis.
        return [category_results[name][key] for name in categories]

    # Pull all series out first so a missing key fails before any figure exists.
    accuracy_agreement = column('accuracy_agreement')
    accuracy_kappa = column('accuracy_kappa')
    grammar_agreement = column('grammar_agreement')
    grammar_kappa = column('grammar_kappa')

    # One entry per panel, listed in row-major order of the 2x2 grid:
    # (bar heights, bar colour, title, y label, y limits)
    panels = [
        (accuracy_agreement, 'skyblue',
         'Accuracy - Percentage Agreement by Category',
         'Percentage Agreement (%)', (0, 100)),
        (grammar_agreement, 'lightgreen',
         'Grammar - Percentage Agreement by Category',
         'Percentage Agreement (%)', (0, 100)),
        (accuracy_kappa, 'orange',
         "Accuracy - Cohen's Kappa by Category",
         "Cohen's Kappa", (-0.2, 1.0)),
        (grammar_kappa, 'pink',
         "Grammar - Cohen's Kappa by Category",
         "Cohen's Kappa", (-0.2, 1.0)),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for ax, (heights, colour, title, y_label, (y_min, y_max)) in zip(axes.flat, panels):
        ax.bar(categories, heights, color=colour, alpha=0.7)
        ax.set_title(title)
        ax.set_ylabel(y_label)
        ax.set_ylim(y_min, y_max)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def _print_agreement_block(heading, metrics, show_valid_pairs=False):
    """Print one agreement report: optional valid-pair count, agreement, kappa, label."""
    print(f"\n{heading}")
    if show_valid_pairs:
        print(f"  Valid pairs: {metrics['valid_pairs']}")
    print(f"  Percentage Agreement: {metrics['percentage_agreement']:.2f}%")
    print(f"  Cohen's Kappa: {metrics['cohens_kappa']:.4f}")
    print(f"  Interpretation: {interpret_kappa(metrics['cohens_kappa'])}")


def _align_annotation_pairs(human_df, model_df):
    """Pair each human row with the model row at position ``No - 1``.

    Returns ``(human_acc, model_acc, human_gram, model_gram, skipped)``:
    four float arrays of normalized labels plus the number of human rows
    dropped because 'No' was missing, non-numeric or out of range.
    """
    human_acc, model_acc, human_gram, model_gram = [], [], [], []
    skipped = 0
    model_rows = len(model_df)

    for _, row in human_df.iterrows():
        try:
            position = int(row['No']) - 1  # 'No' is 1-based
        except (TypeError, ValueError, OverflowError):
            # NaN / text / inf in 'No': skip the row instead of aborting the run
            skipped += 1
            continue

        if not 0 <= position < model_rows:
            skipped += 1
            continue

        model_row = model_df.iloc[position]

        # Human values are normalized before model values so that any
        # "unrecognized value" warnings appear in a stable order.
        human_accuracy = normalize_correctness_value(row['is_correct'])
        human_grammar = normalize_correctness_value(row['grammar_correct'])
        model_accuracy = normalize_correctness_value(model_row['accuracy'])
        model_grammar = normalize_correctness_value(model_row['grammatical_correctness'])

        human_acc.append(human_accuracy)
        model_acc.append(model_accuracy)
        human_gram.append(human_grammar)
        model_gram.append(model_grammar)

    # Explicit float dtype keeps NaN handling uniform, also for empty input.
    return (
        np.array(human_acc, dtype=float),
        np.array(model_acc, dtype=float),
        np.array(human_gram, dtype=float),
        np.array(model_gram, dtype=float),
        skipped,
    )


def _aggregate_category_metrics(human_values, model_values, metric_name, heading):
    """Compute and print pooled metrics for one category, or NaN placeholders if empty."""
    if not (human_values and model_values):
        return {'percentage_agreement': np.nan, 'cohens_kappa': np.nan}

    metrics = calculate_agreement_metrics(
        np.array(human_values), np.array(model_values), metric_name
    )
    _print_agreement_block(heading, metrics, show_valid_pairs=True)
    return metrics


def compare_human_model_by_categories(folder1_path, folder2_path, output_dir=None):
    """
    Compare human annotations with model predictions, categorized by file types

    Human files are every ``*.csv`` in ``folder1_path``. Model files are
    every ``Evaluated_*.csv`` in ``folder2_path``. Files are paired when the
    human stem equals the model stem with the ``Evaluated_`` prefix removed.
    Human files need the columns 'No', 'is_correct' and 'grammar_correct';
    model files need 'accuracy' and 'grammatical_correctness'. A human row
    is matched to the model row at position ``No - 1``.

    Args:
        folder1_path: Folder holding the human annotation CSVs.
        folder2_path: Folder holding the model prediction CSVs.
        output_dir: Optional folder for the CSV summaries and the plot. It
            is created (including missing parents) when it does not exist.

    Returns:
        ``None`` when no file pairs match. Otherwise a dict with
        'individual_files' (list of per-file result dicts),
        'category_summary' (list of per-category dicts) and
        'category_results' (the same per-category data keyed by category).

    Raises:
        FileNotFoundError: If either input folder does not exist.
    """
    folder1 = Path(folder1_path)
    folder2 = Path(folder2_path)

    if not folder1.exists() or not folder2.exists():
        raise FileNotFoundError("One or both folder paths do not exist")

    output_path = None
    if output_dir:
        output_path = Path(output_dir)
        # parents=True so a nested output directory does not fail
        output_path.mkdir(parents=True, exist_ok=True)

    # Get all CSV files from both folders
    model_prefix = 'Evaluated_'
    folder1_files = {f.stem: f for f in folder1.glob('*.csv')}
    # Strip only the leading prefix; str.replace would also remove later
    # occurrences inside the name and break the pairing.
    folder2_files = {
        f.stem[len(model_prefix):]: f
        for f in folder2.glob(f'{model_prefix}*.csv')
    }

    print(f"Found {len(folder1_files)} human annotation files")
    print(f"Found {len(folder2_files)} model prediction files")

    # Find matching files; sorted so output and aggregation are reproducible
    common_files = sorted(set(folder1_files) & set(folder2_files))
    print(f"Found {len(common_files)} matching file pairs")

    if not common_files:
        print("No matching files found!")
        return

    # Categorize files once and reuse the result while processing
    file_categories = {name: extract_file_category(name) for name in common_files}
    categorized_files = defaultdict(list)
    for file_name, category in file_categories.items():
        categorized_files[category].append(file_name)

    print("\nFile categorization:")
    for category, files in categorized_files.items():
        print(f"  {category}: {len(files)} files")

    # Store results
    all_file_results = []
    category_data = defaultdict(lambda: {
        'accuracy_human': [], 'accuracy_model': [],
        'grammar_human': [], 'grammar_model': [],
        'files': [], 'total_pairs': 0
    })

    required_cols_human = ['No', 'is_correct', 'grammar_correct']
    required_cols_model = ['accuracy', 'grammatical_correctness']

    # Process each file
    for file_name in common_files:
        category = file_categories[file_name]

        print(f"\n{'='*60}")
        print(f"Processing: {file_name} (Category: {category})")
        print('='*60)

        # Read files; an unreadable pair is reported and skipped
        try:
            human_df = pd.read_csv(folder1_files[file_name])
            model_df = pd.read_csv(folder2_files[file_name])
        except Exception as e:
            print(f"Error reading files for {file_name}: {e}")
            continue

        # Check required columns
        missing_cols_human = [col for col in required_cols_human if col not in human_df.columns]
        missing_cols_model = [col for col in required_cols_model if col not in model_df.columns]

        if missing_cols_human or missing_cols_model:
            if missing_cols_human:
                print(f"Missing columns in human file: {missing_cols_human}")
            if missing_cols_model:
                print(f"Missing columns in model file: {missing_cols_model}")
            continue

        # Sort model predictions by index
        model_df_sorted = model_df.sort_index().reset_index(drop=True)

        print(f"Human annotations: {len(human_df)} rows")
        print(f"Model predictions: {len(model_df_sorted)} rows")

        # Align data based on 'No' column
        human_acc, model_acc, human_gram, model_gram, skipped = _align_annotation_pairs(
            human_df, model_df_sorted
        )
        if skipped:
            print(
                f"Warning: skipped {skipped} human rows whose 'No' value is "
                "missing, non-numeric or out of range"
            )
        pair_count = len(human_acc)

        # Calculate metrics for this file
        accuracy_metrics = calculate_agreement_metrics(
            human_acc, model_acc, f"{file_name}_accuracy"
        )
        grammar_metrics = calculate_agreement_metrics(
            human_gram, model_gram, f"{file_name}_grammar"
        )

        # Store individual file results
        all_file_results.append({
            'file': file_name,
            'category': category,
            'total_pairs': pair_count,
            'accuracy_agreement': accuracy_metrics['percentage_agreement'],
            'accuracy_kappa': accuracy_metrics['cohens_kappa'],
            'accuracy_interpretation': interpret_kappa(accuracy_metrics['cohens_kappa']),
            'grammar_agreement': grammar_metrics['percentage_agreement'],
            'grammar_kappa': grammar_metrics['cohens_kappa'],
            'grammar_interpretation': interpret_kappa(grammar_metrics['cohens_kappa'])
        })

        # Add to category data for aggregated metrics (valid pairs only)
        valid_acc_mask = ~(np.isnan(human_acc) | np.isnan(model_acc))
        valid_gram_mask = ~(np.isnan(human_gram) | np.isnan(model_gram))

        bucket = category_data[category]
        bucket['accuracy_human'].extend(human_acc[valid_acc_mask])
        bucket['accuracy_model'].extend(model_acc[valid_acc_mask])
        bucket['grammar_human'].extend(human_gram[valid_gram_mask])
        bucket['grammar_model'].extend(model_gram[valid_gram_mask])
        bucket['files'].append(file_name)
        bucket['total_pairs'] += pair_count

        # Print results for this file
        _print_agreement_block("ACCURACY/CORRECTNESS ALIGNMENT:", accuracy_metrics)
        _print_agreement_block("GRAMMATICAL CORRECTNESS ALIGNMENT:", grammar_metrics)

    # Calculate category-wise aggregated metrics
    print(f"\n{'='*80}")
    print("CATEGORY-WISE AGGREGATED RESULTS")
    print('='*80)

    category_results = {}
    category_summary = []

    for category, data in category_data.items():
        if not data['files']:
            continue

        print(f"\n{'-'*50}")
        print(f"CATEGORY: {category}")
        print(f"Files: {len(data['files'])}")
        print(f"Total pairs: {data['total_pairs']}")
        print('-'*50)

        # Calculate aggregated metrics over all valid pairs in the category
        acc_metrics = _aggregate_category_metrics(
            data['accuracy_human'], data['accuracy_model'],
            f"{category}_accuracy", "ACCURACY/CORRECTNESS (Aggregated):"
        )
        gram_metrics = _aggregate_category_metrics(
            data['grammar_human'], data['grammar_model'],
            f"{category}_grammar", "GRAMMATICAL CORRECTNESS (Aggregated):"
        )

        # Store category results; the summary row is the same data plus the name
        summary = {
            'files_count': len(data['files']),
            'total_pairs': data['total_pairs'],
            'accuracy_agreement': acc_metrics['percentage_agreement'],
            'accuracy_kappa': acc_metrics['cohens_kappa'],
            'accuracy_interpretation': interpret_kappa(acc_metrics['cohens_kappa']),
            'grammar_agreement': gram_metrics['percentage_agreement'],
            'grammar_kappa': gram_metrics['cohens_kappa'],
            'grammar_interpretation': interpret_kappa(gram_metrics['cohens_kappa'])
        }
        category_results[category] = summary
        category_summary.append({'category': category, **summary})

    # Create and save summary tables
    if all_file_results:
        file_summary_df = pd.DataFrame(all_file_results)
        category_summary_df = pd.DataFrame(category_summary)

        print(f"\n{'='*80}")
        print("INDIVIDUAL FILE RESULTS SUMMARY")
        print('='*80)
        print(file_summary_df.to_string(index=False))

        print(f"\n{'='*80}")
        print("CATEGORY SUMMARY")
        print('='*80)
        print(category_summary_df.to_string(index=False))

        if output_path is not None:
            files_csv = output_path / 'individual_file_results.csv'
            categories_csv = output_path / 'category_summary.csv'
            plot_png = output_path / 'category_comparison.png'

            # Save individual file results and category summary
            file_summary_df.to_csv(files_csv, index=False)
            category_summary_df.to_csv(categories_csv, index=False)

            # A comparison plot only makes sense with at least two categories
            has_plot = len(category_results) > 1
            if has_plot:
                plot_category_comparison(category_results, plot_png)

            print(f"\nResults saved to:")
            print(f"  - Individual files: {files_csv}")
            print(f"  - Category summary: {categories_csv}")
            if has_plot:
                print(f"  - Visualization: {plot_png}")

    return {
        'individual_files': all_file_results,
        'category_summary': category_summary,
        'category_results': category_results
    }

# Example usage
if __name__ == "__main__":
    # Point these two at the real annotation folders before running.
    human_annotations_folder = "path_to_files_with_human_evaluated_files"
    model_predictions_folder = "path_to_files_with_model_evaluated_files"

    # Where the result CSVs and the plot are written.
    output_directory = "category_agreement_analysis"

    try:
        results = compare_human_model_by_categories(
            folder1_path=human_annotations_folder,
            folder2_path=model_predictions_folder,
            output_dir=output_directory,
        )
    except Exception as e:
        # Top-level boundary: report the failure together with a setup checklist.
        troubleshooting = (
            f"Error: {e}",
            "\nPlease make sure to:",
            "1. Update the folder paths in the script",
            "2. Install required packages: pip install pandas numpy scikit-learn matplotlib seaborn",
            "3. Ensure both folders exist and contain matching CSV files",
            "4. File names should end with: Irrelevant, CounterFactual, CharFlipped, or WordFlipped",
        )
        for line in troubleshooting:
            print(line)