"""
Build processed CSV (turkey_cost_table.csv) from annotations.json.

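Aggregates multi-annotator labels per image (plumage_injury and head_injury
count as injured) and computes a signed, add-one-smoothed log ratio
Δ = log((n_yes + 1) / (n_no + 1)) that measures annotator agreement on the
binary injured vs not_injured decision.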

Usage: python -m src.data.preprocess_turkey
"""
import json
import pandas as pd
import numpy as np
from collections import defaultdict
from pathlib import Path

RAW_JSON = "data/turkey/annotations.json"
OUT_CSV = "data/turkey_cost_table.csv"

# Raw labels that count as a vote for the positive ("injured") class.
INJURY_LABELS = ("plumage_injury", "head_injury")
# Column order of the processed CSV.
OUT_COLUMNS = ["image_path", "n_yes", "n_no", "delta_signed", "abs_delta", "y_star"]


def _flatten_batches(batches):
    """Return every individual annotation from a list of annotator batches.

    Each batch is expected to carry an ``'annotations'`` list; batches
    without that key are skipped.
    """
    flat = []
    for batch in batches:
        if 'annotations' in batch:
            flat.extend(batch['annotations'])
    return flat


def _count_votes(annotations):
    """Tally injured / not_injured votes per ``image_path``.

    Annotations with an unrecognized ``class_label`` are reported and
    ignored, so an image whose labels are all unknown is absent from the
    result. Keys keep first-appearance order.
    """
    votes = defaultdict(lambda: {"injured": 0, "not_injured": 0})
    for ann in annotations:
        img_path = ann["image_path"]
        label = ann["class_label"]
        if label in INJURY_LABELS:
            votes[img_path]["injured"] += 1
        elif label == "not_injured":
            votes[img_path]["not_injured"] += 1
        else:
            print(f"  WARNING: Unknown label '{label}' for {img_path}")
    return votes


def _build_table(votes):
    """Build the per-image cost table from vote counts.

    Columns (in ``OUT_COLUMNS`` order): image_path, n_yes, n_no,
    delta_signed = log((n_yes + 1) / (n_no + 1)), abs_delta = |delta_signed|,
    and y_star = majority label with ties resolved to 1 (injured).
    The columns are present even when ``votes`` is empty.
    """
    df = pd.DataFrame(
        [(path, c["injured"], c["not_injured"]) for path, c in votes.items()],
        columns=["image_path", "n_yes", "n_no"],
    ).astype({"n_yes": "int64", "n_no": "int64"})
    # Add-one smoothing keeps the log finite when one side has zero votes.
    df["delta_signed"] = np.log((df["n_yes"] + 1) / (df["n_no"] + 1))
    df["abs_delta"] = df["delta_signed"].abs()
    # Integer comparison is equivalent to delta_signed >= 0 (ties -> positive)
    # without relying on float comparisons.
    df["y_star"] = (df["n_yes"] >= df["n_no"]).astype("int64")
    return df[OUT_COLUMNS]


def main(raw_json=None, out_csv=None):
    """Build the processed cost-table CSV from the raw annotations JSON.

    Args:
        raw_json: Path to the annotations JSON, a list of annotator batches.
            Defaults to ``RAW_JSON``.
        out_csv: Destination CSV path; missing parent directories are
            created. Defaults to ``OUT_CSV``.

    Returns:
        The DataFrame that was written to ``out_csv``.
    """
    raw_json = RAW_JSON if raw_json is None else raw_json
    out_csv = OUT_CSV if out_csv is None else out_csv

    print(f"Loading {raw_json}...")
    with open(raw_json, encoding="utf-8") as f:
        annotations = json.load(f)
    print(f"  Loaded {len(annotations):,} annotator batches")

    # Flatten nested structure: each batch has an 'annotations' list
    all_annotations = _flatten_batches(annotations)
    print(f"  Flattened to {len(all_annotations):,} individual annotations")

    votes = _count_votes(all_annotations)
    print(f"  Found {len(votes):,} unique images")

    df = _build_table(votes)

    out_path = Path(out_csv)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_path, index=False)

    print(f"\n✓ Wrote {out_csv}")
    print(f"  Rows: {len(df):,}")
    if df.empty:
        # Means/maxima are undefined without rows; report instead of printing NaN.
        print("  WARNING: no usable annotations; the table is empty")
        return df
    print(f"  Mean annotations per image: {(df['n_yes'] + df['n_no']).mean():.1f}")
    print(f"  Mean |Δ|: {df['abs_delta'].mean():.3f}")
    print(f"  Max |Δ|: {df['abs_delta'].max():.3f}")
    print(f"  Positive class (injured): {df['y_star'].mean():.1%}")
    return df


if __name__ == "__main__":
    # Run the build only when executed as a script (e.g. `python -m ...`), not on import.
    main()
