import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.utils import scatter
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
import json
from groq import Groq

# Configure the root logger once at import time: INFO level, bare message text.
logging.basicConfig(format='%(message)s', level=logging.INFO)

class GradientFeatureExtractor:
    """
    Extracts meaningful, interpretable features from gradients for LLM explanation.

    These features are designed to be:
    1. Technically meaningful for graph learning
    2. Interpretable by humans and LLMs
    3. Capturing different aspects of gradient behavior

    Implementation notes:
    - Every neighborhood statistic is computed with edge-list tensor operations
      (``index_add_`` / ``bincount``). Memory is O(N + E) and no Python loop
      runs over the nodes, so no N x N dense adjacency matrix is ever built.
    - Result tensors are created on the device of the gradient tensor ``g``.
      ``self.device`` is kept only for backward compatibility: this object is
      not an ``nn.Module``, so ``Module.to()`` never updates it.
    """

    def __init__(self, device):
        self.device = device

    # ------------------------------------------------------------------
    # Edge-list helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _unique_columns(edge_index):
        """Return ``edge_index`` with duplicate (source, target) columns removed."""
        if edge_index.size(1) == 0:
            # Avoid torch.unique(dim=1) on a zero-width tensor.
            return edge_index
        return torch.unique(edge_index, dim=1)

    @classmethod
    def _undirected_neighbor_pairs(cls, edge_index):
        """
        Return ``(node, neighbor)`` index tensors listing every node's neighbors.

        A neighbor is any node that shares an edge with ``node`` in either
        direction. Each neighbor appears once per node, and a self-loop makes a
        node its own neighbor.
        """
        both_directions = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        pairs = cls._unique_columns(both_directions)
        return pairs[0], pairs[1]

    @staticmethod
    def _neighbor_mean(g, edge_index):
        """
        Mean of ``g[target]`` over the edges leaving each source node.

        Always returns exactly ``g.size(0)`` rows, so it can be compared
        element-wise with ``g``. Nodes without outgoing edges get a zero row.
        """
        num_nodes = g.size(0)
        src, dst = edge_index[0], edge_index[1]
        sums = torch.zeros_like(g).index_add_(0, src, g[dst])
        counts = torch.bincount(src, minlength=num_nodes).to(g.dtype)
        return sums / counts.clamp(min=1).unsqueeze(1)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract_features(self, g, h, edge_index, predictions=None, true_labels=None, epoch=0):
        """
        Extract 6 innovative gradient features for interpretability:

        1. Gradient Conflict Intensity - How much this node disagrees with neighbors
        2. Learning Trajectory Stability - How consistent the gradient direction is
        3. Multi-hop Influence Strength - How much this node affects distant nodes
        4. Prediction Confidence vs Gradient Magnitude - Relationship between confidence and learning signal
        5. Topological Learning Role - Whether this node is a learning hub, bridge, or follower
        6. Correction Receptiveness - How amenable this node is to GGT corrections

        Args:
            g: per-node gradient matrix, one row per node.
            h: unused; accepted for interface compatibility.
            edge_index: ``(2, E)`` integer tensor of (source, target) edges.
            predictions: optional logits (2-D) or class indices (1-D).
            true_labels: unused; accepted for interface compatibility.
            epoch: unused; accepted for interface compatibility.

        Returns:
            dict mapping feature name to a 1-D tensor with one value per node.
        """
        num_nodes = g.size(0)
        features = {}

        g_norm = torch.norm(g, dim=1)

        # Feature 1: Gradient Conflict Intensity
        neighbor_gradients = self._neighbor_mean(g, edge_index)
        gradient_alignment = F.cosine_similarity(g, neighbor_gradients, dim=1)
        features['conflict_intensity'] = self._compute_conflict_intensity(g_norm, gradient_alignment)

        # Feature 2: Learning Trajectory Stability
        features['trajectory_stability'] = self._compute_trajectory_stability(g, edge_index)

        # Feature 3: Multi-hop Influence Strength
        features['influence_strength'] = self._compute_multihop_influence(g, edge_index)

        # Feature 4: Prediction Confidence vs Gradient Magnitude
        if predictions is not None:
            features['confidence_gradient_rel'] = self._compute_confidence_gradient_relationship(
                g_norm, predictions
            )
        else:
            features['confidence_gradient_rel'] = torch.zeros(num_nodes, device=g.device)

        # Feature 5: Topological Learning Role
        features['learning_role'] = self._compute_topological_learning_role(g, edge_index)

        # Feature 6: Correction Receptiveness
        features['correction_receptiveness'] = self._compute_correction_receptiveness(
            g, edge_index, gradient_alignment, g_norm
        )

        return features

    # ------------------------------------------------------------------
    # Individual features
    # ------------------------------------------------------------------

    def _compute_conflict_intensity(self, g_norm, gradient_alignment):
        """
        Conflict Intensity: Measures how much a node's gradient disagrees with its neighborhood.
        Higher values indicate nodes that are learning in opposition to their context.

        Formula: intensity = (gradient_magnitude / avg_magnitude) * (1 - alignment)
        Range: [0, ∞), where 0=no conflict, >2=high conflict
        """
        avg_magnitude = g_norm.mean()
        magnitude_ratio = g_norm / (avg_magnitude + 1e-8)
        misalignment = torch.clamp(1 - gradient_alignment, 0, 2)  # Clamp for numerical stability
        return magnitude_ratio * misalignment

    def _compute_trajectory_stability(self, g, edge_index):
        """
        Learning Trajectory Stability: average pairwise cosine similarity between
        the gradient directions of a node's (undirected, de-duplicated) neighbors.

        Uses the identity  sum_{a != b} <u_a, u_b> = ||sum_a u_a||^2 - sum_a ||u_a||^2
        over unit vectors u, so no per-node pairwise matrix is built.
        Nodes with fewer than two neighbors score 1.0.
        Range: [0, 1], where 1=very stable, 0=very unstable
        """
        num_nodes = g.size(0)
        unit = F.normalize(g, p=2, dim=1)
        node, nbr = self._undirected_neighbor_pairs(edge_index)
        nbr_unit = unit[nbr]

        count = torch.bincount(node, minlength=num_nodes).to(unit.dtype)
        direction_sum = torch.zeros_like(unit).index_add_(0, node, nbr_unit)
        # Diagonal terms <u_a, u_a> (0 for zero gradients, 1 otherwise).
        self_sim_sum = torch.zeros(num_nodes, dtype=unit.dtype, device=g.device).index_add_(
            0, node, nbr_unit.pow(2).sum(dim=1)
        )
        off_diagonal_sum = direction_sum.pow(2).sum(dim=1) - self_sim_sum
        mean_pairwise = off_diagonal_sum / (count * (count - 1)).clamp(min=1)

        stability = torch.where(count > 1, mean_pairwise, torch.ones_like(mean_pairwise))
        return torch.clamp(stability, 0, 1)

    def _compute_multihop_influence(self, g, edge_index):
        """
        Multi-hop Influence Strength: gradient magnitude aggregated along 2- and
        3-step walks from each node, normalized by (out-degree + 1).

        With A the binary adjacency (A[src, dst] = 1, duplicate edges counted
        once) and m the per-node gradient norm:
            influence = ((A^2 - A) m + 0.5 * (A^3 - A^2) m) / (out_degree + 1)
        Each product is computed as repeated sparse propagation over the edge
        list, never as a dense N x N matrix.
        NOTE(review): A^k counts walks, not shortest-path distance, and the
        subtractions can make values negative. This reproduces the original
        formula; confirm it is the intended definition.
        """
        num_nodes = g.size(0)
        edges = self._unique_columns(edge_index)
        src, dst = edges[0], edges[1]

        def propagate(x):
            # (A @ x)[i] = sum of x[j] over edges i -> j
            return torch.zeros_like(x).index_add_(0, src, x[dst])

        magnitude = torch.norm(g, dim=1)
        walk1 = propagate(magnitude)   # A m
        walk2 = propagate(walk1)       # A^2 m
        walk3 = propagate(walk2)       # A^3 m

        influence_2hop = walk2 - walk1
        influence_3hop = (walk3 - walk2) * 0.5  # Decay with distance
        total_influence = influence_2hop + influence_3hop

        # Normalize by out-degree to avoid bias toward high-degree nodes
        degree = torch.bincount(src, minlength=num_nodes).to(magnitude.dtype)
        return total_influence / (degree + 1)

    def _compute_confidence_gradient_relationship(self, g_norm, predictions):
        """
        Prediction Confidence vs Gradient Magnitude: graph-level Pearson
        correlation between the expected learning signal (1 - confidence) and
        the actual gradient magnitude, broadcast to every node.

        Ideally: Low confidence → High gradient (learning needed)
                High confidence → Low gradient (already learned)

        Range: [-1, 1], where 1=ideal relationship, -1=problematic relationship.
        1-D ``predictions`` (class indices) carry no confidence information, so
        they yield zeros, the same as passing no predictions.
        """
        num_nodes = g_norm.size(0)
        if predictions.dim() == 1:
            return torch.zeros(num_nodes, dtype=g_norm.dtype, device=g_norm.device)

        confidence = F.softmax(predictions, dim=1).max(dim=1)[0].to(g_norm.dtype)
        expected_gradient = 1 - confidence

        # Pearson correlation (invariant to rescaling g_norm, so no min-max step).
        # Plain cosine similarity cannot be used: both vectors are non-negative.
        x = expected_gradient - expected_gradient.mean()
        y = g_norm - g_norm.mean()
        relationship = (x * y).sum() / (x.norm() * y.norm() + 1e-8)

        return relationship.clamp(-1.0, 1.0).expand(num_nodes)

    def _compute_topological_learning_role(self, g, edge_index):
        """
        Topological Learning Role: Classifies each node's role in the learning process
        based on gradient patterns and graph structure.

        Roles:
        - Hub (2.0): out-degree above the 80th percentile and gradient norm above the 70th
        - Bridge (1.5): out-degree in (40th, 80th] percentile and neighbor cosine similarity < 0.5
        - Follower (1.0): neighbor cosine similarity > 0.7
        - Outlier (0.5): everything else (including nodes without outgoing edges)

        Range: [0.5, 2.0]
        """
        num_nodes = g.size(0)

        # Out-degree, counting duplicate edges
        degree = torch.bincount(edge_index[0], minlength=num_nodes).to(g.dtype)
        g_norm = torch.norm(g, dim=1)

        neighbor_gradients = self._neighbor_mean(g, edge_index)
        local_similarity = F.cosine_similarity(g, neighbor_gradients, dim=1)

        roles = torch.zeros(num_nodes, device=g.device)

        high_degree_mask = degree > torch.quantile(degree, 0.8)
        strong_gradient_mask = g_norm > torch.quantile(g_norm, 0.7)
        hub_mask = high_degree_mask & strong_gradient_mask
        roles[hub_mask] = 2.0

        medium_degree_mask = (degree > torch.quantile(degree, 0.4)) & (degree <= torch.quantile(degree, 0.8))
        low_similarity_mask = local_similarity < 0.5
        bridge_mask = medium_degree_mask & low_similarity_mask & ~hub_mask
        roles[bridge_mask] = 1.5

        high_similarity_mask = local_similarity > 0.7
        follower_mask = high_similarity_mask & ~hub_mask & ~bridge_mask
        roles[follower_mask] = 1.0

        outlier_mask = ~(hub_mask | bridge_mask | follower_mask)
        roles[outlier_mask] = 0.5

        return roles

    def _compute_correction_receptiveness(self, g, edge_index, gradient_alignment, g_norm):
        """
        Correction Receptiveness: weighted blend of three [0, 1] signals:
          40% min-max normalized gradient magnitude (actively learning),
          35% misalignment with the neighbor-mean gradient, clamped to [0, 1],
          25% sample variance of neighbor gradient norms, scaled by its max.
        Neighbors are undirected and de-duplicated; nodes with fewer than two
        neighbors get zero variance.

        Range: [0, 1], where 1=highly receptive, 0=not receptive
        """
        num_nodes = g.size(0)

        g_norm_normalized = (g_norm - g_norm.min()) / (g_norm.max() - g_norm.min() + 1e-8)

        # Two-pass, unbiased (n - 1) variance of neighbor gradient norms per node
        node, nbr = self._undirected_neighbor_pairs(edge_index)
        nbr_norm = torch.norm(g, dim=1)[nbr]
        count = torch.bincount(node, minlength=num_nodes).to(nbr_norm.dtype)
        zeros = torch.zeros(num_nodes, dtype=nbr_norm.dtype, device=g.device)
        mean = zeros.index_add(0, node, nbr_norm) / count.clamp(min=1)
        squared_dev = (nbr_norm - mean[node]).pow(2)
        local_variance = torch.where(
            count > 1,
            zeros.index_add(0, node, squared_dev) / (count - 1).clamp(min=1),
            zeros,
        )

        if local_variance.max() > 0:
            local_variance = local_variance / local_variance.max()

        # NOTE: anti-aligned gradients (alignment < 0) all saturate at 1 here.
        misalignment = torch.clamp(1 - gradient_alignment, 0, 1)

        receptiveness = (
            0.4 * g_norm_normalized +      # 40% gradient magnitude
            0.35 * misalignment +          # 35% neighbor misalignment
            0.25 * local_variance          # 25% local instability
        )

        return torch.clamp(receptiveness, 0, 1)

