#!/usr/bin/env python3
"""
Hierarchical Tree Creator for Multi-Iteration Schema Induction Pipeline

This module creates a comprehensive hierarchical tree structure from the final iteration
of the schema induction pipeline, including all levels and relationships for inference testing.
"""

import os
import json
import pandas as pd
import numpy as np
from typing import Dict, List, Optional, Any
from pathlib import Path

class HierarchicalTreeCreator:
    """
    Creates a comprehensive hierarchical tree structure for inference testing.

    The tree is assembled from artefacts read from ``iteration_dir``:

    * ``topologically_sorted_graph/topological_sort.parquet`` (abstract nodes)
    * ``high_level_codes/high_level_codes.parquet`` (high-level parent codes)
    * ``topologically_sorted_graph/code_datapoints_enhanced.parquet``
      (final operational codes)
    * ``topologically_sorted_graph/mapping_report.json`` (topological ->
      high-level relationships)

    Two JSON files are written to ``<iteration_dir>/hierarchical_tree/``.
    """

    # Output locations (relative to ``iteration_dir``), shared by
    # ``_save_trees`` and ``_get_saved_files`` so the two cannot drift apart.
    TREE_DIR_NAME = "hierarchical_tree"
    HIERARCHICAL_TREE_FILENAME = "hierarchical_tree_for_inference.json"
    INFERENCE_TREE_FILENAME = "inference_tree.json"

    def __init__(self, iteration_dir: str):
        """
        Initialize the hierarchical tree creator

        Args:
            iteration_dir: Path to the final iteration directory
        """
        self.iteration_dir = iteration_dir
        # Filled by _load_hierarchical_data(): 'topological', 'high_level' and
        # 'operational' hold DataFrames; 'mapping_report' holds the parsed JSON.
        self.tree_data: Dict[str, Any] = {}

    def create_hierarchical_tree(self) -> Dict[str, Any]:
        """
        Load the inputs, build both trees and write them to disk.

        Returns:
            Dictionary with keys ``hierarchical_tree``, ``inference_tree`` and
            ``saved_files`` (paths of the two JSON files written).
        """
        print("🌳 Creating hierarchical tree structure...")

        # Load all the data
        self._load_hierarchical_data()

        # Create the tree structure
        hierarchical_tree = self._build_tree_structure()

        # Create practical inference tree
        inference_tree = self._create_inference_tree(hierarchical_tree)

        # Save both trees
        self._save_trees(hierarchical_tree, inference_tree)

        return {
            "hierarchical_tree": hierarchical_tree,
            "inference_tree": inference_tree,
            "saved_files": self._get_saved_files()
        }

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #
    def _load_hierarchical_data(self):
        """Load all hierarchical inputs from the iteration directory into ``self.tree_data``.

        A missing file is reported and replaced by an empty DataFrame (or an
        empty list for the mapping report) so later stages simply skip it.
        """
        print("   📊 Loading hierarchical data...")
        graph_dir = os.path.join(self.iteration_dir, "topologically_sorted_graph")

        # 1. Topological sort (abstract levels 0-2)
        self.tree_data['topological'] = self._read_parquet_or_empty(
            os.path.join(graph_dir, "topological_sort.parquet"),
            "Topological nodes",
            "Topological sort file not found",
        )

        # 2. High-level codes (parent codes)
        self.tree_data['high_level'] = self._read_parquet_or_empty(
            os.path.join(self.iteration_dir, "high_level_codes", "high_level_codes.parquet"),
            "High-level parents",
            "High-level codes file not found",
        )

        # 3. Final operational codes (level -1)
        self.tree_data['operational'] = self._read_parquet_or_empty(
            os.path.join(graph_dir, "code_datapoints_enhanced.parquet"),
            "Final operational codes",
            "Operational codes file not found",
        )

        # 4. Parent-children mapping
        mapping_path = os.path.join(graph_dir, "mapping_report.json")
        if os.path.exists(mapping_path):
            with open(mapping_path, 'r', encoding='utf-8') as f:
                self.tree_data['mapping_report'] = json.load(f)
            print(f"   - Mapping relationships: {len(self.tree_data['mapping_report'])}")
        else:
            print(f"   ⚠️  Mapping report file not found: {mapping_path}")
            self.tree_data['mapping_report'] = []

    @staticmethod
    def _read_parquet_or_empty(path: str, found_label: str, missing_label: str) -> pd.DataFrame:
        """Read ``path`` as parquet, or warn and return an empty DataFrame if it does not exist."""
        if os.path.exists(path):
            frame = pd.read_parquet(path)
            print(f"   - {found_label}: {len(frame)}")
            return frame
        print(f"   ⚠️  {missing_label}: {path}")
        return pd.DataFrame()

    # ------------------------------------------------------------------ #
    # Tree construction
    # ------------------------------------------------------------------ #
    def _build_tree_structure(self) -> Dict[str, Any]:
        """Build the level-by-level tree from ``self.tree_data`` and link parents to children.

        Only levels whose DataFrame is non-empty appear under ``'levels'``.
        """
        print("   🌳 Building tree structure...")

        hierarchical_tree = {
            'metadata': {
                'total_nodes': 0,
                'levels': {},
                'description': 'Complete hierarchical tree for inference testing',
                'created_from': os.path.basename(self.iteration_dir),
                'created_at': pd.Timestamp.now().isoformat()
            },
            'levels': {}
        }

        # Level 1: Topological nodes (abstract conceptual level)
        if not self.tree_data['topological'].empty:
            print("   - Level 1: Topological nodes (abstract conceptual)")
            hierarchical_tree['levels']['topological'] = {
                'level': 'abstract_conceptual',
                'type': 'topological_sort',
                'nodes': self._build_topological_nodes(self.tree_data['topological'])
            }

        # Level 2: High-level parent codes
        if not self.tree_data['high_level'].empty:
            print("   - Level 2: High-level parent codes")
            hierarchical_tree['levels']['high_level'] = {
                'level': 2,
                'type': 'high_level_parent',
                'nodes': self._build_high_level_nodes(self.tree_data['high_level'])
            }

        # Level 3: Final operational codes (level -1)
        if not self.tree_data['operational'].empty:
            print("   - Level 3: Final operational codes (level -1)")
            hierarchical_tree['levels']['operational'] = {
                'level': -1,
                'type': 'operational',
                'nodes': self._build_operational_nodes(self.tree_data['operational'])
            }

        # Build parent-child relationships
        self._build_relationships(hierarchical_tree)

        # Update metadata
        level_counts = {
            level_name: len(level_data['nodes'])
            for level_name, level_data in hierarchical_tree['levels'].items()
        }
        total_nodes = sum(level_counts.values())
        hierarchical_tree['metadata']['total_nodes'] = total_nodes
        hierarchical_tree['metadata']['levels'] = level_counts

        print(f"   - Total nodes: {total_nodes}")
        for level_name, count in level_counts.items():
            print(f"   - {level_name}: {count} nodes")

        return hierarchical_tree

    @staticmethod
    def _as_code_list(value: Any) -> Any:
        """Normalise a ``source_codes`` cell: ndarray -> list, missing (None/NaN) -> [].

        Other values (lists, tuples, ...) are returned unchanged.
        """
        if isinstance(value, np.ndarray):
            return value.tolist()
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return []
        return value

    @staticmethod
    def _build_topological_nodes(frame: pd.DataFrame) -> List[Dict[str, Any]]:
        """One node per topological-sort row; the id is derived from the DataFrame index."""
        return [
            {
                'id': f'topological_{idx}',
                'level': row['level'],
                'type': 'topological',
                'code': row['node'],
                'children': []
            }
            for idx, row in frame.iterrows()
        ]

    def _build_high_level_nodes(self, frame: pd.DataFrame) -> List[Dict[str, Any]]:
        """One node per high-level cluster; the id is derived from ``cluster_id``."""
        nodes = []
        for _, row in frame.iterrows():
            source_codes = self._as_code_list(row['source_codes'])
            # Fall back to the number of source codes only when the column is absent
            # (computed lazily so a null source list cannot crash this lookup).
            if 'cluster_size' in row.index:
                cluster_size = row['cluster_size']
            else:
                cluster_size = len(source_codes)
            nodes.append({
                'id': f'high_level_{row["cluster_id"]}',
                'level': 2,
                'type': 'high_level_parent',
                'code': row['high_level_code'],
                'cluster_id': row['cluster_id'],
                'cluster_size': cluster_size,
                'source_codes': source_codes,
                'children': []
            })
        return nodes

    @staticmethod
    def _build_operational_nodes(frame: pd.DataFrame) -> List[Dict[str, Any]]:
        """One node per operational code row; the id is derived from the DataFrame index."""
        return [
            {
                'id': f'operational_{idx}',
                'level': row['level'],
                'type': 'operational',
                'code': row['code'],
                'datapoint': row['datapoint'],
                'global_frequency': row['global_frequency'],
                'incoming_edges': row['incoming_edges'],
                'merge_score': row['merge_score'],
                'parents': []
            }
            for idx, row in frame.iterrows()
        ]

    def _build_relationships(self, hierarchical_tree: Dict[str, Any]):
        """Fill ``children``/``parents`` lists in place.

        * high-level -> operational: a cluster's ``source_codes`` are matched
          against operational node codes.
        * topological -> high-level: ``mapping_report`` entries of the form
          ``{'target': <topological code>, 'sources': [<high-level code>, ...]}``.
        """
        print("   - Building parent-child relationships...")
        levels = hierarchical_tree['levels']

        if 'high_level' in levels and 'operational' in levels:
            # Later duplicates win here, as with the original dict lookups.
            high_level_lookup = {node['cluster_id']: node for node in levels['high_level']['nodes']}
            operational_lookup = {node['code']: node for node in levels['operational']['nodes']}

            for _, row in self.tree_data['high_level'].iterrows():
                parent_id = row['cluster_id']
                parent_node = high_level_lookup.get(parent_id)
                if parent_node is None:
                    continue
                for source_code in self._as_code_list(row['source_codes']):
                    op_node = operational_lookup.get(source_code)
                    if op_node is None:
                        continue
                    parent_node['children'].append(op_node['id'])
                    # NOTE: 'parents' stores the raw cluster_id, whereas
                    # 'children' stores node ids ('high_level_<cluster_id>').
                    op_node['parents'].append(parent_id)

        if 'topological' in levels and 'high_level' in levels:
            topological_lookup = {node['code']: node for node in levels['topological']['nodes']}
            # First node wins for duplicate codes (same as a first-match linear scan).
            high_level_by_code: Dict[Any, Dict[str, Any]] = {}
            for node in levels['high_level']['nodes']:
                high_level_by_code.setdefault(node['code'], node)

            for mapping in self.tree_data['mapping_report']:
                # Skip anything that is not a well-formed mapping object, e.g. keys
                # when the report happens to be a JSON object rather than a list.
                if not isinstance(mapping, dict) or 'target' not in mapping or 'sources' not in mapping:
                    continue
                topo_node = topological_lookup.get(mapping['target'])
                if topo_node is None:
                    continue
                for source in mapping['sources']:
                    high_node = high_level_by_code.get(source)
                    if high_node is not None:
                        topo_node['children'].append(high_node['id'])

    # ------------------------------------------------------------------ #
    # Inference tree
    # ------------------------------------------------------------------ #
    def _create_inference_tree(self, hierarchical_tree: Dict[str, Any]) -> Dict[str, Any]:
        """Wrap the hierarchy with metadata, lookup tables and inference helper summaries.

        ``'hierarchy'`` references the same level dictionaries as
        ``hierarchical_tree['levels']`` (not a copy).
        """
        print("   🔧 Creating practical inference tree...")

        inference_tree = {
            'metadata': {
                'description': 'Practical hierarchical tree for inference testing',
                'total_nodes': hierarchical_tree['metadata']['total_nodes'],
                'levels': hierarchical_tree['metadata']['levels'],
                'created_from': hierarchical_tree['metadata']['created_from'],
                'created_at': hierarchical_tree['metadata']['created_at']
            },
            'lookup_tables': {},
            'hierarchy': hierarchical_tree['levels'],
            'inference_helpers': {}
        }

        # Create lookup tables
        self._create_lookup_tables(hierarchical_tree, inference_tree)

        return inference_tree

    def _create_lookup_tables(self, hierarchical_tree: Dict[str, Any], inference_tree: Dict[str, Any]):
        """Populate ``inference_tree['lookup_tables']`` and ``['inference_helpers']``.

        Works when any level is missing (e.g. no operational codes were loaded).
        """
        levels = hierarchical_tree['levels']
        code_to_level = {}
        code_to_parents = {}
        code_to_children = {}
        level_to_codes = {}
        parent_to_children = {}
        topological_to_operational = {}

        for level_name, level_data in levels.items():
            level_to_codes[level_name] = []
            for node in level_data['nodes']:
                code = node['code']
                entry = {'node_id': node['id'], 'type': node['type']}
                # A code present on several levels keeps the last level seen.
                code_to_level[code] = {'level': level_name, **entry}
                level_to_codes[level_name].append({'code': code, **entry})
                if 'children' in node:
                    code_to_children[code] = node['children']
                if 'parents' in node:
                    code_to_parents[code] = node['parents']

        operational_nodes = levels.get('operational', {}).get('nodes', [])
        high_level_nodes = levels.get('high_level', {}).get('nodes', [])

        # Index nodes by id once; first occurrence wins, matching a linear first-match scan.
        operational_code_by_id: Dict[str, Any] = {}
        for node in operational_nodes:
            operational_code_by_id.setdefault(node['id'], node['code'])
        high_level_by_id: Dict[str, Dict[str, Any]] = {}
        for node in high_level_nodes:
            high_level_by_id.setdefault(node['id'], node)

        def operational_codes_of(child_ids: List[str]) -> List[Any]:
            # Map operational node ids to their codes, ignoring unknown ids.
            return [operational_code_by_id[cid] for cid in child_ids if cid in operational_code_by_id]

        # Parent-to-children mapping (for test inference)
        for node in high_level_nodes:
            parent_to_children[node['code']] = operational_codes_of(node['children'])

        # Topological to operational mapping (for full hierarchy)
        for topo_node in levels.get('topological', {}).get('nodes', []):
            operational_codes_list = []
            for child_id in topo_node['children']:
                high_node = high_level_by_id.get(child_id)
                if high_node is not None:
                    operational_codes_list.extend(operational_codes_of(high_node['children']))
            topological_to_operational[topo_node['code']] = operational_codes_list

        inference_tree['lookup_tables'] = {
            'code_to_level': code_to_level,
            'code_to_parents': code_to_parents,
            'code_to_children': code_to_children,
            'level_to_codes': level_to_codes,
            'parent_to_children': parent_to_children,
            'topological_to_operational': topological_to_operational
        }

        inference_tree['inference_helpers'] = {
            'get_operational_codes': {
                'description': 'Get all operational codes for classification',
                'count': len(operational_nodes)
            },
            'get_parent_children_mapping': {
                'description': 'Get parent-to-children mapping for test inference',
                'count': len(parent_to_children)
            },
            'get_topological_hierarchy': {
                'description': 'Get full topological to operational mapping',
                'count': len(topological_to_operational)
            },
            'get_code_level': {
                'description': 'Get the level of any code',
                'count': len(code_to_level)
            }
        }

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #
    @staticmethod
    def _json_default(value: Any) -> Any:
        """``json.dump`` fallback for numpy/pandas values coming out of the DataFrames.

        List-valued parquet columns may arrive as ``np.ndarray`` and cells may be
        numpy scalars; the ``json`` module cannot encode either on its own.
        """
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, pd.Timestamp):
            return value.isoformat()
        if value is pd.NA:
            return None
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def _save_trees(self, hierarchical_tree: Dict[str, Any], inference_tree: Dict[str, Any]):
        """Write both trees as indented JSON into ``<iteration_dir>/hierarchical_tree/``."""
        print("   💾 Saving hierarchical trees...")

        tree_dir = os.path.join(self.iteration_dir, self.TREE_DIR_NAME)
        os.makedirs(tree_dir, exist_ok=True)

        outputs = (
            (self.HIERARCHICAL_TREE_FILENAME, hierarchical_tree),
            (self.INFERENCE_TREE_FILENAME, inference_tree),
        )
        for filename, tree in outputs:
            with open(os.path.join(tree_dir, filename), 'w', encoding='utf-8') as f:
                json.dump(tree, f, indent=2, default=self._json_default)

        # Note: Python class is now in utils/hierarchical_tree_class.py

        print(f"   ✅ Trees saved to: {tree_dir}")
        for filename, _ in outputs:
            print(f"   - {filename}")

    def _get_saved_files(self) -> List[str]:
        """Return the paths of the two JSON files written by ``_save_trees``."""
        tree_dir = os.path.join(self.iteration_dir, self.TREE_DIR_NAME)
        return [
            os.path.join(tree_dir, self.HIERARCHICAL_TREE_FILENAME),
            os.path.join(tree_dir, self.INFERENCE_TREE_FILENAME)
        ]
