

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool
from improved_filter import CausalFilter, CausalFilterWrapper, add_causal_filter_to_model, create_filter_config





class GINModelWithFilter(nn.Module):
    """GIN classifier with optional causal filters on the input and after each layer.

    When ``use_causal_filter`` is truthy, one ``CausalFilter`` is applied to the
    raw node features and one after every GIN layer. Per-filter keyword
    arguments are looked up in ``filter_config`` under ``'input'`` and
    ``'layer_<i>'``. Missing keys, or a falsy ``filter_config``, mean the filter
    is built with no extra keyword arguments.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2,
                 task='graph', use_causal_filter=True, filter_config=None):
        super().__init__()
        self.num_layers = num_layers
        self.task = task
        self.use_causal_filter = use_causal_filter

        def make_mlp(in_dim):
            # Two-layer MLP used as the GIN update function.
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
            )

        # The first layer consumes the raw features; the rest map hidden -> hidden.
        # At least one layer is always built, even when num_layers < 1.
        # Each MLP is built immediately before its GINConv so parameter creation
        # happens layer by layer.
        in_dims = [num_features] + [hidden_dim] * (num_layers - 1)
        self.gin_layers = nn.ModuleList([GINConv(make_mlp(dim)) for dim in in_dims])

        if use_causal_filter:
            per_filter_cfg = filter_config or {}
            # Slot 0 filters the input; slot i + 1 filters the output of layer i.
            filter_specs = [(num_features, 'input')] + [
                (hidden_dim, f'layer_{i}') for i in range(num_layers)
            ]
            self.causal_filters = nn.ModuleList([
                CausalFilter(dim, **per_filter_cfg.get(key, {}))
                for dim, key in filter_specs
            ])

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Run GIN layers (with optional filtering) and return log-probabilities.

        Args:
            x: Node features passed to the first filter / GIN layer.
            edge_index: Graph connectivity passed to every GIN layer.
            batch: Graph assignment vector for ``global_mean_pool``; only used
                when ``task == 'graph'``.

        Returns:
            ``log_softmax`` of the classifier output along dim 1.
        """
        use_filters = self.use_causal_filter
        if use_filters:
            x = self.causal_filters[0](x)

        for slot, conv in enumerate(self.gin_layers, start=1):
            x = F.relu(conv(x, edge_index))
            if use_filters:
                x = self.causal_filters[slot](x)

        if self.task == 'graph':
            x = global_mean_pool(x, batch)

        return F.log_softmax(self.fc(x), dim=1)

    def step_filters(self):
        """Call ``step()`` on every causal filter; a no-op when filtering is disabled."""
        if not self.use_causal_filter:
            return
        for causal_filter in self.causal_filters:
            causal_filter.step()






