# Sequence length for each HyenaDNA checkpoint name; the value mirrors the
# "<N>-seqlen" suffix embedded in the key.
model_to_length_dict = {
    "hyenadna-tiny-1k-seqlen": 1_024,
    "hyenadna-small-32k-seqlen": 32_768,
    "hyenadna-medium-160k-seqlen": 160_000,
    # Per the original author's note: a T4 handles models up to this one.
    "hyenadna-medium-450k-seqlen": 450_000,
    # Per the original author's note: only runs on an A100 (paid tier).
    "hyenadna-large-1m-seqlen": 1_000_000,
}

# Per-dataset length value, keyed by dataset name.
# NOTE(review): presumably the sequence length used for each dataset -- confirm
# against the data-loading code.
dataset_to_length_dict = dict(
    demo_human_or_worm=200,
    human_enhancers_cohn=500,
)

# Label names available for each dataset, in their original order.
dataset_to_label_list = dict(
    demo_human_or_worm=["human", "worm"],
    human_enhancers_cohn=["positive", "negative"],
)

# Largest data index recorded for each dataset.
# NOTE(review): presumably the highest valid sample index (inclusive) -- confirm
# against the code that enumerates samples.
dataset_to_max_data_idx = {
    "demo_human_or_worm": 49_999,
    "human_enhancers_cohn": 3_473,
}

# Embedding dimension per HyenaDNA checkpoint name. Only the tiny and small
# checkpoints have entries; looking up any other model raises KeyError.
model_to_emb_dim = {
    "hyenadna-tiny-1k-seqlen": 128,
    "hyenadna-small-32k-seqlen": 256,
}

def get_n_class(dataset_name):
    """Return the number of classes for ``dataset_name``.

    The dataset "human_ensembl_regulatory" has 3 classes; every other
    value (including unrecognised names) is reported as 2.
    """
    return 3 if dataset_name == "human_ensembl_regulatory" else 2

def get_data_filename(dataset_name, type, label, idx):
    """Build the relative path of one sample file of a dataset.

    Both supported datasets share the directory layout
    ``dataset/<dataset_name>/<type>/<label>/`` and differ only in how the
    file itself is named:

    * "human_enhancers_cohn": ``<type>_<label>_<idx>.txt``
    * "demo_human_or_worm":   ``<idx>.txt``

    Args:
        dataset_name: Name of the dataset; selects the file-naming scheme.
        type: Path component placed directly under the dataset directory.
            (The name shadows the builtin but is kept so keyword callers
            continue to work.)
        label: Path component placed under ``type``.
        idx: Index interpolated into the file name.

    Returns:
        The relative file path as a string.

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the two
            supported datasets; the message names the offending dataset.
    """
    if dataset_name == "human_enhancers_cohn":
        basename = f"{type}_{label}_{idx}.txt"
    elif dataset_name == "demo_human_or_worm":
        basename = f"{idx}.txt"
    else:
        # Name the dataset so the failure is diagnosable from the traceback.
        raise NotImplementedError(
            f"no data filename layout defined for dataset {dataset_name!r}"
        )
    return f"dataset/{dataset_name}/{type}/{label}/{basename}"