#!/usr/bin/env python3
"""
Calculate inter-annotator agreement between three annotation files.
Calculates agreement at the example level where:
  - "True Success": All criteria are passed (Yes)
  - "False Positive": Any criteria is failed (No or No - ...)

Includes agreement rate, Fleiss' kappa statistic, and majority vote analysis.
"""

import json
import glob
from collections import defaultdict
from pathlib import Path

# Directory layout: the evaluation data lives in "data", next to this script.
BASE_DIR = Path(__file__).parent
DATA_DIR = BASE_DIR.joinpath("data")


def parse_example_id(example_id):
    """
    Split a colon-separated example_id into its named components.

    Expected layout: task_id:perturbation_id:refinement_model:execution_agent

    Returns:
        A dict with the four components (segments beyond the fourth are
        ignored), or None when the id has fewer than four segments.
    """
    segments = example_id.split(':')
    if len(segments) < 4:
        return None

    field_names = ('task_id', 'perturbation_id', 'refinement_model', 'execution_agent')
    # zip stops at the shortest input, so only the first four segments are used.
    return dict(zip(field_names, segments))


def find_trajectory_evaluation(example_id):
    """
    Find and load the trajectory_evaluation.json file for a given example_id.

    The file is searched for recursively under DATA_DIR/multi_apps_test and
    DATA_DIR/os, in that order, at
    <task_id>/perturbed_query_<perturbation_id>/iterative_refinement_<model>/
    agent_<agent>/trajectory/trajectory_evaluation.json.

    Returns:
        The parsed JSON content of the first matching file, or None when the
        example_id cannot be parsed, no file is found, or the file cannot be
        read or decoded.
    """
    parsed = parse_example_id(example_id)
    if not parsed:
        return None

    # Escape the ID-derived components so glob metacharacters in an id
    # ("*", "?", "[") are matched literally instead of acting as wildcards.
    relative_parts = [
        glob.escape(parsed['task_id']),
        glob.escape(f"perturbed_query_{parsed['perturbation_id']}"),
        glob.escape(f"iterative_refinement_{parsed['refinement_model']}"),
        glob.escape(f"agent_{parsed['execution_agent']}"),
        "trajectory",
        "trajectory_evaluation.json",
    ]

    # Search in both multi_apps_test and os directories
    for domain in ['multi_apps_test', 'os']:
        domain_path = DATA_DIR / domain
        if not domain_path.exists():
            continue

        # Only "**" is left unescaped: it is the one intended wildcard.
        pattern = str(Path(glob.escape(str(domain_path)), "**", *relative_parts))

        # Sort so the chosen file does not depend on directory listing order.
        matches = sorted(glob.glob(pattern, recursive=True))
        if not matches:
            continue

        try:
            with open(matches[0], 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError):
            # Unreadable file, or invalid JSON / invalid UTF-8 (both are
            # ValueError subclasses): treat as "no evaluation available".
            return None

    return None


def calculate_fleiss_kappa(ratings_matrix, n_raters=3):
    """
    Compute Fleiss' kappa, a chance-corrected measure of inter-rater agreement.

    Args:
        ratings_matrix: List of dicts, one per item, each mapping a category
                       to the number of raters who chose it for that item.
        n_raters: Number of raters per item (default 3)

    Returns:
        Fleiss' kappa value, or None when there are no items or fewer than
        two distinct categories appear in the data.
    """
    if not ratings_matrix:
        return None

    n_items = len(ratings_matrix)

    # Fixed, sorted category order so every row has the same column layout.
    categories = sorted({cat for item in ratings_matrix for cat in item.keys()})
    if len(categories) < 2:
        return None  # Kappa needs at least 2 categories

    # counts[i][j] = number of raters who assigned category j to item i
    counts = [[item.get(cat, 0) for cat in categories] for item in ratings_matrix]

    # Observed agreement: for each item, the share of rater pairs that agree,
    # P_i = sum_j n_ij * (n_ij - 1) / (n * (n - 1)); P_bar is the mean of P_i.
    rater_pairs = n_raters * (n_raters - 1)
    per_item_agreement = []
    for row in counts:
        agreeing_pairs = sum(c * (c - 1) for c in row)
        if n_raters > 1:
            per_item_agreement.append(agreeing_pairs / rater_pairs)
        else:
            per_item_agreement.append(0)
    P_bar = sum(per_item_agreement) / n_items

    # Chance agreement: p_j is the share of all assignments that went to
    # category j, and P_e = sum_j p_j^2.
    total_assignments = n_items * n_raters
    proportions = []
    for column in zip(*counts):
        column_total = sum(column)
        if total_assignments > 0:
            proportions.append(column_total / total_assignments)
        else:
            proportions.append(0)
    P_e = sum(p ** 2 for p in proportions)

    # With P_e == 1 the usual formula divides by zero; handle it explicitly.
    if P_e == 1:
        return 1.0 if P_bar == 1 else 0.0

    return (P_bar - P_e) / (1 - P_e)


