import time, math, heapq, hashlib
from typing import Dict, Any, List, Tuple
from pam.prefix_dag import PrefixDAG, leaves_under
from config import CFG

# Euler-Mascheroni constant (gamma). Used by exp_key() inside
# dist_prune_search, which subtracts (gamma + log N) from a node's bound.
_EULER_GAMMA = 0.57721566490153286060

def _prf_uniform_q064(leaf_id:int)->int:
    h = hashlib.sha256(f"leaf:{leaf_id}".encode()).digest()
    return int.from_bytes(h[:8], "big", signed=False)

def _u_from_q64(u64:int)->float:
    return (u64 + 0.5) / (2**64)

def realized_score(depth:int, tau:float, base:float=0.6):
    """Return the deterministic part of a score for a node at ``depth``.

    Computes ``(base - 0.01*depth)/tau - 0.02*depth``: a base value that
    shrinks with depth, scaled by ``1/tau``, minus a further linear depth
    penalty. Per the original note, this mirrors the toy deterministic part
    used in PaMEngine.
    """
    depth_adjusted_base = base - 0.01*depth
    depth_penalty = 0.02*depth
    return depth_adjusted_base/tau - depth_penalty

def leaf_realized_score(leaf_id:int, depth:int, tau:float=CFG.TAU):
    """Return the realized score of one leaf: deterministic part plus noise.

    The noise is ``-log(-log(1 - U))`` (a Gumbel-style term), where ``U`` is
    the pseudo-random uniform derived from ``leaf_id``, so a given leaf always
    receives the same score for a given ``depth`` and ``tau``.
    """
    u = _u_from_q64(_prf_uniform_q064(leaf_id))
    deterministic = realized_score(depth, tau)
    log_neg_log = math.log(-math.log(1.0 - u))
    return deterministic - log_neg_log

def global_best_realized(g: PrefixDAG)->float:
    """Exhaustively find the highest realized leaf score in the graph.

    Only nodes without children contribute; each of their leaves is scored
    with ``leaf_realized_score`` at the node's depth. Returns ``-1e9`` when no
    leaf is scored (or none exceeds that floor).
    """
    top_score = -1e9
    childless_nodes = (node for node in g.nodes.values() if not node.children)
    for node in childless_nodes:
        for leaf_id in node.leaves:
            candidate = leaf_realized_score(leaf_id, node.depth)
            if candidate > top_score:
                top_score = candidate
    return top_score

# -------- Baseline 1: No-certificate (beamless greedy by bound) ----------
def no_certificate(g: PrefixDAG, mtau, cap:int=None)->Dict[str,Any]:
    """Best-first search ordered by node bound, with no stopping certificate.

    Starting from ``"root"``, nodes are popped from a max-heap keyed by
    ``mtau.mtau(node_id, False)``. Every pop counts as one expansion: a node
    with children pushes each child onto the heap; a node without children
    has its leaves scored with ``leaf_realized_score``. The search runs until
    the heap is empty or the expansion cap is reached.

    Args:
        g: graph whose ``nodes`` maps node id to a node exposing ``children``,
            ``leaves`` and ``depth``.
        mtau: object whose ``mtau(node_id, False)`` gives the priority bound.
        cap: maximum number of expansions; ``None`` or ``0`` means no limit.

    Returns:
        Dict with ``expanded`` (number of nodes popped), ``time`` (elapsed
        ``time.perf_counter`` seconds), and ``cert`` / ``fail`` (always
        ``False`` for this baseline).
    """
    t0 = time.perf_counter()
    expanded = 0
    # Best leaf score seen so far; tracked but not part of the returned dict.
    incumbent = -1e9
    pq: List[Tuple[float,str]] = []
    heapq.heappush(pq, (-mtau.mtau("root", False), "root"))  # max-heap via neg
    while pq:
        _, nid = heapq.heappop(pq)
        expanded += 1
        node = g.nodes[nid]
        if node.children:
            for c in node.children:
                heapq.heappush(pq, (-mtau.mtau(c, False), c))
        else:
            for leaf in node.leaves:
                incumbent = max(incumbent, leaf_realized_score(leaf, node.depth))
        # Enforce the cap after every expansion. Previously the leaf branch
        # skipped this check, so runs of leaf nodes could overshoot ``cap``.
        if cap and expanded >= cap:
            break
    return {"expanded":expanded, "time":(time.perf_counter()-t0), "cert":False, "fail":False}

