from .few_shot_example import few_shots

def zs_single_sep(ascii_repr, text):
    '''
    Zero-shot prompt for a single character, worded separately for digits
    and letters.

    The wording is chosen from the ground-truth ``text``: when
    ``text.isdigit()`` is true the prompt asks for a digit in 0-9,
    otherwise it asks for a letter in a-z. ``ascii_repr`` is appended
    verbatim at the end of the prompt.
    '''
    # Both variants share one template; only the noun and range differ.
    kind, char_range = ("digit", "0-9") if text.isdigit() else ("letter", "a-z")
    return f"The following is an ascii art of a {kind}, which involves the smart placement of typed special characters or letters to make a visual shape that is spread over multiple lines of text. Identify the {kind} in {char_range}, and reply in this way, ```The answer is [Your prediction]```.\nThe ascii art is:\n{ascii_repr}"
    

def zs_single_all(ascii_repr, text):
    '''
    Zero-shot prompt for a single character, using one wording that covers
    both digits and letters.

    ``text`` is accepted for signature parity with the other prompt
    builders but is not used. ``ascii_repr`` is appended verbatim at the
    end of the prompt.
    '''
    preamble = (
        "The following is an ascii art of a digit or letter, which involves "
        "the smart placement of typed special characters or letters to make "
        "a visual shape that is spread over multiple lines of text. "
        "Identify the digit in 0-9 or letter in a-z, and reply in this way, "
        "```The answer is [Your prediction]```.\n"
        "The ascii art is:\n"
    )
    return preamble + f"{ascii_repr}"
    

def zs_long(ascii_repr, text):
    '''
    Zero-shot prompt for a multi-character string of digits and/or letters.

    ``text`` is accepted for signature parity with the other prompt
    builders but does not influence the prompt. ``ascii_repr`` is appended
    verbatim after a blank line.
    '''
    # The previous unused ``length = len(text)`` local was removed: it had no
    # effect on the result and only made the call fail for unsized ``text``.
    return f"The following is an ascii art of a string consist of digits and letters, which involves the smart placement of typed special characters or letters to make a visual shape that is spread over multiple lines of text. Identify the string that may have digit in 0-9 and/or letters in A-Z, and reply in this way, ```The answer is [Your prediction]```.\n\n{ascii_repr}"
    
    
def few_shot(ps, ascii_repr, text, model):
    '''
    Build a few-shot prompt.

    ``ps`` is a prompt selector of the form ``fs-<setting>-<n>-<task>``
    (e.g. ``fs-icl-2-s``): the first dash-separated field is ignored, and
    the remaining three give the example setting, the number of examples
    and the task (``s`` for a single character, ``l`` for a longer string).

    The first ``n`` entries of ``few_shots["<setting>-<task>"]`` are
    appended to a task-specific instruction, followed by the query
    ``ascii_repr``. For task ``s`` the instruction gets an extra sentence
    when ``model`` contains ``'llama'``.

    Raises:
        ValueError: if ``ps`` does not split into exactly four fields, the
            count is not an integer, or the task is neither ``s`` nor ``l``.
        AssertionError: if more than 4 examples are requested, or the
            length of ``text`` does not match the task.
        KeyError: if ``few_shots`` has no entry for ``<setting>-<task>``.
    '''
    setting, num, task = ps.split('-')[1:]
    num = int(num)
    assert num <= 4, "Number of examples should be at most 4"

    # Validate the task before touching few_shots, so an unknown task is
    # reported with this message instead of a KeyError from the lookup.
    if task not in ('s', 'l'):
        raise ValueError(f"task should be s or l, but got {task}")

    if task == 's':
        assert len(text) == 1
    else:
        assert len(text) > 1

    examples = few_shots[f"{setting}-{task}"][:num]

    if task == 's':
        if 'llama' in model:
            prompt = "Follow the given examples of ascii art prediction task, which is either a digit in 0-9 or a letter in a-z, and reply in this way, ```The answer is [Your prediction]```. Only answer the last ascii art prediction:\n\n"
        else:
            prompt = "Follow the given examples of ascii art prediction task, which is either a digit in 0-9 or a letter in a-z, and reply in this way, ```The answer is [Your prediction]```:\n\n"
    else:
        prompt = "Follow the given examples of ascii art prediction task, where each ascii art is a string of 2 to 4 characters, consist of digits in 0-9 and/or letters in a-z. reply in this way, ```The answer is [Your prediction]```:\n\n"

    # Each example is inserted using its own string formatting.
    prompt += "".join(f"{example}" for example in examples)
    prompt += f"Ascii art:\n{ascii_repr}\n"
    return prompt


def mnist_prompt(grid):
    '''
    Build the digit-recognition prompt for a 0/1 image grid.

    ``grid`` is formatted into the prompt as-is, wrapped in a triple-backtick
    fence after a ``Grid:`` line.
    '''
    instruction = "The following is the a 28*28 array for a white-black color image, where the image is a digit in 0-9. In the grid, 0 is black, and 1 is white. Predict what the image is, and reply in this way, ```The answer is [Your prediction]```"
    fenced_grid = f"```{grid}```"
    return "\n".join([instruction, "Grid:", fenced_grid])

def task_prompt(ps, atext, text, model):
    '''
    Dispatch on the prompt selector ``ps`` and return the matching prompt.

    Selectors:
        ``zs-s-easy``  zero-shot single character, digit/letter-specific
                       wording (``text`` must have length 1)
        ``zs-s-hard``  zero-shot single character, combined wording
                       (``text`` must be non-empty)
        ``zs-l``       zero-shot long string
        contains ``fs``     few-shot prompt, see ``few_shot``
        contains ``mnist``  MNIST grid prompt; ``atext`` is passed as the grid

    Raises:
        AssertionError: if ``text`` violates the length requirement of a
            zero-shot single-character selector.
        ValueError: if ``ps`` matches none of the selectors above.
    '''
    # zero-shot
    if ps == 'zs-s-easy':
        assert len(text) == 1
        return zs_single_sep(atext, text)
    if ps == 'zs-s-hard':
        assert len(text) >= 1
        return zs_single_all(atext, text)
    if ps == 'zs-l':
        return zs_long(atext, text)

    # few-shot
    if 'fs' in ps:
        return few_shot(ps, atext, text, model)

    # The original condition was the bare literal 'mnist', which is always
    # truthy, so every unknown selector silently got the MNIST prompt and the
    # error below could never fire. Test the selector itself instead.
    if 'mnist' in ps:
        return mnist_prompt(atext)

    raise ValueError(f'Invalid prompt selector {ps}')
