import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalFilter(nn.Module):
    """Stochastic binary gate that keeps or zeroes whole feature vectors.

    A two-layer MLP maps every feature vector to two logits. A hard
    Gumbel-Softmax sample over those logits selects one of the two channels;
    the sampled value of channel 1 is used as a multiplicative gate on the
    feature vector. The KL divergence between the softmax of the logits and a
    uniform two-way prior is stored after every forward pass and exposed via
    :meth:`get_info_bottleneck_loss`.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer of the gating MLP.
        normalize: If true, apply layer normalisation (no affine parameters)
            over the last dimension of the gated output.
        tau: Gumbel-Softmax temperature; must be strictly positive.

    Raises:
        ValueError: If ``tau`` is not strictly positive.
    """

    # log(0.5): log-probability of either outcome under the uniform prior.
    _LOG_UNIFORM_PRIOR = -0.6931471805599453

    def __init__(self,
                 input_dim,
                 hidden_dim=64,
                 normalize=True,
                 tau=1.0):
        super().__init__()
        if tau <= 0:
            # The logits are divided by tau inside gumbel_softmax.
            raise ValueError(f"tau must be positive, got {tau}")
        self.normalize = normalize
        self.tau = tau

        # Two output logits: index 0 = drop, index 1 = keep.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )
        self._init_weights()

        # Populated by forward(); None until the first call.
        self.gate_log_prob = None  # per-vector KL(q || uniform)
        self.gate_values = None    # sampled gate, shape (..., 1)

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for the gating MLP."""
        for module in self.mlp:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Gate ``x`` and record the gate sample and its KL term.

        Args:
            x: Tensor of shape ``(..., input_dim)``.

        Returns:
            Tensor with the same shape as ``x``: ``x`` multiplied by the
            sampled gate, layer-normalised when ``self.normalize`` is set.
        """
        gate_logits = self.mlp(x)

        # hard=True yields a discretised sample in the forward pass while
        # gradients flow through the soft relaxation. Noise is sampled on
        # every call; self.training is not consulted.
        gate = F.gumbel_softmax(gate_logits, tau=self.tau, hard=True, dim=-1)

        # Keep channel 1 with a trailing singleton dim so it broadcasts over
        # the feature dimension. The ellipsis makes this work for any number
        # of leading dimensions, not only 2-D input.
        self.gate_values = gate[..., 1:2]

        # KL(q || Uniform(2)) computed from log-probabilities, which avoids
        # log(0) without an additive epsilon and follows the logits' dtype.
        log_q_z = F.log_softmax(gate_logits, dim=-1)
        q_z = log_q_z.exp()
        self.gate_log_prob = torch.sum(
            q_z * (log_q_z - self._LOG_UNIFORM_PRIOR), dim=-1
        )

        filtered_x = x * self.gate_values

        if self.normalize:
            filtered_x = F.layer_norm(filtered_x, (filtered_x.shape[-1],))

        return filtered_x

    def get_info_bottleneck_loss(self):
        """Return the mean KL term of the last forward pass.

        Returns:
            A scalar tensor, or the float ``0.0`` if :meth:`forward` has not
            been called yet.
        """
        if self.gate_log_prob is None:
            return 0.0
        return torch.mean(self.gate_log_prob)

    def get_gate_stats(self):
        """Summarise the gate sample of the last forward pass.

        Returns:
            An empty dict before the first forward pass; otherwise a dict
            with ``'activation_ratio'`` (fraction of gate entries > 0) and
            ``'mean_gate'`` (mean gate value), both Python numbers.
        """
        if self.gate_values is None:
            return {}

        num_activated = torch.sum(self.gate_values > 0).item()
        total = self.gate_values.numel()
        activation_ratio = num_activated / total if total > 0 else 0

        return {
            'activation_ratio': activation_ratio,
            'mean_gate': self.gate_values.mean().item(),
        }


