# import os

# from PIL import Image

# import sys
# import numpy as np

# import jax
# from tux import open_file

# from lwm.vqgan import VQGAN
# import json
# import albumentations as A


# preprocessor = A.Compose([
#     A.LongestMaxSize(max_size=256),
#     A.Resize(256, 256),
# ])

# def process_images_batch(image_paths):
#     print(image_paths)
#     images = [np.array(Image.open(open_file(path, 'rb'))).astype(np.uint8) for path in image_paths]
#     processed_images = np.array([preprocessor(image=img)["image"] for img in images])
#     processed_images = (processed_images / 127.5 - 1.0).astype(np.float32)
#     return processed_images

# def encode_images(vqgan, images):
#     # Assuming VQGAN can encode a batch of images
#     # Add batch dimension if VQGAN expects it explicitly
#     encoded = jax.device_get(vqgan.encode(images))[1].astype(int)
#     return encoded

# # JaxDistributedConfig.initialize(JaxDistributedConfig.get_default_config(),)
# # prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}]
# vqgan = VQGAN('checkpoints/lwm_checkpoints/vqgan', replicate=False)
# # json_obj = json.load(open('/home/World-Model/data/bridge_singleview_total_action_processed.json', "r"))
# # json_obj = json.load(open('/home/t-sye/World-Model/data/0808_multiobject_sink_llava_5hz.json', "r"))

# json_obj = json.load(open('/home/t-sye/World-Model/data/814/human_motion_seen_llava/task.json', "r"))
# absolute_path = '/home/t-sye/World-Model/'
# total_list = []
# cnt = 0
# temp_store = []
# temp_cnt = 0

# batch_size = 1  # Adjust based on your GPU's memory
# total_list = []
# print(len(json_obj))
# start_index = 0
# # end_index = 320000
# # start_index = 1000000
# end_index = len(json_obj)
# # end_index = 1000000

# # start_index = 4662416
# # end_index = 4665423
# for i in range(start_index, end_index, batch_size):
#     batch_json = json_obj[i:i+batch_size]
#     image_paths = [absolute_path + elem['image'] for elem in batch_json]
#     # image_paths = [elem['image'] for elem in batch_json]
#     processed_images = process_images_batch(image_paths)
#     encoded_images = encode_images(vqgan, processed_images)
#     if i%1000==0:
#         print("index", i)

#     for j, json_elem in enumerate(batch_json):
#         final_elem = {}
#         image = json_elem['image']

#         # filter out step_0 to step_9
#         # step_info = int(json_elem["id"].split("/")[-1].split("_")[-1])
#         # if step_info < 10:
#         #     continue
#         instruction = json_elem['conversations'][0]['value']
#         instruction = instruction.replace('<image>\n', "")
#         # answer = json_elem['conversations'][1]['actions']
#         raw_actions = json_elem['conversations'][1]['raw_actions']
#         final_elem['id'] = json_elem['id']
#         final_elem['instruction'] = f"<s> You are a helpful assistant. USER: {instruction} ASSISTANT:"
#         enc_list = encoded_images[j].flatten().tolist()
#         final_elem['vision'] = list(map(str, enc_list))
#         # final_elem['answer'] = answer
#         final_elem['image'] = image
#         final_elem['raw_actions'] = list(map(str, raw_actions))
#         final_elem['fields'] = '[instruction],[vision],action'
#         total_list.append(final_elem)


# # filename = '/home/t-sye/World-Model/data/0808_multiobject_sink_llava_5hz.jsonl'
# filename = '/home/t-sye/World-Model/data/814_human_motion_seen_llava.jsonl'
# with open(filename, 'w') as f:
#     for traj in total_list:
#         f.write(json.dumps(traj) + '\n')




import os

from PIL import Image

import sys
import numpy as np

import jax
from tux import open_file

from lwm.vqgan import VQGAN
import json
import albumentations as A


# Resize pipeline applied to every image before it is handed to the VQGAN:
# scale with LongestMaxSize(max_size=256), then Resize to exactly 256x256.
_resize_steps = [
    A.LongestMaxSize(max_size=256),
    A.Resize(256, 256),
]
preprocessor = A.Compose(_resize_steps)