class InterpretableGGT(nn.Module):
    """
    Enhanced GGT with gradient feature extraction for interpretability.

    The incoming gradient matrix ``g`` is treated as node features. It is
    L2-normalized per row and layer-normalized, then passed through one
    TransformerConv over ``edge_index`` with a residual connection. The result
    is projected by a linear layer into a unit-norm correction scaled by the
    learnable ``scale``.
    """

    def __init__(self, hidden_dim, num_heads=2, dropout=0.2):
        super().__init__()
        # Attribute name kept as ``layers`` so existing state_dicts still load.
        self.layers = TransformerConv(hidden_dim, hidden_dim, heads=num_heads, concat=False, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.scale = nn.Parameter(torch.tensor(1.0))

        # Feature extractor for interpretability. It is a plain object, not a
        # submodule, so Module.to() does not move it; forward() re-syncs its
        # device with the input before extracting features.
        self.feature_extractor = GradientFeatureExtractor(device=next(self.parameters()).device)

    def forward(self, g, edge_index, h=None, predictions=None, true_labels=None, epoch=0, extract_features=False):
        """
        Enhanced forward pass with optional feature extraction.

        Args:
            g: per-node gradient matrix (rows are nodes, width ``hidden_dim``).
            edge_index: graph connectivity passed to TransformerConv.
            h, predictions, true_labels, epoch: forwarded to the feature
                extractor only; they do not affect the GGT computation.
            extract_features: when True, also compute the interpretability
                feature dict.

        Returns:
            ``(delta_h, h_combined, attn)`` or, when ``extract_features`` is
            True, ``(delta_h, h_combined, attn, gradient_features)``. ``attn`` is
            TransformerConv's ``(edge_index, weights)`` pair with the weights
            clamped to [1e-8, 1].
        """
        logging.info("Running InterpretableGGT forward pass...")

        gradient_features = None
        if extract_features:
            # The module may have been moved with .to() after construction.
            self.feature_extractor.device = g.device
            gradient_features = self.feature_extractor.extract_features(
                g, h, edge_index, predictions, true_labels, epoch
            )

        # Standard GGT processing
        h_normalized = F.normalize(g, p=2, dim=1)
        h_normalized = self.norm(h_normalized)
        h_transformed, attn = self.layers(h_normalized, edge_index, return_attention_weights=True)
        h_combined = h_normalized + h_transformed  # Residual connection
        # Clamp only the reported weights; the convolution already used the raw ones.
        attn = (attn[0], torch.clamp(attn[1], min=1e-8, max=1.0))
        delta_h = self.linear(h_combined)
        delta_h = self.scale * F.normalize(delta_h, p=2, dim=1)

        if extract_features:
            return delta_h, h_combined, attn, gradient_features
        return delta_h, h_combined, attn

class LLMExplainer:
    """
    Generates natural language explanations using gradient features and Groq LLM.

    ``create_node_context`` turns one node's entries of a per-node feature dict,
    plus optional attention and prediction info, into a dict of plain Python
    values. ``generate_explanation`` renders that dict into a prompt and sends
    it to the Groq chat-completions API.
    """

    def __init__(self, api_key, model_name="llama3-8b-8192"):
        """
        Args:
            api_key: Groq API key handed to the ``Groq`` client.
            model_name: Groq model identifier used for every completion request.
        """
        self.client = Groq(api_key=api_key)
        self.model_name = model_name

    def create_node_context(self, node_idx, gradient_features, conflict_status=False,
                          attention_weights=None, prediction_info=None):
        """
        Create interpretable context from gradient features for LLM.

        Args:
            node_idx: integer index of the node to describe.
            gradient_features: dict of per-node tensors; the six feature keys
                read below must be present.
            conflict_status: stored as ``is_conflict_node``.
            attention_weights: optional ``(edge_index, weights)`` pair. The
                weights of every edge that touches ``node_idx`` (as source or
                target) are summarized.
            prediction_info: optional dict merged into the context unchanged.

        Returns:
            dict of plain Python scalars ready for ``generate_explanation``.
        """
        context = {
            "node_id": node_idx,
            "conflict_intensity": gradient_features['conflict_intensity'][node_idx].item(),
            "trajectory_stability": gradient_features['trajectory_stability'][node_idx].item(),
            "influence_strength": gradient_features['influence_strength'][node_idx].item(),
            "confidence_gradient_rel": gradient_features['confidence_gradient_rel'][node_idx].item(),
            "learning_role": gradient_features['learning_role'][node_idx].item(),
            "correction_receptiveness": gradient_features['correction_receptiveness'][node_idx].item(),
            "is_conflict_node": conflict_status
        }

        if attention_weights is not None:
            attn_edge_index, attn_weights = attention_weights
            node_attention_mask = (attn_edge_index[0] == node_idx) | (attn_edge_index[1] == node_idx)
            if node_attention_mask.any():
                node_attention = attn_weights[node_attention_mask]
                context["attention_focus"] = {
                    "max_attention": node_attention.max().item(),
                    "avg_attention": node_attention.mean().item(),
                    # Sample variance is undefined (NaN) for a single value.
                    "attention_variance": node_attention.var().item() if node_attention.numel() > 1 else 0.0
                }

        if prediction_info:
            context.update(prediction_info)

        return context

    def interpret_learning_role(self, role_value):
        """Convert numeric learning role to interpretable description"""
        if role_value >= 1.8:
            return "Hub (learning center)"
        elif role_value >= 1.3:
            return "Bridge (connector)"
        elif role_value >= 0.8:
            return "Follower (local learner)"
        else:
            return "Outlier (isolated)"

    def interpret_feature_levels(self, context):
        """Map conflict, stability, influence and receptiveness values to text labels."""
        interpretations = {}

        # Conflict Intensity
        conflict = context['conflict_intensity']
        if conflict > 2.0:
            interpretations['conflict'] = "Very High - Strong disagreement with neighbors"
        elif conflict > 1.5:
            interpretations['conflict'] = "High - Moderate disagreement with neighbors"
        elif conflict > 1.0:
            interpretations['conflict'] = "Medium - Some disagreement with neighbors"
        else:
            interpretations['conflict'] = "Low - Good alignment with neighbors"

        # Trajectory Stability
        stability = context['trajectory_stability']
        if stability > 0.8:
            interpretations['stability'] = "Very Stable - Consistent learning direction"
        elif stability > 0.6:
            interpretations['stability'] = "Stable - Mostly consistent learning"
        elif stability > 0.4:
            interpretations['stability'] = "Moderate - Some learning inconsistency"
        else:
            interpretations['stability'] = "Unstable - Erratic learning patterns"

        # Influence Strength
        influence = context['influence_strength']
        if influence > 1.5:
            interpretations['influence'] = "Very High - Strong multi-hop influence"
        elif influence > 1.0:
            interpretations['influence'] = "High - Good multi-hop influence"
        elif influence > 0.5:
            interpretations['influence'] = "Medium - Moderate influence"
        else:
            interpretations['influence'] = "Low - Limited influence on distant nodes"

        # Correction Receptiveness
        receptiveness = context['correction_receptiveness']
        if receptiveness > 0.8:
            interpretations['receptiveness'] = "Very High - Excellent candidate for correction"
        elif receptiveness > 0.6:
            interpretations['receptiveness'] = "High - Good candidate for correction"
        elif receptiveness > 0.4:
            interpretations['receptiveness'] = "Medium - May benefit from correction"
        else:
            interpretations['receptiveness'] = "Low - Unlikely to benefit from correction"

        return interpretations

    def generate_explanation(self, context):
        """
        Generate natural language explanation from gradient features.

        Returns the model's reply with surrounding whitespace stripped. If the
        API call raises, the error is logged and an
        ``"Error generating explanation: ..."`` string is returned instead.
        """
        interpretations = self.interpret_feature_levels(context)
        role_description = self.interpret_learning_role(context['learning_role'])

        # Build detailed context string
        context_str = f"""
        Node {context['node_id']} Analysis:
        
        GRADIENT FEATURES:
        • Conflict Intensity: {context['conflict_intensity']:.3f} - {interpretations['conflict']}
        • Learning Stability: {context['trajectory_stability']:.3f} - {interpretations['stability']}
        • Multi-hop Influence: {context['influence_strength']:.3f} - {interpretations['influence']}
        • Confidence-Gradient Relationship: {context['confidence_gradient_rel']:.3f}
        • Learning Role: {role_description}
        • Correction Receptiveness: {context['correction_receptiveness']:.3f} - {interpretations['receptiveness']}
        • Conflict Node Status: {'Yes' if context['is_conflict_node'] else 'No'}
        """

        if 'attention_focus' in context:
            context_str += f"""
        ATTENTION PATTERNS:
        • Max Attention Weight: {context['attention_focus']['max_attention']:.3f}
        • Average Attention: {context['attention_focus']['avg_attention']:.3f}
        • Attention Variance: {context['attention_focus']['attention_variance']:.3f}
        """

        if 'predicted_class' in context:
            context_str += f"""
        PREDICTIONS:
        • Predicted Class: {context['predicted_class']}
        • True Class: {context.get('true_class', 'Unknown')}
        • Prediction Correct: {context.get('is_correct', 'Unknown')}
        """

        # Create system prompt
        system_prompt = """You are an expert in Graph Neural Networks and gradient-based learning analysis. 
        Your task is to explain a node's learning behavior based on innovative gradient features extracted 
        from a Graph Gradient Transformer (GGT) model.

        Focus on:
        1. What the gradient features reveal about the node's learning state
        2. Why this node might be in conflict or harmony with its neighbors
        3. The node's role in the overall learning process
        4. Whether GGT corrections would be beneficial and why
        5. How attention patterns support or contradict the gradient analysis

        Provide technical but accessible explanations that help researchers understand the model's behavior."""

        # Create user prompt
        user_prompt = f"""
        Analyze this node's learning behavior based on the gradient features:
        
        {context_str}
        
        Please provide a comprehensive explanation covering:
        1. WHY this node shows these gradient patterns (conflict, stability, influence)
        2. WHAT ROLE this node plays in the graph's learning process
        3. HOW the GGT should handle this node (corrections needed?)
        4. IMPLICATIONS for model performance and interpretability
        
        Write as if explaining to a researcher studying graph neural network interpretability.
        Be specific about the technical insights while keeping it understandable.
        """

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.7,
                max_tokens=600,
                top_p=0.9
            )

            return response.choices[0].message.content.strip()

        except Exception as e:
            # Best-effort boundary: one failed API call must not abort a batch of explanations.
            logging.error(f"Error generating explanation: {e}")
            return f"Error generating explanation: {str(e)}"

