
import json
from sklearn.metrics import cohen_kappa_score
from collections import defaultdict, Counter
import argparse
import itertools
import numpy as np

try:
    from statsmodels.stats.inter_rater import cohens_kappa
    STATSMODELS_AVAILABLE = True
except ImportError:
    STATSMODELS_AVAILABLE = False
    print("Warning: statsmodels not available. Weighted kappa will not be computed.")
    print("Install with: pip install statsmodels")


# ---- Domain knowledge (edit here if you change labels) ----
# Ordinal fields mapped to their category order, lowest category first.
# Only fields listed here are eligible for weighted kappa.
FIELD_ORDERS = dict(
    information_clarity=['Contradictory', 'Ambiguous', 'Clear'],
    information_quality=['Insufficient', 'Sufficient'],
    premise_grounding=['Not Grounded', 'Directly Grounded'],
)
# Per-field values that mean "no label": a pair is dropped from a pairwise
# comparison when either annotator's value is in the field's set.
_ALWAYS_MISSING = (None, '')
MISSING_TOKENS = {
    'information_quality': {'Unspecified', *_ALWAYS_MISSING},
    'information_clarity': set(_ALWAYS_MISSING),
    'premise_grounding': set(_ALWAYS_MISSING),
    'type': set(_ALWAYS_MISSING),
}


