import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
from Cfilter import CausalFilter

class GINModelWithCausalFilter(nn.Module):
    """GIN classifier with a ``CausalFilter`` on the input and after every GIN layer.

    Data flow in ``forward`` (filters are applied only when
    ``use_causal_filter`` is True)::

        x -> causal_filters[0]
          -> for each GIN layer i: ReLU(GINConv_i) -> causal_filters[i + 1]
          -> global_mean_pool          (only when task == 'graph')
          -> fc -> log_softmax(dim=1)

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output classes.
        hidden_dim: Width of every hidden representation.
        num_layers: Requested number of GIN layers. At least one layer is
            always built, so values below 1 yield a single layer (and a
            matching number of filters).
        task: ``'graph'`` mean-pools node embeddings with ``batch`` before the
            classifier; any other value classifies each node directly.
        use_causal_filter: Whether to build and apply the causal filters.
        filter_config: Optional dict of per-filter overrides keyed by
            ``'input'`` and ``'layer_{i}'``. Each override is merged over the
            built-in defaults, so it only needs the keys that differ.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2,
                 task='graph', use_causal_filter=True, filter_config=None):
        super().__init__()
        self.num_layers = num_layers
        self.task = task
        self.use_causal_filter = use_causal_filter

        # The first layer lifts raw features to hidden_dim and the rest stay at
        # hidden_dim. `[hidden_dim] * k` is empty for k <= 0, so at least one
        # layer is always created.
        in_dims = [num_features] + [hidden_dim] * (num_layers - 1)
        self.gin_layers = nn.ModuleList()
        for in_dim in in_dims:
            self.gin_layers.append(GINConv(nn.Sequential(
                nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
            )))

        if use_causal_filter:
            # Keyword arguments handed to CausalFilter unless overridden.
            default_config = {
                'lambda_init': 1.0,
                'lambda_min': -2.0,
                'decay_rate': 0.99,
                'temperature': 1.0,
                'residual_weight': 0.2,
                'normalize': False,
                'dropout': 0.1,
            }

            # One filter on the raw input, then one per GIN layer actually
            # built. Sizing by len(self.gin_layers) rather than num_layers keeps
            # forward's `causal_filters[i + 1]` in range when num_layers < 1.
            self.causal_filters = nn.ModuleList()
            self.causal_filters.append(CausalFilter(
                num_features,
                **self._resolve_filter_config(filter_config, 'input', default_config),
            ))
            for i in range(len(self.gin_layers)):
                self.causal_filters.append(CausalFilter(
                    hidden_dim,
                    **self._resolve_filter_config(filter_config, f'layer_{i}', default_config),
                ))

        self.fc = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _resolve_filter_config(filter_config, key, defaults):
        """Return the kwargs for filter *key*: *defaults* overlaid with its override, if any.

        A missing, empty, or ``None`` override yields a copy of *defaults*.
        Neither *defaults* nor the override dict is mutated.
        """
        override = filter_config.get(key) if filter_config else None
        return {**defaults, **(override or {})}

    def forward(self, x, edge_index, batch=None):
        """Return log-probabilities over ``num_classes`` along dim 1.

        Args:
            x: Node features; presumably ``[num_nodes, num_features]``
                (TODO confirm).
            edge_index: Graph connectivity passed to every ``GINConv``.
            batch: Per-node graph assignment, used by ``global_mean_pool``
                when ``task == 'graph'``.
        """
        if self.use_causal_filter:
            x = self.causal_filters[0](x)

        for i, conv in enumerate(self.gin_layers):
            x = F.relu(conv(x, edge_index))
            if self.use_causal_filter:
                # Filter i + 1 follows GIN layer i (filter 0 is the input filter).
                x = self.causal_filters[i + 1](x)

        if self.task == 'graph':
            x = global_mean_pool(x, batch)

        return F.log_softmax(self.fc(x), dim=1)

    def step_filters(self):
        """Call ``step()`` on every causal filter; do nothing when filters are disabled."""
        if not self.use_causal_filter:
            return
        for filter_module in self.causal_filters:
            filter_module.step()

    def get_filter_info(self):
        """Return one ``"Filter {i}: lambda=..."`` line per filter, or a notice if disabled."""
        if not self.use_causal_filter:
            return "Causal filters are disabled"
        return "\n".join(
            f"Filter {i}: lambda={filter_module.get_current_lambda():.4f}"
            for i, filter_module in enumerate(self.causal_filters)
        )



class GINModel(nn.Module):
    """Plain GIN classifier: stacked ``GINConv`` layers, optional mean pooling, linear head.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output classes.
        hidden_dim: Width of the hidden representations.
        num_layers: Requested number of GIN layers. One layer is always
            built, so values below 1 yield a single layer.
        task: ``'graph'`` pools node embeddings per graph with
            ``global_mean_pool``; any other value keeps per-node outputs.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2, task='graph'):
        super().__init__()
        self.num_layers = num_layers
        self.task = task

        # Input width of each layer: raw features first, hidden_dim afterwards.
        # `[hidden_dim] * k` is empty for k <= 0, so one layer always exists.
        layer_in_dims = [num_features] + [hidden_dim] * (num_layers - 1)
        self.gin_layers = nn.ModuleList(
            GINConv(nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ))
            for in_dim in layer_in_dims
        )

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return log-probabilities over ``num_classes`` along dim 1.

        Args:
            x: Node features; presumably ``[num_nodes, num_features]``
                (TODO confirm).
            edge_index: Graph connectivity passed to every ``GINConv``.
            batch: Per-node graph assignment, used by ``global_mean_pool``
                when ``task == 'graph'``.
        """
        hidden = x
        for layer in self.gin_layers:
            hidden = layer(hidden, edge_index).relu()

        if self.task == 'graph':
            hidden = global_mean_pool(hidden, batch)

        logits = self.fc(hidden)
        return F.log_softmax(logits, dim=1)
