import functools
import math
from dataclasses import dataclass
from typing import Dict, List, Sequence, Union

import numpy as np
import pandas as pd
# Structural summary per network, keyed by network name: node count, arc count
# and average degree.  (The names look like the standard bnlearn benchmark
# networks -- confirm the source of the figures before relying on them.)
bn_metrics = {
    name: {"nodes": n_nodes, "arcs": n_arcs, "avg_degree": degree}
    for name, n_nodes, n_arcs, degree in (
        ("asia", 8, 8, 2.00),
        ("alarm", 37, 46, 2.49),
        ("cancer", 5, 4, 1.60),
        ("child", 20, 25, 2.50),
        ("hailfinder", 56, 66, 2.36),
        ("hepar2", 70, 123, 3.51),
        ("insurance", 27, 52, 3.85),
        ("mildew", 35, 46, 2.63),
        ("water", 32, 66, 4.12),
        ("win95pts", 76, 112, 2.95),
    )
}

# Run parameters handed to build_rows() at the bottom of the file:
# alpha and beta become b = alpha and a = 1 - beta inside
# step1_metrics_for_network_relaxed(); k caps the subset size enumerated by
# subset_counts_by_size().
alpha = 0.01
beta = 0.10
k = 3

def comb(n, k):
    """Return the binomial coefficient C(n, k), or 0 when k lies outside [0, n].

    Unlike a bare ``math.comb`` call, a negative ``k`` (or a ``k`` larger than
    ``n``, which also covers negative ``n`` with ``k >= 0``) yields 0 instead
    of raising.
    """
    outside_range = k < 0 or k > n
    return 0 if outside_range else math.comb(n, k)

def subset_counts_by_size(M, k):
    """Count subsets of an M-element pool by size, split on one fixed element Z.

    Sizes run from 0 up to ``min(k, M)`` inclusive.  Returns three parallel
    lists indexed by subset size:

    * the total number of subsets of that size, C(M, size);
    * how many of them contain the fixed element Z, C(M-1, size-1)
      (0 for the empty set or an empty pool);
    * how many do not contain Z (the difference of the two).
    """
    largest = min(k, M)
    totals, with_z, without_z = [], [], []
    for size in range(largest + 1):
        n_all = comb(M, size)
        n_with = 0
        if size >= 1 and M >= 1:
            n_with = comb(M - 1, size - 1)
        totals.append(n_all)
        with_z.append(n_with)
        without_z.append(n_all - n_with)
    return totals, with_z, without_z

@functools.lru_cache(maxsize=8)
def _unit_gauss_legendre(order):
    """Return read-only Gauss-Legendre (nodes, weights) rescaled from [-1, 1] to [0, 1]."""
    nodes, weights = np.polynomial.legendre.leggauss(order)
    gl_x = 0.5 * (nodes + 1.0)
    gl_w = 0.5 * weights
    # The arrays are shared between calls through the cache; freeze them so a
    # caller cannot corrupt later integrations.
    gl_x.setflags(write=False)
    gl_w.setflags(write=False)
    return gl_x, gl_w


def I_mn(m, n, a, b, order=64):
    """Approximate the integral of (1 - a*u)**m * (1 - b*u)**n for u in [0, 1].

    The integral is evaluated with Gauss-Legendre quadrature.  The quadrature
    rule is computed once per ``order`` and cached, instead of being rebuilt
    on every call.

    Args:
        m: Exponent of the ``(1 - a*u)`` factor; may be fractional.
        n: Exponent of the ``(1 - b*u)`` factor; may be fractional.
        a: Coefficient of ``u`` in the first factor.
        b: Coefficient of ``u`` in the second factor.
        order: Number of quadrature nodes (default 64, the value that was
            previously hard-coded).

    Returns:
        The quadrature estimate as a Python float, or 0.0 when either
        exponent is negative.
    """
    if m < 0 or n < 0:
        return 0.0
    u, w = _unit_gauss_legendre(order)
    with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
        # Powers are formed as exp(exponent * log1p(-coef*u)); wherever the
        # base 1 - coef*u is not positive the factor is forced to 0 via
        # exp(-inf).
        term1 = np.exp(np.where((1 - a*u) > 0, m*np.log1p(-a*u), -np.inf))
        term2 = np.exp(np.where((1 - b*u) > 0, n*np.log1p(-b*u), -np.inf))
        vals = term1 * term2
    return float(np.sum(w * vals))

