import torch
import random
import numpy as np


def set_seed(seed):
    """
    Seed torch (CPU and, when CUDA is available, all CUDA devices), NumPy's
    global generator and Python's ``random`` module with ``seed``.

    Returns a dict snapshot of the generator states taken *before* reseeding,
    keyed by ``"torch"``, ``"torch_cuda"``, ``"numpy"`` and ``"random"``, so the
    previous state can later be handed to ``restore_seed``. The
    ``"torch_cuda"`` entry is ``None`` when CUDA is not available.
    """
    cuda_present = torch.cuda.is_available()

    # Capture every generator first; reseeding below would overwrite them.
    snapshot = dict(
        torch=torch.get_rng_state(),
        torch_cuda=torch.cuda.get_rng_state_all() if cuda_present else None,
        numpy=np.random.get_state(),
        random=random.getstate(),
    )

    torch.manual_seed(seed)
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    return snapshot


def restore_seed(state):
    """
    Restores the random state from the provided state.

    ``state`` is expected to be the dict returned by ``set_seed``, with keys
    ``"torch"``, ``"torch_cuda"``, ``"numpy"`` and ``"random"``.

    The per-device CUDA generators are restored only when CUDA is available
    *and* a CUDA state was captured. ``set_seed`` stores ``None`` for
    ``"torch_cuda"`` when CUDA was unavailable at capture time, and
    ``torch.cuda.set_rng_state_all`` cannot accept ``None``. A missing
    ``"torch_cuda"`` key is treated the same way.
    """
    torch.set_rng_state(state["torch"])

    cuda_state = state.get("torch_cuda")
    # Skip CUDA when no CUDA state was recorded (None / absent key).
    if cuda_state is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(cuda_state)

    np.random.set_state(state["numpy"])
    random.setstate(state["random"])