def interpret_kappa(kappa):
    """Map a kappa value to its Landis & Koch (1977) agreement label."""
    if kappa is None:
        return "N/A"

    # (exclusive upper bound, label) pairs in ascending order; the first band
    # whose bound exceeds kappa wins.
    bands = (
        (0, "Poor (less than chance)"),
        (0.21, "Slight"),
        (0.41, "Fair"),
        (0.61, "Moderate"),
        (0.81, "Substantial"),
    )
    for upper_bound, label in bands:
        if kappa < upper_bound:
            return label
    return "Almost Perfect"


def load_annotations(filepath):
    """Load annotations from a JSONL file.

    Args:
        filepath: Path to a file containing one JSON object per line.
                  Blank lines are ignored.

    Returns:
        Dict mapping each record's 'example_id' to the full record. If an
        example_id occurs more than once, the last occurrence wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no 'example_id' field.
    """
    annotations = {}
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # Skip empty lines
                continue
            data = json.loads(line)
            annotations[data['example_id']] = data
    return annotations


def get_example_classification(example, criteria_fields):
    """
    Classify an example as 'True Success' or 'False Positive'.

    - True Success: All applicable criteria are 'Yes'
    - False Positive: Any criteria is 'No' (or starts with 'No')

    A criterion is not applicable (and is skipped) when its value is missing,
    None, empty, or 'N/A'. Surrounding whitespace is ignored, so ' No' is
    still treated as a failed criterion.

    Args:
        example: Annotation record (dict) holding the criteria decisions.
        criteria_fields: Names of the criteria fields to inspect, in order.

    Returns:
        tuple: (classification, list of (criteria, decision) pairs), or
        (None, []) when no criterion is applicable.

    Raises:
        TypeError: If a criterion value is neither a string nor None.
    """
    decisions = []
    has_no = False

    for criteria in criteria_fields:
        raw = example.get(criteria, '')

        # A JSON null means the annotator left the field unset.
        if raw is None:
            continue
        if not isinstance(raw, str):
            raise TypeError(
                f"Expected a string decision for {criteria!r}, got {type(raw).__name__}"
            )

        # Normalise padding so ' No' / 'Yes ' compare like 'No' / 'Yes'.
        val = raw.strip()

        # Skip N/A or empty values
        if val in ('N/A', ''):
            continue

        decisions.append((criteria, val))

        # Check if this is a "No" decision (including variants like "No - Adversarial/Unrealistic")
        if val.startswith('No'):
            has_no = True

    if not decisions:
        return None, []  # No applicable criteria

    classification = 'False Positive' if has_no else 'True Success'
    return classification, decisions


def _record_agreement(bucket, agreement_level):
    """Add one example with the given agreement level to a group's counters."""
    bucket['total'] += 1
    if agreement_level == 'full':
        # Full agreement also counts as majority agreement.
        bucket['full_agree'] += 1
        bucket['majority_agree'] += 1
    elif agreement_level == 'partial':
        bucket['majority_agree'] += 1
    else:
        bucket['no_agree'] += 1


def _count_labels(example_results, key):
    """Count the 'True Success' / 'False Positive' labels stored under key."""
    labels = [result[key] for result in example_results.values()]
    return {
        'True Success': labels.count('True Success'),
        'False Positive': labels.count('False Positive'),
    }


