#!/usr/bin/env python3
"""
Extract the corrected tail-class accuracy for ShapeNet-55 from a WandB run,
reported at the logged step with the best overall accuracy (val/oa).

Uses a threshold-based definition: tail = classes with <80 samples (11 classes),
instead of the incorrect 1/3-percentile split (18 classes).

Usage:
    python extract_corrected_tail_acc.py <wandb_run_path>

Example:
    python extract_corrected_tail_acc.py entity/project/run_id
"""

import argparse

import wandb

# ShapeNet-55 "true" tail classes: the 11 classes with fewer than 80 samples.
# Indices are 0-based and must match the ``val/class_acc/{idx}`` keys logged to
# WandB. Entries are ordered by descending sample count, and this order is also
# the order in which per-class results are printed.
SHAPENET55_TRUE_TAIL_CLASSES = [
    39,  # 76 samples
    33,  # 75 samples
    21,  # 74 samples
    45,  # 68 samples
    2,  # 66 samples
    7,  # 58 samples
    23,  # 58 samples
    34,  # 53 samples
    20,  # 52 samples
    43,  # 52 samples
    15,  # 44 samples
]


def get_corrected_tail_acc_at_best_oa(run, tail_class_indices):
    """
    Query run history and return corrected tail accuracy at the step with best OA.

    The lookup runs in two passes over the run history:

    1. Scan only ``val/oa`` and ``_step`` to find the step with the highest OA
       (the earliest such step wins on ties).
    2. Fetch the full history row logged at that step and read the
       ``val/class_acc/{idx}`` entry for each requested class.

    The per-class keys are deliberately NOT passed to the first scan:
    ``scan_history(keys=...)`` only yields rows in which every requested key is
    defined, so a single never-logged tail class would hide the entire history
    and the partial-result handling below (``None`` entries) could never apply.

    Args:
        run: WandB run object (anything exposing a compatible ``scan_history``)
        tail_class_indices: List of class indices to include in tail accuracy

    Returns:
        dict with
            'best_oa': highest ``val/oa`` value found
            'best_step': ``_step`` of that row (-1 if the row had none)
            'corrected_tail_acc': unweighted mean of the available per-class
                accuracies at that step
            'per_class_accs': {class index: accuracy, or None if not logged}
            'num_classes_found': number of classes with a logged accuracy
        or None if data is not available or the history query fails.
    """
    try:
        # Pass 1: locate the best-OA step using only keys every validation row has.
        saw_rows = False
        best_oa = None
        best_step = -1
        for row in run.scan_history(keys=["val/oa", "_step"]):
            saw_rows = True
            oa = row.get("val/oa")
            # Strict ">" keeps the earliest step when several steps tie.
            if oa is not None and (best_oa is None or oa > best_oa):
                best_oa = oa
                best_step = row.get("_step", -1)

        if not saw_rows:
            print("  Warning: No history data found")
            return None
        if best_oa is None:
            print("  Warning: No valid OA found in history")
            return None

        # Pass 2: fetch the complete row logged at best_step. Also matching on
        # _step keeps this correct whether max_step is inclusive or exclusive.
        best_row = {}
        if best_step >= 0:
            for row in run.scan_history(min_step=best_step, max_step=best_step + 1):
                if row.get("_step") == best_step:
                    best_row = row
                    break
    except Exception as e:  # API/network boundary: report and give up gracefully
        print(f"  Error querying history: {e}")
        return None

    # Classes that were not logged at the best step map to None and are
    # excluded from the average rather than failing the whole extraction.
    per_class_accs = {
        idx: best_row.get(f"val/class_acc/{idx}") for idx in tail_class_indices
    }
    valid_accs = [acc for acc in per_class_accs.values() if acc is not None]

    if not valid_accs:
        print("  Warning: No per-class accuracy data found")
        return None

    return {
        "best_oa": best_oa,
        "best_step": best_step,
        "corrected_tail_acc": sum(valid_accs) / len(valid_accs),
        "per_class_accs": per_class_accs,
        "num_classes_found": len(valid_accs),
    }


def main():
    """
    Command-line entry point: fetch a WandB run and print its corrected tail accuracy.

    Returns:
        Process exit status: 0 on success, 1 if the run could not be fetched
        or the metrics could not be extracted.
    """
    parser = argparse.ArgumentParser(
        description="Extract corrected tail accuracy for ShapeNet-55 from WandB"
    )
    parser.add_argument("run_path", help="WandB run path (e.g., entity/project/run_id)")
    args = parser.parse_args()

    api = wandb.Api()

    print(f"Fetching run: {args.run_path}")
    try:
        run = api.run(args.run_path)
    except Exception as e:
        print(f"Error: Could not fetch run - {e}")
        return 1

    print(f"Run name: {run.name}")
    print(f"Run state: {run.state}")
    print()

    # Get corrected tail accuracy
    result = get_corrected_tail_acc_at_best_oa(run, SHAPENET55_TRUE_TAIL_CLASSES)

    if result is None:
        print("Failed to extract metrics")
        return 1

    print("=" * 60)
    print("RESULTS")
    print("=" * 60)
    print(
        f"Best OA:              {result['best_oa']:.4f} (at step {result['best_step']})"
    )
    print(f"Corrected Tail Acc:   {result['corrected_tail_acc']:.4f}")
    print(
        f"Classes included:     {result['num_classes_found']}/{len(SHAPENET55_TRUE_TAIL_CLASSES)}"
    )
    print()

    # Show per-class breakdown; classes missing at the best step print as N/A.
    print("Per-class accuracies (true tail classes):")
    for idx in SHAPENET55_TRUE_TAIL_CLASSES:
        acc = result["per_class_accs"].get(idx)
        if acc is not None:
            print(f"  Class {idx:2d}: {acc:.4f}")
        else:
            print(f"  Class {idx:2d}: N/A")

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so failures are visible to shells and CI.
    raise SystemExit(main())
