# ------------------------------------------------------------
# path_margin_soft.py  (with built‑in self‑test)
# ------------------------------------------------------------
import numpy as np
from typing import Dict, List, Tuple
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
class FrozenDepthTree:
    """
    A light wrapper around ``DecisionTreeRegressor`` that keeps the first two
    levels (<= 4 bins) unchanged once frozen.

    The first :meth:`fit` grows a full regressor with ``params`` and records
    its root split ``(f1, thr1)`` plus one second-level split ``(f2, thr2)``
    in ``frozen_rules``.  The same ``(f2, thr2)`` is applied on both sides of
    the root split.  Later calls to :meth:`fit` keep ``frozen_rules`` and only
    refit the four per-bin sub-models.

    Leaf ids produced by :meth:`apply` and :meth:`get_paths` are
    ``bin * 10_000 + sub_leaf_id`` (``bin * 10_000`` for constant bins), which
    assumes every sub-tree has fewer than 10_000 nodes.
    """

    # Offset between the leaf-id ranges of consecutive bins.
    _LEAF_STRIDE = 10_000
    # sklearn's child index meaning "no child", i.e. the node is a leaf.
    _TREE_LEAF = -1

    def __init__(self, max_depth=6,  min_samples_leaf=2000,
                 min_samples_split=2000, random_state=0):
        # (f1, thr1, f2, thr2) after the first fit; None means "not frozen yet".
        self.frozen_rules = None
        # bin id (0-3) -> fitted DecisionTreeRegressor, or a float constant.
        self.sub_models = {}
        self.params = dict(max_depth=max_depth,
                           min_samples_leaf=min_samples_leaf,
                           min_samples_split=min_samples_split,
                           random_state=random_state)

    @classmethod
    def _first_two_splits(cls, tree):
        """
        Return ``(f1, thr1, f2, thr2)`` taken from a fitted sklearn ``tree_``.

        ``(f1, thr1)`` is the root split.  ``(f2, thr2)`` is the split of the
        root's left child when that child is internal, otherwise the right
        child's.  Leaf nodes carry a placeholder feature (-2 in sklearn), so
        degenerate trees are mapped to explicit rules instead of letting that
        placeholder be used as a column index:

        * root is a leaf            -> ``(0, inf, 0, inf)``: every finite row
          lands in bin 0;
        * both children are leaves  -> ``(f1, thr1, f1, thr1)``: bins 1 and 2
          stay empty, bins 0 / 3 reproduce the root split.
        """
        if tree.children_left[0] == cls._TREE_LEAF:
            return 0, np.inf, 0, np.inf
        f1, thr1 = tree.feature[0], tree.threshold[0]
        for child in (tree.children_left[0], tree.children_right[0]):
            if tree.children_left[child] != cls._TREE_LEAF:
                return f1, thr1, tree.feature[child], tree.threshold[child]
        return f1, thr1, f1, thr1

    def _apply_frozen(self, X):
        """
        Return the frozen bin id (0-3) of every row of ``X``.

        ``bin = 2 * [x_f1 > thr1] + [x_f2 > thr2]``.  The tests are written as
        ``~(x <= thr)`` so NaN values fall on the ``>`` side.
        """
        f1, thr1, f2, thr2 = self.frozen_rules
        right = ~(X[:, f1] <= thr1)
        high = ~(X[:, f2] <= thr2)
        return 2 * right.astype(int) + high.astype(int)

    def fit(self, X, y):
        """
        Fit the per-bin sub-models, freezing the top two levels on first use.

        Parameters
        ----------
        X : 2-D array (rows are samples; indexed as ``X[:, f]`` and ``X[mask]``).
        y : array-like of targets aligned with the rows of ``X``.

        A bin gets a float constant instead of a sub-tree when it has fewer
        than ``min_samples_leaf`` rows or when ``max_depth <= 2`` leaves no
        depth below the frozen levels.  The constant is the mean of that bin's
        targets (NaNs ignored), or the global target mean if the bin is empty.

        Returns
        -------
        self
        """
        y = np.asarray(y)
        if self.frozen_rules is None:
            # Grow a full tree once and freeze its top two levels.
            tmp = DecisionTreeRegressor(**self.params).fit(X, y)
            self.frozen_rules = self._first_two_splits(tmp.tree_)

        bin_id = self._apply_frozen(X)
        max_depth = self.params['max_depth']
        # Depth remaining under the two frozen levels (None = unlimited).
        sub_depth = None if max_depth is None else max_depth - 2
        no_depth_left = sub_depth is not None and sub_depth < 1
        sub_params = {k: v for k, v in self.params.items() if k != 'max_depth'}
        global_mean = float(np.nanmean(y))

        self.sub_models = {}
        for b in range(4):
            mask = bin_id == b
            n_b = int(mask.sum())
            if no_depth_left or n_b < self.params['min_samples_leaf']:
                self.sub_models[b] = (float(np.nanmean(y[mask])) if n_b
                                      else global_mean)
            else:
                sub = DecisionTreeRegressor(max_depth=sub_depth, **sub_params)
                self.sub_models[b] = sub.fit(X[mask], y[mask])
        return self

    def predict(self, X):
        """Predict each row with the sub-model (or constant) of its frozen bin."""
        bin_id = self._apply_frozen(X)
        out = np.zeros(len(X))
        for b in range(4):
            mask = bin_id == b
            if not mask.any():          # nobody in this bin for this batch
                continue
            model_b = self.sub_models[b]
            if isinstance(model_b, float):
                out[mask] = model_b
            else:
                out[mask] = model_b.predict(X[mask])
        return out

    def apply(self, X):
        """
        Return the composite leaf id of every row of ``X`` (used for
        soft labels / leaf ids).

        Ids are ``bin * 10_000`` for constant bins and
        ``bin * 10_000 + sub_tree.apply(x)`` otherwise.  Each sub-tree is
        queried once per bin rather than once per sample.
        """
        bin_id = self._apply_frozen(X)
        leaf = bin_id * self._LEAF_STRIDE
        for b in range(4):
            mask = bin_id == b
            if not mask.any():
                continue
            model_b = self.sub_models[b]
            if not isinstance(model_b, float):
                leaf[mask] += model_b.apply(X[mask])
        return leaf

    def get_paths(self) -> Dict[int, List[Tuple[int, float, str]]]:
        """
        Return ``leaf_id -> [(feat, thr, '<=' | '>'), ...]`` for every leaf.

        Leaf ids match :meth:`apply` (``bin * 10_000 + sub_leaf_id``).  Each
        path starts with the two frozen conditions of its bin, followed by the
        sub-tree's own conditions (via ``extract_paths``).

        Raises
        ------
        RuntimeError
            If the model has not been fitted yet.
        """
        if self.frozen_rules is None:
            raise RuntimeError("Tree is not fitted yet!")

        f1, thr1, f2, thr2 = self.frozen_rules
        prefix = {                      # root -> bin conditions of the four bins
            0: [(f1, thr1, "<="), (f2, thr2, "<=")],
            1: [(f1, thr1, "<="), (f2, thr2,  ">")],
            2: [(f1, thr1,  ">"), (f2, thr2, "<=")],
            3: [(f1, thr1,  ">"), (f2, thr2,  ">")]
        }

        all_paths = {}
        for b in range(4):
            model_b = self.sub_models[b]
            base = b * self._LEAF_STRIDE
            if isinstance(model_b, float):             # constant bin = one leaf
                all_paths[base] = prefix[b]
            else:
                for sub_leaf, conds in extract_paths(model_b).items():
                    all_paths[base + sub_leaf] = prefix[b] + conds
        return all_paths

