import numpy as np
import torch
import random
from typing import List, Dict, Tuple, Optional,ClassVar,Any,Set
from mcts.node import Node,Subspace
from data.bodata import BoData
from sampler.random import RandomSampler
from collections import defaultdict
from loguru import logger
class MCTS:
    def __init__(self, dataset:BoData  = None,
                 num_init_samples: int = 5,
                 leaf_size:int=100,
                 batch_size:int=14,#20,#10,
                 sampler: Optional[RandomSampler] = None,
                 cp = 1.0,
                 variable_nums:int=0,
                 n_candidates:int=10,
                 use_diverse_sample:bool=False,
                 )->None:

        self._num_init_samples = num_init_samples
        self._cp = cp
        self._dataset = dataset
        self._leaf_size = leaf_size
        self._batch_size = batch_size
        if sampler is None:
            self._sampler = RandomSampler(dataset)
        else:
            self._sampler = sampler
        self._root: Optional[Node] = None

        # 层数
        self.variable_nums: int = variable_nums
        # 初始轮重复组数
        self._n_candidates: int = n_candidates
        # 是否使用多样性采样
        self.use_diverse_sample: bool = use_diverse_sample

        # 统计访问次数
        self._leaf_visit_counts: Dict[int, int] = defaultdict(int)
    
    def analyze_combinations(self, order: List[str]) -> Dict[str, List[float]]:
        all_leaves = Node.get_all_leaves()
        combination_results = {}

        if not all_leaves:
            print("警告: 未找到任何叶子节点。树可能尚未构建。")
            return combination_results

        dataset = Subspace._shared_dataset

        for leaf in all_leaves:
            # 如果一个叶子节点是空的，则跳过
            if len(leaf.space.idxes) == 0:
                continue

            # --- 步骤 1: 通过回溯路径来确定组合名称 ---
            path_from_root = leaf.path_from_root()
            combo_parts = []
            
            # 路径从根开始，第一个子节点深度为1，对应order[0]
            # path_from_root[0] 是根节点，我们从 path_from_root[1] 开始
            for node in path_from_root[1:]:
                parent = node.parent
                component_name = order[node._depth - 1] # 节点深度从1开始，对应order索引0

                # 找到当前节点是其父节点的第几个孩子，以确定分组索引
                child_index = -1
                for i, child_id in enumerate(parent._children):
                    if child_id == node.id:
                        child_index = i
                        break
                
                combo_parts.append(f"{component_name}_Group{child_index}")
            
            combination_key = " | ".join(combo_parts)

            # --- 步骤 2: 提取该组合下所有的真实产率 ---
            # .item() 用于将PyTorch张量转换为Python纯数字
            yields = [dataset[idx].observed_value.item() for idx in leaf.space.idxes]
            
            combination_results[combination_key] = yields

        return combination_results
        
    def find_node_by_data_idx(self,data_idx:int)->Optional[Node]:
        search_node = self._root
        while not search_node.is_leaf:
            for child in search_node.children:
                if data_idx in child.space.idxes:
                    search_node = child
                    break
        # print(f"    search node id: {search_node.id}")
        return search_node
        
        
    def init_tree(self,expert_partition:Optional[Dict[str,List[List]]]=None,order:Optional[List[str]]=None,dont_build_tree:bool=False):
        Subspace.init(self._dataset)
        Node.init()
        self._root = Node(space = Subspace(self._dataset.all_idxes),leaf_size=self._leaf_size,depth=0)
        if dont_build_tree:
            pass
        elif expert_partition is None:
            Node.build_tree(self._root)
        else:
            Node.build_tree_by_expert(self._root,expert_partition,order)
        # print(f"The tree has {len(Node._all_nodes)} nodes and {len(Node._all_leaves)} leaves")
        random_idxes = torch.randperm(len(range(len(self._dataset.all_idxes))))[:self._num_init_samples].tolist()
        candidates = [self._dataset.all_idxes[idx] for idx in random_idxes]
        observed_values = [self._dataset[idx].observed_value for idx in candidates]
        self._sampler.update_train(candidates,torch.vstack(observed_values))
        self._sampler.init_model(self._sampler.train_x,self._sampler.train_y)
        # for candidate, value in zip(candidates, observed_values):
        #     self.find_node_by_data_idx(candidate).backprop(value)
        acq_scores = self._sampler.acq_score(self._root.space.idxes)

        for acq, idx in zip(acq_scores, self._root.space.idxes):
            self.find_node_by_data_idx(idx).backprop(acq)
        return observed_values
    
    def diverse_init_tree(self,expert_partition:Optional[Dict[str,List[List]]]=None,order:Optional[List[str]]=None):
        Subspace.init(self._dataset)
        Node.init()
        self._root = Node(space=Subspace(self._dataset.all_idxes), leaf_size=self._leaf_size, depth=0)

        if expert_partition is None:
            Node.build_tree(self._root)
        else:
            Node.build_tree_by_expert(self._root, expert_partition, order)

        init_leaves = self.diversity_aware_random_sample(n_sample=self._batch_size, n_candidates=self._n_candidates)
        candidates = []
        for leaf in init_leaves:
            idx = random.choice(leaf.space.idxes)
            candidates.append(idx)

        observed_values = [self._dataset[idx].observed_value for idx in candidates]
        self._sampler.update_train(candidates, torch.vstack(observed_values))
        self._sampler.init_model(self._sampler.train_x, self._sampler.train_y)

        for idx in self._dataset.all_idxes:
            leaf = self.find_node_by_data_idx(idx)
            alpha = 1 / (1 + len(leaf.space.idxes))
            self.find_node_by_data_idx(idx).pseudo_backprop(self._dataset[idx].predict_value.squeeze(),
                                                            alpha=alpha)
                      
        for candidate, value in zip(candidates, observed_values):
            node = self.find_node_by_data_idx(candidate)
            node.inc_incomplete_count()
            node.backprop(value.squeeze())
        return observed_values

    def real_exp(self,expert_partition:Dict[str,List[List]],order:List[str],mask:List[int]):
        Subspace.init(self._dataset)
        Node.init()
        self._root = Node(space = Subspace(self._dataset.all_idxes),leaf_size=self._leaf_size,depth=0)
        print(f"root: {self._root}")
        Node.build_tree_by_expert(self._root,expert_partition,order)
        print(f"The tree has {len(Node._all_nodes)} nodes and {len(Node._all_leaves)} leaves")
        
        observed_candidate = [x for x in self._dataset.all_idxes if self._dataset[x].is_observed]
        observed_values = [self._dataset[idx].observed_value for idx in observed_candidate]
        print(f"    observed_candidate: {observed_candidate}")
        print(f"    observed_values: {observed_values}")
        self._sampler.update_train(observed_candidate,torch.vstack(observed_values))
        self._sampler.init_model(self._sampler.train_x,self._sampler.train_y)
        for idx in self._dataset.all_idxes:
            leaf = self.find_node_by_data_idx(idx)
            alpha = 1 / (1 + len(leaf.space.idxes))
            #alpha = 1
            self.find_node_by_data_idx(idx).pseudo_backprop(self._dataset[idx].predict_value.squeeze(),
                                                            alpha=alpha)
        for idx in observed_candidate:
            self.find_node_by_data_idx(idx).inc_incomplete_count()
            self.find_node_by_data_idx(idx).backprop(self._dataset[idx].observed_value.squeeze())
        
        print("backup is completed")
        search_nodes = self.sample_path(self._batch_size)
        candidate = []
        eta=0.01
        beta = 0.01
        # print(f"    search_nodes: {search_nodes}")
        for i,search_node in enumerate(search_nodes):
            # print(f"    search node id: {search_node.id}")
            # print(f"    unobserved length: {len(search_node.space.unobserved_idxes)}")
            eta =  eta + 0.01*i
            if i>=self._batch_size//2:
                beta = beta +0.01*(i-self._batch_size//2)
            next_candidate = self._sampler.real_exp_pseudo_label_sample(search_node.space.unobserved_idxes, n_sample=1,mask=mask,cur_time=i,eta=eta,beta = beta,n_times=self._batch_size)
            mask.extend(next_candidate)
            candidate.extend(next_candidate)
        print("Designed is completed1")
        next_batch = [self._dataset[idx].category_value for idx in candidate]
      
            
            
        return next_batch

    def obersved_init_tree(self,expert_partition:Dict[str,List[List]],order:List[str],observed_idxes:List,pseudo_label:bool=False):
        Subspace.init(self._dataset)
        Node.init()
        self._root = Node(space = Subspace(self._dataset.all_idxes),leaf_size=self._leaf_size,depth=0)
        
        Node.build_tree_by_expert(self._root,expert_partition,order)
        print(f"The tree has {len(Node._all_nodes)} nodes and {len(Node._all_leaves)} leaves")
        ################################################################################

        observed_idxes = [x for x in observed_idxes if x in self._dataset.all_idxes]
        observed_values = [self._dataset[idx].observed_value for idx in observed_idxes]

        if len(observed_idxes) > self._num_init_samples:
            raise ValueError(f"The number of observed values {len(observed_idxes)} is greater than the init nums {self._num_init_samples}")

        if len(observed_idxes) > 0:
            logger.success("Loading train data as observed values")
            self._sampler.update_train(observed_idxes, torch.vstack(observed_values))
            self._sampler.init_model(self._sampler.train_x, self._sampler.train_y)
            
            for idx, value in zip(observed_idxes, observed_values):
                node = self.find_node_by_data_idx(idx)
                node.inc_incomplete_count()
                node.backprop(value.squeeze())
        
        if len(observed_idxes) == self._num_init_samples:
            logger.success(f"Already have {len(observed_idxes)} observed samples, which is == {self._num_init_samples}. Returning directly.")
            return observed_values

        else:
            remaining_samples = self._num_init_samples - len(observed_idxes)
            if pseudo_label:
                avg_pred_value_of_leaves: Dict[int, float] = {}
                max_idxes = []
                for leaf in Node._all_leaves:
                    avg_pred_value_of_leaves[leaf] = np.mean([self._dataset[idx].predict_value for idx in Node._all_nodes[leaf].space.idxes])
                topk_leaves = sorted(avg_pred_value_of_leaves.items(), key=lambda x: x[1], reverse=True)
                for leaf, _ in topk_leaves:
                    idx = np.argmax([self._dataset[idx].predict_value for idx in Node._all_nodes[leaf].space.idxes])
                    max_idxes.append(Node._all_nodes[leaf].space.idxes[idx])
                    if max_idxes == remaining_samples:
                        break

                candidates = max_idxes
            else:
                # Use random sampling to fill the gap
                random_leaves = random.sample(list(Node._all_leaves), remaining_samples)
                random_idxes = []
                for leaf in random_leaves:
                    random_idxes.append(random.choice(Node._all_nodes[leaf].space.idxes))
                candidates = random_idxes
            
            new_observed_values = [self._dataset[idx].observed_value for idx in candidates]
            observed_values += new_observed_values
            self._sampler.update_train(candidates, torch.vstack(new_observed_values))
            self._sampler.init_model(self._sampler.train_x, self._sampler.train_y)
            
            for candidate, value in zip(candidates, new_observed_values):
                node = self.find_node_by_data_idx(candidate)
                node.inc_incomplete_count()
                node.backprop(value.squeeze())
            
            return observed_values

    def pseudo_init_tree(self,expert_partition:Dict[str,List[List]],order:List[str],pseudo_label:bool=False):
        Subspace.init(self._dataset)
        Node.init()
        self._root = Node(space = Subspace(self._dataset.all_idxes),leaf_size=self._leaf_size,depth=0)
        
        Node.build_tree_by_expert(self._root,expert_partition,order)
        print(f"The tree has {len(Node._all_nodes)} nodes and {len(Node._all_leaves)} leaves")
        ################################################################################
        
        if pseudo_label:
            avg_pred_value_of_leaves:Dict[int,float] = {}
            max_idxes = []
            for leaf in Node._all_leaves:
                avg_pred_value_of_leaves[leaf] = np.mean([self._dataset[idx].predict_value for idx in Node._all_nodes[leaf].space.idxes])
            topk_leaves = sorted(avg_pred_value_of_leaves.items(),key=lambda x: x[1],reverse=True)[:self._num_init_samples]
            for leaf, _ in topk_leaves:
                # print(f"    Node space idxes: {Node._all_nodes[leaf].space.idxes}")
                idx = np.argmax([self._dataset[idx].predict_value for idx in Node._all_nodes[leaf].space.idxes])
                max_idxes.append(Node._all_nodes[leaf].space.idxes[idx])
            if len(max_idxes) < self._num_init_samples:
                needed_num = self._num_init_samples-len(max_idxes)
                to_sample = list(Node._all_leaves)
                random_leaves = []
                if len(to_sample) < needed_num:
                    while len(random_leaves) < needed_num:
                        random_leaves += to_sample
                    random_leaves = random.sample(random_leaves, needed_num)
                else:
                    random_leaves = random.sample(to_sample,needed_num)
                random_idxes = []

                for leaf in random_leaves:
                    random_idxes.append(random.choice(Node._all_nodes[leaf].space.idxes))
                max_idxes += random_idxes
            #candidates = random_idxes
            candidates = max_idxes
        else:
            random_leaves = random.sample(list(Node._all_leaves),self._num_init_samples)
            random_idxes = []
            for leaf in random_leaves:
                random_idxes.append(random.choice(Node._all_nodes[leaf].space.idxes))
            candidates = random_idxes
        observed_values = [self._dataset[idx].observed_value for idx in candidates]
        self._sampler.update_train(candidates,torch.vstack(observed_values))
        self._sampler.init_model(self._sampler.train_x,self._sampler.train_y)
        for idx in self._dataset.all_idxes:
            leaf = self.find_node_by_data_idx(idx)
            alpha = 1 / (1 + len(leaf.space.idxes))
            #alpha = 1
            self.find_node_by_data_idx(idx).pseudo_backprop(self._dataset[idx].predict_value.squeeze(),
                                                            alpha=alpha)
            
                                                            
        for candidate, value in zip(candidates, observed_values):
            node = self.find_node_by_data_idx(candidate)
            node.inc_incomplete_count()
            node.backprop(value.squeeze())
        # cb_scores = self._sampler.cb_score(self._root.space.idxes)
        # for cb, idx in zip(cb_scores, self._root.space.idxes):
        #     self.find_node_by_data_idx(idx).backprop(cb)
        return observed_values
    
    def sample_path(self, n_sample:int):
        nice_leaves = []
        while len(nice_leaves) < n_sample:
            search_node = self._root
            while not search_node.is_leaf:
                
                search_node = max(search_node.children, key=lambda x: x.cb)
                # print(f"    search node cb: {search_node.cb}")
            nice_leaves.append(search_node)
            search_node.inc_incomplete_count()
        return nice_leaves

    def random_sample_path(self, n_sample:int):
        """
        在树中完全随机地选择n_sample条路径，直到叶子节点。
        这主要用于MCTS的初始阶段，以保证探索的广度。
        """
        nice_leaves = []
        while len(nice_leaves) < n_sample:
            search_node = self._root
            while not search_node.is_leaf:
                search_node = random.choice(search_node.children)
            nice_leaves.append(search_node)
            search_node.inc_incomplete_count()
        return nice_leaves

    def _get_one_random_path(self):
        """
        辅助函数：从根节点随机走到一个叶子节点并返回，不更新任何计数。
        返回叶子节点
        """
        search_node = self._root
        while not search_node.is_leaf:
            if not search_node.children:
                break
            search_node = random.choice(search_node.children)
        return search_node

    def diversity_aware_random_sample(self, n_sample:int, n_candidates:int = 10):
        """
        实现一个带有“多样性”偏好的随机采样。
        它会生成n_candidates个候选路径集，然后根据一个评价函数选择最优的一个。

        Args:
            n_sample (int): 最终需要返回的路径数量 (例如: 5)。
            n_candidates (int): 生成的候选集数量 (例如: 10)。
        """
        # 定义价值函数
        # TODO: 不一定要线性
        step = (0.8 - 0.2) / (self.variable_nums - 1) if self.variable_nums > 1 else 0
        level_weights = {i: round(0.8 - i * step, 2) for i in range(self.variable_nums)}

        candidate_sets = []
        # 生成 n_candidates 个候选集，每个集包含 n_sample 条随机路径（叶子节点）
        for _ in range(n_candidates):
            current_set = [self._get_one_random_path() for _ in range(n_sample)]
            candidate_sets.append(current_set)

        best_set = None
        max_score = -1

        # 遍历所有候选集，为它们打分
        for path_set in candidate_sets:
            # 使用 set 来自动处理路径上的重复节点
            all_nodes_in_paths = set()
            for leaf_node in path_set:
                curr = leaf_node
                while curr is not None:
                    all_nodes_in_paths.add(curr)
                    curr = curr.parent

            # 按照层级统计不同节点的数量
            nodes_per_level = defaultdict(int)
            for node in all_nodes_in_paths:
                nodes_per_level[node._depth] += 1
            
            # 计算当前候选集的总分
            current_score = 0
            for level, count in nodes_per_level.items():
                # 使用 .get(level, 0.01) 来处理超出预设权重的深层节点
                weight = level_weights.get(level, 0.01) 
                current_score += count * weight
            
            # 更新最优选择
            if current_score > max_score:
                max_score = current_score
                best_set = path_set
        
        # 只对最终选出的最优路径集进行计数更新
        # if best_set:
        #     for leaf in best_set:
        #         leaf.inc_incomplete_count()
        
        return best_set
            
    def search(self, pseudo_label: bool = False, iteration_index: int = 1):
        if self.use_diverse_sample and iteration_index == 1:
            avg_pred_value_of_leaves:Dict[int,float] = {}
            max_idxes = []
            for leaf in Node._all_leaves:
                avg_pred_value_of_leaves[leaf] = np.mean(
                    [self._dataset[idx].predict_value for idx in Node._all_nodes[leaf].space.idxes]
                )
            topk_leaves = sorted(avg_pred_value_of_leaves.items(), key=lambda x: x[1], reverse=True)[:self._num_init_samples]
            for leaf, _ in topk_leaves:
                idx = np.argmax([self._dataset[idx].predict_value for idx in Node._all_nodes[leaf].space.idxes])
                max_idxes.append(Node._all_nodes[leaf].space.idxes[idx])
            if len(max_idxes) < self._num_init_samples:
                needed_num = self._num_init_samples-len(max_idxes)
                to_sample = list(Node._all_leaves)
                random_leaves = []
                if len(to_sample) < needed_num:
                    while len(random_leaves) < needed_num:
                        random_leaves += to_sample
                    random_leaves = random.sample(random_leaves, needed_num)
                else:
                    random_leaves = random.sample(to_sample,needed_num)
                random_idxes = []
                for leaf in random_leaves:
                    random_idxes.append(random.choice(Node._all_nodes[leaf].space.idxes))
                max_idxes += random_idxes

            candidates = max_idxes
            observed_values = [self._dataset[idx].observed_value for idx in candidates]
            self._sampler.update_train(candidates, torch.vstack(observed_values))
            self._sampler.init_model(self._sampler.train_x, self._sampler.train_y)

            # 遍历伪数据并更新树
            for idx in self._dataset.all_idxes:
                leaf = self.find_node_by_data_idx(idx)
                alpha = 1 / (1 + len(leaf.space.idxes))
                self.find_node_by_data_idx(idx).pseudo_backprop(self._dataset[idx].predict_value.squeeze(),
                                                                alpha=alpha)
                
                                                                
            for candidate, value in zip(candidates, observed_values):
                node = self.find_node_by_data_idx(candidate)
                node.inc_incomplete_count()
                node.backprop(value.squeeze())

        search_nodes = self.sample_path(self._batch_size)
        candidate = []
        for search_node in search_nodes:
            # 统计搜索次数
            self._leaf_visit_counts[search_node.id] += 1
                
            if pseudo_label:
                candidate.extend(self._sampler.pseudo_label_sample(search_node.space.unobserved_idxes, n_sample=1))
            else:
                candidate.extend(self._sampler.sample(search_node.space.unobserved_idxes, n_sample=1))

        observed_values = [self._dataset[idx].observed_value for idx in candidate]
        #按照observed value更新节点
        for node, v in zip(candidate, observed_values):
            self.find_node_by_data_idx(node).backprop(v.squeeze())
        # print(f"    candidate: {candidate}")
        # print(f"    observed_values: {observed_values}")
        return observed_values

    def get_search_statistics(self, order: List[str]) -> Dict[str, int]:
        """
        Retrieves the historical search counts for each combination/group (leaf node).

        Args:
            order (List[str]): The order of components used to build the tree,
                               needed to generate readable combination names.

        Returns:
            Dict[str, int]: A dictionary where keys are combination names and
                            values are their total visit counts.
        """
        # If no searches have happened, return empty dict
        if not self._leaf_visit_counts:
            return {}

        all_leaves = Node.get_all_leaves()
        # Create a mapping from node ID to its human-readable combination name
        id_to_combo_name = {}
        for leaf in all_leaves:
            path = leaf.path_from_root()
            combo_parts = []
            for node in path[1:]: # Skip root node
                parent = node.parent
                component_name = order[node._depth - 1]
                child_index = parent._children.index(node.id)
                combo_parts.append(f"{component_name}_Group{child_index}")
            id_to_combo_name[leaf.id] = " | ".join(combo_parts)

        # Build the final statistics dictionary
        statistics = {}
        for node_id, count in self._leaf_visit_counts.items():
            # It's possible a node was visited but is no longer a leaf (if the tree can be re-split).
            # We only report stats for current leaves.
            if node_id in id_to_combo_name:
                combo_name = id_to_combo_name[node_id]
                statistics[combo_name] = count
        
        return statistics

    def search_weighted(self):
        search_node = self._root
        leaf_nodes = Node.get_all_leaves()
        cbs = np.array([node.cb for node in leaf_nodes])
        cbs[cbs == float('inf')] = np.mean(cbs[cbs != float('inf')])
        tau = 0.7
        cbs_logits = torch.softmax(torch.tensor(cbs)/tau,dim=0)
        print(cbs_logits)
        #根据idx，给每个idx一个权重
        search_space_idx = search_node.space.unobserved_idxes
        search_space_weight = torch.zeros(len(search_space_idx))
        for i, idx in enumerate(search_space_idx):
            for node, cb in zip(leaf_nodes,cbs_logits):
                if idx in node.space.idxes:
                    search_space_weight[i] = cb
                    break
        cb_top5_node_id = [node.id for node in sorted(leaf_nodes,key=lambda x: x.cb,reverse=True)[:5]]
        
        candidate = self._sampler.pseudo_label_sample(search_space_idx, n_sample=self._batch_size,weight=search_space_weight)
        candidate_node_id = [self.find_node_by_data_idx(idx).id for idx in candidate]
        print(f"    ucb_top5_node_id: {cb_top5_node_id}")
        print(f"    ucb_top5: {[cb for cb in sorted(cbs_logits ,reverse=True)[:5]]}")
        print(f"    candidate_node_id: {candidate_node_id}")
        print(f"    candidate: {candidate}")
        observed_values = [self._dataset[idx].observed_value for idx in candidate]
        # 按照observed value更新节点
        for node, v in zip(candidate, observed_values):
            self.find_node_by_data_idx(node).backprop(v.squeeze())
            
        return observed_values
    
    def baseline_search(self,pseudo_label:bool=False):
        
        search_node = self._root
        # print(f"    unobserved length: {len(search_node.space.unobserved_idxes)}")
        if pseudo_label:
            candidate = self._sampler.pseudo_label_sample(search_node.space.unobserved_idxes, n_sample=self._batch_size)
        else:
            candidate = self._sampler.sample(search_node.space.unobserved_idxes, n_sample=self._batch_size)
        observed_values = [self._dataset[idx].observed_value for idx in candidate]
        candidate_node_id = [self.find_node_by_data_idx(idx).id for idx in candidate]
        # print(f"    candidate_node_id: {candidate_node_id}")
        return observed_values
    

    def pseudo_label_search(self):
        pass