import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
from Cfilter import CausalFilter

class GINModel(nn.Module):
    """GIN classifier with optional causal filtering around every layer.

    The network is a stack of ``GINConv`` layers (each wrapping a
    Linear -> ReLU -> Linear MLP) followed by a single linear classifier.
    When ``use_causal_filter`` is enabled, one ``CausalFilter`` is applied to
    the raw input features and one more after each GIN layer.

    Args:
        num_features: Size of the input node features.
        num_classes: Number of output classes.
        hidden_dim: Width of every GIN layer's output.
        num_layers: Requested number of GIN layers. At least one layer is
            always built, even if a value below 1 is passed.
        task: ``'graph'`` mean-pools node embeddings with
            ``global_mean_pool`` before classification; any other value
            (e.g. ``'node'``) classifies each node row directly.
        use_causal_filter: Whether to create and apply ``CausalFilter``
            modules.
        filter_config: Optional mapping with keys ``'input'`` and
            ``'layer_<i>'`` whose values are keyword-argument dicts for the
            corresponding ``CausalFilter``. A missing key, or a falsy
            ``filter_config``, selects the built-in default configuration.
            A supplied entry replaces the defaults wholesale; it is not
            merged with them.
    """

    # Keyword arguments handed to ``CausalFilter`` when no override is given.
    # Only ever unpacked with ``**``, so sharing one dict is safe.
    _DEFAULT_FILTER_CONFIG = {
        'lambda_init': 10.0,
        'lambda_min': 0.1,
        'decay_rate': 0.99,
        'normalize': True,
    }

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2, task='graph', use_causal_filter=False, filter_config=None):
        super().__init__()
        self.num_layers = num_layers
        self.task = task
        self.use_causal_filter = use_causal_filter

        # The first layer maps num_features -> hidden_dim and is always
        # present; the remaining (num_layers - 1) layers are hidden -> hidden.
        # A non-positive multiplier yields an empty list, so num_layers < 1
        # still produces exactly one layer.
        input_dims = [num_features] + [hidden_dim] * (num_layers - 1)
        self.gin_layers = nn.ModuleList(
            [GINConv(self._build_mlp(in_dim, hidden_dim)) for in_dim in input_dims]
        )

        if self.use_causal_filter:
            filters = [
                CausalFilter(num_features, **self._filter_kwargs(filter_config, 'input'))
            ]
            # One filter per GIN layer that actually exists. Sizing this from
            # gin_layers (not num_layers) keeps forward()'s
            # causal_filters[i + 1] lookup in range when num_layers < 1.
            for layer_idx in range(len(self.gin_layers)):
                kwargs = self._filter_kwargs(filter_config, f'layer_{layer_idx}')
                filters.append(CausalFilter(hidden_dim, **kwargs))
            self.causal_filters = nn.ModuleList(filters)

        self.fc = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _build_mlp(in_dim, hidden_dim):
        """Return the Linear -> ReLU -> Linear MLP wrapped by one GIN layer."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
        )

    @classmethod
    def _filter_kwargs(cls, filter_config, key):
        """Return the CausalFilter kwargs for ``key``, or the defaults."""
        if not filter_config:
            return cls._DEFAULT_FILTER_CONFIG
        return filter_config.get(key, cls._DEFAULT_FILTER_CONFIG)

    def forward(self, x, edge_index, batch=None):
        """Compute class log-probabilities.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every GIN layer.
            batch: Node-to-graph assignment handed to ``global_mean_pool``;
                only used when ``task == 'graph'``.

        Returns:
            ``log_softmax`` over dimension 1 of the classifier output: one
            row per graph for the graph task, otherwise one row per row of
            ``x``.
        """
        if self.use_causal_filter:
            # Filter index 0 is reserved for the raw input features.
            x = self.causal_filters[0](x)

        for layer_idx, conv in enumerate(self.gin_layers):
            x = F.relu(conv(x, edge_index))
            if self.use_causal_filter:
                x = self.causal_filters[layer_idx + 1](x)

        # Only the graph task pools; every other task value keeps
        # node-level embeddings untouched.
        if self.task == 'graph':
            x = global_mean_pool(x, batch)

        logits = self.fc(x)
        return F.log_softmax(logits, dim=1)

    def step_filters(self):
        """Call ``step()`` on every causal filter; no-op when filtering is off.

        NOTE(review): presumably this advances each filter's lambda decay
        schedule -- confirm against ``Cfilter.CausalFilter``.
        """
        if not self.use_causal_filter:
            return
        for filter_module in self.causal_filters:
            filter_module.step()