# ---------- 1) Pre-extract the path constraints of every leaf node ----------

def extract_paths(tree) -> Dict[int, List[Tuple[int, float, str]]]:
    """Map every leaf of a decision tree to the split conditions on its path.

    Parameters
    ----------
    tree : object
        Either an object exposing ``get_paths()`` (its result is returned
        as-is), a raw tree structure with ``children_left`` /
        ``children_right`` / ``feature`` / ``threshold`` arrays, or an
        estimator wrapping one via ``.tree_``, ``.reg_.tree_`` or
        ``.inner_.tree_``.

    Returns
    -------
    dict
        ``leaf_id -> [(feat_index, threshold, '<=' | '>'), ...]`` listed
        root-to-leaf.  Leaves are inserted in depth-first order that explores
        the ``>`` (right) branch before the ``<=`` (left) one.

    Raises
    ------
    AttributeError
        If no tree structure can be located inside ``tree``.
    """
    if hasattr(tree, "get_paths"):
        return tree.get_paths()          # the object knows its own leaf ids

    # Peel estimator wrappers until we reach the raw node arrays.
    raw = tree
    while not hasattr(raw, "children_left"):
        if hasattr(raw, "tree_"):
            raw = raw.tree_
            continue
        for wrapper in ("reg_", "inner_"):
            if hasattr(raw, wrapper):
                raw = getattr(raw, wrapper).tree_
                break
        else:
            raise AttributeError("Cannot locate .tree_ inside the estimator")

    left_of, right_of = raw.children_left, raw.children_right
    feat_of, thr_of = raw.feature, raw.threshold

    leaf_paths: Dict[int, List[Tuple[int, float, str]]] = {}
    pending = [(0, [])]                  # LIFO of (node_id, conditions so far)
    while pending:
        node, trail = pending.pop()
        lo, hi = left_of[node], right_of[node]
        if lo == hi:                     # identical child ids mark a leaf
            leaf_paths[node] = trail
            continue
        split = (feat_of[node], thr_of[node])
        # "<=" is pushed first, so the ">" branch is popped (visited) first.
        pending.extend((
            (lo, trail + [(*split, "<=")]),
            (hi, trail + [(*split, ">")]),
        ))
    return leaf_paths


