


import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_scipy_sparse_matrix
import scipy.sparse as sp


from models_with_filter import ChebyNetWithFilter
from script import utility


def set_env(seed):
    """Seed every random number generator used by the script.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, and PyTorch's CPU and CUDA generators.  cuDNN
    auto-tuning (``benchmark``) is switched off and its deterministic mode is
    switched on.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def process_citeseer_data_for_chebnet(data, device, gso_type='sym_norm_lap'):
    """Build the graph shift operator (GSO) tensor that is fed to the model.

    The edge list in ``data.edge_index`` is converted to a SciPy sparse
    adjacency matrix, passed through ``utility.calc_gso`` and
    ``utility.calc_chebynet_gso``, and finally converted to a tensor on
    ``device`` by ``utility.cnv_sparse_mat_to_coo_tensor``.

    Args:
        data: Graph object exposing ``x`` (node features) and ``edge_index``.
        device: Target device for the returned tensor.
        gso_type: Operator name forwarded unchanged to ``utility.calc_gso``.

    Returns:
        Whatever ``utility.cnv_sparse_mat_to_coo_tensor`` returns
        (presumably a sparse COO tensor -- confirm against ``script.utility``).
    """
    # The node count comes from the feature matrix so that isolated nodes
    # (absent from edge_index) still get a row/column in the adjacency.
    adjacency = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.x.size(0))

    chebynet_gso = utility.calc_chebynet_gso(utility.calc_gso(adjacency, gso_type))

    return utility.cnv_sparse_mat_to_coo_tensor(chebynet_gso, device)


def train_full_graph(model, data, gso, optimizer, device):
    """Run one full-graph optimisation step and report training metrics.

    The model is evaluated on all nodes, but the loss and the accuracy are
    computed only on the nodes selected by ``data.train_mask``.

    Args:
        model: Module called as ``model(data.x, gso)``; its output is passed
            to ``F.nll_loss`` (so it is presumably log-probabilities --
            confirm against the model definition).
        data: Graph object providing ``x``, ``y``, ``train_mask`` and ``to``.
        gso: Graph shift operator forwarded to the model.
        optimizer: Optimizer that owns the model parameters.
        device: Device the data is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)`` of Python floats; the accuracy is measured
        on the outputs produced before the parameter update.
    """
    model.train()
    data = data.to(device)
    train_idx = data.train_mask
    targets = data.y[train_idx]

    optimizer.zero_grad()
    logits = model(data.x, gso)
    loss = F.nll_loss(logits[train_idx], targets)
    loss.backward()
    optimizer.step()

    # Accuracy reuses the forward pass above instead of running a second one.
    n_correct = logits.argmax(dim=1)[train_idx].eq(targets).sum().item()
    n_train = train_idx.sum().item()

    return loss.item(), n_correct / n_train


def evaluate_full_graph(model, data, gso, device, split_mask):
    """Compute classification accuracy on the nodes selected by ``split_mask``.

    The model is put in eval mode and run on the full graph with gradient
    tracking disabled.

    Args:
        model: Module called as ``model(data.x, gso)``.
        data: Graph object providing ``x``, ``y`` and ``to``.
        gso: Graph shift operator forwarded to the model.
        device: Device the data is moved to before the forward pass.
        split_mask: Boolean mask over nodes choosing the evaluation split.

    Returns:
        Fraction of masked nodes whose arg-max prediction equals the label.
    """
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        predicted = model(data.x, gso).argmax(dim=1)
        matches = predicted[split_mask].eq(data.y[split_mask])
        n_correct = matches.sum().item()
        n_total = split_mask.sum().item()
    return n_correct / n_total


def str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- is parsed as ``True``.  This
    converter interprets the text instead.  The empty string maps to ``False``
    to match what ``bool("")`` used to return.

    Args:
        value: Raw command-line string (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


if __name__ == '__main__':
    print("Starting training with ChebNet on CiteSeer dataset:")

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs',     type=int,   default=1000)
    parser.add_argument('--lr',         type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_dim', type=int,   default=64)
    parser.add_argument('--Ko',         type=int,   default=2, help='Chebyshev polynomial order')
    parser.add_argument('--Kl',         type=int,   default=2, help='Number of ChebNet layers')
    parser.add_argument('--droprate',   type=float, default=0.5)
    # type=bool would turn every non-empty string (even "False") into True,
    # making it impossible to disable the bias from the command line.
    parser.add_argument('--enable_bias', type=str_to_bool, default=True)
    parser.add_argument('--gso_type',   type=str,   default='sym_norm_lap', choices=['sym_norm_lap', 'rw_norm_lap'])
    parser.add_argument('--no_cuda',    action='store_true')
    parser.add_argument('--use_causal_filter', action='store_true', default=False, help='Whether to use causal filter module')
    parser.add_argument('--runs',       type=int,   default=10, help='Number of runs')
    parser.add_argument('--seed',       type=int,   default=42, help='Random seed')
    parser.add_argument('--patience',   type=int,   default=100, help='Early stopping patience')
    args = parser.parse_args()

    # With zero runs the summary statistics below would be NaN.
    if args.runs < 1:
        parser.error('--runs must be at least 1')

    set_env(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Using device: {device}")
    print(f"Causal filter module: {'Enabled' if args.use_causal_filter else 'Disabled'}")
    print(f"Chebyshev polynomial order: {args.Ko}, Number of ChebNet layers: {args.Kl}")

    # The dataset and the graph shift operator depend only on the command-line
    # arguments, not on the run index, so they are built once and shared by
    # all runs instead of being reloaded and recomputed every run.
    dataset = Planetoid(root='./data/Planetoid', name='CiteSeer', transform=NormalizeFeatures())
    num_features = dataset.num_node_features
    num_classes = dataset.num_classes
    reference_graph = dataset[0]

    print(f"Dataset info: Nodes={reference_graph.x.size(0)}, Edges={reference_graph.edge_index.size(1)}, Feature dimensions={num_features}, Number of classes={num_classes}")

    gso = process_citeseer_data_for_chebnet(reference_graph, device, args.gso_type)

    test_accs = []
    val_accs = []

    for run in range(args.runs):
        print(f"\n=== Run {run + 1} ===")

        # Each run gets its own seed so the repetitions are independent but
        # reproducible.
        set_env(args.seed + run)

        # Fetch the graph object again for every run, as the original
        # per-run loading did.
        data = dataset[0]

        # Built per run so that a model which modifies its config cannot
        # leak state into the next run.
        filter_config = {
            'input': {
                'lambda_init': 10.0,
                'decay_rate': 0.95,
                'hidden_dim': args.hidden_dim // 4,
                'normalize': False,
                'dropout': 0.1
            },
            'hidden': {
                'lambda_init': 10.0,
                'decay_rate': 0.95,
                'hidden_dim': args.hidden_dim // 4,
                'normalize': False,
                'dropout': 0.1
            }
        }

        model = ChebyNetWithFilter(
            n_feat=num_features,
            n_hid=args.hidden_dim,
            n_class=num_classes,
            enable_bias=args.enable_bias,
            Ko=args.Ko,
            Kl=args.Kl,
            droprate=args.droprate,
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None,
            task='node'
        ).to(device)

        print(f"Model parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

        optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )

        best_val = 0
        final_test_acc = 0
        patience_counter = 0

        for epoch in range(1, args.epochs + 1):
            loss, train_acc = train_full_graph(model, data, gso, optimizer, device)

            val_acc = evaluate_full_graph(model, data, gso, device, data.val_mask)
            test_acc = evaluate_full_graph(model, data, gso, device, data.test_mask)

            # Per-epoch hook of the filter module (presumably advances its
            # decay schedule -- confirm in models_with_filter).
            if args.use_causal_filter:
                model.step_epoch()

            # Model selection: the reported test accuracy is the one measured
            # at the epoch with the best validation accuracy so far.
            if val_acc > best_val:
                best_val = val_acc
                final_test_acc = test_acc
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= args.patience:
                    print(f'Early stopping at epoch {epoch}')
                    break

            if epoch % 100 == 0 or epoch == 1:
                print(f'Epoch {epoch:04d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
                if args.use_causal_filter and epoch % 200 == 0:
                    print(model.get_filter_info())

        # Accuracies are stored as percentages for the final summary.
        test_accs.append(final_test_acc * 100)
        val_accs.append(best_val * 100)
        print(f'Run {run + 1}: Best Val Acc: {best_val:.4f}, Test Acc: {final_test_acc:.4f}')

    test_mean = np.mean(test_accs)
    test_std = np.std(test_accs)
    val_mean = np.mean(val_accs)
    val_std = np.std(val_accs)

    print('\n=== Final Results ===')
    print(f'Validation Accuracy: {val_mean:.2f} ± {val_std:.2f}%')
    print(f'Test Accuracy: {test_mean:.2f} ± {test_std:.2f}%')
    print(f'Individual test runs: {[f"{acc:.2f}" for acc in test_accs]}')
    
    
    
    
    
    
    
    
    
    
    
    
    
    