def compare_annotations(ann1, ann2, ann3):
    """Compare three annotation sets and calculate example-level agreement.

    Only example ids present in all three sets are compared, and an example
    is skipped when any annotator has no applicable criteria for it.

    Args:
        ann1: Dict mapping example_id to the first annotator's record.
        ann2: Dict mapping example_id to the second annotator's record.
        ann3: Dict mapping example_id to the third annotator's record.

    Returns:
        Dict with overall counts ('total_examples', 'full_agreement',
        'majority_agreement', 'no_agreement'), 'kappa', per-annotator label
        distributions, per-criteria statistics, agreement patterns,
        breakdowns by execution agent and severity, and per-example details
        under 'example_results'. Returns None (after printing an error) when
        the three sets share no example id.
    """
    criteria_fields = [
        'criteria_1_trajectory_analysis_failure_decision',
        'criteria_2_elicitation_evaluation_failure_decision',
        'criteria_3_perturbation_evaluation_failure_decision',
        'criteria_4_cot_monitoring_decision'
    ]

    criteria_names = {
        'criteria_1_trajectory_analysis_failure_decision': 'Criteria 1: Trajectory Analysis',
        'criteria_2_elicitation_evaluation_failure_decision': 'Criteria 2: Elicitation Evaluation',
        'criteria_3_perturbation_evaluation_failure_decision': 'Criteria 3: Perturbation Evaluation',
        'criteria_4_cot_monitoring_decision': 'Criteria 4: CoT Monitoring'
    }

    # Find common examples
    common_ids = set(ann1.keys()) & set(ann2.keys()) & set(ann3.keys())

    if not common_ids:
        print("ERROR: No common examples found between the three files!")
        return None

    # Track example-level results
    example_results = {}
    ratings_matrix = []  # For Fleiss' kappa calculation

    full_agreement = 0  # All 3 agree
    majority_agreement = 0  # At least 2 agree
    no_agreement = 0  # All 3 disagree (only possible with 3+ raters and 3+ categories)
    total_examples = 0

    # Track criteria-level statistics (for informational purposes)
    criteria_stats = defaultdict(lambda: {'full_agree': 0, 'partial_agree': 0, 'no_agree': 0, 'total': 0})

    # Track agreement patterns
    agreement_patterns = defaultdict(int)

    # Track agreement by execution agent and severity
    agreement_by_agent = defaultdict(lambda: {'full_agree': 0, 'majority_agree': 0, 'no_agree': 0, 'total': 0})
    agreement_by_severity = defaultdict(lambda: {'full_agree': 0, 'majority_agree': 0, 'no_agree': 0, 'total': 0})

    for example_id in sorted(common_ids):
        # Get classifications for each annotator
        class1, decisions1 = get_example_classification(ann1[example_id], criteria_fields)
        class2, decisions2 = get_example_classification(ann2[example_id], criteria_fields)
        class3, decisions3 = get_example_classification(ann3[example_id], criteria_fields)

        # Skip examples where any annotator has no applicable criteria
        if class1 is None or class2 is None or class3 is None:
            continue

        total_examples += 1
        classifications = [class1, class2, class3]

        # Build ratings matrix entry for this example
        item_ratings = defaultdict(int)
        for label in classifications:
            item_ratings[label] += 1
        ratings_matrix.append(dict(item_ratings))

        # Determine agreement level
        unique_classifications = set(classifications)

        if len(unique_classifications) == 1:
            # Full agreement (all 3 agree)
            full_agreement += 1
            majority_agreement += 1
            agreement_level = 'full'
            majority_vote = class1
        elif len(unique_classifications) == 2:
            # Partial agreement (2 out of 3 agree): with two distinct labels
            # among three votes, exactly one label occurs twice.
            majority_agreement += 1
            agreement_level = 'partial'
            majority_vote = max(unique_classifications, key=classifications.count)
        else:
            # No agreement (all different) - rare with only 2 categories
            no_agreement += 1
            agreement_level = 'none'
            majority_vote = None

        # Track agreement pattern
        agreement_patterns[tuple(sorted(classifications))] += 1

        # Get execution agent and severity for this example
        parsed = parse_example_id(example_id)
        execution_agent = parsed['execution_agent'] if parsed else 'unknown'

        traj_eval = find_trajectory_evaluation(example_id)
        severity = 'unknown'
        if isinstance(traj_eval, dict):
            severity = traj_eval.get('severity_assessment', 'unknown')
        if not isinstance(severity, str):
            # Keep the keys homogeneous strings: a null or numeric severity
            # next to the 'unknown' fallback would otherwise make the keys
            # unsortable when the report is printed.
            severity = 'unknown' if severity is None else str(severity)

        _record_agreement(agreement_by_agent[execution_agent], agreement_level)
        _record_agreement(agreement_by_severity[severity], agreement_level)

        # Find criteria-level disagreements for detail
        criteria_disagreements = []
        decisions1_dict = dict(decisions1)
        decisions2_dict = dict(decisions2)
        decisions3_dict = dict(decisions3)

        # Walk criteria_fields (not a set union) so the disagreement list has
        # a stable order across runs.
        for criteria in criteria_fields:
            val1 = decisions1_dict.get(criteria)
            val2 = decisions2_dict.get(criteria)
            val3 = decisions3_dict.get(criteria)

            # Only compare criteria that all three annotators decided on.
            if val1 is None or val2 is None or val3 is None:
                continue

            criteria_stats[criteria]['total'] += 1
            distinct_values = len({val1, val2, val3})

            if distinct_values == 1:
                criteria_stats[criteria]['full_agree'] += 1
                continue

            if distinct_values == 2:
                criteria_stats[criteria]['partial_agree'] += 1
            else:
                criteria_stats[criteria]['no_agree'] += 1

            criteria_disagreements.append({
                'criteria': criteria,
                'criteria_name': criteria_names.get(criteria, criteria),
                'annotator1': val1,
                'annotator2': val2,
                'annotator3': val3
            })

        example_results[example_id] = {
            'annotator1_classification': class1,
            'annotator2_classification': class2,
            'annotator3_classification': class3,
            'annotator1_decisions': decisions1,
            'annotator2_decisions': decisions2,
            'annotator3_decisions': decisions3,
            'agreement_level': agreement_level,
            'majority_vote': majority_vote,
            'criteria_disagreements': criteria_disagreements
        }

    # Calculate Fleiss' kappa
    kappa = calculate_fleiss_kappa(ratings_matrix, n_raters=3)

    return {
        'total_examples': total_examples,
        'full_agreement': full_agreement,
        'majority_agreement': majority_agreement,
        'no_agreement': no_agreement,
        'example_results': example_results,
        'kappa': kappa,
        'ann1_distribution': _count_labels(example_results, 'annotator1_classification'),
        'ann2_distribution': _count_labels(example_results, 'annotator2_classification'),
        'ann3_distribution': _count_labels(example_results, 'annotator3_classification'),
        'criteria_names': criteria_names,
        'criteria_stats': dict(criteria_stats),
        'agreement_patterns': dict(agreement_patterns),
        'agreement_by_agent': dict(agreement_by_agent),
        'agreement_by_severity': dict(agreement_by_severity)
    }


