#!/usr/bin/env python3
"""
Reproduce Cross-Model Transfer experiments from the paper.

Tests whether JO precedents learned from a teacher model (GPT-5.2)
transfer zero-shot to student models from different families.

Table 6: Cross-Model Transfer Results
- Teacher: GPT-5.2 (learns precedents on 100 tasks)
- Students: GPT-4o-mini, Moonshot-v1-8k, Llama-3.1-8B, Qwen2.5-72B, Claude-3.5-Haiku
"""

import sys
import argparse
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from jo.precedent_store import PrecedentStore


def pr(*values, **print_kwargs):
    """Print exactly like the built-in ``print``, then flush ``sys.stdout``.

    The explicit flush keeps progress lines visible immediately even when
    stdout is redirected to a file or pipe.
    """
    print(*values, **print_kwargs)
    stream = sys.stdout
    stream.flush()


# Paper configuration. Each student runs N_TASKS tasks under each of N_SEEDS
# seeds, i.e. 100 * 2 = 200 runs per student. MAX_STEPS is presumably the
# per-task step limit (not referenced in this script's visible code).
N_TASKS = 100
N_SEEDS = 2
MAX_STEPS = 10


# Student models, keyed by model name. "family" is the provider family printed
# in the per-student header; "api" presumably selects the backend client that
# serves the model -- confirm against the experiment runner.
STUDENT_MODELS = {
    name: {"family": family, "api": api}
    for name, family, api in (
        ("gpt-4o-mini", "OpenAI", "openai"),
        ("moonshot-v1-8k", "Kimi", "moonshot"),
        ("llama-3.1-8b", "Meta", "together"),
        ("qwen2.5-72b", "Alibaba", "together"),
        ("claude-3.5-haiku", "Anthropic", "anthropic"),
    )
}


def load_teacher_precedents(path: "str | Path | None" = None) -> PrecedentStore:
    """Load the precedents learned by the GPT-5.2 teacher into a new store.

    Args:
        path: JSON file to read, given as a ``str`` or ``Path``. The file may
            hold either a bare list of precedents or an object whose
            ``"precedents"`` key holds that list. Defaults to
            ``artifacts/expert_gpt52_precedents.json`` inside the parent of
            this script's directory.

    Returns:
        A ``PrecedentStore`` filled via ``import_from_dict``. If the file does
        not exist, a warning is printed and an empty store is returned.
    """
    import json

    if path is None:
        path = Path(__file__).parent.parent / "artifacts" / "expert_gpt52_precedents.json"
    else:
        # Coerce so callers may pass a plain string; ``.exists()`` below
        # needs a Path object.
        path = Path(path)

    store = PrecedentStore()
    if not path.exists():
        pr(f"Warning: Teacher precedents not found at {path}")
        return store

    # JSON is UTF-8 by specification; don't depend on the locale's encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    # Handle both the bare-list format and a dict with a "precedents" key.
    precedents = data if isinstance(data, list) else data.get("precedents", [])
    store.import_from_dict(precedents)
    pr(f"Loaded {store.size()} precedents from teacher")
    return store


def run_table6_transfer(models=None, seeds=None):
    """Table 6: Cross-Model Transfer Results.

    Prints the Table 6 banner and loads the teacher's precedents. Then, for
    each requested student model, it announces the no-JO baseline run and the
    JO run that reuses the teacher's precedents.

    NOTE(review): the per-student experiment calls are not wired up yet. The
    loop only prints progress lines; nothing is executed or scored.

    Args:
        models: Student model names to evaluate. A falsy value means every
            key of STUDENT_MODELS. Names not in STUDENT_MODELS are reported
            and skipped.
        seeds: Seeds to run each task under. A falsy value means
            ``range(N_SEEDS)``.
    """
    rule = "=" * 70
    for banner_line in (rule, "TABLE 6: Cross-Model Transfer (GPT-5.2 Teacher)", rule):
        pr(banner_line)

    selected_models = models if models else list(STUDENT_MODELS.keys())
    selected_seeds = seeds if seeds else list(range(N_SEEDS))

    # Loaded up front so a missing file is reported once, before the loop.
    teacher_store = load_teacher_precedents()

    for model_name in selected_models:
        if model_name not in STUDENT_MODELS:
            pr(f"Unknown model: {model_name}, skipping")
            continue

        family = STUDENT_MODELS[model_name]["family"]
        pr(f"\n--- Student: {model_name} ({family}) ---")

        # TODO: run the baseline with
        #   run_student_experiment(model, seeds, use_jo=False)
        # and report its success rate.
        n_runs = N_TASKS * len(selected_seeds)
        pr(f"  NO: Running {n_runs} tasks without JO...")

        # TODO: run with the transferred precedents via
        #   run_student_experiment(model, seeds, use_jo=True, store=teacher_store)
        # and report its success rate plus the delta over the baseline.
        n_runs = N_TASKS * len(selected_seeds)
        pr(f"  JO: Running {n_runs} tasks with transferred precedents...")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Cross-Model Transfer experiments")
    parser.add_argument("--models", type=str, default=None,
                       help="Comma-separated model names (default: all)")
    parser.add_argument("--seeds", type=str, default=None,
                       help="Comma-separated seeds")
    args = parser.parse_args()

    models = args.models.split(",") if args.models else None
    seeds = [int(s) for s in args.seeds.split(",")] if args.seeds else None

    run_table6_transfer(models, seeds)
