'''
 * The Inference of RAM and Tag2Text Models
 * Written by Xinyu Huang
'''
import torch


def inference_tag2text(image, model, input_tag="None"):
    """Caption an image with Tag2Text, optionally guided by user tags.

    The model is always queried once without tag guidance to obtain its own
    predicted tags. If the caller supplies tags, the model is queried a
    second time with those tags as guidance and that caption is returned
    instead.

    Args:
        image: Image input forwarded unchanged to ``model.generate``.
        model: Object exposing ``generate(image, tag_input=..., max_length=...,
            return_tag_predict=True)`` that returns a ``(caption, tags)`` pair
            of indexable sequences.
        input_tag: Comma-separated user tags. ``None``, a blank string, or
            the word "none" (any casing, surrounding whitespace ignored)
            means no user tags.

    Returns:
        A 3-tuple ``(predicted_tags, user_tags, caption)`` holding the first
        element of each model output. ``user_tags`` is ``None`` when no user
        tags were supplied.
    """
    # None used to fall through to `.replace` and raise AttributeError;
    # padded / upper-case sentinels used to be sent to the model as real tags.
    no_user_tags = (input_tag is None
                    or input_tag.strip().lower() in ('', 'none'))

    with torch.no_grad():
        caption, tag_predict = model.generate(image,
                                              tag_input=None,
                                              max_length=50,
                                              return_tag_predict=True)

        if no_user_tags:
            return tag_predict[0], None, caption[0]

        # The model is given tags separated by ' | ' rather than commas.
        user_tag_list = [input_tag.replace(',', ' | ')]

        # Second pass: caption guided by the user tags. The tags from the
        # first pass are still reported as the model's own prediction.
        caption, guided_tags = model.generate(image,
                                              tag_input=user_tag_list,
                                              max_length=50,
                                              return_tag_predict=True)

    return tag_predict[0], guided_tags[0], caption[0]


def inference_ram(image, model):
    """Tag an image with RAM and return the first English/Chinese results.

    Args:
        image: Image input forwarded unchanged to ``model.generate_tag``.
        model: Object exposing ``generate_tag(image)`` that returns exactly
            two indexable sequences (tags, Chinese tags).

    Returns:
        A 2-tuple holding element 0 of each sequence returned by the model.
    """
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        english_tags, chinese_tags = model.generate_tag(image)

    first_english = english_tags[0]
    first_chinese = chinese_tags[0]
    return first_english, first_chinese


def inference_ram_openset(image, model, return_logits=True):
    """Run RAM open-set tagging, optionally returning logits with the tags.

    Args:
        image: Image input forwarded unchanged to
            ``model.generate_tag_openset``.
        model: Object exposing
            ``generate_tag_openset(image, return_logits=...)``.
        return_logits: Forwarded to the model. When truthy, the model output
            is indexed and its first two elements are returned as a pair;
            otherwise the model output is returned as-is.

    Returns:
        ``(tags, logits)`` when ``return_logits`` is truthy, else the raw
        model output.
    """
    with torch.no_grad():
        result = model.generate_tag_openset(image,
                                            return_logits=return_logits)

        # Guard clause: without logits the model output is handed back as-is.
        if not return_logits:
            return result

        # Only the first two entries of the output are used.
        return result[0], result[1]
