import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import xgboost as xgb
from typing import List, Optional, Dict, Union

# ==========================================
# 1. Utilities
# ==========================================

def perturb(tensor: torch.Tensor, strength: float = 1.5):
    """
    Aggressively perturb an initial value, used to test whether ArbNN can
    recover from a badly corrupted initialisation.

    The input is shrunk to 10% of its value and Gaussian noise is added whose
    scale is ``std(tensor) * strength`` (with the standard deviation floored
    at 1e-6).

    Args:
        tensor: Values to perturb. The input itself is not modified.
        strength: Multiplier applied to the noise scale.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    # The sample standard deviation of fewer than two values is NaN, and
    # clamp() does not remove NaN. Without this guard a one-element tensor
    # (e.g. the bias of a single-split tree) would come back as NaN.
    scale = tensor.std().clamp(min=1e-6) if tensor.numel() > 1 else 1e-6
    noise = torch.randn_like(tensor) * scale * strength
    return tensor * 0.1 + noise


def tree_parser_for_compilation(tree_dump: Dict, input_dim: int) -> Optional[Dict]:
    """
    Convert one tree from an XGBoost JSON dump into the dense matrices that
    the compiled tree modules consume.

    Split nodes and leaves are numbered in depth-first pre-order. With
    ``n_splits`` split nodes and ``n_leaves`` leaves the result holds:

    * ``W``     (n_splits, input_dim): one-hot row selecting each split feature.
    * ``B``     (n_splits,): split thresholds.
    * ``V``     (n_leaves,): leaf values.
    * ``P_bin`` (n_splits, n_leaves): -1.0 / +1.0 when the leaf lies on the
      left / right side of the split, 0 when the split is not on its path.
    * ``P_w``   same signs as ``P_bin`` scaled by ``0.5 ** depth`` where the
      root split has depth 0.
    * ``L``     (n_leaves,): number of splits on the path to each leaf.
    * ``dims``  ``(input_dim, n_splits, n_leaves)``.
    * ``nodeid2Index``: maps from the dump's node ids (as strings) to the
      row / column indices used above.

    Returns ``None`` when the tree contains no split node at all.
    """
    split_index_of = {}      # nodeid (str) -> row in W / B / P
    leaf_index_of = {}       # nodeid (str) -> column in P / entry in V
    split_features = []      # feature index per split row
    split_thresholds = []    # threshold per split row
    leaf_values = []         # value per leaf column
    routes = {}              # leaf nodeid -> ((split_row, direction), ...)

    def walk(node, route):
        key = str(node['nodeid'])

        if 'leaf' in node:
            if key not in leaf_index_of:
                leaf_index_of[key] = len(leaf_values)
                leaf_values.append(node['leaf'])
            routes[key] = route
            return

        if key not in split_index_of:
            split_index_of[key] = len(split_features)
            # Feature names look like "f7"; a bare integer string is accepted too.
            feature_name = str(node['split'])
            split_features.append(
                int(feature_name[1:] if feature_name.startswith('f') else feature_name)
            )
            split_thresholds.append(node['split_condition'])

        row = split_index_of[key]
        for child in node.get('children', ()):
            # The child whose id equals 'yes' is the left branch (value below
            # the threshold in XGBoost's dump) and is encoded as -1.0 so it
            # agrees with the sign of tanh(x - threshold); the other is +1.0.
            sign = -1.0 if child['nodeid'] == node['yes'] else 1.0
            walk(child, route + ((row, sign),))

    walk(tree_dump, ())

    n_splits = len(split_index_of)
    n_leaves = len(leaf_index_of)
    if not n_splits:
        # A lone leaf (typically just a bias term) has nothing to compile.
        return None

    W = np.zeros((n_splits, input_dim))
    B = np.zeros(n_splits)
    for row, (feature, threshold) in enumerate(zip(split_features, split_thresholds)):
        W[row, feature] = 1.0
        B[row] = threshold

    V = np.zeros(n_leaves)
    L_depth = np.zeros(n_leaves)
    P_bin = np.zeros((n_splits, n_leaves))
    P_w = np.zeros((n_splits, n_leaves))
    for key, col in leaf_index_of.items():
        route = routes[key]
        V[col] = leaf_values[col]
        L_depth[col] = len(route)
        for level, (row, sign) in enumerate(route):
            P_bin[row, col] = sign
            # Influence halves with every level below the root.
            P_w[row, col] = sign * 0.5 ** level

    return {
        'W': W, 'B': B, 'V': V,
        'P_bin': P_bin, 'P_w': P_w, 'L': L_depth,
        'dims': (input_dim, n_splits, n_leaves),
        'nodeid2Index': {'interNode': split_index_of, 'leafNode': leaf_index_of}
    }


class BaseCompiledTree(nn.Module):
    """
    Common base for compiled trees. Keeps the optional leaf-index -> node-id
    lookup and tolerates its absence.
    """
    def __init__(self, matrices):
        super().__init__()
        # 'nodeid2Index' may be missing (e.g. when the tree is rebuilt from a
        # state_dict); in that case no lookup table is available.
        node_map = matrices.get('nodeid2Index')
        index_to_nid = None
        if node_map:
            index_to_nid = {
                position: nodeid
                for nodeid, position in node_map['leafNode'].items()
            }
        self.leaf_index_to_nid = index_to_nid

    def compute_raw_splits(self, x: torch.Tensor) -> torch.Tensor:
        """Project the input onto every split: ``x @ W^T - bias``."""
        projected = torch.matmul(x, self.split_weights.t())
        return projected - self.split_bias

    @torch.no_grad()
    def _indices_to_nids(self, leaf_indices: torch.Tensor, return_index: bool = False) -> Union[List[str], torch.Tensor]:
        """
        Translate leaf indices into the node ids recorded at compile time.

        With ``return_index=True`` the index tensor is handed back untouched.
        Otherwise a Python list is returned: node ids when the lookup table
        exists (an index without an entry falls back to ``str(index)``), or
        the plain integer indices when the table is missing.
        """
        if return_index:
            return leaf_indices

        as_list = leaf_indices.cpu().numpy().tolist()
        lookup = self.leaf_index_to_nid
        if lookup is None:
            # Metadata was lost (weights-only load); indices are all we have.
            return as_list
        return [lookup.get(position, str(position)) for position in as_list]

class ArbNN(BaseCompiledTree):
    """
    [Type 1] Arboreal Neural Network.

    A single decision tree expressed as matrix operations: sparse feature
    selection, matrix-based soft routing, end-to-end differentiable.
    """
    def __init__(self, matrices):
        super().__init__(matrices)

        def as_float(name):
            return torch.tensor(matrices[name], dtype=torch.float32)

        bias_init = as_float('B')
        value_init = as_float('V')

        # Thresholds and leaf values start from perturbed copies of the
        # compiled values; the feature-selection weights start exact.
        self.split_bias = nn.Parameter(perturb(bias_init))
        self.leaf_vals = nn.Parameter(perturb(value_init))
        self.split_weights = nn.Parameter(as_float('W'))

        # Routing matrix: registered as a parameter but excluded from gradients.
        self.P = nn.Parameter(as_float('P_w'), requires_grad=False)

        self.tau1 = 100.0  # sharpness of the tanh split decision
        self.tau2 = 100.0  # sharpness of the softmax over leaves

    def _routing_logits(self, x):
        """Per-leaf affinity: ``tanh(tau1 * (x W^T - b)) @ P`` -> [B, n_leaves]."""
        decisions = torch.tanh(self.tau1 * self.compute_raw_splits(x))
        return decisions @ self.P

    def forward(self, x):
        """Soft-route each sample to the leaves and return the weighted leaf value."""
        leaf_weights = (self.tau2 * self._routing_logits(x)).softmax(dim=1)
        return leaf_weights @ self.leaf_vals

    @torch.no_grad()
    def pred_leaf(self, x: torch.Tensor, return_index: bool = False):
        """
        Predict the leaf each sample falls into.

        The leaf with the largest routing logit is chosen; softmax is skipped
        because it does not change the argmax.
        """
        best_leaf = self._routing_logits(x).argmax(dim=1)
        return self._indices_to_nids(best_leaf, return_index)


# ==========================================
# 3. Forest container
# ==========================================

class CompiledForest(nn.Module):
    """
    Additive ensemble of compiled trees: the forest output is the sum of the
    outputs of its member trees.
    """
    def __init__(self, tree_cells: List[nn.Module]):
        super().__init__()
        self.trees = nn.ModuleList(tree_cells)

    def forward(self, x):
        """
        Sum the outputs of all trees for the batch ``x``.

        An empty forest yields zeros of shape ``(batch,)``, the same rank the
        ArbNN trees produce, so downstream code sees a consistent shape.
        """
        if len(self.trees) == 0:
            return torch.zeros(x.shape[0], device=x.device)
        return torch.stack([tree(x) for tree in self.trees], dim=0).sum(dim=0)

    def pred_leaf(self, x: torch.Tensor, return_index: bool = False) -> Union[List[List[str]], torch.Tensor]:
        """
        Return, for every sample, the predicted leaf in each tree.

        Args:
            x: Input batch.
            return_index: When True return a ``[B, num_trees]`` tensor of leaf
                indices; otherwise a list (one entry per sample) of per-tree
                leaf ids.

        An empty forest yields a ``[B, 0]`` index tensor or a list of ``B``
        empty lists instead of raising.
        """
        if len(self.trees) == 0:
            if return_index:
                return torch.zeros(x.shape[0], 0, dtype=torch.long, device=x.device)
            return [[] for _ in range(x.shape[0])]

        leaves = [tree.pred_leaf(x, return_index) for tree in self.trees]

        if return_index:
            return torch.stack(leaves, dim=1)

        # Transpose [num_trees][B] -> [B][num_trees].
        return [list(per_sample) for per_sample in zip(*leaves)]

    @classmethod
    def from_xgboost_model(cls, model_path: str, input_dim: int, mode: str = 'arbnn'):
        """
        Factory: build a compiled forest from a saved XGBoost model file.

        Args:
            model_path: Path passed to ``xgb.Booster.load_model``.
            input_dim: Number of input features; sets the width of each
                tree's weight matrix.
            mode: ``'arbnn'`` (case-insensitive) selects ``ArbNN``; any other
                value selects ``NRF``.

        Returns:
            The compiled forest. If the model cannot be loaded or dumped the
            error is printed and an empty forest is returned. Trees without
            any split are skipped.

        Raises:
            ImportError: If ``xgb`` is ``None``.
        """
        if xgb is None:
            raise ImportError("XGBoost library is required to use from_xgboost_model.")

        print(f"Attempting to compile {mode.upper()} Forest from {model_path}...")

        try:
            # 1. Load the XGBoost model.
            bst = xgb.Booster()
            bst.load_model(model_path)
            # 2. Dump every tree as a JSON string.
            model_dump = bst.get_dump(dump_format='json')
        except Exception as e:
            # Deliberate best-effort: report and fall back to an empty forest.
            print(f"Error loading XGBoost model: {e}")
            return cls([])

        parsed_trees = [json.loads(tree) for tree in model_dump]
        tree_cells = []

        # NOTE(review): NRF is not defined in the visible part of this file;
        # it must be importable/defined for any mode other than 'arbnn'.
        TreeClass = ArbNN if mode.lower() == 'arbnn' else NRF

        # 3. Parse and instantiate tree by tree.
        for i, tree_json in enumerate(parsed_trees):
            matrices = tree_parser_for_compilation(tree_json, input_dim)

            if matrices is None:
                print(f"Skipping tree {i} (empty or bias-only).")
                continue

            tree_cells.append(TreeClass(matrices))

        print(f"Successfully compiled {len(tree_cells)} trees.")
        return cls(tree_cells)

    @classmethod
    def load_state_dict_from_file(cls, state_dict, mode: str = "arbnn"):
        """
        Rebuild a forest from a ``state_dict`` alone.

        Tree shapes are inferred from the stored ``split_weights`` and
        ``leaf_vals`` tensors, placeholder trees of those shapes are created,
        and the stored tensors are then loaded into them. Node-id metadata is
        not part of the state dict, so the rebuilt trees carry no
        leaf-index -> node-id lookup.

        Args:
            state_dict: Mapping with keys of the form ``trees.<i>.<name>``.
            mode: ``'arbnn'`` (case-insensitive) selects ``ArbNN``; any other
                value selects ``NRF``.
        """
        instance = cls([])

        # 1. Work out which tree indices are present.
        tree_indices = sorted({
            int(key.split('.')[1])
            for key in state_dict.keys() if key.startswith('trees.')
        })

        # NOTE(review): NRF is not defined in the visible part of this file;
        # it must be importable/defined for any mode other than 'arbnn'.
        TreeClass = ArbNN if mode.lower() == 'arbnn' else NRF

        for index in tree_indices:
            prefix = f"trees.{index}."

            # 2. Recover the dimensions from the stored tensors.
            W = state_dict[prefix + "split_weights"]
            n_splits, input_dim = W.shape
            n_leaves = state_dict[prefix + "leaf_vals"].shape[0]

            # 3. Minimal placeholder matrices; 'nodeid2Index' is intentionally
            #    omitted because it is not recoverable from the weights.
            matrices = {
                'W': np.zeros((n_splits, input_dim)),
                'B': np.zeros(n_splits),
                'V': np.zeros(n_leaves),
            }

            if mode.lower() == 'arbnn':
                matrices['P_w'] = np.zeros((n_splits, n_leaves))
            else:
                matrices['P_bin'] = np.zeros((n_splits, n_leaves))
                matrices['L'] = np.zeros(n_leaves)

            # 4. Instantiate, then overwrite with this tree's stored tensors.
            tree = TreeClass(matrices)

            tree_state = {
                k[len(prefix):]: v
                for k, v in state_dict.items() if k.startswith(prefix)
            }
            # strict=False tolerates key differences between the stored
            # state and the freshly built tree.
            tree.load_state_dict(tree_state, strict=False)
            instance.trees.append(tree)

        return instance

# ==========================================
# 4. Demo
# ==========================================
if __name__ == '__main__':

    # Demo settings.
    INPUT_DIM = 20
    BATCH_SIZE = 256
    XGB_MODEL_PATH = "demo_xgb_model.bin"

    # 1. Compile an ArbNN forest from the saved XGBoost model.
    arb_forest = CompiledForest.from_xgboost_model(
        XGB_MODEL_PATH,
        INPUT_DIM,
        mode="arbnn",
    )

    # 2. Run a random batch through the forest, then look up the predicted
    #    leaf of every tree for each sample.
    x_input = torch.randn(BATCH_SIZE, INPUT_DIM)
    y_pred = arb_forest(x_input)
    leaf_ids = arb_forest.pred_leaf(x_input)