import numpy as np

def get_llm_prompt_templates():
    """Return the prompt templates used when querying a language model.

    Each template contains exactly one ``{}`` placeholder for the class name.
    A new list is returned on every call, so callers may mutate it freely.
    """
    templates = (
        "{}.",
        "Describe the concept of {}.",
        "An image of {}.",
        "Describe {}.",
        "Describe visually a {}.",
        "Describe an image of {}.",
        "Describe visually an image of {}.",
        "Describe specifically and visually an image of {}.",
        "How to determine visually an {}.",
        "Let's describe an image of {}.",
        "Let's explain an image of {}.",
        "I am seeing an image of {}.",
    )
    return list(templates)

def get_prompt_templates():
    """Return the image-caption style prompt templates.

    Every template contains exactly one ``{}`` placeholder for the class name
    and ends with a period (``prompt_engineering`` relies on that period to
    apply its ``suffix``). The order matters: callers sample only from the
    first ``topk`` entries. A new list is returned on every call.
    """
    templates = (
        "{}.",
        "a photo of a {}.",
        "a bad photo of a {}.",
        "a photo of many {}.",
        "a sculpture of a {}.",
        "a photo of the hard to see {}.",
        "a low resolution photo of the {}.",
        "a rendering of a {}.",
        "graffiti of a {}.",
        "a bad photo of the {}.",
        "a cropped photo of the {}.",
        "a tattoo of a {}.",
        "the embroidered {}.",
        "a photo of a hard to see {}.",
        "a bright photo of a {}.",
        "a photo of a clean {}.",
        "a photo of a dirty {}.",
        "a dark photo of the {}.",
        "a drawing of a {}.",
        "a photo of my {}.",
        "the plastic {}.",
        "a photo of the cool {}.",
        "a close-up photo of a {}.",
        "a black and white photo of the {}.",
        "a painting of the {}.",
        "a painting of a {}.",
        "a pixelated photo of the {}.",
        "a sculpture of the {}.",
        "a bright photo of the {}.",
        "a cropped photo of a {}.",
        "a plastic {}.",
        "a photo of the dirty {}.",
        "a jpeg corrupted photo of a {}.",
        "a blurry photo of the {}.",
        "a photo of the {}.",
        "a good photo of the {}.",
        "a rendering of the {}.",
        "a {} in a video game.",
        "a photo of one {}.",
        "a doodle of a {}.",
        "a close-up photo of the {}.",
        "the origami {}.",
        "the {} in a video game.",
        "a sketch of a {}.",
        "a doodle of the {}.",
        "a origami {}.",
        "a low resolution photo of a {}.",
        "the toy {}.",
        "a rendition of the {}.",
        "a photo of the clean {}.",
        "a photo of a large {}.",
        "a rendition of a {}.",
        "a photo of a nice {}.",
        "a photo of a weird {}.",
        "a blurry photo of a {}.",
        "a cartoon {}.",
        "art of a {}.",
        "a sketch of the {}.",
        "a embroidered {}.",
        "a pixelated photo of a {}.",
        "itap of the {}.",
        "a jpeg corrupted photo of the {}.",
        "a good photo of a {}.",
        "a plushie {}.",
        "a photo of the nice {}.",
        "a photo of the small {}.",
        "a photo of the weird {}.",
        "the cartoon {}.",
        "art of the {}.",
        "a drawing of the {}.",
        "a photo of the large {}.",
        "a black and white photo of a {}.",
        "the plushie {}.",
        "a dark photo of a {}.",
        "itap of a {}.",
        "graffiti of the {}.",
        "a toy {}.",
        "itap of my {}.",
        "a photo of a cool {}.",
        "a photo of a small {}.",
        "a tattoo of the {}.",
    )
    return list(templates)

def prompt_engineering(classnames, topk=1, suffix='.'):
    """Build one randomly templated prompt for a class name.

    A template is drawn uniformly from the first ``topk`` entries of
    ``get_prompt_templates()``. ``topk`` is clamped to the range
    [1, number of templates]. Every '.' in the chosen template is replaced by
    ``suffix``, then the cleaned class name is substituted into ``{}``.

    Args:
        classnames: A single class-name string, or a list/tuple of candidate
            names from which one is picked uniformly at random.
        topk: How many leading templates are eligible for sampling.
        suffix: Text that replaces every '.' in the chosen template.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``classnames`` is an empty list/tuple.
    """
    prompt_templates = get_prompt_templates()
    # np.random.randint(high) requires high >= 1. Clamp so that topk <= 0
    # falls back to the first template instead of raising.
    n_eligible = max(1, min(len(prompt_templates), topk))
    temp_idx = np.random.randint(n_eligible)

    if isinstance(classnames, (list, tuple)):
        if not classnames:
            raise ValueError("classnames must not be empty")
        # The original used ``random.choice`` without importing ``random``
        # (NameError). Drawing the index from numpy's RNG also lets a single
        # ``np.random.seed`` make both draws reproducible.
        classname = classnames[np.random.randint(len(classnames))]
    else:
        classname = classnames

    # Drop commas and turn '+' into spaces before substituting the name.
    cleaned = classname.replace(',', '').replace('+', ' ')
    return prompt_templates[temp_idx].replace('.', suffix).format(cleaned)

def prompt_engineering_llm(classnames, topk=1):
    """Wrap class name(s) in randomly chosen LLM prompt templates.

    For each name, a template is drawn uniformly from the first ``topk``
    entries of ``get_llm_prompt_templates()``. ``topk`` is clamped to the
    range [1, number of templates]. The name is substituted into the
    template's ``{}``.

    Args:
        classnames: A single name, or a list of names. Each list element gets
            its own independent template draw.
        topk: How many leading templates are eligible for sampling.

    Returns:
        A list of prompt strings (one per element, same order) when
        ``classnames`` is a list, otherwise a single prompt string.
    """
    prompt_templates = get_llm_prompt_templates()
    # np.random.randint(high) requires high >= 1. Clamp so that topk <= 0
    # falls back to the first template instead of raising.
    n_eligible = max(1, min(len(prompt_templates), topk))

    def _render(name):
        # One fresh template draw per call.
        return prompt_templates[np.random.randint(n_eligible)].format(name)

    if isinstance(classnames, list):
        return [_render(cls_name) for cls_name in classnames]
    return _render(classnames)