
import math
import random
from dataclasses import dataclass
from typing import Tuple, Dict, List, Optional

import numpy as np


@dataclass(frozen=True)
class STState:
    """Immutable, hashable handle on a node of the synthetic tree.

    A node is identified by the sequence of child indices actually taken on the
    way down from the root:

    * ``()``     -- the root itself;
    * ``(2,)``   -- child #2 of the root;
    * ``(2, 0)`` -- child #0 of that node, and so on.
    """

    # Child indices from the root, one entry per level descended.
    path: Tuple[int, ...]

    @property
    def depth(self) -> int:
        """Distance from the root, i.e. how many transitions led to this node."""
        n_steps = len(self.path)
        return n_steps


class SyntheticTree:
    """
    SyntheticTree environment as specified in Sec. 5.1:
      - balanced k-ary tree of fixed depth d
      - only leaf gives stochastic reward: Gaussian(mu_leaf, sigma)
      - transition randomness: with prob (1 - trans_noise) go to the intended child (index = chosen action),
        otherwise go uniformly among the other k-1 children.
      - discount gamma (default 1.0 for this env), applied once per step; since only the final
        step into a leaf is rewarded, that reward is discounted once per preceding step.
    We also expose a closed-form computation of the *true* optimal root value V*(s0) by dynamic programming,
    given the sampled leaf means and the discount gamma.

    Args:
        k: branching factor (>= 2).
        d: number of steps from the root to any leaf (>= 1).
        sigma: standard deviation of the Gaussian leaf reward (>= 0).
        trans_noise: probability of NOT landing on the intended child, in [0, 1].
        gamma: per-step discount factor.
        seed: seeds both the NumPy generator (leaf means, leaf rewards) and the
            Python ``random.Random`` generator (transitions).
    """
    def __init__(self, k: int, d: int, sigma: float = 0.5, trans_noise: float = 0.5,
                 gamma: float = 1.0, seed: Optional[int] = None):
        assert k >= 2 and d >= 1, "Need k>=2 and d>=1"
        assert 0.0 <= trans_noise <= 1.0, "trans_noise in [0,1]"
        assert sigma >= 0.0, "sigma >= 0"
        self.k = k
        self.d = d
        self.sigma = sigma
        self.trans_noise = trans_noise
        self.gamma = gamma
        self.rng = np.random.default_rng(seed)
        self.py_rng = random.Random(seed)
        # sample leaf means mu ~ Uniform(0,1) for each leaf
        # map leaf index path (tuple of length d) -> mu
        self.leaf_means: Dict[Tuple[int, ...], float] = {}
        self._generate_leaf_means()

        # compute exact V* by DP with known means, transition noise and discount
        self.true_V_root = self._compute_true_values_dp()

    # --------------- basic helpers ---------------
    def reset(self) -> "STState":
        """Return the root state (empty path)."""
        return STState(())

    @property
    def reward_range(self) -> Tuple[float, float]:
        """Approximate (lo, hi) range of a single leaf reward.

        Leaf rewards are drawn from Gaussian(mu, sigma) with mu in [0, 1), so we use
        [-3*sigma, 1 + 3*sigma] as a practical (not strict) bound.
        """
        lo = -3.0 * self.sigma
        hi = 1.0 + 3.0 * self.sigma
        return lo, hi

    def step(self, state: "STState", action: int) -> Tuple["STState", float, bool]:
        """Simulate a *single* environment step. Intermediate rewards are zero; only leaf returns reward.

        Returns (next_state, reward, done). If the next state is a leaf, reward ~ N(mu_leaf, sigma)
        and done is True; otherwise reward = 0.0 and done is False.
        """
        assert 0 <= action < self.k, "invalid action"
        assert state.depth < self.d, "cannot step from a terminal leaf"
        # random() is uniform on [0, 1), so a strict '<' yields P(intended) == 1 - trans_noise exactly;
        # with '<=', trans_noise == 1.0 could still (rarely) select the intended child.
        if self.py_rng.random() < (1.0 - self.trans_noise):
            next_child = action
        else:
            # choose uniformly among the other k-1 children
            candidates = [i for i in range(self.k) if i != action]
            next_child = self.py_rng.choice(candidates)
        next_state = STState(state.path + (next_child,))
        if next_state.depth == self.d:
            # terminal: emit noisy reward
            mu = self.leaf_means[next_state.path]
            r = float(self.rng.normal(mu, self.sigma))
            return next_state, r, True
        return next_state, 0.0, False

    # --------------- structure & ground-truth DP ---------------
    def _generate_leaf_means(self):
        """Assign a mean mu ~ Uniform[0, 1) to each of the k**d leaves of the full tree."""
        from itertools import product

        # product(range(k), repeat=d) enumerates leaf paths in lexicographic order -- the same
        # order as a depth-first walk visiting children 0..k-1 -- so seeded draws map to the
        # same leaves as a recursive traversal would.
        for path in product(range(self.k), repeat=self.d):
            self.leaf_means[path] = float(self.rng.random())

    def _compute_true_values_dp(self) -> float:
        """Compute exact optimal value function at the root by dynamic programming on the *means*.

        V(s) = max_a Q(s,a),  Q(s,a) = E[r + gamma * V(s')] where s' is random due to transition noise.
        Only the step into a leaf is rewarded (r = mu_leaf, after which the episode ends), so:
          - a child that is a leaf contributes mu_leaf (undiscounted, it is the immediate reward);
          - an internal child contributes gamma * V(child) (reward 0 now, value one step later).
        """
        from functools import lru_cache

        p_skip = self.trans_noise
        k = self.k
        d = self.d
        gamma = self.gamma

        @lru_cache(None)
        def V_of(path: Tuple[int, ...]) -> float:
            depth = len(path)
            if depth == d:
                return self.leaf_means[path]
            child_vals = [V_of(path + (a,)) for a in range(k)]
            if depth + 1 < d:
                # children are internal nodes: their value arrives one step later
                child_vals = [gamma * v for v in child_vals]
            total = sum(child_vals)
            # Q(a) = (1 - p) * intended + p * mean(other children)
            return max(
                (1.0 - p_skip) * intended + p_skip * ((total - intended) / (k - 1.0))
                for intended in child_vals
            )

        return V_of(())
