import json
import statistics
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from pathlib import Path

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

from pcot.prompts.rsm import BASELINE_GAME_PROMPT, RSM_GAME_PROMPT, SIMPLE_GAME_PROMPT
from pcot.tasks.rsp.bot import BotFromScript, GreenbergBot, IocaineBot
from pcot.tasks.rsp.contest import Contest
from pcot.tasks.rsp.llm_bot import LLMBot
from pcot.tasks.rsp.rsp_dojo_adapter import DojoBot

logger = getLogger(__name__)


def run_single_match(cfg: DictConfig, sequence_id: int, output_dir: Path) -> dict:
    """Play one LLM-vs-script match sequence and persist its outcome.

    The system prompt is picked from ``cfg.experiment.prompt`` and the
    opponent from ``cfg.experiment.bot_name``. The contest result is converted
    to a dict, extended with the LLM bot's ``output_history`` and written to
    ``result_<sequence_id>.json`` inside ``output_dir``.

    Args:
        cfg: Hydra configuration; the ``experiment`` and ``model`` sections
            are read.
        sequence_id: Index of this sequence, used in log lines and in the
            output file name.
        output_dir: Directory that receives the per-sequence JSON file.

    Returns:
        The result dict that was written to disk.

    Raises:
        RuntimeError: If ``cfg.experiment.prompt`` is not one of ``"rsm"``,
            ``"baseline"`` or ``"simple"``.
    """
    logger.info(f"Starting RSP Match Sequence {sequence_id}")

    # Map each supported prompt identifier to its prompt text.
    prompt_by_kind = {
        "rsm": RSM_GAME_PROMPT,
        "baseline": BASELINE_GAME_PROMPT,
        "simple": SIMPLE_GAME_PROMPT,
    }
    if cfg.experiment.prompt not in prompt_by_kind:
        raise RuntimeError(f"Invalid prompt type {cfg.experiment.prompt}")
    chosen_prompt = prompt_by_kind[cfg.experiment.prompt]

    llm_player = LLMBot(
        litellm_name=cfg.model.litellm_model_name,
        system_prompt=chosen_prompt,
        show_history=cfg.experiment.show_history,
    )

    # Opponent lookup order: a script in rsp_dojo_agents/ wins over the
    # built-in names, and anything else is loaded from bot_functions/.
    here = Path(__file__).parent
    dojo_script = (here / "rsp_dojo_agents").resolve() / f"{cfg.experiment.bot_name}.py"
    if dojo_script.exists():
        opponent = DojoBot(str(dojo_script))
    elif cfg.experiment.bot_name == "iocaine":
        opponent = IocaineBot()
    elif cfg.experiment.bot_name == "greenberg":
        opponent = GreenbergBot()
    else:
        functions_dir = (here / "bot_functions").resolve()
        opponent = BotFromScript(
            name=str(functions_dir / f"{cfg.experiment.bot_name}.py")
        )

    logger.info(f"Starting match with model: {cfg.model.litellm_model_name}")
    contest_result = Contest(
        bot1=llm_player,
        bot2=opponent,
        rounds=cfg.experiment.num_matches,
    ).run()
    logger.info(f"Result for sequence {sequence_id}: {contest_result}")

    result = contest_result.to_dict()
    result["output_history"] = llm_player.output_history

    # Persist this sequence's outcome next to the other Hydra outputs.
    result_path = output_dir / f"result_{sequence_id}.json"
    serialized = json.dumps(result, indent=2)
    result_path.write_text(serialized)
    logger.info(f"Result for sequence {sequence_id} saved to {result_path}")

    return result


def _summarize_scores(all_results: list[dict]) -> dict:
    """Build the aggregated summary for a list of per-sequence results.

    Only results that contain a ``"score"`` key contribute to the statistics.
    ``score_ave`` is ``None`` when there is no score at all and ``score_std``
    is ``None`` when there are fewer than two scores, because
    ``statistics.mean`` / ``statistics.stdev`` raise ``StatisticsError`` for
    those inputs.
    """
    scores: list[float] = [r["score"] for r in all_results if "score" in r]
    return {
        "score_ave": statistics.mean(scores) if scores else None,
        "score_std": statistics.stdev(scores) if len(scores) > 1 else None,
        "scores": scores,
        "valid_matches": len(scores),
        "individual_results": all_results,
    }


@hydra.main(
    config_path="../../../../scripts/conf/rsp",
    config_name="config",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """Run the configured number of RSP match sequences and save the results.

    ``cfg.experiment.num_episodes`` (default 10) controls how many sequences
    are played. With more than one, the sequences are submitted to a
    ``ThreadPoolExecutor`` and an ``aggregated_result.json`` summary is written
    to the Hydra run output directory; otherwise a single sequence with id 0
    is run directly.

    Args:
        cfg: Configuration object supplied by Hydra.
    """
    logger.info("Starting RSP Match")
    hydra_output_dir = Path(HydraConfig.get().runtime.output_dir)
    logger.info(f"Hydra Run Output Directory: {hydra_output_dir}")
    logger.info("Loaded configuration:\n%s", OmegaConf.to_yaml(cfg))

    num_episodes = cfg.experiment.get("num_episodes", 10)

    if num_episodes > 1:
        logger.info(f"Running {num_episodes} sequences concurrently.")
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(run_single_match, cfg, i, hydra_output_dir)
                for i in range(num_episodes)
            ]
            # Collect in submission order; result() re-raises any exception
            # raised inside a worker.
            all_results = [future.result() for future in futures]

        aggregated_result = _summarize_scores(all_results)
        if aggregated_result["valid_matches"] < 2:
            logger.warning(
                "Only %d result(s) contained a score; "
                "undefined statistics are saved as null.",
                aggregated_result["valid_matches"],
            )

        # Save aggregated result
        agg_result_path = hydra_output_dir / "aggregated_result.json"
        agg_result_path.write_text(json.dumps(aggregated_result, indent=2))
        logger.info(f"Aggregated results saved to {agg_result_path}")

    else:
        run_single_match(cfg, 0, hydra_output_dir)

    logger.info("RSP completed successfully")


# Invoke the Hydra-decorated entry point only when this file is run as a script.
if __name__ == "__main__":
    main()
