
import time
import numpy as np
from env_synth_tree import SyntheticTree
from algorithms import CATSOSelector, PATSOSelector, UCTSelector, PowerUCTSelector
from mcts_runner import MCTSRunner


def benchmark_once(algo: str, n_decisions: int = 100, n_sims: int = 1000, seed: int = 0) -> float:
    """Return the mean wall-clock time, in milliseconds, of one MCTS decision.

    One ``SyntheticTree`` environment and one selector are built from the
    arguments. Then ``n_decisions`` ``MCTSRunner`` instances (seeded
    ``seed + i``) each run ``n_sims`` simulations. The total
    ``time.perf_counter`` time of that loop, which includes runner
    construction, is divided by ``n_decisions``.

    Args:
        algo: Selector to benchmark: "UCT", "PowerUCT", "CATSO" or "PATSO".
        n_decisions: Number of decisions to time; must be >= 1.
        n_sims: Simulations per decision, forwarded to ``MCTSRunner.run``.
        seed: Seed for the environment; runner ``i`` uses ``seed + i``.

    Returns:
        Average milliseconds per decision.

    Raises:
        ValueError: If ``algo`` is not one of the supported names, or if
            ``n_decisions`` is less than 1.
    """
    # Map each algorithm name to (selector factory, p passed to MCTSRunner).
    # The factories take the environment's branching factor so the selector
    # is sized to match it. They are lambdas, so nothing is constructed until
    # the name has been validated.
    configs = {
        "UCT": (lambda k: UCTSelector(k=k, ucb_c=1.0), 1.0),
        "PowerUCT": (lambda k: PowerUCTSelector(k=k, ucb_c=1.0), 2.0),  # example p
        "CATSO": (lambda k: CATSOSelector(k=k, C=1.5, N=100, p=1.0), 1.0),
        "PATSO": (lambda k: PATSOSelector(k=k, C=1.5, K=200, p=1.0), 1.0),
    }
    if algo not in configs:
        raise ValueError(f"unknown algo {algo!r}; expected one of: {', '.join(configs)}")
    # Guard the division below; a zero or negative count has no meaningful average.
    if n_decisions < 1:
        raise ValueError(f"n_decisions must be >= 1, got {n_decisions}")
    make_selector, p = configs[algo]

    env = SyntheticTree(k=8, d=3, sigma=0.5, trans_noise=0.5, gamma=1.0, seed=seed)
    # NOTE(review): this one selector instance is reused by every runner below;
    # if selectors keep per-search state, timings may drift across decisions — confirm.
    selector = make_selector(env.k)

    t0 = time.perf_counter()
    for i in range(n_decisions):
        runner = MCTSRunner(env, selector, p=p, seed=seed + i)
        runner.run(n_sims)
    elapsed = time.perf_counter() - t0
    return 1000.0 * elapsed / n_decisions


if __name__ == "__main__":
    # Demo run with smaller settings than benchmark_once's defaults so it
    # finishes quickly. The settings live in variables so the printed summary
    # always matches what was actually run.
    n_decisions, n_sims = 10, 200
    algos = ["UCT", "PowerUCT", "CATSO", "PATSO"]
    for a in algos:
        ms = benchmark_once(a, n_decisions=n_decisions, n_sims=n_sims)
        print(f"{a}: {ms:.2f} ms per decision (avg over {n_decisions} decisions at {n_sims} sims)")
