
from __future__ import annotations
import time
from typing import Dict, Tuple, Optional

import numpy as np

from env_synth_tree import SyntheticTree, STState
from mcts_core import NodeStats, QEdgeStats
from algorithms import BaseSelector, CATSOSelector, PATSOSelector, ScalarTSOptSelector, UCTSelector, PowerUCTSelector


class MCTSRunner:
    """
    Glue between the environment and the generic MCTS tree with a specific
    selector (algorithm).

    The runner owns the tree (``self.nodes``, keyed by the state's ``path``),
    drives the simulate loop, feeds discounted returns into whichever edge
    distributions are present, and triggers the V-backup on every node along
    the visited path.

    Attributes:
        env: Environment; this class uses ``reset()``, ``step(state, a)``,
            and the attributes ``d``, ``k``, ``gamma`` and ``reward_range``.
        selector: Algorithm object; this class uses ``prepare_edge`` and
            ``select_action``.
        p: Exponent forwarded to ``BaseSelector.compute_v_backup``.
        nodes: Mapping from node path to its ``NodeStats``.
        rng: NumPy random generator handed to the selector.
    """

    def __init__(self, env: SyntheticTree, selector: BaseSelector, p: float, seed: Optional[int] = None):
        """Store the collaborators and start from an empty tree.

        Args:
            env: Environment to simulate in.
            selector: Action-selection algorithm.
            p: Exponent passed as ``p`` to the V-backup.
            seed: Seed for ``np.random.default_rng``; ``None`` leaves the
                generator unseeded.
        """
        self.env = env
        self.selector = selector
        self.p = p
        self.nodes: Dict[Tuple[int, ...], NodeStats] = {}
        self.rng = np.random.default_rng(seed)

    # ---- public ----
    def run(self, n_sims: int) -> None:
        """Run ``n_sims`` simulations; a non-positive count does nothing."""
        for _ in range(n_sims):
            self._simulate_once()

    def root_value_estimate(self) -> float:
        """Return the root node's ``v_value``, or 0.0 if the root was never visited."""
        root = self.nodes.get(tuple())
        return root.v_value if root is not None else 0.0

    # ---- internal ----
    def _simulate_once(self) -> None:
        """Run one simulation: descend from the root, then back up the return.

        Raises:
            AssertionError: If the descent took no step at all (the state
                returned by ``env.reset()`` already has ``depth >= env.d``).
        """
        env = self.env
        selector = self.selector

        # One (node_path, action_chosen, immediate_reward) entry per edge taken.
        # Keeping them in a single list guarantees path and rewards stay aligned.
        trajectory = []
        state = env.reset()
        node_path = state.path

        # Selection & Expansion
        while state.depth < env.d:
            # Look up first and only build a NodeStats on a miss;
            # dict.setdefault would construct a throwaway default on every step.
            node = self.nodes.get(node_path)
            if node is None:
                node = self.nodes[node_path] = NodeStats()

            # Make sure every action has an edge and let the selector set it up.
            for a in range(env.k):
                qe = node.ensure_edge(a)
                selector.prepare_edge(qe, env.reward_range)

            # Choose an action and advance the environment.
            a = selector.select_action(node, self.rng)
            next_state, r, done = env.step(state, a)
            trajectory.append((node_path, a, r))
            state = next_state
            node_path = state.path
            if done:
                break

        # Backup from leaf to root.
        # (Original note: for this env, only the final step yields non-zero reward.)
        assert trajectory, "simulation took no step; root state is already at max depth"
        G = 0.0
        for node_key, a, r in reversed(trajectory):
            # Discounted return as seen from this edge.
            G = r + env.gamma * G

            # Update edge stats.
            node = self.nodes[node_key]
            qe = node.edges[a]
            qe.visits += 1

            # Update whichever return distributions this edge carries.
            if qe.cat is not None:
                qe.cat.update(G)
            if qe.part is not None:
                qe.part.update(G)
            if qe.scalar is not None:
                qe.scalar.update(G)

            # Update node visits, then recompute v_value via the power-mean
            # backup over expected Q's (weights = T_{s,a}/T_s).
            node.visits += 1
            BaseSelector.compute_v_backup(node, p=self.p)