class CausalGINWrapper(nn.Module):
    """Wrap a model and gate its ``gin_layers`` with ``CausalFilter`` modules.

    Filter 0 is applied to the first positional input of
    ``base_model.gin_layers[0]`` through a forward pre-hook. Filter ``i + 1``
    is applied to the output of ``base_model.gin_layers[i]`` through a forward
    hook. The base model's own ``forward`` is otherwise untouched.

    The hooks are installed on ``base_model`` and hold a reference to this
    wrapper, so they stay active for as long as ``base_model`` is used. Call
    :meth:`remove_hooks` to detach them explicitly.

    Args:
        base_model: Module called as ``base_model(x, edge_index, batch)``.
            Hooks are only installed if it has a non-empty ``gin_layers``.
        num_features: Input feature size; used as the first filter dimension
            when ``layer_dims`` is not given.
        layer_dims: Input dimension of each filter, in order. If ``None`` the
            dimensions are inferred by :meth:`_infer_layer_dims`.
        filter_hidden_dim: Hidden width of every filter's gating MLP.
        normalize_filter: Passed to every filter as ``normalize``.
        tau: Gumbel-Softmax temperature passed to every filter.
    """

    # Hidden size assumed when it cannot be read from the base model.
    _DEFAULT_HIDDEN_DIM = 64

    def __init__(self, base_model, num_features, layer_dims=None, filter_hidden_dim=64, normalize_filter=True, tau=1.0):
        super().__init__()
        # Created first so remove_hooks()/__del__ are safe even if a later
        # step of construction raises.
        self.hooks = []
        self.base_model = base_model

        if layer_dims is None:
            layer_dims = self._infer_layer_dims(num_features)

        self.filters = nn.ModuleList(
            CausalFilter(dim, filter_hidden_dim, normalize_filter, tau)
            for dim in layer_dims
        )

        # Not used inside this class; kept for external code that reads them.
        self.activations = {}
        self.layer_counter = 0

        self._register_hooks()

    @staticmethod
    def _last_linear_out_features(mlp, default):
        """Return ``out_features`` of the last top-level ``nn.Linear`` in ``mlp``.

        ``mlp`` may be a single ``nn.Linear``, a container module, or a plain
        list/tuple of layers. ``default`` is returned when no linear layer is
        found.
        """
        if isinstance(mlp, nn.Linear):
            return mlp.out_features
        if isinstance(mlp, nn.Module):
            layers = list(mlp.children())
        elif isinstance(mlp, (list, tuple)):
            layers = list(mlp)
        else:
            layers = []
        for layer in reversed(layers):
            if isinstance(layer, nn.Linear):
                return layer.out_features
        return default

    def _infer_layer_dims(self, num_features):
        """Guess one filter dimension per hook site.

        The first entry is ``num_features``. If the base model has a
        non-empty ``gin_layers``, one further entry per layer is added, all
        equal to the output size of the last linear layer found in
        ``gin_layers[0].nn`` (64 if none is found). Without ``gin_layers`` the
        result is ``[num_features, 64, 64]``.
        """
        gin_layers = getattr(self.base_model, 'gin_layers', None)
        if gin_layers is None or len(gin_layers) == 0:
            return [num_features, self._DEFAULT_HIDDEN_DIM, self._DEFAULT_HIDDEN_DIM]

        # NOTE: every layer is assumed to share the first layer's width.
        hidden_dim = self._last_linear_out_features(
            getattr(gin_layers[0], 'nn', None), self._DEFAULT_HIDDEN_DIM
        )
        return [num_features] + [hidden_dim] * len(gin_layers)

    def _register_hooks(self):
        """Install the input pre-hook and the per-layer output hooks."""
        gin_layers = getattr(self.base_model, 'gin_layers', None)
        if gin_layers is None or len(gin_layers) == 0:
            return

        def input_hook(module, args):
            # Forward pre-hooks receive (module, positional_args) only.
            # Returning None leaves the input untouched.
            if len(self.filters) == 0 or not args:
                return None
            return (self.filters[0](args[0]),) + tuple(args[1:])

        self.hooks.append(gin_layers[0].register_forward_pre_hook(input_hook))

        def make_output_hook(filter_idx):
            # A factory binds filter_idx per layer instead of sharing the
            # loop variable between closures.
            def output_hook(module, args, output):
                if filter_idx < len(self.filters):
                    return self.filters[filter_idx](output)
                return None  # no filter for this layer: keep the output
            return output_hook

        for i, gin_layer in enumerate(gin_layers):
            handle = gin_layer.register_forward_hook(make_output_hook(i + 1))
            self.hooks.append(handle)

    def forward(self, x, edge_index, batch=None):
        """Run the base model; the filters are applied by the hooks."""
        return self.base_model(x, edge_index, batch)

    def get_total_info_bottleneck_loss(self):
        """Return the average of the filters' information-bottleneck losses.

        Each filter contributes the value stored by its most recent forward
        pass (``0.0`` if it has not run). Returns the float ``0.0`` when
        there are no filters.
        """
        if len(self.filters) == 0:
            return 0.0
        total_ib_loss = sum(f.get_info_bottleneck_loss() for f in self.filters)
        return total_ib_loss / len(self.filters)

    def get_all_gate_stats(self):
        """Return ``{'filter_layer_<i>': stats}`` for every filter."""
        return {
            f'filter_layer_{i}': f.get_gate_stats()
            for i, f in enumerate(self.filters)
        }

    def remove_hooks(self):
        """Detach every hook this wrapper installed; safe to call repeatedly."""
        hooks = getattr(self, 'hooks', None)
        if not hooks:
            return
        for handle in hooks:
            handle.remove()
        hooks.clear()

    def __del__(self):
        # Best-effort cleanup: the object may be only partially constructed
        # and a finaliser must not raise.
        try:
            self.remove_hooks()
        except Exception:
            pass
