from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch.nn as nn
import torch
import gymnasium as gym

class OptimizedHierarchicalBPlusFeatureExtractor(BaseFeaturesExtractor):
    """Bottom-up feature extractor for a flattened, fixed-fanout tree observation.

    Layout of one observation row, as read by this class:

    * the first ``num_ops`` entries are passed through untouched and appended
      to the output;
    * the remaining entries hold the tree, ``features_per_node`` (3) scalars
      per node, with the leaf level stored last and each higher level stored
      immediately before the level below it.

    Leaves are embedded with a linear layer; every higher level is then built
    by concatenating the embeddings of each parent's ``values_per_node + 1``
    children with the parent's own features and projecting the result back to
    ``feature_dim``. The output is the root embedding concatenated with the
    ``num_ops`` leading entries, i.e. ``feature_dim + num_ops`` features.

    Args:
        observation_space: Observation space; only ``shape[0]`` is read here.
        feature_dim: Width of every node embedding.
        values_per_node: Fanout minus one (children per node is
            ``values_per_node + 1``). Must be at least 1.
        num_ops: Number of leading observation entries that bypass the tree
            encoder.
        num_heads: Attention heads of the transformer encoder layer.
        dropout: Dropout probability of the transformer encoder layer.
        max_levels: Number of learned level embeddings; must cover the depth
            derived from the observation size.

    Raises:
        ValueError: If ``values_per_node`` is smaller than 1, or if the tree
            depth derived from the observation size needs more level
            embeddings than ``max_levels`` provides.
    """

    def __init__(self,
                 observation_space: gym.spaces.Box,
                 feature_dim: int = 256,
                 values_per_node: int = 4,
                 num_ops: int = 6,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 max_levels: int = 10):
        super().__init__(observation_space, feature_dim + num_ops)

        # A fanout below 2 would make the depth computation loop forever.
        if values_per_node < 1:
            raise ValueError(
                f"values_per_node must be >= 1, got {values_per_node}"
            )

        self.features_per_node = 3
        self.values_per_node = values_per_node
        self.children_per_node = values_per_node + 1
        self.num_ops = num_ops
        self.feature_dim = feature_dim
        self.max_levels = max_levels
        self.debug = False

        self.level_structure = self._compute_level_structure(observation_space.shape)

        # Level embeddings are indexed with 0 for the leaves and with the
        # position of each non-root entry of ``level_structure``; fail here
        # with a clear message instead of an IndexError during forward().
        highest_level_idx = max(len(self.level_structure) - 2, 0)
        if highest_level_idx >= max_levels:
            raise ValueError(
                f"observation implies level index {highest_level_idx}, "
                f"but max_levels={max_levels}"
            )

        # Module creation order is kept stable so that seeded initialisation
        # and state_dict keys do not change.
        self.transformer = nn.TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=feature_dim,
            dropout=dropout,
            batch_first=True
        )
        # ``linear`` and ``node_norm`` are not used by the forward pass of
        # this class; they are kept so the parameter set stays the same.
        self.linear = nn.Linear(feature_dim, feature_dim)
        self.leaf_embedding = nn.Linear(self.features_per_node, feature_dim)
        self.node_combiner = nn.Linear(feature_dim * self.children_per_node + self.features_per_node, feature_dim)
        self.level_embeddings = nn.Parameter(torch.randn(max_levels, 1, feature_dim))
        self.level_norm = nn.LayerNorm(feature_dim)
        self.node_norm = nn.LayerNorm(feature_dim * self.children_per_node + self.features_per_node)

    def _compute_level_structure(self, obs):
        """Pre-compute the column range of every tree level.

        Args:
            obs: Shape of the observation space; ``obs[0]`` is the flat length
                including the ``num_ops`` leading entries.

        Returns:
            A list of dicts ordered from the leaf level (index 0) to the root
            (last entry). Each dict holds ``num_nodes``, ``num_parents``,
            ``value_start_idx``, ``values_per_level`` and ``value_end_idx``;
            the two indices are column offsets into the tree part of the
            observation (after the ``num_ops`` entries are removed).
        """
        fanout = self.children_per_node
        obs_without_actions = obs[0] - self.num_ops

        # Size the tree in *nodes*: every node occupies ``features_per_node``
        # columns, so comparing a node count against the raw column count
        # overestimates the leaf level for small fanouts.
        total_nodes = obs_without_actions // self.features_per_node

        # Leaf count is the largest power of the fanout that fits.
        num_levels = 0
        idx = 1
        while idx * fanout <= total_nodes:
            idx *= fanout
            num_levels += 1

        level_structure = []
        current_value_end = obs_without_actions
        nodes_this_level = idx

        # Walk from the leaves (stored last) up to the root.
        for _ in range(num_levels + 1):
            values_this_level = nodes_this_level * self.features_per_node
            value_start_idx = current_value_end - values_this_level

            level_structure.append({
                'num_nodes': nodes_this_level,
                'num_parents': nodes_this_level // fanout,
                'value_start_idx': value_start_idx,
                'values_per_level': values_this_level,
                'value_end_idx': current_value_end
            })

            current_value_end = value_start_idx
            nodes_this_level //= fanout

        return level_structure

    def _get_empty_mask(self, node_values):
        """Return a boolean mask that is True where a node's first feature is 0."""
        empty = node_values[..., 0] == 0
        return empty

    def process_level(self, level_info, current_embeddings, tree_data, level_idx):
        """Combine one level's embeddings into embeddings for its parents.

        Args:
            level_info: Entry of ``level_structure`` describing the level whose
                embeddings are in ``current_embeddings``.
            current_embeddings: Tensor viewed as
                ``(batch, num_parents, children_per_node, feature_dim)``.
            tree_data: Tree columns of the observation; the parents' features
                are read from the columns directly before this level's
                ``value_start_idx``.
            level_idx: Index into ``level_embeddings`` added to the parents.

        Returns:
            ``(batch, num_parents, feature_dim)`` parent embeddings, zero for
            parents flagged empty; ``current_embeddings`` unchanged when the
            level has no parents (the root).
        """
        batch_size = current_embeddings.shape[0]
        num_parents = level_info['num_parents']

        if num_parents == 0:  # root level: nothing above it to build
            return current_embeddings

        # Parent features sit immediately before this level's columns.
        parent_values_start = level_info['value_start_idx']
        parent_values = tree_data[:, parent_values_start - num_parents * self.features_per_node:parent_values_start]
        parent_values = parent_values.view(batch_size, num_parents, self.features_per_node)
        empty_mask = self._get_empty_mask(parent_values)

        # Parents flagged empty keep an all-zero embedding.
        output_embeddings = torch.zeros(
            batch_size, num_parents, self.feature_dim,
            device=current_embeddings.device
        )

        if (~empty_mask).any():
            non_empty_indices = torch.nonzero(~empty_mask)
            non_empty_parents = parent_values[~empty_mask]

            # Group consecutive children under their parent, then keep only
            # the groups whose parent is not empty.
            grouped_children = current_embeddings.view(batch_size, -1, self.children_per_node, self.feature_dim)
            non_empty_children = grouped_children[non_empty_indices[:, 0], non_empty_indices[:, 1]]

            children_flat = non_empty_children.view(-1, self.children_per_node * self.feature_dim)

            combined = torch.cat([children_flat, non_empty_parents], dim=1)
            parent_embeddings = self.node_combiner(combined)

            # NOTE(review): forward() passes level_idx 0 for the first parent
            # level, the same index the leaves use - confirm that sharing one
            # level embedding between those two levels is intended.
            level_embedding = self.level_embeddings[level_idx].expand(len(parent_embeddings), -1)
            parent_embeddings = parent_embeddings + level_embedding
            parent_embeddings = self.level_norm(parent_embeddings)

            # Each parent goes through the encoder layer as its own
            # length-1 sequence (batch_first=True).
            transformed = self.transformer(parent_embeddings.unsqueeze(1)).squeeze(1)
            output_embeddings[non_empty_indices[:, 0], non_empty_indices[:, 1]] = transformed

        return output_embeddings

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode a batch of flat observations.

        Args:
            obs: ``(batch, num_ops + tree_columns)`` tensor.

        Returns:
            ``(batch, feature_dim + num_ops)`` tensor: the root embedding
            followed by the unmodified first ``num_ops`` observation entries.
        """
        batch_size = obs.shape[0]
        ops = obs[:, :self.num_ops]
        tree_data = obs[:, self.num_ops:]

        # Per-sample min-max scaling of the tree columns.
        # NOTE(review): after scaling, each sample's minimum is exactly 0 and
        # _get_empty_mask tests the scaled values, so a node counts as empty
        # when its first feature equals the sample minimum - confirm this
        # matches how empty nodes are encoded.
        row_min = tree_data.min(dim=1, keepdim=True).values
        row_max = tree_data.max(dim=1, keepdim=True).values
        tree_data = (tree_data - row_min) / (row_max - row_min + 1e-8)

        # Leaves occupy the trailing columns of the tree data.
        num_leaf_nodes = self.level_structure[0]['num_nodes'] if self.level_structure else (tree_data.shape[1] // self.features_per_node)
        leaf_values = tree_data[:, -num_leaf_nodes * self.features_per_node:].view(
            batch_size, num_leaf_nodes, self.features_per_node
        )

        # Leaves flagged empty keep an all-zero embedding.
        leaf_embeddings = torch.zeros(batch_size, num_leaf_nodes, self.feature_dim, device=obs.device)

        empty_mask = self._get_empty_mask(leaf_values)
        if (~empty_mask).any():
            non_empty_leaves = leaf_values[~empty_mask]
            non_empty_indices = torch.nonzero(~empty_mask)

            embeddings = self.leaf_embedding(non_empty_leaves)

            level_embedding = self.level_embeddings[0].expand(len(embeddings), -1)
            embeddings = embeddings + level_embedding
            embeddings = self.level_norm(embeddings)

            leaf_embeddings[non_empty_indices[:, 0], non_empty_indices[:, 1]] = embeddings

        # Fold the tree upward one level at a time; the last entry (root) has
        # no parents and is returned unchanged by process_level.
        current_embeddings = leaf_embeddings
        for level_idx, level_info in enumerate(self.level_structure):
            current_embeddings = self.process_level(
                level_info,
                current_embeddings,
                tree_data,
                level_idx
            )

        root_embedding = current_embeddings.squeeze(1)
        return torch.cat([root_embedding, ops], dim=1)