# Usage example
def demonstrate_interpretable_ggt():
    """
    Demonstrate the interpretable GGT with gradient feature extraction.

    The demo:
    - loads the PubMed Planetoid dataset (downloaded to /tmp/PubMed on first use);
    - runs InterpretableGGT on random stand-in gradients and logits;
    - prints per-feature mean/std;
    - if the ``GROQ_API_KEY`` environment variable is set, asks the Groq LLM to
      explain up to three high-conflict nodes.
    """
    import os  # local import: only this demo reads the environment

    print("🚀 Demonstrating Interpretable GGT with Gradient Features")
    print("=" * 60)

    # Load dataset
    dataset = Planetoid(root='/tmp/PubMed', name='PubMed')
    data = dataset[0]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)

    # Initialize interpretable GGT; inference only, so disable attention dropout
    hidden_dim = 64
    ggt = InterpretableGGT(hidden_dim).to(device)
    ggt.eval()

    # Random stand-ins for gradients, embeddings and logits
    num_nodes = data.num_nodes
    g = torch.randn(num_nodes, hidden_dim, device=device)
    h = torch.randn(num_nodes, hidden_dim, device=device)
    predictions = torch.randn(num_nodes, 3, device=device)  # 3 classes for PubMed

    print("Extracting gradient features...")
    with torch.no_grad():  # no backward pass follows; don't build an autograd graph
        delta_h, h_combined, attention_weights, gradient_features = ggt(
            g, data.edge_index, h, predictions, data.y, epoch=1, extract_features=True
        )

    print(f"✅ Extracted {len(gradient_features)} gradient features for {num_nodes} nodes")

    for feature_name, feature_values in gradient_features.items():
        print(f"{feature_name}: mean={feature_values.mean():.3f}, std={feature_values.std():.3f}")

    # Never hard-code credentials: read the key from the environment.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        print("\n⚠️ GROQ_API_KEY is not set; skipping LLM explanations.")
    else:
        explainer = LLMExplainer(api_key)

        # Select interesting nodes for explanation
        conflict_nodes = torch.where(gradient_features['conflict_intensity'] > 1.5)[0][:3].tolist()

        print(f"\n🤖 Generating explanations for {len(conflict_nodes)} high-conflict nodes...")

        for node in conflict_nodes:
            print(f"\n--- Node {node} Analysis ---")

            predicted_class = predictions[node].argmax().item()
            true_class = data.y[node].item()
            context = explainer.create_node_context(
                node,
                gradient_features,
                conflict_status=True,
                attention_weights=attention_weights,
                prediction_info={
                    'predicted_class': predicted_class,
                    'true_class': true_class,
                    'is_correct': predicted_class == true_class
                }
            )

            explanation = explainer.generate_explanation(context)
            print(explanation)
            print("-" * 50)

    print("✅ Interpretable GGT demonstration completed!")


if __name__ == "__main__":
    demonstrate_interpretable_ggt()