# ---------- 2) Compute soft labels for a batch of samples ----------

def path_margin_softlabel(
    dt_model,
    X: np.ndarray,
    tau: float = 5.0,
    eps=1e-8
) -> np.ndarray:
    """Soft-assign every sample to every leaf by how far it is from the leaf's path.

    For a sample ``x`` and a leaf ``L`` whose root-to-leaf conditions are
    ``(f, thr, dir)``, the path distance is::

        d_L(x) = sum_{dir == '<='} max(0, x_f - thr) + sum_{dir == '>'} max(0, thr - x_f)

    i.e. the total amount by which ``x`` violates the conditions on the path
    (0 for any leaf whose conditions ``x`` all satisfies).  A NaN feature
    value contributes 0.  Probabilities are ``exp(-tau * d) + eps``
    normalised per row.

    Parameters
    ----------
    dt_model : object
        Anything with ``get_paths()`` (e.g. ``FrozenDepthTree``) or anything
        ``extract_paths`` accepts (sklearn estimator or raw ``tree_``).
    X : 2-D array, one sample per row.
    tau : float
        Temperature; larger values give sharper (more one-hot) labels.
    eps : float
        Added to every weight before normalising, so rows never sum to 0.

    Returns
    -------
    np.ndarray
        ``float32`` array of shape ``(B, n_leaf)``; columns follow the key
        order of the paths dict, and each row sums to 1.

    Raises
    ------
    ValueError
        If ``X`` is not 2-D.
    """
    if hasattr(dt_model, "get_paths"):
        paths = dt_model.get_paths()            # model supplies its own leaf ids
    else:
        paths = extract_paths(dt_model)         # unwraps .tree_ automatically
    leaf_ids = list(paths.keys())               # fixes the column order
    leaf_conditions = [paths[lid] for lid in leaf_ids]

    # float64 so the per-condition gaps do not depend on X's input dtype.
    X = np.asarray(X, dtype=np.float64)
    if X.ndim != 2:
        raise ValueError(
            f"X must be 2-D (n_samples, n_features); got shape {X.shape}")
    B = X.shape[0]

    # Distance of every sample to every leaf, one leaf (column) at a time.
    dists = np.zeros((B, len(leaf_ids)), dtype=np.float32)
    for j, conds in enumerate(leaf_conditions):
        d = np.zeros(B)
        for feat, thr, direction in conds:
            col = X[:, feat]
            gap = col - thr if direction == "<=" else thr - col
            # fmax (not maximum): a NaN gap counts as 0 instead of poisoning d.
            d += np.fmax(gap, 0.0)
        dists[:, j] = d

    # Soft-max over negative distances -> probabilities.
    w = np.exp(-tau * dists) + eps
    return (w / w.sum(axis=1, keepdims=True)).astype(np.float32, copy=False)


