import re
from typing import Optional

from data_utils.data_corruption.corruption_type import remove_corruption_type_from_dataset_name


def remove_z_dim_from_dataset_name(dataset_name: str) -> str:
    """Return *dataset_name* with its first ``_z<digits>`` segment removed.

    If the name contains no such segment it is returned unchanged.
    """
    z_dim_part = re.search(r'_z\d+', dataset_name)
    if z_dim_part is None:
        return dataset_name
    # Cut out exactly the matched span. Using ``str.replace`` on the matched
    # text would also delete any other occurrence of that text elsewhere in
    # the name (e.g. the "_z1" prefix of a later "_z10").
    start, end = z_dim_part.span()
    return dataset_name[:start] + dataset_name[end:]

def remove_e_from_dataset_name(dataset_name: str) -> str:
    """Return *dataset_name* with its first ``e_<digits>_`` segment removed.

    If the name contains no such segment it is returned unchanged.
    """
    # NOTE(review): the pattern is unanchored, so it can also match inside a
    # longer token (e.g. "lie_2_"). It is kept identical to the pattern in
    # get_e_from_dataset_name so extraction and removal stay in agreement.
    e_part = re.search(r'e_\d+_', dataset_name)
    if e_part is None:
        return dataset_name
    # Remove only the matched span. ``str.replace`` on the matched text
    # would also strip every other occurrence of the same text.
    start, end = e_part.span()
    return dataset_name[:start] + dataset_name[end:]


def get_e_from_dataset_name(dataset_name: str) -> Optional[int]:
    """Extract the integer from the first ``e_<digits>_`` segment of a name.

    Args:
        dataset_name: Dataset name that may contain an ``e_<digits>_`` tag.

    Returns:
        The parsed integer, or ``None`` if the name has no such segment.
    """
    # Capture the digits directly instead of stripping "e_" and "_" by hand.
    e_dim_part = re.search(r'e_(\d+)_', dataset_name)
    if e_dim_part is None:
        return None
    return int(e_dim_part.group(1))


def get_original_dataset_name(dataset_name: str) -> str:
    """Strip derived-dataset tags from *dataset_name*.

    Applies, in order: removal of the ``e_<digits>_`` segment, removal of the
    ``_z<digits>`` segment, removal of the corruption-type tag (delegated to
    ``remove_corruption_type_from_dataset_name``), and finally deletion of
    every ``"adversarial_"`` substring.
    """
    # Order matters: each stripper operates on the previous one's output.
    strippers = (
        remove_e_from_dataset_name,
        remove_z_dim_from_dataset_name,
        remove_corruption_type_from_dataset_name,
    )
    for strip in strippers:
        dataset_name = strip(dataset_name)
    return dataset_name.replace("adversarial_", "")


def get_z_dim_from_data_name(dataset_name: str) -> int:
    """Extract the integer from the first ``_z<digits>`` segment of a name.

    Args:
        dataset_name: Dataset name that may contain a ``_z<digits>`` tag.

    Returns:
        The parsed integer, or ``1`` when the name has no such segment.
    """
    # Capture the digits directly instead of stripping the "_z" prefix.
    z_dim_part = re.search(r'_z(\d+)', dataset_name)
    if z_dim_part is None:
        return 1
    return int(z_dim_part.group(1))
