from . import A2Aug, AutoTSA, RandAugment, NoAug, TrivialAugment, UniformAugment, AutoTCL, InfoTS



# Registry of selectable augmentation strategies. Each key is the last dotted
# component of the imported object's ``__name__``; each value is that object's
# ``Model`` attribute. Insertion order follows the tuple below.
AVAILABLE_TSA = {
    module.__name__.rsplit('.', 1)[-1]: module.Model
    for module in (
        A2Aug,
        AutoTSA,
        RandAugment,
        NoAug,
        TrivialAugment,
        UniformAugment,
        AutoTCL,
        InfoTS,
    )
}

def get_auto_augment_class(tsa: str):
    """Return the entry registered in ``AVAILABLE_TSA`` under ``tsa``.

    Args:
        tsa: Registry key to look up; must be one of the keys of
            ``AVAILABLE_TSA``.

    Returns:
        The value stored in ``AVAILABLE_TSA`` for ``tsa``.

    Raises:
        NotImplementedError: If ``tsa`` is not a key of ``AVAILABLE_TSA``.
            The message includes the offending key and the list of
            available keys.
    """
    if tsa in AVAILABLE_TSA:
        return AVAILABLE_TSA[tsa]
    # Unknown key: report it together with every valid option.
    raise NotImplementedError(f"Unknown TSA:{tsa}(options:{list(AVAILABLE_TSA.keys())})")