# --- rho broadcasting utility ---
def _as_level_list(val: Union[float, Sequence[float]], L: int) -> List[float]:
    """Expand ``val`` into a list of exactly ``L`` floats.

    Args:
        val: Either a scalar, which is repeated ``L`` times, or a list, tuple
            or 1-d numpy array holding one value per level.  A 0-d numpy
            array is treated as a scalar.
        L: Required number of levels.

    Returns:
        A new list of ``L`` Python floats.

    Raises:
        ValueError: If ``val`` is a list/tuple/array whose length is not ``L``.
    """
    # A 0-d array wraps a single value; list() cannot iterate it (TypeError),
    # so unwrap it and fall through to the scalar broadcast below.
    if isinstance(val, np.ndarray) and val.ndim == 0:
        val = val.item()
    if isinstance(val, (list, tuple, np.ndarray)):
        arr = list(val)
        if len(arr) != L:
            raise ValueError(f"rho list must have length {L}, got {len(arr)}")
        return [float(x) for x in arr]
    return [float(val)] * L

# --- core structures ---
@dataclass
class LevelCounts:
    """Per-level candidate-set counts, split by sepset status and by Z membership.

    Every field is a list indexed by level (subset size 0, 1, ...), as built
    by counts_for_truth_relaxed().  Entries are floats because the counts are
    scaled by fractional rho values there.
    """
    S_Z:  List[float]  # sepsets among the Z-including subsets
    S_nZ: List[float]  # sepsets among the subsets that exclude Z
    U_Z:  List[float]  # non-sepsets among the Z-including subsets
    U_nZ: List[float]  # non-sepsets among the subsets that exclude Z
    S:    List[float]  # all sepsets at the level
    U:    List[float]  # all non-sepsets at the level (U_Z + U_nZ)

def counts_for_truth_relaxed(M: int, k: int, collider_truth: bool,
                             rho_Z: Union[float, Sequence[float]] = 1.0,
                             rho_nZ: Union[float, Sequence[float]] = 1.0) -> LevelCounts:
    """
    Build per-level sepset / non-sepset counts under the relaxed bucket model.

    Candidate subsets (sizes 0..min(k, M)) are split into a Z-including bucket
    and a non-Z bucket.  Exactly one bucket can hold sepsets:
      collider_truth=True  -> the non-Z bucket, a fraction rho_nZ[level] of it
      collider_truth=False -> the Z-including bucket, a fraction rho_Z[level] of it
    Whatever is left in that bucket, and the whole of the other bucket, counts
    as non-sepsets.  rho_Z / rho_nZ may be scalars (applied to every level) or
    one value per level.
    """
    size_totals, with_z_raw, without_z_raw = subset_counts_by_size(M, k)
    n_levels = len(size_totals)
    frac_z = _as_level_list(rho_Z, n_levels)
    frac_nz = _as_level_list(rho_nZ, n_levels)
    with_z = list(map(float, with_z_raw))
    without_z = list(map(float, without_z_raw))

    if collider_truth:
        # Sepsets live only among the subsets that exclude Z.
        sep_nz = [frac * cnt for frac, cnt in zip(frac_nz, without_z)]
        non_nz = [(1.0 - frac) * cnt for frac, cnt in zip(frac_nz, without_z)]
        sep_z = [0.0 for _ in range(n_levels)]
        non_z = list(with_z)
        sep_all = list(sep_nz)
    else:
        # Sepsets live only among the subsets that include Z.
        sep_z = [frac * cnt for frac, cnt in zip(frac_z, with_z)]
        non_z = [(1.0 - frac) * cnt for frac, cnt in zip(frac_z, with_z)]
        sep_nz = [0.0 for _ in range(n_levels)]
        non_nz = list(without_z)
        sep_all = list(sep_z)

    non_all = [in_z + out_z for in_z, out_z in zip(non_z, non_nz)]
    return LevelCounts(S_Z=sep_z, S_nZ=sep_nz, U_Z=non_z, U_nZ=non_nz,
                       S=sep_all, U=non_all)

