##############################################################################################
# In this section, we set the user authentication, app ID, model details, and the location
# of the image we want as an input. Change these strings to run your own example.
##############################################################################################

import os

USER_ID = 'zsdl3djys8p7'
# Your PAT (Personal Access Token) can be found in the portal under Authentication.
# The CLARIFAI_PAT environment variable takes precedence when it is set; the literal
# below is only kept as a fallback so existing setups keep working.
# NOTE(security): a token committed to source control should be treated as leaked --
# rotate it and rely on the environment variable instead.
PAT = os.environ.get('CLARIFAI_PAT', '3bc16064701f4b6b8dcc7b0c0e57c997')
APP_ID = 'my-first-application'
# Change these to whatever model and image input you want to use
MODEL_ID = 'general-image-embedding-vit'
IMAGE_FILE_LOCATION = './output/samples/leptodactylus_pentadactylus_s_0000046.jpg'
# This is optional. You can specify a model version or the empty string for the default
MODEL_VERSION_ID = 'a78386d5142c4025ac42272b86f06134'

############################################################################
# YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
############################################################################

from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2

# Authorization header sent with every call; the PAT is passed as a "Key" credential.
metadata = (('authorization', f'Key {PAT}'),)

# Identifies which user/app the requests are issued against.
userDataObject = resources_pb2.UserAppIDSet(app_id=APP_ID, user_id=USER_ID)

# gRPC channel and the V2 service stub used to issue the API calls.
channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)

def get_embedding(file_bytes):
    """Return the embedding the configured model produces for one image.

    Sends a single-image ``PostModelOutputs`` request for ``MODEL_ID`` /
    ``MODEL_VERSION_ID`` using the module-level ``stub``, ``metadata`` and
    ``userDataObject``.

    Args:
        file_bytes: Content of the image file as bytes; it is placed in the
            ``base64`` field of the request's ``Image`` message.

    Returns:
        A list of floats holding the first embedding of the first output.
        An empty list is returned when the call does not report SUCCESS
        (the response status is printed) or when the response contains no
        output, no embedding, or an empty vector.
    """
    request = service_pb2.PostModelOutputsRequest(
        # The userDataObject is required when authenticating with a PAT.
        user_app_id=userDataObject,
        model_id=MODEL_ID,
        # Optional: an empty string selects the default model version.
        version_id=MODEL_VERSION_ID,
        inputs=[
            resources_pb2.Input(
                data=resources_pb2.Data(
                    image=resources_pb2.Image(base64=file_bytes)
                )
            )
        ],
    )
    response = stub.PostModelOutputs(request, metadata=metadata)

    if response.status.code != status_code_pb2.SUCCESS:
        # Deliberately best-effort: report the failure and hand back an empty
        # embedding instead of raising.
        print(response.status)
        return []

    # Guard against a successful response that carries no usable embedding,
    # which would otherwise surface as an IndexError.
    if not response.outputs or not response.outputs[0].data.embeddings:
        return []

    # Read the repeated float field directly instead of parsing the message's
    # printed text form; the values are the floats stored in the message.
    return list(response.outputs[0].data.embeddings[0].vector)


if __name__ == "__main__":
    # Load the sample image as raw bytes for the embedding request.
    with open(IMAGE_FILE_LOCATION, mode="rb") as image_file:
        file_bytes = image_file.read()

    # Debug check: print the first sample of a local ImageFolder dataset.
    from torchvision import datasets
    imagenet_root = '/home/nicolas/data/imagenet'
    imagenet = datasets.ImageFolder(imagenet_root)
    first_sample = next(iter(imagenet), None)
    img = None
    if first_sample is not None:
        # Each dataset item unpacks into (image, label); the label is unused.
        img, _label = first_sample
        print(img)
    print('img: ', img)

    # Request the embedding for the sample image (the result is not used here).
    get_embedding(file_bytes=file_bytes)