# -------- Baseline 2: Beam search by bound (beam width K) ---------------
def beam_search(g: PrefixDAG, mtau, beam:int=CFG.BEAM_WIDTH, cap:int=None)->Dict[str,Any]:
    """Level-by-level beam search that keeps the ``beam`` best nodes by bound.

    At each level the frontier is sorted by ``mtau.mtau(node_id, False)`` in
    descending order and truncated to ``beam`` entries. Each kept node counts
    as one expansion: a node with children contributes them to the next
    frontier; a node without children has its leaves scored with
    ``leaf_realized_score``. The search stops when the frontier is empty or
    the expansion cap is reached.

    Args:
        g: graph whose ``nodes`` maps node id to a node exposing ``children``,
            ``leaves`` and ``depth``.
        mtau: object whose ``mtau(node_id, False)`` gives the ranking bound.
        beam: number of frontier nodes kept per level.
        cap: maximum number of expansions; ``None`` or ``0`` means no limit.

    Returns:
        Dict with ``expanded`` (number of nodes expanded), ``time`` (elapsed
        ``time.perf_counter`` seconds), and ``cert`` / ``fail`` (always
        ``False`` for this baseline).
    """
    t0 = time.perf_counter()
    expanded = 0
    # Best leaf score seen so far; tracked but not part of the returned dict.
    incumbent = -1e9
    frontier = [("root", mtau.mtau("root", False))]
    cap_reached = False
    while frontier and not cap_reached:
        # expand top-K by bound
        frontier.sort(key=lambda entry: entry[1], reverse=True)
        next_frontier = []
        for nid, _ in frontier[:beam]:
            expanded += 1
            node = g.nodes[nid]
            if node.children:
                next_frontier.extend((c, mtau.mtau(c, False)) for c in node.children)
            else:
                for leaf in node.leaves:
                    incumbent = max(incumbent, leaf_realized_score(leaf, node.depth))
            # Enforce the cap after every expansion. Previously the leaf
            # branch skipped this check, so a level full of leaf nodes could
            # overshoot ``cap``.
            if cap and expanded >= cap:
                cap_reached = True
                break
        frontier = next_frontier
    return {"expanded":expanded, "time":(time.perf_counter()-t0), "cert":False, "fail":False}

# -------- Baseline 3: Distribution-level pruning (expected key) ----------
def dist_prune_search(g: PrefixDAG, mtau, cap:int=None)->Dict[str,Any]:
    """Best-first search pruned by an expected (distribution-level) key.

    Uses E[-log t(v)] = gamma + log N(v) for Exp(N). Not coupled to realized race.
    Stops when max_v M_tau(v) - (gamma + log N(v)) <= incumbent. Can be UNSOUND.

    Each popped node that survives the stopping test counts as one expansion:
    a node with children pushes them onto the heap; a node without children
    has its leaves scored with ``leaf_realized_score`` to update the incumbent.
    After the search, the incumbent is compared against the exhaustive optimum
    from ``global_best_realized`` to flag an unsound result.

    Args:
        g: graph exposing ``nodes`` (node id -> node with ``children``,
            ``leaves`` and ``depth``) and ``N(node_id)``.
        mtau: object whose ``mtau(node_id, False)`` gives the node bound.
        cap: maximum number of expansions; ``None`` or ``0`` means no limit.

    Returns:
        Dict with ``expanded``, ``time`` (elapsed ``time.perf_counter``
        seconds for the search only), ``cert`` (always ``False``), ``fail``
        (``True`` when the incumbent is below the global best by more than
        1e-9), ``incumbent`` and ``global_best``.
    """
    t0 = time.perf_counter()
    expanded = 0
    incumbent = -1e9
    pq: List[Tuple[float,str]] = []

    def exp_key(nid:str)->float:
        # Clamp N to at least 1 so log(N) is defined and non-negative.
        N = max(1, g.N(nid))
        return mtau.mtau(nid, False) - (_EULER_GAMMA + math.log(N))

    heapq.heappush(pq, (-exp_key("root"), "root"))  # max-heap via neg
    while pq:
        neg_key, nid = heapq.heappop(pq)
        # stop if bound says nothing can beat incumbent
        if -neg_key <= incumbent:
            break
        expanded += 1
        node = g.nodes[nid]
        if node.children:
            for c in node.children:
                heapq.heappush(pq, (-exp_key(c), c))
        else:
            for leaf in node.leaves:
                incumbent = max(incumbent, leaf_realized_score(leaf, node.depth))
        # Enforce the cap after every expansion. Previously the leaf branch
        # skipped this check, so runs of leaf nodes could overshoot ``cap``.
        if cap and expanded >= cap:
            break

    # Capture the search time before the exhaustive scan below, so the
    # diagnostic pass does not inflate this baseline's reported cost.
    elapsed = time.perf_counter() - t0

    # UNSOUNDNESS check: compare with true global best realized score.
    gb = global_best_realized(g)
    fail = (incumbent + 1e-9) < gb
    return {"expanded":expanded, "time":elapsed, "cert":False, "fail":fail, "incumbent":incumbent, "global_best":gb}
