import os, random
import torch

import numpy as np

def set_seed(seed=1) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Value passed to every generator and exported as
            ``PYTHONHASHSEED`` (as a string).

    Side effects:
        Mutates global RNG state, the ``torch.backends.cudnn`` flags and
        ``os.environ``, then prints a confirmation line.

    NOTE: assigning ``PYTHONHASHSEED`` here only influences interpreters
    started afterwards (e.g. subprocesses / workers); the running process's
    ``str`` hash seed is fixed at interpreter startup.
    """
    # NumPy, stdlib `random`, PyTorch (all devices), then the current CUDA device.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

    # cuDNN needs two extra switches: deterministic kernels only, and no
    # benchmark-driven algorithm selection (which can vary between runs).
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic, cudnn_backend.benchmark = True, False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def smart_dir(dir_name, base_list=None):
    """Create ``dir_name`` (and optionally named sub-directories) if missing.

    Args:
        dir_name: Directory path. A single ``'/'`` is always appended, so it
            is expected to be given without a trailing slash.
        base_list: Optional iterable of sub-directory names (strings) to
            create inside ``dir_name``.

    Returns:
        ``dir_name + '/'`` when ``base_list`` is None. Otherwise, a list
        ``[dir_name + '/' + name + '/' for name in base_list]`` in the same
        order as ``base_list``.
    """
    root = dir_name + '/'
    # exist_ok=True replaces the racy "exists? then makedirs" pattern: another
    # process creating the directory in between no longer raises.
    os.makedirs(root, exist_ok=True)
    if base_list is None:
        return root

    sub_dirs = [root + name + '/' for name in base_list]
    for path in sub_dirs:
        os.makedirs(path, exist_ok=True)
    return sub_dirs
    

def config_valididty_check(config):
    """Check cross-field constraints of an experiment config.

    The (misspelled) function name is kept for backward compatibility.

    Rules enforced:
      * agent ``'LUMP'``: ``config.agent.mem_batch_size`` must equal
        ``config.scenario.batch_size``.
      * loss ``'simsiam'`` or ``'vicreg'``: ``config.agent.n_views`` must be 2.

    Only the attributes a rule needs are read, so configs lacking e.g.
    ``mem_batch_size`` are fine when the agent is not LUMP.

    Raises:
        AssertionError: when a rule is violated. It is raised explicitly
            instead of via ``assert`` so the checks still run under
            ``python -O``, while keeping the exception type existing callers
            may catch.
    """
    agent = config.agent

    if agent.name == 'LUMP' and agent.mem_batch_size != config.scenario.batch_size:
        raise AssertionError(
            f"LUMP requires agent.mem_batch_size ({agent.mem_batch_size}) "
            f"to equal scenario.batch_size ({config.scenario.batch_size})"
        )

    # Both of these losses are computed over exactly two augmented views.
    if agent.loss in ('simsiam', 'vicreg') and agent.n_views != 2:
        raise AssertionError(
            f"loss '{agent.loss}' requires agent.n_views == 2, got {agent.n_views}"
        )