import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
from improved_filter import ImprovedCausalFilter

class GINWithImprovedFilter(nn.Module):
    """GIN network with optional ``ImprovedCausalFilter`` stages and a linear head.

    The model stacks ``GINConv`` layers (each wrapping a two-layer MLP with a
    ReLU in between). When ``use_causal_filter`` is true, one filter is applied
    to the raw input features and one after every convolution. The final
    representation goes through a linear classifier and ``log_softmax``.

    Args:
        num_features: Size of the input node features.
        num_classes: Number of outputs of the final linear layer.
        hidden_dim: Width of every GIN MLP and of the classifier input.
        num_layers: Requested number of GIN layers. At least one layer is
            always built, so values below 1 behave like 1.
        task: ``'graph'`` mean-pools node embeddings per graph before the
            classifier; any other value skips pooling.
        use_causal_filter: Whether to create and apply the filters.
        filter_config: Optional mapping from ``'input'`` / ``'layer_{i}'`` to a
            dict of keyword arguments for ``ImprovedCausalFilter``. A missing
            key (or a falsy ``filter_config``) falls back to the built-in
            defaults. A supplied entry replaces the defaults wholesale; it is
            not merged with them.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2,
                 task='graph', use_causal_filter=True, filter_config=None):
        super().__init__()
        self.num_layers = num_layers
        self.task = task
        self.use_causal_filter = use_causal_filter

        # The first convolution maps num_features -> hidden_dim and is always
        # present; the remaining ones map hidden_dim -> hidden_dim.
        self.gin_layers = nn.ModuleList()
        self.gin_layers.append(GINConv(nn.Sequential(
            nn.Linear(num_features, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
        )))
        for _ in range(1, num_layers):
            self.gin_layers.append(GINConv(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
            )))

        if use_causal_filter:
            default_config = {'lambda_init': 10.0, 'decay_rate': 0.95, 'normalize': False}

            def config_for(key):
                # An empty/None filter_config means "use defaults everywhere".
                if filter_config:
                    return filter_config.get(key, default_config)
                return default_config

            self.causal_filters = nn.ModuleList()
            # Index 0 filters the raw input; index i + 1 follows gin_layers[i].
            self.causal_filters.append(
                ImprovedCausalFilter(num_features, **config_for('input')))

            # Build one filter per convolution that actually exists. Using
            # num_layers here would leave the always-present first layer
            # without a filter when num_layers < 1 and make forward() fail
            # with an IndexError.
            for i in range(len(self.gin_layers)):
                self.causal_filters.append(
                    ImprovedCausalFilter(hidden_dim, **config_for(f'layer_{i}')))

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Run the network and return log-probabilities.

        Args:
            x: Node feature tensor fed to the first filter/convolution.
            edge_index: Graph connectivity passed to every ``GINConv``.
            batch: Node-to-graph assignment forwarded to ``global_mean_pool``;
                only used when ``task == 'graph'``.

        Returns:
            ``log_softmax`` over dimension 1 of the classifier output.
        """
        if self.use_causal_filter:
            x = self.causal_filters[0](x)

        for i, conv in enumerate(self.gin_layers):
            x = F.relu(conv(x, edge_index))
            if self.use_causal_filter:
                # Offset by one because slot 0 is the input filter.
                x = self.causal_filters[i + 1](x)

        if self.task == 'graph':
            x = global_mean_pool(x, batch)

        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def step_epoch(self):
        """Call ``step()`` on every filter; a no-op when filters are disabled.

        NOTE(review): presumably advances each filter's lambda schedule once
        per epoch -- confirm against ``ImprovedCausalFilter.step``.
        """
        if not self.use_causal_filter:
            return
        for filter_module in self.causal_filters:
            filter_module.step()

    def get_filter_info(self):
        """Return a human-readable, newline-separated summary of all filters.

        Each line reports the filter index, the value returned by
        ``get_current_lambda()`` (formatted to four decimals) and the object
        returned by ``get_stats()``. Returns a fixed message when filters are
        disabled.
        """
        if not self.use_causal_filter:
            return "Causal filters are disabled."

        info = []
        for i, filter_module in enumerate(self.causal_filters):
            lambda_val = filter_module.get_current_lambda()
            gate_stats = filter_module.get_stats()
            info.append(f"Filter {i}: λ={lambda_val:.4f}, Gate Stats={gate_stats}")
        return "\n".join(info)