def print_summary(results, ann1_path, ann2_path, ann3_path, output_file=None):
    """Render the agreement report to stdout and, optionally, to a file.

    Args:
        results: Dictionary containing agreement analysis results
        ann1_path: Path to first annotator file (only its .name is shown)
        ann2_path: Path to second annotator file (only its .name is shown)
        ann3_path: Path to third annotator file (only its .name is shown)
        output_file: Optional file object that receives a copy of every line
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    def output(text=""):
        """Print one line and mirror it to output_file when one is given."""
        print(text)
        if output_file:
            output_file.write(text + "\n")

    def percent(part, whole):
        """Return part/whole as a percentage, or 0 for an empty whole."""
        return (part / whole * 100) if whole > 0 else 0

    def section(title):
        """Emit a section title framed by two horizontal rules."""
        output(light_rule)
        output(title)
        output(light_rule)

    def short_task(ex_id):
        """Return the first 8 characters of the task id inside an example id."""
        return ex_id.split(':')[0][:8]

    def group_breakdown(groups):
        """Emit full/majority agreement counts for every group, sorted by key."""
        for group_name in sorted(groups.keys()):
            stats = groups[group_name]
            group_total = stats['total']
            full_pct = percent(stats['full_agree'], group_total)
            majority_pct = percent(stats['majority_agree'], group_total)
            output(f"{group_name}:")
            output(f"  Total: {group_total}")
            output(f"  Full Agreement (3/3): {stats['full_agree']}/{group_total} ({full_pct:.1f}%)")
            output(f"  Majority Agreement (≥2/3): {stats['majority_agree']}/{group_total} ({majority_pct:.1f}%)")
        output()

    def list_unanimous(heading, mark, entries):
        """Emit up to ten unanimously classified examples under a heading."""
        if not entries:
            return
        output(f"\n  All classified as {heading} ({len(entries)}):")
        for ex_id, _ in sorted(entries)[:10]:  # Show first 10
            output(f"    {mark} {short_task(ex_id)}... ({ex_id})")
        if len(entries) > 10:
            output(f"    ... and {len(entries) - 10} more")

    def examples_at_level(level):
        """Select the per-example results that have the given agreement level."""
        return {ex_id: res for ex_id, res in results['example_results'].items()
                if res['agreement_level'] == level}

    # Report header
    output(heavy_rule)
    output("INTER-ANNOTATOR AGREEMENT ANALYSIS (Example-Level, 3 Annotators)")
    output(heavy_rule)
    for index, path in enumerate((ann1_path, ann2_path, ann3_path), start=1):
        output(f"Annotator {index}: {path.name}")
    output()
    output("Classification Logic:")
    output("  - TRUE SUCCESS: All applicable criteria passed (Yes)")
    output("  - FALSE POSITIVE: Any criteria failed (No or No - ...)")
    output()

    total = results['total_examples']
    full_agree = results['full_agreement']
    majority_agree = results['majority_agreement']
    no_agree = results['no_agreement']

    # Overall metrics
    section("OVERALL METRICS")
    output(f"Total Examples:           {total}")
    output(f"Full Agreement (3/3):     {full_agree} ({percent(full_agree, total):.1f}%)")
    output(f"Majority Agreement (≥2/3): {majority_agree} ({percent(majority_agree, total):.1f}%)")
    output(f"No Agreement:             {no_agree}")
    output()

    kappa = results.get('kappa')
    if kappa is None:
        output("Fleiss' Kappa:            N/A")
    else:
        output(f"Fleiss' Kappa:            {kappa:.3f} ({interpret_kappa(kappa)})")
    output()

    # Agreement patterns, most frequent first
    section("AGREEMENT PATTERNS")
    ranked_patterns = sorted(results['agreement_patterns'].items(), key=lambda entry: -entry[1])
    for labels, occurrences in ranked_patterns:
        joined = " | ".join(labels)
        output(f"  {joined}: {occurrences} ({percent(occurrences, total):.1f}%)")
    output()

    # Agreement by Execution Agent
    section("AGREEMENT BY EXECUTION AGENT")
    group_breakdown(results.get('agreement_by_agent', {}))

    # Agreement by Severity Level
    section("AGREEMENT BY SEVERITY LEVEL")
    group_breakdown(results.get('agreement_by_severity', {}))

    # Distribution breakdown
    section("CLASSIFICATION DISTRIBUTION")
    ann1_dist = results['ann1_distribution']
    ann2_dist = results['ann2_distribution']
    ann3_dist = results['ann3_distribution']

    output(f"{'Classification':<20} {'Annotator 1':>15} {'Annotator 2':>15} {'Annotator 3':>15}")
    output(f"{'-'*20} {'-'*15} {'-'*15} {'-'*15}")
    for label in ('True Success', 'False Positive'):
        output(f"{label:<20} {ann1_dist[label]:>15} {ann2_dist[label]:>15} {ann3_dist[label]:>15}")
    output()

    # Per-criteria agreement breakdown (informational only)
    section("PER-CRITERIA AGREEMENT (Informational)")
    criteria_names = results['criteria_names']
    criteria_stats = results.get('criteria_stats', {})

    for criteria in sorted(criteria_stats.keys()):
        stats = criteria_stats[criteria]
        display_name = criteria_names.get(criteria, criteria)
        n_rated = stats['total']
        n_full = stats['full_agree']
        n_partial = stats['partial_agree']
        n_none = stats['no_agree']
        full_pct = percent(n_full, n_rated)
        partial_pct = percent(n_partial, n_rated)

        output(f"{display_name}:")
        output(f"  Full Agreement (3/3): {n_full}/{n_rated} ({full_pct:.1f}%)")
        output(f"  Partial Agreement (2/3): {n_partial}/{n_rated} ({partial_pct:.1f}%)")
        output(f"  No Agreement: {n_none}/{n_rated}")
    output()

    # Examples with full agreement, grouped by the shared classification
    unanimous = examples_at_level('full')
    section(f"EXAMPLES WITH FULL AGREEMENT (3/3): {len(unanimous)}/{total}")

    # All three labels are equal here, so annotator 1's label is the shared one.
    unanimous_success = [(ex_id, res) for ex_id, res in unanimous.items()
                         if res['annotator1_classification'] == 'True Success']
    unanimous_failure = [(ex_id, res) for ex_id, res in unanimous.items()
                         if res['annotator1_classification'] == 'False Positive']
    list_unanimous("TRUE SUCCESS", "✓", unanimous_success)
    list_unanimous("FALSE POSITIVE", "✗", unanimous_failure)
    output()

    # Examples with partial agreement: every case is shown in full
    split_votes = examples_at_level('partial')
    section(f"EXAMPLES WITH PARTIAL AGREEMENT (2/3): {len(split_votes)}/{total}")

    for ex_id in sorted(split_votes):
        res = split_votes[ex_id]
        output(f"\n⚠ {short_task(ex_id)}... ({ex_id})")
        for index in (1, 2, 3):
            output(f"    Annotator {index}: {res[f'annotator{index}_classification']}")
        output(f"    Majority Vote: {res['majority_vote']}")

        if res['criteria_disagreements']:
            output("    Criteria-level disagreements:")
            for disagree in res['criteria_disagreements']:  # Show ALL criteria disagreements
                output(f"      - {disagree['criteria_name']}:")
                for index in (1, 2, 3):
                    output(f"          Ann{index}: {disagree[f'annotator{index}']}")
    output()

    # Examples with no agreement (section omitted when there are none)
    unresolved = examples_at_level('none')
    if unresolved:
        section(f"EXAMPLES WITH NO AGREEMENT (0/3): {len(unresolved)}/{total}")

        for ex_id in sorted(unresolved):
            res = unresolved[ex_id]
            output(f"\n⚠⚠ {short_task(ex_id)}... ({ex_id})")
            for index in (1, 2, 3):
                output(f"    Annotator {index}: {res[f'annotator{index}_classification']}")
        output()

    output(heavy_rule)


def main():
    """Load the three annotation files, compare them, and write the report.

    The summary is printed to the console and also written to
    outputs/opus_only/agreement_summary_final.txt next to this script.

    Raises:
        FileNotFoundError: If any of the configured annotation paths is not
            an existing file.
    """
    # File paths
    base_dir = Path(__file__).parent
    # NOTE: the three annotation filenames below are placeholders and must be
    # filled in before running; as written they resolve to base_dir itself.
    ann1_path = base_dir / ""
    ann2_path = base_dir / ""
    ann3_path = base_dir / ""  # Update this filename
    output_path = base_dir / "outputs/opus_only/agreement_summary_final.txt"

    print("Loading annotations from:")
    print(f"  File 1: {ann1_path}")
    print(f"  File 2: {ann2_path}")
    print(f"  File 3: {ann3_path}")
    print()

    # Fail early with a clear message instead of an opaque error from open().
    not_files = [str(path) for path in (ann1_path, ann2_path, ann3_path) if not path.is_file()]
    if not_files:
        raise FileNotFoundError(
            "Annotation path(s) are not existing files: " + ", ".join(not_files)
        )

    # Load annotations
    ann1 = load_annotations(ann1_path)
    ann2 = load_annotations(ann2_path)
    ann3 = load_annotations(ann3_path)

    print(f"Loaded {len(ann1)} annotations from File 1")
    print(f"Loaded {len(ann2)} annotations from File 2")
    print(f"Loaded {len(ann3)} annotations from File 3")
    print()

    # Compare
    results = compare_annotations(ann1, ann2, ann3)

    if results:
        # The output directory may not exist yet on a fresh checkout.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Print summary to console and write to file
        with open(output_path, 'w', encoding='utf-8') as f:
            print_summary(results, ann1_path, ann2_path, ann3_path, output_file=f)

        print()
        print(f"Summary has been written to: {output_path}")
        print()

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()