


import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from models_with_filter import GCNWithImprovedFilter  


def set_env(seed):
    """Seed every RNG used by this script and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU and CUDA
    generators, then disables cuDNN autotuning and requests deterministic
    cuDNN kernels so repeated runs with the same seed are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuned speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def train_full_graph(model, data, optimizer, device):
    """Run one full-batch optimization step on the training nodes.

    The whole graph goes through ``model`` in a single forward pass; only the
    nodes selected by ``data.train_mask`` contribute to the NLL loss.

    Args:
        model: Module called as ``model(x, edge_index)``. Its output is fed to
            ``F.nll_loss``, so it presumably returns per-node
            log-probabilities (TODO confirm against the model definition).
        data: Graph object exposing ``x``, ``edge_index``, ``y``,
            ``train_mask`` and a ``to(device)`` method.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device ``data`` is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``: the training loss as a Python float, and
        the fraction of training nodes whose arg-max prediction equals ``y``.
        Accuracy comes from the same forward pass as the loss, i.e. computed
        in train mode and before the optimizer update takes effect.
    """
    model.train()
    graph = data.to(device)
    optimizer.zero_grad()

    scores = model(graph.x, graph.edge_index)
    mask = graph.train_mask
    labels = graph.y[mask]

    loss = F.nll_loss(scores[mask], labels)
    loss.backward()
    optimizer.step()

    hits = (scores.argmax(dim=1)[mask] == labels).sum().item()
    return loss.item(), hits / mask.sum().item()


def evaluate_full_graph(model, data, device, split_mask):
    """Return the accuracy of ``model`` on the nodes selected by ``split_mask``.

    Runs one gradient-free, full-graph forward pass in eval mode and compares
    the arg-max prediction against ``data.y`` on the masked nodes.

    Args:
        model: Module called as ``model(x, edge_index)`` returning per-node
            class scores.
        data: Graph object exposing ``x``, ``edge_index``, ``y`` and a
            ``to(device)`` method.
        device: Device on which the forward pass runs.
        split_mask: Boolean node mask (e.g. ``data.val_mask``) selecting the
            nodes to score.

    Returns:
        Fraction of the selected nodes that are classified correctly.

    Raises:
        ValueError: If ``split_mask`` selects no nodes.
    """
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        # The mask arrives separately from ``data``; move it explicitly rather
        # than relying on ``data.to`` having moved the caller's tensor in
        # place, so it always lives on the same device as ``pred``.
        split_mask = split_mask.to(device)
        total = split_mask.sum().item()
        if total == 0:
            raise ValueError("split_mask selects no nodes; accuracy is undefined")
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        correct = (pred[split_mask] == data.y[split_mask]).sum().item()
    return correct / total


if __name__ == '__main__':
    # Script entry point: train GCNWithImprovedFilter on CiteSeer for
    # ``--runs`` seeds (with or without the causal filter) and report the
    # mean ± std (np.std, ddof=0) of validation/test accuracy across runs.
    print("Starting training with GCN on CiteSeer dataset:")

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_channels', type=int, default=64)
    # NOTE(review): --dropout is parsed but never passed to the model below;
    # confirm whether GCNWithImprovedFilter accepts a dropout argument.
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--no_cuda', action='store_true')
    # ``action='store_true'`` combined with ``default=True`` can never yield
    # False, which made the baseline branch unreachable. The filter stays on
    # by default; ``--use_causal_filter`` is kept so existing command lines
    # still parse, and ``--no_causal_filter`` (same dest) switches it off.
    parser.add_argument('--use_causal_filter', dest='use_causal_filter',
                        action='store_true', default=True,
                        help='Whether to use causal filter module (default: enabled)')
    parser.add_argument('--no_causal_filter', dest='use_causal_filter',
                        action='store_false',
                        help='Disable the causal filter module (plain GCN baseline)')
    parser.add_argument('--runs', type=int, default=10, help='Number of runs')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--patience', type=int, default=100,
                        help='Early stopping patience')
    args = parser.parse_args()
    if args.runs < 1:
        # np.mean / np.std over an empty list would silently report nan.
        parser.error('--runs must be at least 1')

    set_env(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Using device: {device}")
    print(f"Causal filter module: {'Enabled' if args.use_causal_filter else 'Disabled'}")
    print(f"Hidden layer dimension: {args.hidden_channels}")

    # The dataset is identical for every run, so load it (and move it to the
    # target device) once instead of re-reading it from disk per run.
    dataset = Planetoid(root='./data/Planetoid', name='CiteSeer',
                        transform=NormalizeFeatures())
    data = dataset[0].to(device)
    num_features = dataset.num_node_features
    num_classes = dataset.num_classes

    print(f"Dataset info: Nodes={data.x.size(0)}, Edges={data.edge_index.size(1)}, Feature dim={num_features}, Number of classes={num_classes}")
    print(f"Training nodes={data.train_mask.sum().item()}, Validation nodes={data.val_mask.sum().item()}, Testing nodes={data.test_mask.sum().item()}")

    # Both filter stages ('input' and 'hidden') use the same hyper-parameters.
    filter_stage_config = {
        'lambda_init': 10.0,
        'decay_rate': 0.95,
        'hidden_dim': args.hidden_channels // 4,
        'normalize': False,
        'dropout': 0.1,
        'temperature': 1.0,
        'residual_weight': 0.2,
    }

    test_accs = []
    val_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")

        # Re-seed per run so runs differ from each other but stay reproducible.
        set_env(args.seed + run)

        # Fresh dict copies per run, as the original per-run literals were,
        # in case the model keeps or mutates its config.
        filter_config = {
            'input': dict(filter_stage_config),
            'hidden': dict(filter_stage_config),
        }

        model = GCNWithImprovedFilter(
            num_features=num_features,
            num_classes=num_classes,
            hidden_channels=args.hidden_channels,
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None,
            task='node'
        ).to(device)

        print(f"Model parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

        optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )

        best_val = 0
        final_test_acc = 0
        patience_counter = 0

        for epoch in range(1, args.epochs + 1):
            loss, train_acc = train_full_graph(model, data, optimizer, device)

            val_acc = evaluate_full_graph(model, data, device, data.val_mask)
            test_acc = evaluate_full_graph(model, data, device, data.test_mask)

            if args.use_causal_filter:
                model.step_epoch()

            # Model selection on validation accuracy: report the test accuracy
            # of the epoch with the best validation score.
            if val_acc > best_val:
                best_val = val_acc
                final_test_acc = test_acc
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= args.patience:
                    print(f'Early stopping at epoch {epoch}')
                    break

            if epoch % 100 == 0 or epoch == 1:
                print(f'Epoch {epoch:04d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
                if args.use_causal_filter and epoch % 200 == 0:
                    print(model.get_filter_info())

        test_accs.append(final_test_acc * 100)
        val_accs.append(best_val * 100)
        print(f'Run {run + 1}: Best Val Acc: {best_val:.4f}, Test Acc: {final_test_acc:.4f}')

    test_mean = np.mean(test_accs)
    test_std = np.std(test_accs)
    val_mean = np.mean(val_accs)
    val_std = np.std(val_accs)

    print('\n=== Final Results ===')
    print(f'Validation Accuracy: {val_mean:.2f} ± {val_std:.2f}%')
    print(f'Test Accuracy: {test_mean:.2f} ± {test_std:.2f}%')
    print(f'Individual test runs: {[f"{acc:.2f}" for acc in test_accs]}')

    print('\n=== Final Results of GCN on CiteSeer Dataset ===')
    if args.use_causal_filter:
        print(f'GCN + ImprovedCausalFilter: Test Accuracy = {test_mean:.2f} ± {test_std:.2f}%')
    else:
        print(f'GCN (baseline): Test Accuracy = {test_mean:.2f} ± {test_std:.2f}%')
