import os

from PIL import Image

import sys
import numpy as np

import jax
from tux import open_file

from lwm.vqgan import VQGAN
import json
import albumentations as A



# Resizing pipeline applied to every decoded image: the longest side is first
# scaled to 256 pixels, then the result is resized to exactly 256x256.
# NOTE(review): the final Resize fixes both dimensions, so non-square inputs
# presumably end up with a changed aspect ratio -- confirm this is intended.
preprocessor = A.Compose(
    [
        A.LongestMaxSize(max_size=256),
        A.Resize(height=256, width=256),
    ]
)

def process_images_batch(image_paths):
    """Load, resize and normalize a batch of images.

    Each path is opened through ``open_file``, decoded with PIL, cast to
    ``uint8`` and passed through the module-level ``preprocessor``. The
    results are gathered into a single array and rescaled with
    ``x / 127.5 - 1.0``, which maps the 0..255 range onto -1.0..1.0.

    Args:
        image_paths: Iterable of image paths that ``open_file`` can open in
            binary read mode.

    Returns:
        A ``float32`` NumPy array with the processed images along the
        leading axis. An empty input yields an empty ``float32`` array.
    """
    processed = []
    for path in image_paths:
        # Close the handle deterministically instead of leaking it; PIL
        # decodes lazily, so np.array() must run while the file is open.
        with open_file(path, 'rb') as image_file:
            pixels = np.array(Image.open(image_file)).astype(np.uint8)
        processed.append(preprocessor(image=pixels)["image"])
    processed_images = np.array(processed)
    return (processed_images / 127.5 - 1.0).astype(np.float32)

def encode_images(vqgan, images):
    """Encode ``images`` with ``vqgan`` and return the integer result.

    ``vqgan.encode`` is called once on the whole ``images`` argument, its
    output is fetched to the host with ``jax.device_get``, and the element
    at index 1 of that output is cast to ``int``.

    Args:
        vqgan: Object exposing an ``encode(images)`` method.
        images: Input handed to ``vqgan.encode`` unchanged.

    Returns:
        Element 1 of the encoder output as an integer array.
        NOTE(review): presumably these are the codebook indices -- confirm
        against the VQGAN implementation.
    """
    encoder_outputs = vqgan.encode(images)
    host_outputs = jax.device_get(encoder_outputs)
    token_ids = host_outputs[1]
    return token_ids.astype(int)

# Build the VQGAN from the local checkpoint directory.
vqgan = VQGAN('checkpoints/lwm_checkpoints/vqgan', replicate=False)

# Single hard-coded sample frame to run through the pipeline.
image_paths = ['data/0731_multi_200/episode_189/step_0.jpg']
processed_images = process_images_batch(image_paths)
encoded_images = encode_images(vqgan, processed_images)

# Flatten the first image's encoding into a plain Python list, print it, and
# store the values as strings under the 'vision' key.
final_elem = {}
enc_list = encoded_images[0].reshape(-1).tolist()
print(enc_list)
final_elem['vision'] = [str(token) for token in enc_list]



