
import random
import dgl
import numpy as np
import torch

import os
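# Deterministic cuBLAS (required by torch.use_deterministic_algorithms on CUDA)
# needs CUBLAS_WORKSPACE_CONFIG set before CUDA is first initialized. PyTorch's
# determinism check recognizes ":4096:8" or ":16:8"; the value below is not one
# of those and would still trigger the non-determinism warning.
# os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"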

def set_seed(seed, device=None):
    """Seed all random number generators and enable deterministic torch kernels.

    Seeds Python's ``random``, NumPy, DGL and PyTorch. CUDA is also seeded
    explicitly when *device* refers to a CUDA device. The function limits
    PyTorch to a single intra-op thread and switches cuDNN / PyTorch to
    deterministic algorithms.

    Args:
        seed: Integer seed applied to every generator.
        device: Optional target device. Any value whose string form starts
            with ``'cuda'`` additionally triggers explicit seeding of all CUDA
            devices. Examples are ``'cuda'``, ``'cuda:1'`` and
            ``torch.device('cuda', 0)``. ``None`` or a CPU device skips that
            explicit step.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects child processes started
        afterwards. The current interpreter's ``hash()`` randomization is
        fixed at interpreter startup.
    """
    random.seed(seed)
    # Only inherited by subprocesses launched after this point (see Note).
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    # DGL keeps its own generator, separate from torch/numpy.
    dgl.seed(seed)
    dgl.random.seed(seed)

    # This prevents multithreading in pytorch, as it can lead to wait condition
    # failures if used with multiple cores when running parallel operations.
    torch.set_num_threads(1)
    # Seeds the CPU generator and, per the PyTorch docs, all CUDA devices.
    torch.manual_seed(seed)
    if device is not None and str(device).startswith('cuda'):
        # Explicit CUDA seeding. This is redundant with torch.manual_seed but
        # harmless, and silently ignored when CUDA is not available.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # warn_only=True: ops without a deterministic implementation emit a warning
    # instead of raising. Deterministic cuBLAS additionally requires the
    # CUBLAS_WORKSPACE_CONFIG environment variable to be set before CUDA starts.
    torch.use_deterministic_algorithms(True, warn_only=True)