#!/usr/bin/env python3
"""
Graph-based Retrieval Module

This module implements efficient graph-based retrieval using the optimized parquet-based
clique storage for fast access to hierarchical relationships and graph structures.

Features:
1. Fast clique loading using parquet files
2. Node-to-clique mapping for quick lookups
3. Hierarchical relationship retrieval (parents, children, siblings)
4. Global frequency-based ranking
5. Top-k retrieval with frequency weighting
"""

import json
import logging
import os
from collections import defaultdict
from typing import List, Dict, Any, Set, Tuple, Optional

import networkx as nx
import numpy as np
import pandas as pd

class GraphRetriever:
    """
    Efficient graph-based retriever using parquet-optimized clique storage.

    On construction the retriever loads whichever clique tables exist
    (``metadata``, ``nodes``, ``edges``, ``hierarchies``) plus the global code
    frequencies, then builds in-memory lookup dictionaries so per-query work
    does not rescan the DataFrames.

    Attributes:
        cliques_dir: Directory the clique parquet files were read from.
        mapping_dir: Directory the code frequency parquet file was read from.
        top_k: Maximum number of results returned per query (0..MAX_TOP_K).
        clique_data: Table name -> DataFrame for every table found on disk.
        code_frequencies: Code -> global frequency.
        node_to_cliques: Node -> list of clique ids containing that node.
        clique_sizes: Clique id -> clique size.
    """

    # Hard upper bound applied to the ``top_k`` constructor argument.
    MAX_TOP_K = 10

    # Tables looked up as ``clique_<name>.parquet``; each one is optional.
    _CLIQUE_TABLES = ('metadata', 'nodes', 'edges', 'hierarchies')

    # (key in the hierarchy record, label reported in the results), listed in
    # the order candidates are collected within a clique.
    _RELATIONSHIP_KINDS = (
        ('parents', 'parent'),
        ('children', 'child'),
        ('siblings', 'sibling'),
    )

    def __init__(self, cliques_dir: str, mapping_dir: str, top_k: int = 5):
        """
        Initialize the graph retriever.

        Args:
            cliques_dir: Directory containing clique parquet files
            mapping_dir: Directory containing code frequency mappings
            top_k: Number of top results to return (default: 5). Values are
                clamped to the range 0..MAX_TOP_K.

        Raises:
            FileNotFoundError: If ``cliques_dir`` or the code frequency file
                does not exist.
        """
        self.cliques_dir = cliques_dir
        self.mapping_dir = mapping_dir
        # Clamp on both sides: a negative value would otherwise be used as a
        # negative slice bound and silently drop results from the END of the
        # ranking instead of limiting its length.
        self.top_k = max(0, min(top_k, self.MAX_TOP_K))

        # Load clique data
        self.clique_data = self._load_clique_data()

        # Load code frequencies
        self.code_frequencies = self._load_code_frequencies()

        # Build lookup indices for fast access
        self._build_indices()

    def _load_clique_data(self) -> Dict[str, pd.DataFrame]:
        """
        Load every available clique table from ``cliques_dir``.

        Returns:
            Mapping of table name ('metadata', 'nodes', 'edges',
            'hierarchies') to its DataFrame. Tables whose parquet file is
            missing are simply left out of the mapping.

        Raises:
            FileNotFoundError: If ``cliques_dir`` itself does not exist.
        """
        if not os.path.exists(self.cliques_dir):
            raise FileNotFoundError(f"Cliques directory not found: {self.cliques_dir}")

        data = {}
        for table in self._CLIQUE_TABLES:
            filepath = os.path.join(self.cliques_dir, f"clique_{table}.parquet")
            if os.path.exists(filepath):
                data[table] = pd.read_parquet(filepath)
        return data

    def _load_code_frequencies(self) -> Dict[str, int]:
        """
        Load the code -> global frequency mapping.

        Returns:
            Dictionary built from the ``code`` and ``global_frequency``
            columns of ``code_frequencies.parquet``. If a code appears more
            than once, the last row wins.

        Raises:
            FileNotFoundError: If the frequency file does not exist.
        """
        freq_path = os.path.join(self.mapping_dir, "code_frequencies.parquet")
        if not os.path.exists(freq_path):
            raise FileNotFoundError(f"Code frequencies file not found: {freq_path}")

        freq_df = pd.read_parquet(freq_path)
        # Column-wise zip avoids building one Series per row (iterrows).
        return dict(zip(freq_df['code'].tolist(), freq_df['global_frequency'].tolist()))

    def _build_indices(self):
        """
        Build the dictionary indices used by the query methods.

        Populates ``node_to_cliques`` and ``clique_sizes`` from the loaded
        tables (empty when the corresponding table is absent) and resets the
        lazily built hierarchy index.
        """
        # Node to clique mapping
        self.node_to_cliques = defaultdict(list)
        nodes = self.clique_data.get('nodes')
        if nodes is not None:
            for node, clique_id in zip(nodes['node'].tolist(), nodes['clique_id'].tolist()):
                self.node_to_cliques[node].append(clique_id)

        # Clique size mapping
        self.clique_sizes = {}
        metadata = self.clique_data.get('metadata')
        if metadata is not None:
            self.clique_sizes = dict(
                zip(metadata['clique_id'].tolist(), metadata['size'].tolist())
            )

        # Rebuilt on demand by _hierarchy_row_positions.
        self._hierarchy_index = None

    def _hierarchy_row_positions(self, frame: pd.DataFrame) -> Dict[Tuple[Any, Any], int]:
        """
        Return the (node, clique_id) -> first row position index for ``frame``.

        The index is built on first use and cached together with the frame it
        was built from; if ``clique_data['hierarchies']`` is later replaced by
        a different DataFrame object the index is rebuilt.
        """
        cached = getattr(self, '_hierarchy_index', None)
        if cached is None or cached[0] is not frame:
            positions: Dict[Tuple[Any, Any], int] = {}
            keys = zip(frame['node'].tolist(), frame['clique_id'].tolist())
            for position, key in enumerate(keys):
                # Keep the FIRST matching row, like a filter + iloc[0] would.
                positions.setdefault(key, position)
            cached = (frame, positions)
            self._hierarchy_index = cached
        return cached[1]

    def get_cliques_for_node(self, node: str) -> List[int]:
        """Get all clique IDs that contain a specific node (empty list if unknown)."""
        # .get() rather than [] so unknown nodes are not inserted into the
        # defaultdict as a side effect of a lookup.
        return self.node_to_cliques.get(node, [])

    def get_node_relationships(self, node: str, clique_id: int) -> Dict[str, Any]:
        """
        Get hierarchical relationships for a node in a specific clique.

        Args:
            node: Node (code) name to look up
            clique_id: Clique in which to look the node up

        Returns:
            Dictionary with ``parents``, ``children`` and ``siblings``
            (decoded from their JSON columns) and the ``num_parents``,
            ``num_children`` and ``num_siblings`` counts. Empty dictionary if
            the hierarchies table is not loaded or has no matching row.
        """
        frame = self.clique_data.get('hierarchies')
        if frame is None:
            return {}

        # Dictionary lookup instead of a full boolean scan of the table on
        # every call (this method runs once per clique of the queried node).
        position = self._hierarchy_row_positions(frame).get((node, clique_id))
        if position is None:
            return {}

        row = frame.iloc[position]
        return {
            'parents': json.loads(row['parents']),
            'children': json.loads(row['children']),
            'siblings': json.loads(row['siblings']),
            'num_parents': row['num_parents'],
            'num_children': row['num_children'],
            'num_siblings': row['num_siblings']
        }

    def get_global_frequency(self, code: str) -> int:
        """Get global frequency for a code (0 when the code has no entry)."""
        return self.code_frequencies.get(code, 0)

    def find_related_nodes(self, query_node: str) -> List[Dict[str, Any]]:
        """
        Find all related nodes (parents, children, siblings) for a given code.

        Candidates are gathered clique by clique, in the order parents,
        children, siblings. A code reached through several relationships or
        cliques is reported once; an earlier candidate is only replaced by a
        later one with a strictly higher global frequency.

        Args:
            query_node: Code name to find relationships for

        Returns:
            Up to ``top_k`` entries sorted by global frequency (descending;
            ties keep first-seen order). Each entry has the keys ``code``,
            ``relationship``, ``clique_id``, ``clique_size`` and
            ``global_frequency``.
        """
        unique_nodes: Dict[Any, Dict[str, Any]] = {}

        for clique_id in self.get_cliques_for_node(query_node):
            relationships = self.get_node_relationships(query_node, clique_id)
            clique_size = self.clique_sizes.get(clique_id, 0)

            for key, label in self._RELATIONSHIP_KINDS:
                for code in relationships.get(key, []):
                    candidate = {
                        'code': code,
                        'relationship': label,
                        'clique_id': clique_id,
                        'clique_size': clique_size,
                        'global_frequency': self.get_global_frequency(code)
                    }
                    kept = unique_nodes.get(code)
                    if kept is None or candidate['global_frequency'] > kept['global_frequency']:
                        unique_nodes[code] = candidate

        # Sort by global frequency (descending) and return top_k
        ranked = sorted(
            unique_nodes.values(),
            key=lambda info: info['global_frequency'],
            reverse=True,
        )
        return ranked[:self.top_k]

    def retrieve_graph_codes(self, source_code: str) -> List[Dict[str, Any]]:
        """
        Retrieve related codes based on graph relationships, ranked by global frequency.

        This is the best-effort entry point: any failure during lookup is
        logged and an empty list is returned instead of raising.

        Args:
            source_code: The code name to find relationships for

        Returns:
            List of top k related codes sorted by global frequency
            Format: code, relationship, global_frequency, clique_id, clique_size
        """
        try:
            related_nodes = self.find_related_nodes(source_code)
        except Exception:
            # Deliberately broad: callers rely on never getting an exception
            # from this method. Log with traceback so failures stay visible.
            logging.getLogger(__name__).exception(
                "Graph retrieval failed for code %r", source_code
            )
            return []

        fields = ('code', 'relationship', 'global_frequency', 'clique_id', 'clique_size')
        return [{field: node_info[field] for field in fields} for node_info in related_nodes]

    def get_clique_context(self, clique_id: int) -> Dict[str, Any]:
        """
        Get comprehensive context for a clique.

        Args:
            clique_id: Identifier of the clique to describe

        Returns:
            Dictionary with the clique's metadata (``size``, ``num_edges``,
            ``density``, ``is_strongly_connected``), its ``nodes`` and its
            ``edges`` as (source, target) tuples. ``nodes`` / ``edges`` are
            empty lists when the corresponding table was not loaded. Empty
            dictionary if the metadata table is missing or has no such clique.
        """
        metadata = self.clique_data.get('metadata')
        if metadata is None:
            return {}

        clique_meta = metadata[metadata['clique_id'] == clique_id]
        if clique_meta.empty:
            return {}
        meta_row = clique_meta.iloc[0]

        # The nodes/edges parquet files are optional at load time, so their
        # absence must not turn into a KeyError here.
        nodes: List[Any] = []
        nodes_df = self.clique_data.get('nodes')
        if nodes_df is not None:
            nodes = nodes_df.loc[nodes_df['clique_id'] == clique_id, 'node'].tolist()

        edges: List[Tuple[Any, Any]] = []
        edges_df = self.clique_data.get('edges')
        if edges_df is not None:
            clique_edges = edges_df[edges_df['clique_id'] == clique_id]
            edges = list(zip(clique_edges['source'].tolist(), clique_edges['target'].tolist()))

        return {
            'clique_id': clique_id,
            'size': meta_row['size'],
            'num_edges': meta_row['num_edges'],
            'density': meta_row['density'],
            'is_strongly_connected': meta_row['is_strongly_connected'],
            'nodes': nodes,
            'edges': edges
        }

    def get_graph_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the graph structure.

        Returns:
            Row counts per table (0 for tables that were not loaded), the
            number of distinct nodes, the number of codes with a frequency
            and, when metadata is available, clique size / density
            aggregates. Empty dictionary if no clique table was loaded.
        """
        if not self.clique_data:
            return {}

        metadata = self.clique_data.get('metadata')
        nodes = self.clique_data.get('nodes')
        edges = self.clique_data.get('edges')

        stats = {
            'total_cliques': 0 if metadata is None else len(metadata),
            'total_nodes': 0 if nodes is None else len(nodes),
            'total_edges': 0 if edges is None else len(edges),
            'unique_nodes': 0 if nodes is None else len(set(nodes['node'].tolist())),
            'total_codes_with_frequencies': len(self.code_frequencies)
        }

        if metadata is not None:
            sizes = metadata['size']
            stats.update({
                'avg_clique_size': sizes.mean(),
                'max_clique_size': sizes.max(),
                'min_clique_size': sizes.min(),
                'avg_density': metadata['density'].mean(),
                'strongly_connected_cliques': metadata['is_strongly_connected'].sum()
            })

        return stats


# Manual smoke test: executed only when this file is run as a script
if __name__ == "__main__":
    # Artifact locations, relative to the current working directory
    cliques_dir = "../temp_files/topologically_sorted_graph/cliques"
    mapping_dir = "../temp_files/topologically_sorted_graph/datapoint_code_mapping"

    inputs_present = os.path.exists(cliques_dir) and os.path.exists(mapping_dir)

    if not inputs_present:
        # Nothing to demo without both directories; report what was expected
        print(f"❌ Required directories not found:")
        print(f"   Cliques: {cliques_dir}")
        print(f"   Mapping: {mapping_dir}")
    else:
        retriever = GraphRetriever(cliques_dir, mapping_dir, top_k=5)

        # Overall graph summary
        print("📊 Graph Statistics:")
        for key, value in retriever.get_graph_statistics().items():
            print(f"   {key}: {value}")

        if not retriever.node_to_cliques:
            print("❌ No nodes found in cliques")
        else:
            # Use the first indexed node as the sample query
            test_code = next(iter(retriever.node_to_cliques))
            print(f"\n🧪 Testing with code: {test_code[:60]}...")

            results = retriever.retrieve_graph_codes(test_code)

            print(f"\n🔍 Graph Retrieval Results for {test_code[:30]}...:")
            print(f"   Found {len(results)} related codes")

            # One short report per hit, numbered from 1
            for i, result in enumerate(results):
                print(f"   {i+1}. {result['code'][:50]}...")
                print(f"      Relationship: {result['relationship']}")
                print(f"      Global Frequency: {result['global_frequency']}")
                print(f"      Clique ID: {result['clique_id']}")
                print()