def prev_no_hit_prefix(S: List[float], U: List[float], pA: float, pB: float) -> List[float]:
    """Return, for each level, the product of the miss factors of all earlier levels.

    The miss factor of a level is ``pA**S[level] * pB**U[level]``.  Entry 0 is
    therefore 1.0 (empty product) and entry i is the running product over
    levels 0..i-1.  The result has one entry per element of ``S``.
    """
    prefix = []
    running = 1.0
    for level, sep_count in enumerate(S):
        prefix.append(running)
        level_miss = (pA ** sep_count) * (pB ** U[level])
        running *= level_miss
    return prefix

def pr_D(S: List[float], U: List[float], pA: float, pB: float) -> float:
    """Return one minus the product of ``pA**s * pB**u`` over all paired levels.

    With pA / pB read as per-set miss probabilities this is the probability of
    at least one hit across every level.  Levels are paired with ``zip``, so
    the shorter of ``S`` and ``U`` decides how many are used.
    """
    all_miss = math.prod(
        ((pA ** s_cnt) * (pB ** u_cnt) for s_cnt, u_cnt in zip(S, U)),
        start=1.0,
    )
    return 1.0 - all_miss

def pr_CPC_collider_givenD(levels: LevelCounts, a: float, b: float, pA: float, pB: float) -> float:
    """Conditional probability of the CPC "collider" outcome given at least one hit.

    For each level the contribution is
    (no hit at any earlier level) * (no Z-including set hits at this level)
    * (at least one non-Z set hits at this level); the contributions are summed
    and divided by pr_D.  Returns 0.0 when pr_D is exactly 0.

    ``a`` and ``b`` are accepted for signature parity with the PC variants and
    are not used here.
    """
    sep_z, non_z = levels.S_Z, levels.U_Z
    sep_nz, non_nz = levels.S_nZ, levels.U_nZ
    reach = prev_no_hit_prefix(levels.S, levels.U, pA, pB)
    p_any_hit = pr_D(levels.S, levels.U, pA, pB)
    if p_any_hit == 0:
        return 0.0
    total = 0.0
    for ell, reach_here in enumerate(reach):
        z_all_miss = (pA ** sep_z[ell]) * (pB ** non_z[ell])
        nonz_some_hit = 1.0 - (pA ** sep_nz[ell]) * (pB ** non_nz[ell])
        total += reach_here * z_all_miss * nonz_some_hit
    return total / p_any_hit

def pr_CPC_nonc_givenD(levels: LevelCounts, a: float, b: float, pA: float, pB: float) -> float:
    """Conditional probability of the CPC "non-collider" outcome given at least one hit.

    For each level the contribution is
    (no hit at any earlier level) * (no non-Z set hits at this level)
    * (at least one Z-including set hits at this level); the contributions are
    summed and divided by pr_D.  Returns 0.0 when pr_D is exactly 0.

    ``a`` and ``b`` are accepted for signature parity with the PC variants and
    are not used here.
    """
    sep_z, non_z = levels.S_Z, levels.U_Z
    sep_nz, non_nz = levels.S_nZ, levels.U_nZ
    reach = prev_no_hit_prefix(levels.S, levels.U, pA, pB)
    p_any_hit = pr_D(levels.S, levels.U, pA, pB)
    if p_any_hit == 0:
        return 0.0
    total = 0.0
    for ell, reach_here in enumerate(reach):
        nonz_all_miss = (pA ** sep_nz[ell]) * (pB ** non_nz[ell])
        z_some_hit = 1.0 - (pA ** sep_z[ell]) * (pB ** non_z[ell])
        total += reach_here * nonz_all_miss * z_some_hit
    return total / p_any_hit