def load_data(file_path):
    """Load annotated records from a JSON Lines file, keyed by ``record['id']``.

    Blank (whitespace-only) lines are skipped. Every other line is parsed as
    JSON and must contain an ``'id'`` key. A later record with the same id
    replaces an earlier one.

    Args:
        file_path: Path to the ``.jsonl`` file (read as UTF-8).

    Returns:
        dict mapping each record's ``'id'`` value to the parsed record.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON; the message is
            prefixed with ``"<file_path>:<line number>:"``.
        KeyError: if a record has no ``'id'`` key.
    """
    data = {}
    # Explicit UTF-8: the platform default encoding is not UTF-8 everywhere
    # (e.g. Windows), which would mis-decode non-ASCII annotation text.
    with open(file_path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # The parser only sees this single line and so always reports
                # "line 1"; re-raise the same type with file/line context.
                raise json.JSONDecodeError(
                    f"{file_path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            data[record['id']] = record
    return data


def get_annotations(record):
    """Flatten each trace step of ``record`` into a dict of annotation fields.

    For every step in ``record['trace']`` the flat dict holds:

    * ``'type'``: ``annotation['type']`` if present, else the step's own
      ``'type'``, else None;
    * whichever of ``information_quality``, ``information_clarity`` and
      ``premise_grounding`` appear in ``annotation['attributes']``, in that
      order. Other attribute keys are ignored.

    A missing or JSON-``null`` ``trace``, ``annotation`` or ``attributes``
    is treated as empty instead of raising.

    Returns:
        list of dicts, one per step, in trace order.
    """
    attribute_fields = ('information_quality', 'information_clarity', 'premise_grounding')
    annotations = []
    # ``or`` fallbacks turn explicit nulls into empty containers; a plain
    # ``.get(key, default)`` only covers absent keys.
    for step in record.get('trace') or []:
        annotation = step.get('annotation') or {}
        attrs = annotation.get('attributes') or {}
        flat = {'type': annotation.get('type', step.get('type'))}
        flat.update((name, attrs[name]) for name in attribute_fields if name in attrs)
        annotations.append(flat)
    return annotations


def create_contingency_table(labels1, labels2, ordered_labels):
    """Count paired labels into a square matrix laid out by ``ordered_labels``.

    Cell ``[i, j]`` is the number of positions where the first sequence has
    ``ordered_labels[i]`` and the second has ``ordered_labels[j]``. Pairs in
    which either label is not in ``ordered_labels`` are ignored, and pairing
    stops at the shorter sequence.

    Returns:
        numpy int array of shape ``(len(ordered_labels), len(ordered_labels))``.
    """
    position = {label: k for k, label in enumerate(ordered_labels)}
    size = len(ordered_labels)
    counts = np.zeros((size, size), dtype=int)
    cells = [
        (position[left], position[right])
        for left, right in zip(labels1, labels2)
        if left in position and right in position
    ]
    for row, col in cells:
        counts[row, col] += 1
    return counts


def compute_weighted_kappa(labels1, labels2, ordered_labels, weight_type='linear'):
    """Compute Cohen's kappa (optionally weighted) with statsmodels.

    The two label sequences are tabulated with ``create_contingency_table``
    in the order given by ``ordered_labels``. That order defines the
    category distances used by linear/quadratic weights.

    Args:
        labels1, labels2: Paired label sequences from two annotators.
        ordered_labels: Category order (low -> high) for the table axes.
        weight_type: ``'linear'``, ``'quadratic'`` or ``'unweighted'``.

    Returns:
        ``(kappa, var_kappa, pvalue_two_sided)`` from statsmodels'
        ``cohens_kappa``. Returns a triple of NaN when statsmodels is not
        available, or when no pair has both labels in ``ordered_labels``.

    Raises:
        ValueError: if ``weight_type`` is not one of the supported values.

    Warns:
        UserWarning: when some pairs contain a label outside
            ``ordered_labels`` and are therefore left out of the table.
    """
    import warnings  # function-scope: only needed for the dropped-pairs warning

    # weight_type -> statsmodels ``wt`` argument (None means unweighted kappa).
    statsmodels_wt = {'linear': 'linear', 'quadratic': 'quadratic', 'unweighted': None}
    # Validate first so a bad weight_type fails loudly even without statsmodels.
    if weight_type not in statsmodels_wt:
        raise ValueError(f"Unknown weight_type: {weight_type}")
    if not STATSMODELS_AVAILABLE:
        return np.nan, np.nan, np.nan

    # Materialise once so one-shot iterators can be both tabulated and counted.
    labels1 = list(labels1)
    labels2 = list(labels2)
    table = create_contingency_table(labels1, labels2, ordered_labels)
    n_pairs = min(len(labels1), len(labels2))
    n_tabulated = int(table.sum())
    if n_tabulated < n_pairs:
        # Labels outside ``ordered_labels`` cannot be placed on the ordinal
        # scale; report it instead of silently shrinking the sample.
        warnings.warn(
            f"{n_pairs - n_tabulated} of {n_pairs} label pairs have a value outside "
            f"{list(ordered_labels)} and were excluded from the weighted kappa.",
            stacklevel=2,
        )
    if n_tabulated == 0:
        # An empty table would make statsmodels divide by zero; report "undefined".
        return np.nan, np.nan, np.nan

    result = cohens_kappa(table, wt=statsmodels_wt[weight_type], return_results=True)
    return result.kappa, result.var_kappa, result.pvalue_two_sided


def analyze_label_distribution(labels1, labels2, field_name):
    """Print each side's label counts and up to five disagreeing positions.

    A disagreement is reported as ``(index, label_from_1, label_from_2)``
    for positions where the paired labels differ (pairing stops at the
    shorter sequence). Output goes to stdout; nothing is returned.
    """
    dist_a = dict(Counter(labels1))
    dist_b = dict(Counter(labels2))
    mismatches = []
    for pos, pair in enumerate(zip(labels1, labels2)):
        left, right = pair
        if left != right:
            mismatches.append((pos, left, right))

    print(f"\n  Label Distribution for '{field_name}':")
    print(f"    File 1: {dist_a}")
    print(f"    File 2: {dist_b}")
    if not mismatches:
        print("    No disagreements found")
        return
    print(f"    Disagreements found: {len(mismatches)}")
    print(f"    First 5 disagreements: {mismatches[:5]}")


def drop_missing_pairs(labels1, labels2, field):
    """Remove positions where either annotator's label counts as missing.

    The missing values for ``field`` come from ``MISSING_TOKENS``; fields
    not listed there fall back to ``{None, ''}``. Pairing stops at the
    shorter sequence.

    Returns:
        Two new lists holding the surviving labels, still aligned by position.
    """
    missing = MISSING_TOKENS.get(field, {None, ''})
    kept = [
        (left, right)
        for left, right in zip(labels1, labels2)
        if left not in missing and right not in missing
    ]
    return [left for left, _ in kept], [right for _, right in kept]


def _count_present_fields(annotations):
    """Count, per field name, how many step dicts carry a non-None value."""
    return Counter(
        field
        for ann in annotations
        for field, value in ann.items()
        if value is not None
    )


def _collect_pair_labels(data1, data2, common_ids):
    """Align the step-level annotations of two annotators.

    Args:
        data1, data2: Mappings of record id -> record, as returned by ``load_data``.
        common_ids: Record ids present in both mappings, in visiting order.

    Returns:
        Tuple ``(labels, length_stats, field_misalignment, mismatched_ids)``:

        * ``labels``: field -> ``(labels_from_data1, labels_from_data2)``, one
          entry per positionally aligned step where both values are not None.
        * ``length_stats``: ``(len1, len2)`` trace-length pair -> record count.
        * ``field_misalignment``: field -> number of equal-length records in
          which the two sides have a different count of non-None values.
        * ``mismatched_ids``: ids skipped because their trace lengths differ.
    """
    labels = defaultdict(lambda: ([], []))
    length_stats = defaultdict(int)
    field_misalignment = defaultdict(int)
    mismatched_ids = []

    for record_id in common_ids:
        ann1_list = get_annotations(data1[record_id])
        ann2_list = get_annotations(data2[record_id])
        length_stats[(len(ann1_list), len(ann2_list))] += 1

        if len(ann1_list) != len(ann2_list):
            # Steps are aligned by position, which is meaningless when the
            # traces differ in length, so the whole record is skipped.
            mismatched_ids.append(record_id)
            continue

        counts1 = _count_present_fields(ann1_list)
        counts2 = _count_present_fields(ann2_list)
        for field in counts1.keys() | counts2.keys():
            if counts1[field] != counts2[field]:
                field_misalignment[field] += 1

        for ann1, ann2 in zip(ann1_list, ann2_list):
            for field in ann1.keys() | ann2.keys():
                label1 = ann1.get(field)
                label2 = ann2.get(field)
                if label1 is not None and label2 is not None:
                    labels[field][0].append(label1)
                    labels[field][1].append(label2)

    return labels, length_stats, field_misalignment, mismatched_ids


def _score_field(rater1_labels, rater2_labels, field, weighted, weight_type):
    """Return ``(kappa, weighted_kappa, weighted_var, weighted_pval)``.

    All four values are NaN when there are no labels. The weighted trio is
    NaN unless ``weighted`` is set, statsmodels is available and ``field``
    has an explicit order in ``FIELD_ORDERS``.
    """
    if not rater1_labels:
        return np.nan, np.nan, np.nan, np.nan
    kappa = cohen_kappa_score(rater1_labels, rater2_labels)
    if weighted and STATSMODELS_AVAILABLE and field in FIELD_ORDERS:
        return (kappa, *compute_weighted_kappa(
            rater1_labels, rater2_labels, FIELD_ORDERS[field], weight_type
        ))
    return kappa, np.nan, np.nan, np.nan


def main():
    """Command-line entry point: pairwise inter-annotator agreement report.

    Loads two or more annotated ``.jsonl`` files and keeps the record ids
    present in every file. For each pair of files, it compares the per-step
    annotations of records whose traces have equal length. For every
    annotation field it prints Cohen's kappa (scikit-learn). With
    ``--weighted`` it also prints the statsmodels weighted kappa for fields
    listed in ``FIELD_ORDERS``. Records skipped for trace-length mismatch and
    per-field annotation-count misalignments are summarised first.
    """
    import os  # function-scope: only used for portable basename handling

    parser = argparse.ArgumentParser(description="Compute pairwise Cohen's Kappa for multiple annotation files, with safe weighted kappa for ordinal fields.")
    parser.add_argument("files", nargs='+', help="Paths to the annotated .jsonl files (2 or more).")
    parser.add_argument("--debug", action="store_true", help="Show label distributions when regular kappa < 0.3.")
    parser.add_argument("--weighted", action="store_true", help="Compute weighted kappa for ORDINAL fields only.")
    parser.add_argument("--weight-type", choices=['linear', 'quadratic', 'unweighted'], default='linear', help="Weighting for weighted kappa (default: linear).")
    args = parser.parse_args()

    if len(args.files) < 2:
        # parser.error prints usage to stderr and exits with status 2, so
        # callers see the failure (a bare return would exit with status 0).
        parser.error("Please provide at least two files to compare.")

    print("Loading data...")
    all_data = [load_data(path) for path in args.files]
    # os.path.basename honours the platform separator (also '\\' on Windows).
    file_basenames = [os.path.basename(path) for path in args.files]

    common_ids = sorted(set.intersection(*(set(data) for data in all_data)))
    if not common_ids:
        print("No common records found across all files.")
        return
    print(f"Found {len(common_ids)} common records across {len(all_data)} files.")

    annotator_pairs = list(itertools.combinations(range(len(all_data)), 2))

    def pair_label(pair):
        r1, r2 = pair
        return f"'{file_basenames[r1]}' vs. '{file_basenames[r2]}'"

    # Each of these maps field -> pair -> value.
    kappa_scores = defaultdict(dict)
    weighted_kappa_scores = defaultdict(dict)
    weighted_kappa_vars = defaultdict(dict)
    weighted_kappa_pvalues = defaultdict(dict)
    compared_items_count = defaultdict(dict)
    # The exact labels each kappa was computed on (after dropping missing
    # tokens), kept so --debug shows what was scored instead of rebuilding it.
    scored_labels = defaultdict(dict)
    mismatched_traces = set()
    # pair -> field -> count, and pair -> (len1, len2) -> count
    field_misalignment_stats = {}
    trace_length_stats = {}

    for pair in annotator_pairs:
        r1, r2 = pair
        labels, length_stats, misalignment, mismatched_ids = _collect_pair_labels(
            all_data[r1], all_data[r2], common_ids
        )
        trace_length_stats[pair] = length_stats
        field_misalignment_stats[pair] = misalignment
        mismatched_traces.update(mismatched_ids)

        for field, (raw1, raw2) in labels.items():
            # Drop missing tokens like 'Unspecified' before computing.
            clean1, clean2 = drop_missing_pairs(raw1, raw2, field)
            kappa, w_kappa, w_var, w_pval = _score_field(
                clean1, clean2, field, args.weighted, args.weight_type
            )
            kappa_scores[field][pair] = kappa
            weighted_kappa_scores[field][pair] = w_kappa
            weighted_kappa_vars[field][pair] = w_var
            weighted_kappa_pvalues[field][pair] = w_pval
            compared_items_count[field][pair] = len(clean1)
            scored_labels[field][pair] = (clean1, clean2)

    # Report detailed misalignment statistics
    if mismatched_traces:
        print(f"\nWarning: Skipped {len(mismatched_traces)} records due to trace length mismatch.")

        print("\n--- Trace Length Mismatch Details ---")
        for pair in annotator_pairs:
            print(f"\nPair {pair_label(pair)}:")
            for (len1, len2), count in sorted(trace_length_stats[pair].items()):
                if len1 != len2:
                    print(f"  Records with trace lengths {len1} vs {len2}: {count}")

    # Report field-specific misalignments
    print("\n--- Field-Specific Misalignment Details ---")
    for pair in annotator_pairs:
        print(f"\nPair {pair_label(pair)}:")
        for field, count in field_misalignment_stats[pair].items():
            print(f"  Field '{field}': {count} records with different annotation counts")

    print("\n--- Pairwise Cohen's Kappa Score Report ---")
    for field in sorted(kappa_scores):
        print(f"\nField: '{field}'")
        pair_kappas = []
        pair_weighted_kappas = []

        for pair in annotator_pairs:
            pair_str = pair_label(pair)
            kappa = kappa_scores[field].get(pair, np.nan)
            num_items = compared_items_count[field].get(pair, 0)

            if np.isnan(kappa):
                # Undefined kappa: either no pair survived filtering, or the
                # labels left no chance disagreement (e.g. one shared label).
                reason = "No comparable items" if num_items == 0 else "Not enough label diversity"
                print(f"  - Pair {pair_str}: Skipped ({reason})")
                continue

            print(f"  - Pair {pair_str}:")
            print(f"    Regular Kappa: {kappa:.4f} (on {num_items} items)")
            pair_kappas.append(kappa)

            if args.weighted:
                weighted_kappa = weighted_kappa_scores[field].get(pair, np.nan)
                if not np.isnan(weighted_kappa):
                    print(f"    Weighted Kappa ({args.weight_type}): {weighted_kappa:.4f}")
                    print(f"    Variance: {weighted_kappa_vars[field][pair]:.6f}")
                    print(f"    P-value: {weighted_kappa_pvalues[field][pair]:.6f}")
                    pair_weighted_kappas.append(weighted_kappa)
                elif not STATSMODELS_AVAILABLE:
                    print("    Weighted Kappa: N/A (statsmodels not installed)")
                elif field not in FIELD_ORDERS:
                    print("    Weighted Kappa: N/A (nominal or no explicit order)")
                else:
                    print("    Weighted Kappa: N/A (undefined for these labels)")

            if args.debug and kappa < 0.3:
                clean1, clean2 = scored_labels[field][pair]
                analyze_label_distribution(clean1, clean2, field)

        if pair_kappas:
            print("  --------------------------------------------------")
            print(f"  => Average Pairwise Regular Kappa for '{field}': {np.mean(pair_kappas):.4f}")
            if pair_weighted_kappas:
                print(f"  => Average Pairwise Weighted Kappa ({args.weight_type}) for '{field}': {np.mean(pair_weighted_kappas):.4f}")


if __name__ == "__main__":
    main()
