import math, torch
from typing import Callable, Tuple

# --- Utility: coerce a numpy array / tensor into a 1-D torch.Tensor on the target device ---
def _to_1d_tensor(x, device, dtype=torch.float32):
    if isinstance(x, torch.Tensor):
        t = x.detach().to(device=device, dtype=dtype).view(-1)
    else:
        t = torch.as_tensor(x, device=device, dtype=dtype).view(-1)
    return t

# --- Coalition-size distribution for Shapley sampling: ρ(s) ∝ 1 / (s * (n - s)) for |U| = s (sizes 0 and n are skipped) ---
def _rho_sizes(n: int):
    s = torch.arange(1, n, dtype=torch.float64)  # 1..n-1
    rho = 1.0 / (s * (n - s))
    rho = rho / rho.sum()
    return s.to(torch.int64), rho.to(torch.float64)

# --- Reproducible mask generation keyed by global sample index (masks are never stored whole; they are rebuilt on demand) ---
def _make_masks(n: int, sizes_this_batch: torch.LongTensor, base_seed: int, start_idx: int, device):
    """
    返回 mask ∈ {0,1}^{bs×n}（float16 存储，减少显存）；每行恰好有 size 个 1
    通过 (base_seed + global_sample_idx) 保证跨多次遍历一致
    """
    bs = sizes_this_batch.numel()
    masks = torch.zeros(bs, n, device=device, dtype=torch.float16)  # 低精度存储
    for r in range(bs):
        g = torch.Generator(device=device)
        g.manual_seed(int(base_seed + start_idx + r))
        s = int(sizes_this_batch[r].item())
        # randperm(n)[:s] 在 GPU 上可不放回采样 s 个索引
        idx = torch.randperm(n, generator=g, device=device)[:s]
        masks[r, idx] = 1.0
    return masks

# --- Conjugate gradient (driven by an implicit operator A·v, where A = X^T W X + αI; X is never stored explicitly) ---
def _cg_solve(matvec: Callable[[torch.Tensor], torch.Tensor],
              b: torch.Tensor,
              max_iter: int = 50,
              tol: float = 1e-5):
    x = torch.zeros_like(b)
    r = b.clone()
    p = r.clone()
    rs_old = torch.dot(r, r)
    for _ in range(max_iter):
        Ap = matvec(p)
        denom = torch.dot(p, Ap) + 1e-12
        alpha = rs_old / denom
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.dot(r, r)
        if torch.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / (rs_old + 1e-12)) * p
        rs_old = rs_new
    return x

