#!/usr/bin/env python3
"""
Reproduce Overcooked experiments from the paper.

Experiments:
- Table 3: Constraint enforcement (NO vs JO)
- Table 4: Composition invariance (K1 vs K2)
- Table 7: External/Internal baselines (Reflexion, CRITIC, LlamaGuard, RBR, Centralized Prompt)
"""

import sys
import argparse
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from overcooked.constraints import TaskSpec
from overcooked.experiment import ExperimentConfig, run_experiment, print_experiment_report
from overcooked.adapter import create_adapter


# Paper configuration (Table 1, matching appendix F.4).
# Every experiment below builds its ExperimentConfig from these values.
MODEL: str = "gpt-4o-mini"    # model name passed to ExperimentConfig(model=...)
LAYOUT: str = "cramped_room"  # Overcooked layout name
HORIZON: int = 80             # episode horizon passed to ExperimentConfig
N_SEEDS: int = 10             # default number of seeds for Table 3
N_EPISODES: int = 20          # episodes per seed -> N = 20 * 10 = 200 in total


def run_table3_enforcement(seeds=None):
    """Table 3: Constraint enforcement across domains (Overcooked row).

    Runs the ``NO`` and ``JO_dynamic`` conditions for every seed under the
    task spec ``max_consecutive_stays=1`` (T4) and ``enforce_plate_timing``
    (H3). For each seed it prints the violation rate (VR) and the success
    rate. It then prints the unweighted mean over seeds for each condition.

    Args:
        seeds: Iterable of integer seeds. ``None`` or an empty value falls
            back to ``range(N_SEEDS)``, the paper configuration.

    Returns:
        dict mapping each condition name to the list of per-seed results
        returned by ``run_experiment``, in seed order.
    """
    print("=" * 70)
    print("TABLE 3: Constraint Enforcement (Overcooked)")
    print("=" * 70)

    # Materialize once: the seeds are re-iterated for every condition, so a
    # one-shot iterator would otherwise leave later conditions with no runs.
    seeds = list(seeds or ())
    if not seeds:
        seeds = list(range(N_SEEDS))
    task_spec = TaskSpec(
        max_consecutive_stays=1,  # T4
        enforce_plate_timing=True,  # H3
    )

    results = {}
    for condition in ["NO", "JO_dynamic"]:
        print(f"\n--- Condition: {condition} ---")
        per_seed = []
        for seed in seeds:
            config = ExperimentConfig(
                name=f"table3_{condition}_s{seed}",
                condition=condition,
                task_spec=task_spec,
                layout=LAYOUT,
                n_episodes=N_EPISODES,
                horizon=HORIZON,
                model=MODEL,
                seed=seed,
            )
            result = run_experiment(config, verbose=False)
            per_seed.append(result)
            print(f"  Seed {seed}: VR={result.mean_vr:.1%}, Success={result.success_rate:.1%}")
        results[condition] = per_seed

        # Every seed runs N_EPISODES episodes. If mean_vr and success_rate are
        # per-episode averages (assumed, not checked here), this unweighted
        # mean equals the pooled rate over all episodes.
        if per_seed:
            n = len(per_seed)
            mean_vr = sum(r.mean_vr for r in per_seed) / n
            mean_success = sum(r.success_rate for r in per_seed) / n
            print(f"  Mean over {n} seeds: VR={mean_vr:.1%}, Success={mean_success:.1%}")
    return results


def run_table4_composition(seeds=None):
    """Table 4: Composition invariance (K1 vs K2).

    Runs every combination of agent count (K1 = 1 agent, K2 = 2 agents) and
    condition (``NO``, ``JO_dynamic``) for every seed under the same task spec
    as Table 3. For each seed it prints the violation rate (VR). It then
    prints the unweighted mean VR over seeds for each combination.

    Args:
        seeds: Iterable of integer seeds. ``None`` or an empty value falls
            back to ``range(5)``.

    Returns:
        dict mapping ``(n_agents, condition)`` to the list of per-seed
        results returned by ``run_experiment``, in seed order.
    """
    print("=" * 70)
    print("TABLE 4: Composition Invariance")
    print("=" * 70)

    # Materialize once: the seeds are re-iterated for each (K, condition)
    # pair, so a one-shot iterator would otherwise be exhausted after one pair.
    seeds = list(seeds or ())
    if not seeds:
        seeds = list(range(5))
    task_spec = TaskSpec(
        max_consecutive_stays=1,
        enforce_plate_timing=True,
    )

    results = {}
    for n_agents in [1, 2]:  # K1, K2
        for condition in ["NO", "JO_dynamic"]:
            print(f"\n--- K{n_agents}, {condition} ---")
            per_seed = []
            for seed in seeds:
                config = ExperimentConfig(
                    name=f"table4_K{n_agents}_{condition}_s{seed}",
                    condition=condition,
                    task_spec=task_spec,
                    layout=LAYOUT,
                    # NOTE(review): hard-coded 10 episodes per seed (not
                    # N_EPISODES) -- confirm this matches the paper's Table 4.
                    n_episodes=10,
                    horizon=HORIZON,
                    model=MODEL,
                    n_agents=n_agents,
                    seed=seed,
                )
                result = run_experiment(config, verbose=False)
                per_seed.append(result)
                print(f"  Seed {seed}: VR={result.mean_vr:.1%}")
            results[(n_agents, condition)] = per_seed

            # Unweighted mean over seeds. Every seed runs the same number of
            # episodes, so no per-seed weighting is needed.
            if per_seed:
                n = len(per_seed)
                mean_vr = sum(r.mean_vr for r in per_seed) / n
                print(f"  Mean over {n} seeds: VR={mean_vr:.1%}")
    return results