@add_causal_filter_to_model
class GCNModel(nn.Module):
    """GCN classifier whose causal-filter hooks come from a class decorator.

    ``forward`` calls ``self.apply_causal_filter`` on the raw input (key
    ``'input'``) and after every GCN layer (keys ``'layer_0'``, ``'layer_1'``,
    ...). That method is not defined in this class body.

    NOTE(review): presumably ``add_causal_filter_to_model`` injects
    ``apply_causal_filter``. It presumably also consumes the
    ``use_causal_filter`` / ``filter_config`` keyword arguments that
    ``example_usage`` passes to ``GCNModel(...)``, even though ``__init__``
    below does not declare them. Confirm against ``improved_filter``.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=2, task='graph'):
        """Build the GCN stack and the linear classifier head.

        Args:
            num_features: Input node feature dimension (in_channels of the
                first GCNConv).
            num_classes: Number of output classes.
            hidden_dim: Output width of every GCN layer.
            num_layers: Requested number of GCN layers. One layer is always
                built, even if this is < 1.
            task: ``'graph'`` mean-pools node embeddings before the classifier;
                any other value skips pooling.
        """
        # NOTE(review): the explicit super(GCNModel, self) form resolves against
        # whatever the decorator binds to the name GCNModel. If the decorator
        # ever returns a subclass of this class, this call would recurse. Confirm.
        super(GCNModel, self).__init__()
        self.num_layers = num_layers
        self.task = task
        
        # First layer maps raw features to hidden_dim; the rest are hidden -> hidden.
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(num_features, hidden_dim))
        for _ in range(1, num_layers):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        
        self.fc = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x, edge_index, batch=None):
        """Apply filters and GCN layers, optionally pool, and return log-probabilities.

        Args:
            x: Node features.
            edge_index: Graph connectivity passed to every GCN layer.
            batch: Graph assignment vector for ``global_mean_pool``; only used
                when ``task == 'graph'``.

        Returns:
            ``log_softmax`` of the classifier output along dim 1.
        """
        # Filter the raw input; apply_causal_filter is provided externally (see class docstring).
        x = self.apply_causal_filter(x, 'input')
        
        
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            # Filter key 'layer_i' follows GCN layer i.
            x = self.apply_causal_filter(x, f'layer_{i}')
        
        if self.task == 'graph':
            x = global_mean_pool(x, batch)
        
        x = self.fc(x)
        return F.log_softmax(x, dim=1)






class OriginalGATModel(nn.Module):
    """Plain multi-layer GAT classifier without any causal filtering.

    Every GATConv uses ``concat=False``, so each layer outputs ``hidden_dim``
    features regardless of ``heads``. At least one layer is always built.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, heads=4, num_layers=2, task='graph'):
        super().__init__()
        self.num_layers = num_layers
        self.task = task

        # First layer takes the raw features; the remaining layers are hidden -> hidden.
        in_dims = [num_features] + [hidden_dim] * (num_layers - 1)
        self.convs = nn.ModuleList([
            GATConv(dim, hidden_dim, heads=heads, concat=False) for dim in in_dims
        ])

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Run all GAT layers, optionally mean-pool per graph, and return log-probabilities."""
        hidden = x
        for conv in self.convs:
            hidden = F.relu(conv(hidden, edge_index))

        if self.task == 'graph':
            hidden = global_mean_pool(hidden, batch)

        return F.log_softmax(self.fc(hidden), dim=1)


class GATModelWithFilterWrapper(nn.Module):
    """GAT classifier whose input and per-layer activations pass through causal filters.

    Builds an ``OriginalGATModel`` and drives its layers manually so that a
    ``CausalFilterWrapper`` filter slot can be applied to the raw input
    (position 0) and to the output of every GAT layer (positions 1..L).

    Args:
        num_features: Dimensionality of the input node features.
        num_classes: Number of output classes.
        hidden_dim: Output width of every GAT layer.
        heads: Attention heads per GAT layer (forwarded to ``OriginalGATModel``).
        num_layers: Requested number of GAT layers.
        task: ``'graph'`` mean-pools node embeddings before the classifier;
            any other value skips pooling.
        lambda_init, lambda_min, decay_rate: Forwarded unchanged to
            ``CausalFilterWrapper``. The defaults equal the values that were
            previously hard-coded.
    """

    def __init__(self, num_features, num_classes, hidden_dim=64, heads=4, num_layers=2,
                 task='graph', lambda_init=100.0, lambda_min=0.1, decay_rate=0.95):
        super(GATModelWithFilterWrapper, self).__init__()

        self.gat_model = OriginalGATModel(num_features, num_classes, hidden_dim, heads, num_layers, task)

        # Size the filter slots from the layers actually built. OriginalGATModel
        # always builds at least one conv, and forward() uses slots
        # 0..len(convs). The old hard-coded [0, 1, 2] only matched num_layers == 2.
        num_convs = len(self.gat_model.convs)
        layer_dims = [num_features] + [hidden_dim] * num_convs
        self.filter_wrapper = CausalFilterWrapper(
            self.gat_model,
            layer_dims,
            filter_positions=list(range(len(layer_dims))),
            lambda_init=lambda_init,
            lambda_min=lambda_min,
            decay_rate=decay_rate,
        )

    def forward(self, x, edge_index, batch=None):
        """Run the GAT layers with a filter before the first layer and after each one.

        Args:
            x: Node features fed to the first filter slot.
            edge_index: Graph connectivity passed to every GAT layer.
            batch: Graph assignment vector for ``global_mean_pool``; only used
                when ``task == 'graph'``.

        Returns:
            ``log_softmax`` of the classifier output along dim 1.
        """
        x = self.filter_wrapper.apply_filter(x, 0)

        for i, conv in enumerate(self.gat_model.convs):
            x = F.relu(conv(x, edge_index))
            # Filter slot i + 1 follows GAT layer i.
            x = self.filter_wrapper.apply_filter(x, i + 1)

        if self.gat_model.task == 'graph':
            x = global_mean_pool(x, batch)

        x = self.gat_model.fc(x)
        return F.log_softmax(x, dim=1)

    def step_filters(self):
        """Advance every wrapped filter via ``CausalFilterWrapper.step_all_filters()``.

        NOTE(review): presumably this decays each filter's lambda toward
        ``lambda_min``. Confirm in ``improved_filter``.
        """
        self.filter_wrapper.step_all_filters()






def example_usage():
    """Build one model per integration style and run a single forward pass on each.

    Covers three styles:
    - direct ``CausalFilter`` integration (GIN);
    - the ``add_causal_filter_to_model`` decorator (GCN);
    - the ``CausalFilterWrapper`` approach (GAT).

    For each model it prints the parameter count and the output shape on a
    random 100-node, 200-edge graph.
    """
    num_features, num_classes = 128, 10

    def param_count(module):
        # Total number of scalar parameters registered on ``module``.
        return sum(p.numel() for p in module.parameters())

    print("=== Method 1: Direct integration of CausalFilter ===")
    gin_filter_config = {
        'input': {'lambda_init': 100.0, 'lambda_min': 0.1, 'decay_rate': 0.95},
        'layer_0': {'lambda_init': 80.0, 'lambda_min': 0.1, 'decay_rate': 0.9},
        'layer_1': {'lambda_init': 60.0, 'lambda_min': 0.1, 'decay_rate': 0.85},
    }
    gin = GINModelWithFilter(
        num_features, num_classes,
        use_causal_filter=True,
        filter_config=gin_filter_config,
    )
    print(f"GIN model parameter count: {param_count(gin)}")

    print("\n=== Method 2: Using decorator ===")
    gcn_filter_config = create_filter_config(
        {'input': num_features, 'layer_0': 64, 'layer_1': 64}
    )
    gcn = GCNModel(
        num_features, num_classes,
        use_causal_filter=True,
        filter_config=gcn_filter_config,
    )
    print(f"GCN model parameter count: {param_count(gcn)}")

    print("\n=== Method 3: Using Wrapper ===")
    gat = GATModelWithFilterWrapper(num_features, num_classes)
    print(f"GAT model parameter count: {param_count(gat)}")

    # Random toy graph in which every node belongs to graph 0.
    num_nodes = 100
    x = torch.randn(num_nodes, num_features)
    edge_index = torch.randint(0, num_nodes, (2, 200))
    batch = torch.zeros(num_nodes, dtype=torch.long)

    print("\n=== Testing forward propagation ===")
    with torch.no_grad():
        outputs = [
            (label, model(x, edge_index, batch))
            for label, model in (("GIN", gin), ("GCN", gcn), ("GAT", gat))
        ]
        for label, out in outputs:
            print(f"{label} output shape: {out.shape}")


def training_example():
    """Run a five-epoch toy training loop for ``GINModelWithFilter``.

    Each epoch:
    - draws a fresh random 50-node graph with a single random label;
    - takes one Adam step;
    - calls ``step_filters()`` after the optimizer update;
    - prints the loss together with every filter's current lambda.
    """
    num_features, num_classes = 64, 2
    num_nodes, num_edges = 50, 100

    model = GINModelWithFilter(num_features, num_classes, use_causal_filter=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    model.train()

    print("=== Training example ===")
    for epoch in range(1, 6):
        # Fresh random single-graph sample each epoch.
        x = torch.randn(num_nodes, num_features)
        edge_index = torch.randint(0, num_nodes, (2, num_edges))
        batch = torch.zeros(num_nodes, dtype=torch.long)
        y = torch.randint(0, num_classes, (1,))

        optimizer.zero_grad()
        loss = F.nll_loss(model(x, edge_index, batch), y)
        loss.backward()
        optimizer.step()

        # Filters are stepped only after the parameter update.
        model.step_filters()

        lambdas = [flt.get_current_lambda().item() for flt in model.causal_filters]
        lambda_strs = [f"{value:.2f}" for value in lambdas]
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}, Lambda values: {lambda_strs}")


if __name__ == "__main__":
    # Run the model-construction demo, then the toy training loop, separated by a rule.
    example_usage()
    print(f"\n{'=' * 50}")
    training_example()
