import heapq, math, random, time, hashlib
from dataclasses import dataclass
from typing import Tuple, List, Optional
from pam.prefix_dag import PrefixDAG
from pam.race import q064_to_uniform, draw_q064, exp_from_uniform
from pam.mtau import MTau
from pam.rdp import rdp_gaussian, best_eps
from dp_ledger.logger import DPLedger
from config import CFG

@dataclass
class KeyItem:
    """Frontier entry for ``heapq`` whose ordering is inverted so the largest key pops first.

    ``heapq`` is a min-heap; because ``__lt__`` answers "has a larger key", the
    smallest element in heap order is the entry with the largest ``key``.
    """

    key: float   # priority; larger keys are popped earlier
    nid: str     # id of the DAG node this entry refers to
    data: Tuple  # mode tag followed by mode-specific payload, e.g. ("exact", t)

    def __lt__(self, other: "KeyItem") -> bool:
        # Reverse the natural float order to turn heapq into a max-heap.
        return other.key < self.key

class PaMEngine:
    """Best-first search over a ``PrefixDAG`` that records its routing in a ``DPLedger``.

    Modes:

    * ``"Exact"``     -- frontier keys combine ``MTau`` with arrival times whose rates
      come from ``g.N(node)``; the search stops once no frontier key beats the incumbent.
    * ``"Surrogate"`` -- a single time is drawn at the root with the inflated rate
      ``max(1, ceil(N * Nub_factor))`` and inherited unchanged by all descendants;
      same early-stop rule as Exact.
    * ``"Fallback"``  -- keys are the raw ``MTau`` value, there is no early stop, and
      the ledger claim type is ``"NoCert"``.
    """

    def __init__(self, g: PrefixDAG, mtau: MTau, mode:str="Surrogate", use_r2=False, rng_seed:int=42):
        """Set up the engine, seed the RNG and create the ledger.

        Args:
            g: DAG to search; ``g.nodes[node_id]`` and ``g.N(node_id)`` are used.
            mtau: Provides ``mtau(node_id, use_r2)``, the bound term of every key.
            mode: ``"Exact"``, ``"Surrogate"`` or ``"Fallback"``.
            use_r2: Forwarded verbatim to ``mtau.mtau``.
            rng_seed: Seeds the global ``random`` module and is passed to the ledger.
        """
        assert mode in ("Exact","Surrogate","Fallback")
        self.g, self.mtau, self.mode, self.use_r2 = g, mtau, mode, use_r2
        # NOTE(review): this reseeds the process-wide RNG. It is left global in case
        # pam.race.draw_q064 also draws from ``random`` -- confirm before switching to a
        # private random.Random instance (that would change the reproducible stream).
        random.seed(rng_seed)
        # Noise scales of every edge logged by this engine. NOTE(review): this list is
        # never reset, so repeated run() calls on one engine accumulate sigmas.
        self.sigmas_realized: List[float] = []
        self.ledger = DPLedger(mode=mode, claim_type=("RunWiseExact" if mode!="Fallback" else "NoCert"),
                               graph_meta={"nodes":len(g.nodes)}, seed=rng_seed)

        # --- Adapter metadata (one per run, for the table) ---
        self.adapter_meta = None
        if mode in ("Exact", "Surrogate"):
            idx = 0 if mode == "Exact" else 1  # Small for Exact, Medium for Surrogate
            meta = CFG.ADAPTER_CATALOG[idx]
            self.adapter_meta = meta
            self.ledger.set_model_meta(
                model_id=meta["tier"],
                adapter_id=meta["name"],
                dp_cert_id=meta["dp_cert_id"],
                eps_train=str(meta["eps_train"]),
                delta_train=str(meta["delta_train"]),
            )
        else:
            # Fallback: no adapter (keeps claim_type=NoCert semantics straightforward)
            self.ledger.set_model_meta(
                model_id="None",
                adapter_id="",
                dp_cert_id="",
                eps_train="",
                delta_train=""
            )

    @staticmethod
    def _key_exact(Mtau, minus_log_t):
        """Exact-mode frontier key: ``Mtau - (-log t)``, i.e. ``Mtau + log t``."""
        return Mtau - minus_log_t

    @staticmethod
    def _key_surrogate(Mtau, minus_log_that):
        """Surrogate-mode frontier key: ``Mtau - (-log t_hat)``."""
        return Mtau - minus_log_that

    def _choose_sigma(self, nid:str)->float:
        """Noise scale for edges into ``nid``: ``1.2 + 0.3 * depth``, capped at 3.0."""
        d = self.g.nodes[nid].depth
        return min(3.0, 1.2 + 0.3*d)

    def _leaf_uniform_q064(self, leaf_id:int)->int:
        """Deterministic 64-bit value for a leaf: first 8 bytes of sha256("leaf:<id>")."""
        h = hashlib.sha256(f"leaf:{leaf_id}".encode()).digest()
        return int.from_bytes(h[:8], "big", signed=False)

    def _leaf_score(self, leaf_id: int, depth: int) -> float:
        """Realized score of a leaf: a depth-dependent base value plus Gumbel noise.

        The noise ``-log(-log(1 - U))`` uses the leaf's hash-derived uniform, so a
        given leaf always scores the same. NOTE(review): the base expression is
        hard-coded here and looks like a placeholder utility -- confirm.
        """
        U = q064_to_uniform(self._leaf_uniform_q064(leaf_id))
        base = (0.6 - 0.01*depth)/CFG.TAU - 0.02*depth
        return base - math.log(-math.log(1.0-U))

    def _log_child_edge(self, nid: str, c: str) -> None:
        """Choose the noise scale for edge ``nid->c``, remember it and log it to the ledger."""
        sigma = self._choose_sigma(c)
        self.sigmas_realized.append(sigma)
        # NOTE(review): 16 is rdp_gaussian's first argument (presumably the RDP order) -- confirm.
        self.ledger.log_edge(f"{nid}->{c}", sigma, rdp_gaussian(16, sigma), CFG.DELTA)

    def _push_root(self, pq: List[KeyItem], root: str, Nub_factor: float) -> None:
        """Seed ``pq`` with the root entry for the current mode and log its draws/counts."""
        if self.mode == "Fallback":
            heapq.heappush(pq, KeyItem(self.mtau.mtau(root, self.use_r2), root, ("fallback", None)))
            self.ledger.claim_type = "NoCert"
            return

        Nroot = self.g.N(root)
        if self.mode == "Exact":
            Nub = Nroot
            rate = max(Nroot, 1)
        else:  # Surrogate: inflate the count so the drawn time uses an upper bound
            Nub = max(1, int(math.ceil(Nroot*Nub_factor)))
            rate = Nub
        u64 = draw_q064()
        U = q064_to_uniform(u64)
        self.ledger.log_uniform(root, u64)
        t_root = -math.log(1.0-U)/rate  # inverse-CDF exponential sample with the given rate
        self.ledger.log_counts(root, Nroot, Nub)
        m_root = self.mtau.mtau(root, self.use_r2)
        if self.mode == "Exact":
            key0 = self._key_exact(m_root, -math.log(t_root))
            heapq.heappush(pq, KeyItem(key0, root, ("exact", t_root)))
        else:
            key0 = self._key_surrogate(m_root, -math.log(t_root))
            heapq.heappush(pq, KeyItem(key0, root, ("sur", t_root)))

    def _expand_exact(self, pq: List[KeyItem], nid: str, node, parent_t: float) -> None:
        """Push the children of an Exact-mode node onto ``pq``.

        One child is picked with probability proportional to ``g.N(child)`` and
        inherits ``parent_t``; every other child gets ``parent_t`` plus an
        ``exp_from_uniform`` increment (rate ``max(g.N(child), 1)``) from a freshly
        drawn and logged uniform. Children whose count is zero are not pushed.
        """
        weights = [self.g.N(c) for c in node.children]
        total_w = sum(weights)
        # With no positive counts, pick uniformly so the RNG still advances exactly once
        # per expanded node (no child is pushed in that case anyway).
        if total_w > 0:
            pick_weights, tot = weights, total_w
        else:
            pick_weights, tot = [1]*len(weights), len(weights)
        r = random.uniform(0, tot)
        acc = 0
        win_idx = 0
        for i, w in enumerate(pick_weights):
            acc += w
            if r <= acc:
                win_idx = i
                break

        for i, c in enumerate(node.children):
            if weights[i] == 0:
                continue
            if i == win_idx:
                t_child = parent_t
            else:
                u64 = draw_q064()
                U = q064_to_uniform(u64)
                self.ledger.log_uniform(c, u64)
                t_child = parent_t + exp_from_uniform(U, max(weights[i], 1))
            minus_log = -math.log(t_child)
            key = self._key_exact(self.mtau.mtau(c, self.use_r2), minus_log)
            heapq.heappush(pq, KeyItem(key, c, ("exact", t_child)))
            self._log_child_edge(nid, c)

    def _expand_surrogate(self, pq: List[KeyItem], nid: str, node, that_v: float) -> None:
        """Push the children of a Surrogate-mode node; each inherits the parent's time ``that_v``."""
        minus_log_that = -math.log(that_v)  # identical for every child, so computed once
        for c in node.children:
            key = self._key_surrogate(self.mtau.mtau(c, self.use_r2), minus_log_that)
            heapq.heappush(pq, KeyItem(key, c, ("sur", that_v)))
            self._log_child_edge(nid, c)

    def run(self, eps_max:float=CFG.EPS_MAX, price_max:int=CFG.PRICE_MAX_CENTS, slo_ms:int=CFG.SLO_MS,
            Nub_factor:float=1.2, cap_k:Optional[int]=None):
        """Run one best-first search from ``"root"`` and save the ledger.

        Args:
            eps_max: Accepted but not enforced (see NOTE below).
            price_max: Budget in cents. Each expanded internal node costs 5; once the
                spend exceeds this, a ``"BudgetFail"`` budget event is logged, the
                claim type becomes ``"NoCert"`` and the search stops.
            slo_ms: Accepted but not enforced (see NOTE below).
            Nub_factor: Count inflation factor used by Surrogate (and for Fallback's
                logged upper bound).
            cap_k: If given, at most this many popped nodes are processed; the next
                pop logs a ``"CapExceeded"`` guard and ends the search.

        Returns:
            dict with keys ``incumbent`` (best realized leaf score, ``-inf`` if no leaf
            was reached), ``best_leaf`` (that leaf's id or ``None``), ``ledger`` (the
            value returned by ``DPLedger.save()``), ``expanded`` (number of processed
            nodes, leaves included) and ``stop_slack``.
        """
        # NOTE(review): eps_max and slo_ms are never checked, and eps_used is never
        # incremented, so the ledger always records eps_used == 0.0.
        t0 = time.time()
        pq: List[KeyItem] = []
        root = "root"
        self._push_root(pq, root, Nub_factor)

        # Exact/Surrogate keys act as bounds, so those modes may stop early.
        early_stop = self.mode in ("Exact", "Surrogate")
        incumbent = float("-inf")
        best_leaf: Optional[int] = None
        expanded = 0
        price_spent = 0
        eps_used = 0.0
        stop_key = None
        while pq:
            it = heapq.heappop(pq)
            nid = it.nid
            # KeyItem makes pq a max-heap, so it.key is the largest frontier key.
            if early_stop and it.key <= incumbent:
                stop_key = it.key
                break
            self.ledger.log_route_pop(nid)

            # Check the cap before counting, so `expanded` never exceeds cap_k.
            if cap_k is not None and expanded >= cap_k:
                self.ledger.add_guard("CapExceeded")
                break
            expanded += 1

            node = self.g.nodes[nid]

            # Leaf: materialize realized scores (PRF-based uniforms)
            if not node.children:
                for P in node.leaves:
                    realized = self._leaf_score(P, node.depth)
                    if realized > incumbent:
                        incumbent = realized
                        best_leaf = P
                continue

            # Log counts at this node for κ-tightening stats (Surrogate can use Nub_factor)
            N_here = self.g.N(nid)
            Nub_here = N_here if self.mode=="Exact" else max(1,int(math.ceil(N_here*Nub_factor)))
            self.ledger.log_counts(nid, N_here, Nub_here)

            if self.mode == "Exact":
                self._expand_exact(pq, nid, node, it.data[1])
            elif self.mode == "Surrogate":
                self._expand_surrogate(pq, nid, node, it.data[1])
            else:  # Fallback
                for c in node.children:
                    heapq.heappush(pq, KeyItem(self.mtau.mtau(c,self.use_r2), c, ("fallback", None)))

            price_spent += 5
            if price_spent > price_max:
                self.ledger.set_budget(eps_used, price_spent, int((time.time()-t0)*1000), event="BudgetFail")
                self.ledger.claim_type = "NoCert"
                break

        # Router-internal RDP bookkeeping (for transparency only)
        if self.sigmas_realized:
            eps, alpha = best_eps(self.sigmas_realized, CFG.DELTA)
            self.ledger.set_router_eps(eps, alpha)
        # NOTE(review): after a BudgetFail this is a second set_budget call without the
        # event; confirm DPLedger does not overwrite the "BudgetFail" record.
        self.ledger.set_budget(eps_used, price_spent, int((time.time()-t0)*1000))
        stop_slack = 0.0
        if stop_key is not None:
            # NOTE(review): stop_key <= incumbent whenever it is set, so this is always
            # 0.0; `incumbent - stop_key` may have been intended -- confirm.
            stop_slack = max(0.0, stop_key - incumbent)
        path = self.ledger.save()
        return {
            "incumbent": incumbent,
            "best_leaf": best_leaf,
            "ledger": path,
            "expanded": expanded,
            "stop_slack": stop_slack,
        }