import sys, time
sys.path.append("NL-Augmenter")

# pip install spacy torchtext cucco fastpunct sacremoses
# python -m spacy download en_core_web_sm


from nlaugmenter.transformations.butter_fingers_perturbation.transformation import ButterFingersPerturbation
from nlaugmenter.transformations.random_deletion.transformation import RandomDeletion
from nlaugmenter.transformations.synonym_substitution.transformation import SynonymSubstitution
from nlaugmenter.transformations.back_translation.transformation import BackTranslation
from nlaugmenter.transformations.change_char_case.transformation import ChangeCharCase
from nlaugmenter.transformations.whitespace_perturbation.transformation import WhitespacePerturbation
from nlaugmenter.transformations.underscore_trick.transformation import UnderscoreTrick
from nlaugmenter.transformations.style_paraphraser.transformation import StyleTransferParaphraser
from nlaugmenter.transformations.punctuation.transformation import PunctuationWithRules




def get_aug_generator(aug_style):

    if aug_style == "butter_fingers":
        trans = ButterFingersPerturbation(max_outputs=1)
        return trans
    
    elif aug_style == "random_deletion":
        trans = RandomDeletion(prob=0.25)
        return trans

    elif aug_style == "synonym_substitution":
        trans = SynonymSubstitution(max_outputs=1, prob = 0.2)
        return trans

    elif aug_style == "back_translation":
        trans = BackTranslation()
        return trans

    elif aug_style == "change_char_case":
        trans = ChangeCharCase()
        return trans

    elif aug_style == "whitespace_perturbation":
        trans = WhitespacePerturbation()
        return trans

    elif aug_style == "underscore_trick":
        trans = UnderscoreTrick(prob = 0.25)
        return trans

    elif aug_style == "style_paraphraser":
        trans = StyleTransferParaphraser(style = "Basic", upper_length="same_5")
        return trans

    elif aug_style == "punctuation_perturbation":
        normalizations = ['remove_extra_white_spaces', ('replace_characters', {'characters': 'was', 'replacement': 'TZ'}),
                      ('replace_emojis', {'replacement': 'TESTO'})]
        trans = PunctuationWithRules(rules=normalizations)
        return trans
        
    else:
        raise ValueError("Augmentation style not found. Please check the available styles.")

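# NOTE: stale example. `aug_generator` below is not defined in this file; the
# current entry point is `get_aug_generator(style)`, which returns a
# transformation object rather than a list of augmented texts.
# def generate_perturbations(text_list):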
#     augmentation_styles = ["synonym_substitution", "butter_fingers", "random_deletion", "change_char_case", "whitespace_perturbation", "underscore_trick"]
#     all_augmented = {}
#     for style in augmentation_styles:
#         start = time.time()
#         aug_list = aug_generator(text_list, style)
#         all_augmented[style] = aug_list
#         print(f"Perturbing with {style} took {time.time() - start} seconds")
#     return all_augmented
        
# if __name__ == "__main__":
#     text_list = ["This is a test sentence. It is a good sentence.", "This is another test sentence. It is a bad sentence."]
#     print(generate_perturbations(text_list))