import subprocess
from dataclasses import dataclass

import pyrallis

from machiavelli.config import MachiavelliConfig
from pathlib import Path
from machiavelli.annotate.check_status import get_game_completion_rate, get_salient_scene_ids


@dataclass
class Config:
    """Options for the label-job launcher, passed to ``main`` as ``cfg``.

    ``topic`` has no default and therefore must always be provided. It is
    forwarded to the completion-check helpers and to ``run.sbatch``, and is
    also used in the Slurm job name and log path.
    """
    topic: str # "money" / "morality" / "social" / "utility" / "watts"
    n_parallel_groups: int = 3 # Number of groups the games are split into; jobs within a group are chained one after another
    salient_only: bool = False # Only label scenes with salient labels

def _submit_sbatch(cmd: list) -> str:
    """Run an ``sbatch --parsable`` command and return the submitted job's ID.

    Args:
        cmd: Full argv list for the ``sbatch`` invocation.

    Returns:
        The numeric job ID printed by ``sbatch``, as a string.

    Raises:
        RuntimeError: If ``sbatch`` exits non-zero or does not print a
            numeric job ID.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(
            f"sbatch failed with exit code {result.returncode}: {' '.join(cmd)}\n"
            f"stderr: {result.stderr.strip()}"
        )
    # With --parsable, sbatch prints "<job_id>" or "<job_id>;<cluster>".
    job_id = result.stdout.strip().split(";")[0]
    if not job_id.isdigit():
        raise RuntimeError(
            f"Could not parse a job id from sbatch output {result.stdout!r}: {' '.join(cmd)}\n"
            f"stderr: {result.stderr.strip()}"
        )
    return job_id


@pyrallis.wrap()
def main(cfg: Config):
    """Submit one Slurm labelling job per unfinished game.

    Every working game whose completion rate over its salient scenes for
    ``cfg.topic`` is below 1.0 is selected. The selected games are dealt
    round-robin into ``cfg.n_parallel_groups`` groups. Within a group the
    jobs are chained with ``--dependency=afterany`` so they run one after
    another, while the groups themselves are independent of each other.

    Args:
        cfg: Launcher options (topic, number of groups, salient-only flag).

    Raises:
        ValueError: If ``cfg.n_parallel_groups`` is smaller than 1.
        RuntimeError: If an ``sbatch`` submission fails or returns no job ID.
    """
    topic = cfg.topic
    n_parallel_groups = cfg.n_parallel_groups
    if n_parallel_groups < 1:
        raise ValueError(f"n_parallel_groups must be >= 1, got {n_parallel_groups}")

    machiavelli_cfg = MachiavelliConfig()

    # All working games, in a deterministic order. Replace this with an
    # explicit list of game names to prioritise a hand-picked subset.
    prioritized_games = sorted(machiavelli_cfg.working_games)

    # NOTE(review): completion is always measured over the salient scenes,
    # even when cfg.salient_only is False -- confirm this is intended.
    games_to_process = []
    for game in prioritized_games:
        salient_ids = get_salient_scene_ids(game, topic)
        if get_game_completion_rate(topic, game, whitelist_scenes=salient_ids)[0] < 1.0:
            games_to_process.append(game)
    print(f"Launching {len(games_to_process)} games in {n_parallel_groups} groups")

    # Deal games round-robin: group i receives games i, i + n, i + 2n, ...
    groups = [games_to_process[i::n_parallel_groups] for i in range(n_parallel_groups)]

    # Slurm does not create missing directories for --output, so make sure
    # the log directory exists before submitting anything.
    Path(f"./machiavelli/annotate/{topic}/slurm_logs").mkdir(parents=True, exist_ok=True)

    sbatch_script = f"{Path(__file__).parent}/run.sbatch"
    for i, group in enumerate(groups):
        # The first job of a group has no predecessor and starts as soon as
        # resources allow; each later job waits for the previous one to end.
        prev_job_id = None
        print(f"GROUP {i}")
        for game in group:
            cmd = ["sbatch", "--parsable"]
            if prev_job_id is not None:
                cmd.append(f"--dependency=afterany:{prev_job_id}")
            cmd += [
                f"--job-name={topic}_labels",
                f"--output=./machiavelli/annotate/{topic}/slurm_logs/slurm_%A.out",
                sbatch_script,
                str(topic),
                str(game),
                str(cfg.salient_only),
            ]
            print(" ".join(cmd))
            prev_job_id = _submit_sbatch(cmd)
            print("game", game, "job_id", prev_job_id)

# Script entry point. main() is called without arguments here, so its `cfg`
# parameter is expected to be supplied by the @pyrallis.wrap() decorator.
if __name__ == "__main__":
    main()