def run_table7_baselines(seeds=None):
    """Table 7: External/Internal baselines (Reflexion, CRITIC, LlamaGuard, RBR, Centralized Prompt).

    Runs each baseline for every seed under the same task spec as Table 3.
    For each seed it prints the violation rate (VR) and the success rate. It
    then prints the unweighted mean over seeds for each baseline.

    Args:
        seeds: Iterable of integer seeds. ``None`` or an empty value falls
            back to ``range(3)``.

    Returns:
        dict mapping each baseline's run-name key (``"reflexion"``,
        ``"critic"``, ``"llamaguard"``, ``"rbr"``, ``"centralized"``) to the
        list of per-seed results returned by ``run_experiment``, in seed order.
    """
    print("=" * 70)
    print("TABLE 7: External/Internal Baselines")
    print("=" * 70)

    # Materialize once: the seeds are re-iterated for every baseline.
    seeds = list(seeds or ())
    if not seeds:
        seeds = list(range(3))
    task_spec = TaskSpec(
        max_consecutive_stays=1,  # T4
        enforce_plate_timing=True,  # H3
    )

    # Each entry is (printed label, run-name key, condition, agent_type).
    # An agent_type of None means ExperimentConfig's own default is used.
    # - reflexion / critic / llamaguard: wrapper agents run without JO
    #   ("NO"). NOTE(review): the previous comment grouped all three as
    #   "internal self-correction"; LlamaGuard is normally an external guard
    #   model -- confirm the paper's grouping.
    # - RBR: rule-based repair, modelled as the JO_static condition with the
    #   default agent.
    # - Centralized Prompt: no JO; constraints are embedded in the prompt.
    baselines = [
        ("reflexion", "reflexion", "NO", "reflexion"),
        ("critic", "critic", "NO", "critic"),
        ("llamaguard", "llamaguard", "NO", "llamaguard"),
        ("RBR", "rbr", "JO_static", None),
        ("Centralized Prompt", "centralized", "NO", "centralized"),
    ]

    results = {}
    for label, key, condition, agent_type in baselines:
        print(f"\n--- Baseline: {label} ---")
        # Pass agent_type only where the original call did. This keeps RBR on
        # whatever default ExperimentConfig defines.
        agent_kwargs = {} if agent_type is None else {"agent_type": agent_type}
        per_seed = []
        for seed in seeds:
            config = ExperimentConfig(
                name=f"table7_{key}_s{seed}",
                condition=condition,
                task_spec=task_spec,
                layout=LAYOUT,
                n_episodes=N_EPISODES,
                horizon=HORIZON,
                model=MODEL,
                seed=seed,
                **agent_kwargs,
            )
            result = run_experiment(config, verbose=False)
            per_seed.append(result)
            print(f"  Seed {seed}: VR={result.mean_vr:.1%}, Success={result.success_rate:.1%}")
        results[key] = per_seed

        # Unweighted mean over seeds. Every seed runs N_EPISODES episodes.
        if per_seed:
            n = len(per_seed)
            mean_vr = sum(r.mean_vr for r in per_seed) / n
            mean_success = sum(r.success_rate for r in per_seed) / n
            print(f"  Mean over {n} seeds: VR={mean_vr:.1%}, Success={mean_success:.1%}")
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Overcooked experiments")
    parser.add_argument("--experiment", type=str, default="all",
                       choices=["all", "table3", "table4", "table7"],
                       help="Which experiment to run")
    parser.add_argument("--seeds", type=str, default=None,
                       help="Comma-separated seeds (default: use paper config)")
    args = parser.parse_args()

    seeds = [int(s) for s in args.seeds.split(",")] if args.seeds else None

    if args.experiment in ["all", "table3"]:
        run_table3_enforcement(seeds)
    if args.experiment in ["all", "table4"]:
        run_table4_composition(seeds)
    if args.experiment in ["all", "table7"]:
        run_table7_baselines(seeds)
