import os
import io, base64, pickle
import torch.nn.functional as F

import numpy as np
import torch as T
import torchvision.transforms as transforms
import sys
from CLIP import clip
import h5py
from PIL import Image

def printname(name: object) -> None:
    """Print ``name`` to standard output on its own line.

    Args:
        name: Value to display; it is handed to ``print`` unchanged.
    """
    print(name)

# Short model names accepted by this script, mapped to the architecture
# identifiers passed to ``clip.load``.
CLIP_ARCHITECTURES = {
    "clip_b16": "ViT-B/16",
    "clip_l14": "ViT-L/14",
    "clip_l14@336px": "ViT-L/14@336px",
    "clip_b32": "ViT-B/32",
    "rn50": "RN50",
    "rn101": "RN101",
    "rn50_16": "RN50x16",
    "rn50_64": "RN50x64",
}


def resolve_clip_architecture(model_name):
    """Return the ``clip.load`` identifier for a short model name.

    Args:
        model_name: One of the keys of ``CLIP_ARCHITECTURES``.

    Returns:
        The architecture string to pass to ``clip.load``.

    Raises:
        NotImplementedError: If ``model_name`` is not a supported key.
    """
    try:
        return CLIP_ARCHITECTURES[model_name]
    except KeyError:
        raise NotImplementedError(
            f"unsupported model name: {model_name!r}"
        ) from None


if __name__ == "__main__":
    import sys

    # The HDF5 trajectory file was previously read from an undefined
    # ``h5_path`` variable; it is now taken from the command line.
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} <trajectory.h5>")
    h5_path = sys.argv[1]

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = T.device("cuda" if T.cuda.is_available() else "cpu")
    text = "The robot is grasping the red object"

    model_name = "clip_l14@336px"
    traj_num = 0

    with T.no_grad():
        clip_model, transforms_clip = clip.load(
            resolve_clip_architecture(model_name)
        )
        clip_model = clip_model.to(device)

        # The number of ``env_states`` entries is used as the frame count.
        # Only that length is needed, so the file is closed right away.
        with h5py.File(h5_path, "r") as traj_all:
            length = len(traj_all[f"traj_{traj_num}"]["env_states"])
        if length == 0:
            raise ValueError(
                f"traj_{traj_num} in {h5_path!r} contains no env_states"
            )

        # Frames are expected as ``0.jpg`` .. ``{length - 1}.jpg`` in imgs_dir.
        imgs_dir = f'./demo_imgs/traj_{traj_num}'
        imgs = []
        for i in range(length):
            # Preprocess inside the ``with`` so the pixel data is read
            # before the underlying file handle is closed.
            with Image.open(os.path.join(imgs_dir, f'{i}.jpg')) as frame:
                imgs.append(transforms_clip(frame))
        image = T.stack(imgs).to(device)

        tokens = clip.tokenize([text]).to(device)

        # ``logits_per_text`` has one row for the single prompt; the softmax
        # along the last dimension normalizes across the stacked frames.
        _, logits_per_text = clip_model(image, tokens)
        probs2 = logits_per_text.softmax(dim=-1).cpu().squeeze()
        idx = T.argmax(probs2).item()

    # Report the result; it was previously computed and then discarded.
    print(f"best matching frame for {text!r}: {idx}")


