import json
import os
from typing import Dict, List, Optional, Tuple

class HierarchicalTree:
    """
    A class to manage the hierarchical tree structure for inference testing.

    This class provides easy access to:
    - Operational codes for classification
    - Parent-child relationships for test inference
    - Full hierarchical structure
    - Level lookups

    The tree is read from a JSON file that must contain the top-level keys
    'lookup_tables', 'hierarchy' and 'metadata'. Data can be loaded from any
    iteration directory or from an explicit file path.
    """

    def __init__(self, iteration_dir: Optional[str] = None, tree_file: Optional[str] = None):
        """
        Initialize the hierarchical tree from a JSON file.

        Args:
            iteration_dir: Path to the iteration directory (e.g., "temp_files/iteration_04").
                The tree is read from "<iteration_dir>/hierarchical_tree/inference_tree.json".
            tree_file: Direct path to the inference tree JSON file (overrides iteration_dir).
                If neither argument is given, "inference_tree.json" in the current
                working directory is used.

        Raises:
            FileNotFoundError: If the resolved tree file does not exist.
            KeyError: If the JSON lacks 'lookup_tables', 'hierarchy' or 'metadata'.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        if tree_file:
            # Direct file path provided
            self.tree_file = tree_file
        elif iteration_dir:
            # Iteration directory provided
            self.tree_file = os.path.join(iteration_dir, "hierarchical_tree", "inference_tree.json")
        else:
            # Default to current directory
            self.tree_file = "inference_tree.json"

        if not os.path.exists(self.tree_file):
            raise FileNotFoundError(f"Hierarchical tree file not found: {self.tree_file}")

        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(self.tree_file, 'r', encoding='utf-8') as f:
            self.tree = json.load(f)

        self.lookup_tables = self.tree['lookup_tables']
        self.hierarchy = self.tree['hierarchy']
        self.metadata = self.tree['metadata']

        # Store the iteration info. normpath strips a trailing separator so that
        # "iteration_04/" yields "iteration_04" rather than an empty basename.
        self.iteration_dir = iteration_dir
        self.iteration_name = (
            os.path.basename(os.path.normpath(iteration_dir)) if iteration_dir else "unknown"
        )

    def get_operational_codes(self) -> List[str]:
        """
        Get all operational codes for classification.

        Returns:
            List of the 'code' fields of every node under
            hierarchy['operational']['nodes'] (level -1), in file order;
            an empty list if the hierarchy has no 'operational' entry.
        """
        if 'operational' not in self.hierarchy:
            return []
        operational_nodes = self.hierarchy['operational']['nodes']
        return [node['code'] for node in operational_nodes]

    def get_parent_children_mapping(self) -> Dict[str, List[str]]:
        """
        Get parent-to-children mapping for test inference.

        Returns:
            The 'parent_to_children' lookup table (the stored dict itself, not a
            copy), or an empty dict if absent.
        """
        return self.lookup_tables.get('parent_to_children', {})

    def get_topological_hierarchy(self) -> Dict[str, List[str]]:
        """
        Get full topological to operational mapping.

        Returns:
            The 'topological_to_operational' lookup table (the stored dict itself),
            or an empty dict if absent.
        """
        return self.lookup_tables.get('topological_to_operational', {})

    def get_code_level(self, code: str) -> Optional[Dict]:
        """
        Get the level information for a given code.

        Args:
            code: The code to look up

        Returns:
            The entry for `code` in the 'code_to_level' lookup table, or None if
            the code (or the table) is missing.
        """
        return self.lookup_tables.get('code_to_level', {}).get(code)

    def get_children(self, code: str) -> List[str]:
        """
        Get all children of a given code.

        Args:
            code: The parent code

        Returns:
            List of child codes from 'code_to_children'; empty if unknown.
        """
        return self.lookup_tables.get('code_to_children', {}).get(code, [])

    def get_parents(self, code: str) -> List[str]:
        """
        Get all parents of a given code.

        Args:
            code: The child code

        Returns:
            List of parent codes from 'code_to_parents'; empty if unknown.
        """
        return self.lookup_tables.get('code_to_parents', {}).get(code, [])

    def get_codes_by_level(self, level: str) -> List[Dict]:
        """
        Get all codes at a specific level.

        Args:
            level: The level name ('topological', 'high_level', 'operational')

        Returns:
            List of code dictionaries from 'level_to_codes'; empty if unknown.
        """
        return self.lookup_tables.get('level_to_codes', {}).get(level, [])

    def get_tree_stats(self) -> Dict:
        """
        Get statistics about the hierarchical tree.

        Returns:
            Dictionary with keys 'total_nodes', 'levels', 'operational_codes_count',
            'parent_children_mappings', 'topological_mappings', 'iteration' and
            'created_from'. Missing metadata falls back to 0 / {} / 'unknown'.
        """
        return {
            'total_nodes': self.metadata.get('total_nodes', 0),
            'levels': self.metadata.get('levels', {}),
            'operational_codes_count': len(self.get_operational_codes()),
            'parent_children_mappings': len(self.get_parent_children_mapping()),
            'topological_mappings': len(self.get_topological_hierarchy()),
            'iteration': self.iteration_name,
            'created_from': self.metadata.get('created_from', 'unknown')
        }

    def find_code_in_hierarchy(self, code: str) -> Optional[Dict]:
        """
        Find a code in the hierarchy and return its full context.

        Args:
            code: The code to find

        Returns:
            Dictionary with 'code', 'level', 'type', 'node_id', 'parents' and
            'children', or None if the code has no (or an empty) level entry.

        Raises:
            KeyError: If the level entry lacks 'level', 'type' or 'node_id'.
        """
        level_info = self.get_code_level(code)
        if not level_info:
            return None

        return {
            'code': code,
            'level': level_info['level'],
            'type': level_info['type'],
            'node_id': level_info['node_id'],
            'parents': self.get_parents(code),
            'children': self.get_children(code)
        }

    def get_inference_candidates(self, test_code: str) -> List[str]:
        """
        Get candidate codes for inference testing.

        Args:
            test_code: The test code to find candidates for

        Returns:
            [test_code] if it is an operational code; otherwise every
            operational code as a candidate.
        """
        # Build the list once instead of once per check.
        operational_codes = self.get_operational_codes()
        if test_code in operational_codes:
            return [test_code]
        return operational_codes

    def print_tree_summary(self) -> None:
        """
        Print a summary of the hierarchical tree to stdout.
        """
        stats = self.get_tree_stats()
        print(f"🌳 Hierarchical Tree Summary ({self.iteration_name}):")
        print(f"   Total nodes: {stats['total_nodes']}")
        print(f"   Operational codes: {stats['operational_codes_count']}")
        print(f"   Parent-children mappings: {stats['parent_children_mappings']}")
        print(f"   Topological mappings: {stats['topological_mappings']}")
        print(f"   Created from: {stats['created_from']}")
        print()
        print("📊 Levels:")
        for level, count in stats['levels'].items():
            print(f"   {level}: {count} nodes")

    @classmethod
    def from_iteration(cls, iteration_dir: str) -> "HierarchicalTree":
        """
        Create a HierarchicalTree instance from an iteration directory.

        Args:
            iteration_dir: Path to the iteration directory

        Returns:
            HierarchicalTree instance
        """
        return cls(iteration_dir=iteration_dir)

    @classmethod
    def from_file(cls, tree_file: str) -> "HierarchicalTree":
        """
        Create a HierarchicalTree instance from a direct file path.

        Args:
            tree_file: Path to the inference tree JSON file

        Returns:
            HierarchicalTree instance
        """
        return cls(tree_file=tree_file)
