import os
import random

import numpy as np
import torch


def set_seed_torch(seed, *, deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs.

    Args:
        seed: Integer seed applied to every RNG. When ``None`` the function
            does nothing, leaving all RNG state and cuDNN flags untouched.
        deterministic: When true (the default), force cuDNN to use
            deterministic kernels and disable its benchmark autotuner so
            that repeated runs with the same seed match on GPU as well.
            Pass ``False`` to keep the faster, non-reproducible cuDNN
            settings (``benchmark=True``, ``deterministic=False``).
    """
    if seed is None:
        return

    # Python-level and NumPy RNGs.
    random.seed(seed)
    np.random.seed(seed)

    # NOTE: this only influences str/bytes hashing in interpreters started
    # *after* this point (e.g. subprocesses / workers that inherit the
    # environment); the hash seed of the running process is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Seed the CPU generator and the generator of every CUDA device.
    # The CUDA call is safe on CPU-only machines: it is ignored when CUDA
    # is unavailable. manual_seed_all already covers the current device, so
    # a separate torch.cuda.manual_seed call is unnecessary.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN's benchmark mode picks convolution algorithms by timing them,
    # which may select different (and non-deterministic) kernels from run
    # to run; it must be off for seeded runs to be repeatable.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic
