import numpy as np
import random
import torch
import selfies as sf
from selfies.exceptions import EncoderError
from datasets import Dataset
import tqdm
import yaml

def set_random_seed(seed_value=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed_value: Seed passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``. Defaults to 42.
    """
    # The argument is used as given; it must not be overwritten with a
    # constant, otherwise callers cannot select a different seed.

    # Set the seed for Python's random module
    random.seed(seed_value)

    # Set the seed for NumPy
    np.random.seed(seed_value)

    # Set the seed for PyTorch
    torch.manual_seed(seed_value)

def smiles2selfies(data, smiles_name="SMILES", selfies_name="SELFIES"):
    """Add a SELFIES column computed from a SMILES column.

    Args:
        data: Object exposing ``to_list()`` that yields one dict per row
            (presumably a ``datasets.Dataset`` -- confirm against callers).
        smiles_name: Key holding the SMILES string in each row.
        selfies_name: Key under which the encoded SELFIES string is stored.

    Returns:
        A new ``Dataset`` built from the rows with ``selfies_name`` filled in:
        ``None`` when the SMILES value is ``None``, an empty string when
        ``sf.encoder`` raises ``EncoderError``, and the encoder output
        otherwise.
    """

    def _encode(row):
        # Missing SMILES stays missing rather than being sent to the encoder.
        if row[smiles_name] is None:
            return None
        try:
            return sf.encoder(row[smiles_name])
        except EncoderError:
            # Unencodable molecules are marked with an empty string.
            return ""

    converted = []
    for row in tqdm.tqdm(data.to_list()):
        row[selfies_name] = _encode(row)
        converted.append(row)
    return Dataset.from_list(converted)

def load_yaml(path):
    """Open ``path`` for reading and return the result of ``yaml.safe_load``.

    Args:
        path: Location of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.
    """
    with open(path, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed

# Characters that may appear in a plain decimal or scientific-notation number.
# Restricting numeric parsing to these keeps words such as "nan" or "inf"
# (which ``float`` would accept) as ordinary strings.
_NUMERIC_CHARS = frozenset("0123456789+-.eE")


def _parse_mod_value(raw):
    """Convert one override string to bool, int or float when it looks like one."""
    lowered = raw.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    if set(raw) <= _NUMERIC_CHARS:
        # ``int``/``float`` validate the actual grammar, so malformed tokens
        # such as "1.2.3", "-" or "2020-01-01" fall through and stay strings.
        for cast in (int, float):
            try:
                return cast(raw)
            except ValueError:
                continue
    return raw


def merge_mod(params, mod_args):
    """Apply flat ``key value key value ...`` overrides to ``params`` in place.

    Each value string is converted as follows:
      * ``"true"`` / ``"false"`` (any casing) -> ``bool``
      * integers, optionally signed (``"10"``, ``"-2"``) -> ``int``
      * decimals and scientific notation (``"0.5"``, ``"-0.5"``, ``"1e-3"``)
        -> ``float``
      * anything else is kept as the original string

    Args:
        params: Mapping to update; it is mutated and also returned.
        mod_args: Sequence of strings alternating between keys and values.

    Returns:
        The same ``params`` object with the overrides applied.

    Raises:
        IndexError: If ``mod_args`` has an odd number of items, i.e. the last
            key has no value. The check runs before any key is written, so
            ``params`` is left untouched in that case.
    """
    if len(mod_args) % 2:
        raise IndexError(
            "mod_args must contain key/value pairs; got an odd number of items"
        )
    for i in range(0, len(mod_args), 2):
        params[mod_args[i]] = _parse_mod_value(mod_args[i + 1])
    return params