import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

from improved_filter import ImprovedCausalFilter

class GCNWithImprovedFilter(nn.Module):
    """Two-layer GCN with optional ImprovedCausalFilter stages.

    When ``use_causal_filter`` is True, one filter is applied to the raw input
    features before ``conv1``. A second filter is applied to the ReLU'd hidden
    activations before ``conv2``.

    Args:
        num_features: Input feature dimension (in-channels of ``conv1``).
        num_classes: Output dimension (out-channels of ``conv2``).
        hidden_channels: Width of the hidden GCN layer.
        use_causal_filter: Whether to build and apply the two filters.
        filter_config: Optional dict that may contain ``'input'`` and/or
            ``'hidden'`` keys. Each value is a dict of keyword arguments
            forwarded to the corresponding ``ImprovedCausalFilter``
            constructor. A missing key, or a ``None``/empty config, means
            that filter is built with only its dimension argument.
        task: ``'graph'`` mean-pools node outputs per graph (using ``batch``)
            before ``log_softmax``. Any other value returns per-node outputs.
    """

    def __init__(self, num_features, num_classes, hidden_channels=64, use_causal_filter=False, filter_config=None, task='node'):
        super().__init__()
        self.use_causal_filter = use_causal_filter
        self.task = task

        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

        if self.use_causal_filter:
            # Build the filters whenever the flag is set. Previously they were
            # only created when filter_config was truthy, so the default
            # filter_config=None made forward() fail with AttributeError.
            cfg = filter_config or {}
            self.input_filter = ImprovedCausalFilter(num_features, **(cfg.get('input') or {}))
            self.hidden_filter = ImprovedCausalFilter(hidden_channels, **(cfg.get('hidden') or {}))
            # Convenience handle used by step_epoch(). If ImprovedCausalFilter
            # is an nn.Module (presumably — TODO confirm), registration
            # already happened via the attribute assignments above.
            self.filters = [self.input_filter, self.hidden_filter]
        else:
            self.filters = []

    def forward(self, x, edge_index, batch=None):
        """Run the network and return log-probabilities along dim 1.

        Args:
            x: Node feature tensor passed to ``conv1``.
            edge_index: Graph connectivity passed to both GCN layers.
            batch: Graph-assignment vector for ``global_mean_pool``. It is
                only used when ``task == 'graph'``.
        """
        if self.use_causal_filter:
            x = self.input_filter(x)

        x = F.relu(self.conv1(x, edge_index))

        if self.use_causal_filter:
            x = self.hidden_filter(x)

        x = self.conv2(x, edge_index)

        if self.task == 'graph':
            # NOTE(review): with batch=None PyG treats all nodes as one graph — confirm for your PyG version.
            x = global_mean_pool(x, batch)

        return F.log_softmax(x, dim=1)

    def step_epoch(self):
        """Call ``step()`` on every filter. Does nothing when filtering is disabled."""
        if self.use_causal_filter:
            for f in self.filters:
                f.step()

    def get_filter_info(self):
        """Return a human-readable summary of both filters' ``get_stats()``.

        Reads ``stats['lambda']`` and ``stats['gate_stats']['mean']`` from each
        filter. Returns a fixed message when filtering is disabled.
        """
        if not self.use_causal_filter:
            return "Causal filter not in use."

        lines = ["Filter Stats:"]
        # The label widths keep the "lambda=" columns aligned.
        for label, filt in (("Input Filter: ", self.input_filter), ("Hidden Filter:", self.hidden_filter)):
            stats = filt.get_stats()
            lines.append(f"  {label} lambda={stats['lambda']:.4f}, gate_mean={stats['gate_stats']['mean']:.4f}")
        return "\n".join(lines)
