import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
from improved_filter import ImprovedCausalFilter, AdaptiveCausalFilter

class GINModelWithImprovedFilter(nn.Module):
    """GIN classifier with optional causal filters around each convolution.

    When ``use_causal_filter`` is true, one filter is applied to the raw input
    features and one after every GIN convolution (after its ReLU). The filter
    class is chosen by ``filter_type``: 'improved' -> ``ImprovedCausalFilter``,
    'adaptive' -> ``AdaptiveCausalFilter``.

    Args:
        num_features: Width of the input node features.
        num_classes: Number of output classes.
        hidden_dim: Output width of every GIN layer.
        num_layers: Number of GIN convolutions; at least one is always built.
        task: 'graph' mean-pools node embeddings per graph before the
            classifier; any other value classifies each node directly.
        use_causal_filter: Whether to build and apply causal filters.
        filter_type: 'improved' or 'adaptive'; only consulted when
            ``use_causal_filter`` is true.
        filter_config: Optional dict of per-filter keyword overrides. Key
            'input' configures the input filter and 'layer_{i}' the filter
            after the i-th convolution; each value (a dict) is merged over the
            filter type's defaults.

    Raises:
        ValueError: If ``use_causal_filter`` is true and ``filter_type`` is not
            'improved' or 'adaptive'.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2,
                 task='graph', use_causal_filter=True, filter_type='improved',
                 filter_config=None):
        super().__init__()
        # Fail fast: an unknown type would otherwise leave causal_filters empty
        # and only surface later as an IndexError inside forward().
        if use_causal_filter and filter_type not in ('improved', 'adaptive'):
            raise ValueError(
                f"Unknown filter_type {filter_type!r}; expected 'improved' or 'adaptive'"
            )

        self.num_layers = num_layers
        self.task = task
        self.use_causal_filter = use_causal_filter
        self.filter_type = filter_type

        # The first convolution always exists (even if num_layers < 1) and maps
        # raw features to hidden_dim; the remaining ones are hidden -> hidden.
        self.gin_layers = nn.ModuleList()
        self.gin_layers.append(self._make_gin_layer(num_features, hidden_dim))
        for _ in range(1, num_layers):
            self.gin_layers.append(self._make_gin_layer(hidden_dim, hidden_dim))

        if use_causal_filter:
            if filter_type == 'improved':
                filter_cls = ImprovedCausalFilter
                default_config = {
                    'lambda_init': 1.0,
                    'lambda_min': -1.0,
                    'decay_rate': 0.99,
                    'temperature': 1.0,
                    'residual_weight': 0.2,
                    'normalize': False,
                    'dropout': 0.1
                }
            else:  # 'adaptive' (validated above)
                filter_cls = AdaptiveCausalFilter
                default_config = {
                    'warmup_epochs': 10,
                    'max_epochs': 50,
                    'dropout': 0.1,
                    'residual_weight': 0.1
                }
            self.causal_filters = self._build_causal_filters(
                filter_cls, default_config, num_features, hidden_dim, filter_config
            )

        self.fc = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _make_gin_layer(in_dim, hidden_dim):
        """Build one GINConv whose MLP maps ``in_dim`` -> ``hidden_dim``."""
        return GINConv(nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
        ))

    def _build_causal_filters(self, filter_cls, default_config, num_features,
                              hidden_dim, filter_config):
        """Create the input filter followed by one filter per built GIN layer."""
        overrides = filter_config or {}
        filters = nn.ModuleList()
        filters.append(filter_cls(num_features, **{**default_config, **overrides.get('input', {})}))
        # Size by the layers actually built, so forward() can always index
        # causal_filters[i + 1] for every convolution.
        for i in range(len(self.gin_layers)):
            layer_kwargs = {**default_config, **overrides.get(f'layer_{i}', {})}
            filters.append(filter_cls(hidden_dim, **layer_kwargs))
        return filters

    def forward(self, x, edge_index, batch=None):
        """Return log-probabilities (``log_softmax`` over dim 1).

        Assumes ``x`` is (num_nodes, num_features) and ``edge_index`` follows
        the PyG convention — TODO confirm against callers. ``batch`` is passed
        to ``global_mean_pool`` when ``task == 'graph'``.
        """
        if self.use_causal_filter:
            x = self.causal_filters[0](x)

        for i, conv in enumerate(self.gin_layers):
            x = F.relu(conv(x, edge_index))
            if self.use_causal_filter:
                # Filter 0 is the input filter, so layer i uses filter i + 1.
                x = self.causal_filters[i + 1](x)

        if self.task == 'graph':
            x = global_mean_pool(x, batch)

        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def step_filters(self):
        """Call ``step()`` on every causal filter (no-op when filters are off).

        Every filter module must implement ``step()``; it is called without a
        ``hasattr`` guard.
        """
        if self.use_causal_filter:
            for filter_module in self.causal_filters:
                filter_module.step()

    def get_filter_info(self):
        """Return a human-readable, newline-separated summary of filter stats."""
        if not self.use_causal_filter:
            return "Causal filters are disabled"

        info = []
        for i, filter_module in enumerate(self.causal_filters):
            if hasattr(filter_module, 'get_stats'):
                stats = filter_module.get_stats()
                if self.filter_type == 'improved':
                    gate_stats = stats['gate_stats']
                    info.append(f"Filter {i}: λ={stats['lambda']:.3f}, gate=[{gate_stats['min']:.3f}, {gate_stats['max']:.3f}], mean={gate_stats['mean']:.3f}")
                elif self.filter_type == 'adaptive':
                    info.append(f"Filter {i}: strength={stats['filter_strength']:.3f}, phase={stats['phase']}")
            else:
                info.append(f"Filter {i}: no stats available")
        return "\n".join(info)

    def reset_filters(self):
        """Call ``reset_lambda()`` on each filter that provides it."""
        if self.use_causal_filter:
            for filter_module in self.causal_filters:
                if hasattr(filter_module, 'reset_lambda'):
                    filter_module.reset_lambda()



class GINModel(nn.Module):
    """Baseline GIN classifier without any causal filtering.

    Args:
        num_features: Width of the input node features.
        num_classes: Number of output classes.
        hidden_dim: Output width of every GIN layer.
        num_layers: Number of GIN convolutions; the first one is always built.
        task: 'graph' mean-pools node embeddings per graph before the
            classifier; any other value classifies each node directly.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2, task='graph'):
        super().__init__()
        self.num_layers = num_layers
        self.task = task

        # Input width of each convolution: raw features for the first,
        # hidden_dim for the (num_layers - 1) that follow.
        in_widths = [num_features] + [hidden_dim] * max(num_layers - 1, 0)
        self.gin_layers = nn.ModuleList(
            GINConv(nn.Sequential(
                nn.Linear(width, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
            ))
            for width in in_widths
        )

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return log-probabilities (``log_softmax`` over dim 1)."""
        hidden = x
        for layer in self.gin_layers:
            hidden = F.relu(layer(hidden, edge_index))

        if self.task == 'graph':
            hidden = global_mean_pool(hidden, batch)

        return F.log_softmax(self.fc(hidden), dim=1)