def pr_PC_collider_givenD(levels: LevelCounts, a: float, b: float, pA: float, pB: float) -> float:
    """Conditional probability of the PC "collider" outcome given at least one hit.

    For each level, the weight of the non-Z bucket is
    ``S_nZ * a * I_mn(S - 1, U, a, b) + U_nZ * b * I_mn(S, U - 1, a, b)``
    (each term is skipped when its count is not positive).  That weight is
    multiplied by the probability of no hit at any earlier level, summed over
    levels, and divided by pr_D.  Returns 0.0 when pr_D is exactly 0.

    NOTE(review): the I_mn terms presumably give the chance that a particular
    set is the first one to hit within its level -- confirm against the
    derivation.
    """
    sep_all, non_all = levels.S, levels.U
    sep_nz, non_nz = levels.S_nZ, levels.U_nZ
    reach = prev_no_hit_prefix(sep_all, non_all, pA, pB)
    p_any_hit = pr_D(sep_all, non_all, pA, pB)
    if p_any_hit == 0:
        return 0.0
    total = 0.0
    for ell, reach_here in enumerate(reach):
        from_sepsets = 0.0
        if sep_nz[ell] > 0:
            from_sepsets = sep_nz[ell] * a * I_mn(sep_all[ell] - 1, non_all[ell], a, b)
        from_nonsepsets = 0.0
        if non_nz[ell] > 0:
            from_nonsepsets = non_nz[ell] * b * I_mn(sep_all[ell], non_all[ell] - 1, a, b)
        total += reach_here * (from_sepsets + from_nonsepsets)
    return total / p_any_hit

def pr_PC_nonc_givenD(levels: LevelCounts, a: float, b: float, pA: float, pB: float) -> float:
    """Conditional probability of the PC "non-collider" outcome given at least one hit.

    For each level, the weight of the Z-including bucket is
    ``S_Z * a * I_mn(S - 1, U, a, b) + U_Z * b * I_mn(S, U - 1, a, b)``
    (each term is skipped when its count is not positive).  That weight is
    multiplied by the probability of no hit at any earlier level, summed over
    levels, and divided by pr_D.  Returns 0.0 when pr_D is exactly 0.

    NOTE(review): the I_mn terms presumably give the chance that a particular
    set is the first one to hit within its level -- confirm against the
    derivation.
    """
    sep_all, non_all = levels.S, levels.U
    sep_z, non_z = levels.S_Z, levels.U_Z
    reach = prev_no_hit_prefix(sep_all, non_all, pA, pB)
    p_any_hit = pr_D(sep_all, non_all, pA, pB)
    if p_any_hit == 0:
        return 0.0
    total = 0.0
    for ell, reach_here in enumerate(reach):
        from_sepsets = 0.0
        if sep_z[ell] > 0:
            from_sepsets = sep_z[ell] * a * I_mn(sep_all[ell] - 1, non_all[ell], a, b)
        from_nonsepsets = 0.0
        if non_z[ell] > 0:
            from_nonsepsets = non_z[ell] * b * I_mn(sep_all[ell], non_all[ell] - 1, a, b)
        total += reach_here * (from_sepsets + from_nonsepsets)
    return total / p_any_hit

