import numpy as np
import random
import argparse
import torch

def set_random_seed(seed):
    """Seed every RNG used by the experiment and pin cuDNN to deterministic mode.

    The same ``seed`` is handed to Python's ``random`` module, NumPy's global
    RNG, the torch CPU generator, and the CUDA generators (the current device
    via ``torch.cuda.manual_seed`` and every device via
    ``torch.cuda.manual_seed_all``). cuDNN is then switched to deterministic
    algorithms with autotuning disabled.

    Args:
        seed: Integer seed forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Give up cuDNN's benchmark-driven algorithm selection in exchange for
    # run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def get_args(argv=None):
    """Parse the command-line options for the DiP-G experiment.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse falls back to ``sys.argv[1:]``. Passing an
            explicit list (e.g. ``[]`` for all defaults) lets tests, notebooks
            or sweep scripts build a configuration without depending on the
            process command line.

    Returns:
        argparse.Namespace with attributes ``dataset_name``, ``pretrain_task``,
        ``shots``, ``hidden_dim``, ``batch_size``, ``lr``, ``epochs``,
        ``gpu_id`` and the DiP-G-specific ``k``, ``m`` and ``r``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description='DiP-G Implementation')
    parser.add_argument('--dataset_name', type=str, default='PubMed', help='Dataset name')
    parser.add_argument('--pretrain_task', type=str, default='EdgePredGraphPrompt', help='pre-training task')
    parser.add_argument('--shots', type=int, default=5, help='number of shots')
    parser.add_argument('--hidden_dim', type=int, default=128, help='hidden_dim')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
    parser.add_argument('--epochs', type=int, default=500, help='epochs')
    parser.add_argument('--gpu_id', type=int, default=0, help='GPU device ID')

    # DiP-G specific args
    parser.add_argument('--k', type=int, default=8, help='sparsity k')
    parser.add_argument('--m', type=int, default=128, help='candidate budget m')
    parser.add_argument('--r', type=int, default=4, help='screening buffer r')

    # argv=None preserves the original behaviour of reading sys.argv[1:].
    return parser.parse_args(argv)