@torch.no_grad()
def sampling_shap(
    predict_fn: Callable[[torch.Tensor], torch.Tensor],  # (B, n) -> (B,)
    x,                       # 1-D: np.ndarray or torch.Tensor
    baseline,                # 1-D: same kind as x
    *,
    k: int | None = None,    # total number of sampled coalitions
    reg_alpha: float = 1e-3, # ridge coefficient
    seed: int | None = None,
    device: torch.device | str = "cuda",
    batch_size: int = 64,    # tune to available memory
    cg_max_iter: int = 30,
    cg_tol: float = 1e-5,
) -> Tuple[torch.Tensor, float]:
    """Estimate per-feature attributions by sampled ridge regression.

    ``k`` random coalitions (0/1 masks over the ``n`` features) are drawn,
    the model is evaluated on inputs that take ``x`` where the mask is 1
    and ``baseline`` elsewhere, and the ridge system
    ``(X^T X + reg_alpha * I) beta = X^T y`` is solved with conjugate
    gradient. Each row of the design matrix ``X`` is ``[1, mask]``; ``X``
    is never stored. Masks are regenerated batch by batch from
    per-sample seeds, once to accumulate ``X^T y`` (the only model
    forward pass) and once per CG iteration to apply ``X^T X``.

    Coalition sizes are drawn from the distribution returned by
    ``_rho_sizes`` and every sample carries weight 1 in the regression.

    Args:
        predict_fn: Maps a float32 ``(B, n)`` batch to ``B`` scalar
            outputs (any shape with ``B`` elements is flattened).
        x: Input to explain; flattened to 1-D.
        baseline: Reference input with the same number of elements as ``x``.
        k: Number of sampled coalitions. Defaults to ``2 * n``.
        reg_alpha: Ridge coefficient, applied to every coefficient
            including the intercept.
        seed: Seeds both the coalition-size draw and the per-sample mask
            generators. When ``None`` the sizes use an unseeded generator
            and the masks use a fixed base seed.
        device: Device on which masks and the linear algebra live.
        batch_size: Number of coalitions processed per batch.
        cg_max_iter: Maximum number of CG iterations.
        cg_tol: Absolute residual-norm tolerance for CG.

    Returns:
        ``(phi, intercept)``: ``phi`` is a float32 tensor of shape ``(n,)``
        on ``device`` and ``intercept`` is a Python float.

    Raises:
        ValueError: If ``x`` and ``baseline`` differ in size, if there are
            fewer than two features, if ``k`` or ``batch_size`` is not
            positive, or if ``predict_fn`` does not return one value per row.
    """
    device = torch.device(device)
    x_t = _to_1d_tensor(x, device, dtype=torch.float32)
    b_t = _to_1d_tensor(baseline, device, dtype=torch.float32)
    if x_t.shape != b_t.shape:
        raise ValueError(
            f"x and baseline must have the same number of elements, "
            f"got {x_t.numel()} and {b_t.numel()}"
        )
    n = x_t.numel()
    if n < 2:
        # Coalition sizes range over 1..n-1, which is empty for n < 2.
        raise ValueError(f"sampling_shap needs at least 2 features, got {n}")
    if k is None:
        k = 2 * n  # legacy default; consider a smaller k for very large n
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    if batch_size < 1:
        # A non-positive batch size would never advance through the samples.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # 1) Assign a coalition size |U| to each of the k samples.
    sizes_vals, rho = _rho_sizes(n)  # sizes in [1..n-1]
    gen_cpu = torch.Generator()
    if seed is not None:
        gen_cpu.manual_seed(int(seed))
    picks = torch.multinomial(rho, num_samples=k, replacement=True, generator=gen_cpu)
    # Sizes stay on the CPU: they are only read back as Python ints when
    # the masks are built, so a device copy would just add synchronisation.
    sizes_all = sizes_vals[picks]
    base_seed = 17_171 if seed is None else int(seed)

    def mask_batches():
        """Yield the float32 (bs, n) masks of each batch, in sample order."""
        for start in range(0, k, batch_size):
            end = min(start + batch_size, k)
            masks = _make_masks(
                n, sizes_all[start:end],
                base_seed=base_seed, start_idx=start, device=device,
            )
            yield masks.to(torch.float32)

    # 2) Accumulate g = X^T y with a single pass of model evaluations.
    g = torch.zeros(n + 1, device=device, dtype=torch.float32)  # intercept + n features
    for masks in mask_batches():
        # Select x where the mask is 1 and the baseline elsewhere. This is
        # done directly in float32 so the model sees the exact input
        # values; blending in float16 would round them and overflow for
        # magnitudes beyond the float16 range.
        xb = torch.where(masks > 0, x_t, b_t)  # (bs, n)
        y = torch.as_tensor(predict_fn(xb)).reshape(-1)
        y = y.to(device=device, dtype=torch.float32)
        if y.numel() != masks.shape[0]:
            raise ValueError(
                f"predict_fn must return one value per row: expected "
                f"{masks.shape[0]}, got {y.numel()}"
            )
        g[0] += y.sum()          # intercept column of X is all ones
        g[1:] += masks.T @ y     # feature columns of X are the masks

    # 3) Implicit operator A·v = X^T (X v) + reg_alpha * v.
    def matvec(v: torch.Tensor) -> torch.Tensor:
        """Return A·v for a vector ``v`` of shape (n + 1,)."""
        v0 = v[0]
        vf = v[1:]  # (n,)
        out0 = torch.zeros((), device=device, dtype=torch.float32)
        outf = torch.zeros(n, device=device, dtype=torch.float32)
        for masks in mask_batches():
            s = v0 + masks @ vf   # X v for this batch, shape (bs,)
            out0 += s.sum()       # X^T s, intercept component
            outf += masks.T @ s   # X^T s, feature components
        out = torch.empty_like(v, dtype=torch.float32)
        out[0] = out0 + reg_alpha * v0
        out[1:] = outf + reg_alpha * vf
        return out

    # 4) Solve (X^T X + reg_alpha * I) beta = g with conjugate gradient.
    beta = _cg_solve(matvec, g, max_iter=cg_max_iter, tol=cg_tol)  # (n + 1,)
    intercept = float(beta[0].item())
    phi = beta[1:]  # (n,)

    return phi, intercept