# --- metrics wrapper with rho ---
def step1_metrics_for_network_relaxed(N_nodes: int, k: int, alpha: float, beta: float,
                                      rho_Z: Union[float, Sequence[float]] = 1.0,
                                      rho_nZ: Union[float, Sequence[float]] = 1.0) -> Dict[str, Dict[str, float]]:
    """Compute TP / FP / FN figures for the four orientation rules on one network size.

    The candidate pool has ``max(0, N_nodes - 2)`` elements.  Level counts are
    built twice, once with the collider as ground truth and once with the
    non-collider.  Each rule is then evaluated on its matching truth (TP) and
    on the opposite truth (FP); FN is reported as ``1 - TP``.

    ``a = 1 - beta`` and ``b = alpha`` are the per-set hit probabilities handed
    to the rules, with ``pA = 1 - a`` and ``pB = 1 - b`` as their complements.

    Returns a dict keyed by "CPC_colliders_first", "CPC_nonc_first",
    "PC_colliders_first" and "PC_nonc_first", each mapping to a dict with keys
    "TP", "FP" and "FN".
    """
    hit_sep = 1.0 - beta
    hit_non = alpha
    miss_sep = 1.0 - hit_sep
    miss_non = 1.0 - hit_non
    pool_size = max(0, N_nodes - 2)

    truth_collider = counts_for_truth_relaxed(pool_size, k, collider_truth=True,
                                              rho_Z=rho_Z, rho_nZ=rho_nZ)
    truth_nonc = counts_for_truth_relaxed(pool_size, k, collider_truth=False,
                                          rho_Z=rho_Z, rho_nZ=rho_nZ)

    def _confusion(rule, matching, opposite):
        # TP: rule evaluated under its own truth; FP: under the opposite truth.
        tp = rule(matching, hit_sep, hit_non, miss_sep, miss_non)
        fp = rule(opposite, hit_sep, hit_non, miss_sep, miss_non)
        return {"TP": tp, "FP": fp, "FN": 1.0 - tp}

    return {
        "CPC_colliders_first": _confusion(pr_CPC_collider_givenD, truth_collider, truth_nonc),
        "CPC_nonc_first":      _confusion(pr_CPC_nonc_givenD, truth_nonc, truth_collider),
        "PC_colliders_first":  _confusion(pr_PC_collider_givenD, truth_collider, truth_nonc),
        "PC_nonc_first":       _confusion(pr_PC_nonc_givenD, truth_nonc, truth_collider),
    }
    
def build_rows(bn_metrics: Dict[str, Dict], alpha, beta, k) -> List[Dict[str, float]]:
    """Produce one flat result row per network in ``bn_metrics``.

    Each network's "nodes" entry is fed to step1_metrics_for_network_relaxed()
    (with its default rho values).  The nested result is flattened into keys
    of the form ``<rule>_<metric>``, preceded by a "network" key holding the
    network name.  Key insertion order follows ``layout`` below.
    """
    # (rule, metric) pairs in the exact order the flat keys must be inserted.
    layout = (
        ("PC_colliders_first", "TP"), ("PC_nonc_first", "TP"),
        ("PC_colliders_first", "FP"), ("PC_nonc_first", "FP"),
        ("PC_colliders_first", "FN"), ("PC_nonc_first", "FN"),
        ("CPC_colliders_first", "TP"), ("CPC_nonc_first", "TP"),
        ("CPC_colliders_first", "FP"), ("CPC_nonc_first", "FP"),
        ("CPC_nonc_first", "FN"), ("CPC_colliders_first", "FN"),
    )
    table = []
    for net_name, info in bn_metrics.items():
        stats = step1_metrics_for_network_relaxed(info["nodes"], k, alpha, beta)
        record = {"network": net_name}
        for rule, metric in layout:
            record[f"{rule}_{metric}"] = stats[rule][metric]
        table.append(record)
    return table

# Driver: evaluate every network in bn_metrics with the module-level alpha,
# beta and k, then collect the flat rows into a DataFrame with an explicit
# column order.
rows = build_rows(bn_metrics, alpha, beta, k)
df = pd.DataFrame(rows, columns=[
    'network',
    'PC_colliders_first_TP','PC_nonc_first_TP','PC_colliders_first_FP','PC_nonc_first_FP','PC_colliders_first_FN','PC_nonc_first_FN',
    'CPC_colliders_first_TP','CPC_nonc_first_TP','CPC_colliders_first_FP','CPC_nonc_first_FP','CPC_nonc_first_FN','CPC_colliders_first_FN'
])
# Bare expression selecting the network name and the four FP columns; the
# result is not assigned, so it has no effect when run as a plain script
# (presumably meant as notebook cell output -- confirm).
df[['network', 'PC_colliders_first_FP', 'PC_nonc_first_FP', 'CPC_colliders_first_FP','CPC_nonc_first_FP']]