def process_images_batch(image_paths):
    """Load a batch of images and normalise them for encoding.

    Each path is opened through ``open_file``, decoded with PIL, run through
    the module-level ``preprocessor`` and finally rescaled from the uint8
    range ``[0, 255]`` to ``[-1.0, 1.0]``.

    Args:
        image_paths: Iterable of image locations accepted by ``open_file``.

    Returns:
        A ``float32`` NumPy array holding the preprocessed images in the same
        order as ``image_paths``.
    """
    print(image_paths)
    images = []
    for path in image_paths:
        # PIL decodes lazily, so the pixels are materialised with np.array()
        # while the handle is still open; the `with` block then closes the
        # handle instead of leaking one open file per image.
        with open_file(path, 'rb') as image_file:
            images.append(np.array(Image.open(image_file)).astype(np.uint8))
    processed_images = np.array([preprocessor(image=img)["image"] for img in images])
    # Map uint8 pixel values onto [-1.0, 1.0].
    processed_images = (processed_images / 127.5 - 1.0).astype(np.float32)
    return processed_images

def encode_images(vqgan, images):
    """Encode ``images`` with ``vqgan`` and return the result as integers.

    Args:
        vqgan: Object exposing ``encode(images)``; its result is indexed at
            position 1 (presumably the codebook indices - confirm against
            the VQGAN implementation).
        images: Batch of images passed straight to ``vqgan.encode``.

    Returns:
        Element 1 of the encoder output, fetched to host memory and cast to
        ``int``.
    """
    encoder_output = vqgan.encode(images)
    # Pull the result off the accelerator before indexing into it.
    host_output = jax.device_get(encoder_output)
    token_ids = host_output[1]
    return token_ids.astype(int)

# JaxDistributedConfig.initialize(JaxDistributedConfig.get_default_config(),)
# prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}]
vqgan = VQGAN('/root/checkpoints/lwm_checkpoints/vqgan', replicate=False)

# The annotation file is rewritten in place: every record is loaded into
# memory first, then the same path receives the records augmented with the
# 'vision' tokens and the 'fields' marker.
filename = '/root/data/data_0919/data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# Prefix joined with each record's 'image' entry to locate the file on disk.
absolute_path = '/root/data/data_0919/'

# Blank lines (e.g. stray empty lines at the end of the file) are skipped so
# they do not crash json.loads.
with open(filename, 'r', encoding='utf-8') as f:
    json_obj = [json.loads(line) for line in f if line.strip()]
print(len(json_obj))

batch_size = 1  # Adjust based on your GPU's memory
# Narrow [start_index, end_index) to process only a slice of the records.
# Only the processed slice is written back to `filename`.
start_index = 0
end_index = len(json_obj)

total_list = []
for i in range(start_index, end_index, batch_size):
    batch_json = json_obj[i:i+batch_size]
    image_paths = [absolute_path + elem['image'] for elem in batch_json]
    processed_images = process_images_batch(image_paths)
    encoded_images = encode_images(vqgan, processed_images)
    if i%1000==0:
        print("index", i)

    for j, json_elem in enumerate(batch_json):
        # Flatten this image's encoding and store it as a list of strings.
        enc_list = encoded_images[j].flatten().tolist()
        total_list.append({
            'id': json_elem['id'],
            'instruction': json_elem['instruction'],
            'vision': list(map(str, enc_list)),
            'image': json_elem['image'],
            'raw_actions': json_elem['raw_actions'],
            'action': json_elem['action'],
            'fields': '[instruction],[vision],action',
        })

# The output path is the input path, so write to a sibling temp file and swap
# it in with os.replace: a crash mid-write then leaves the original intact.
tmp_filename = filename + '.tmp'
with open(tmp_filename, 'w', encoding='utf-8') as f:
    for traj in total_list:
        f.write(json.dumps(traj) + '\n')
os.replace(tmp_filename, filename)