if __name__ == "__main__":
    # Self-test / demo on iris: fit a depth-3 classifier, inspect its leaves,
    # build path-margin soft labels, and check each row's argmax against the
    # leaf the tree actually routes that sample to.
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris
    from sklearn.tree import plot_tree, export_text
    import seaborn as sns

    # Data & tree
    X, y = load_iris(return_X_y=True)
    dt_model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
    feature_names = [f"f{i}" for i in range(X.shape[1])]
    leaf_paths = extract_paths(dt_model.tree_)
    # Column order of the soft-label matrix.  extract_paths explores the ">"
    # branch before the "<=" branch, so the ids are in that DFS order, not sorted.
    leaf_ids = list(leaf_paths.keys())
    print("leaf_ids", leaf_ids)

    # 1) Text dump of the tree structure
    print("\n=== Decision Tree structure (text) ===")
    print(export_text(dt_model, feature_names=feature_names))

    # 2) Tree diagram
    plt.figure(figsize=(10, 6))
    plot_tree(dt_model, feature_names=feature_names,
              class_names=[str(c) for c in np.unique(y)], filled=True)
    plt.title("Decision Tree (iris)")
    plt.show()

    # 3) Number of samples routed to each leaf
    leaf_id_all = dt_model.apply(X)
    unique, counts = np.unique(leaf_id_all, return_counts=True)
    plt.figure(figsize=(6, 3))
    plt.bar(range(len(unique)), counts, tick_label=unique)
    plt.xlabel("leaf id"); plt.ylabel("#samples")
    plt.title("Sample distribution across leaf nodes")
    plt.show()

    # 4) Soft labels
    tau = 2.0
    P = path_margin_softlabel(dt_model, X, tau=tau)
    print("Soft label shape:", P.shape)
    print("Row sums (≈1):", P.sum(axis=1)[:10])

    # 5) Each row's argmax should be the column of the sample's hard leaf
    col_of_leaf = {lid: j for j, lid in enumerate(leaf_ids)}
    hard_cols = np.array([col_of_leaf[lid] for lid in leaf_id_all])
    argmax_ok = (np.argmax(P, axis=1) == hard_cols).all()
    print("argmax equal?", argmax_ok)

    # 6) Heatmap of the first 30 samples
    n_show = 30
    plt.figure(figsize=(8, 4))
    ax = sns.heatmap(P[:n_show], cmap="viridis", cbar=True)
    ax.set_xticks(np.arange(len(leaf_ids)) + 0.5)
    ax.set_xticklabels(leaf_ids)     # real leaf ids, in P's column order
    plt.title(f"Soft-label heatmap (tau={tau})")
    plt.xlabel("leaf id"); plt.ylabel("sample index")
    plt.show()

    # 7) One sample: path distance vs. probability for every leaf
    sample_id = 51
    print(f"sample_id:{sample_id}, label: {y[sample_id]}")
    d_list = []
    for lid in leaf_ids:
        dist = 0.0
        for feat, thr, direction in leaf_paths[lid]:
            val = X[sample_id, feat]
            dist += max(0.0, val-thr) if direction=="<=" else max(0.0, thr-val)
        d_list.append(dist)
    print("\nLeaf | dist | prob")
    for lid, d, p in zip(leaf_ids, d_list, P[sample_id]):
        print(f"{lid:4d} | {d:6.3f} | {p:6.3f}")

    # Per-leaf probabilities on their own, labelled figure
    plt.figure(figsize=(6, 3))
    plt.bar([str(lid) for lid in leaf_ids], P[sample_id])
    plt.xlabel("leaf id"); plt.ylabel("probability")
    plt.title(f"Soft label of sample {sample_id}")
    plt.show()

    plt.figure(figsize=(6, 3))
    plt.scatter(d_list, P[sample_id])
    plt.xlabel("distance d_L(x)"); plt.ylabel("probability")
    plt.title(f"Distance vs probability (sample {sample_id})")
    plt.grid(True)
    plt.show()
