import os

from PIL import Image

import sys
import numpy as np

import jax
from tux import open_file

from lwm.vqgan import VQGAN
import json
import albumentations as A

def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of length ``n``.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    Each slice has the same type as ``lst`` (list -> list, str -> str, ...).
    """
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)

import albumentations
def _process_frame(image, size):
    """Resize a single frame and rescale its pixels for the VQGAN.

    Args:
        image: Anything ``np.array`` can turn into a pixel array (e.g. a PIL
            image); it is cast to uint8 first.
        size: Currently ignored -- the output side length is always 256.
            NOTE(review): presumably meant to replace the hard-coded 256;
            confirm with callers before wiring it in.

    Returns:
        float32 array of height x width 256 x 256 with uint8 values 0..255
        mapped linearly onto [-1, 1].
    """
    frame = np.array(image).astype(np.uint8)

    resize = albumentations.Compose(
        [
            # Shrink/grow so the longest side is 256 px ...
            albumentations.LongestMaxSize(max_size=256),
            # ... then force an exact 256 x 256 output (aspect ratio not kept).
            albumentations.Resize(256, 256),
        ]
    )
    resized = resize(image=frame)["image"]

    scaled = resized / 127.5 - 1.0
    return scaled.astype(np.float32)

# Resize pipeline shared by process_images_batch: scale so the longest side is
# 256 px, then force an exact 256 x 256 output (aspect ratio is not preserved).
_VQGAN_INPUT_SIDE = 256
preprocessor = A.Compose(
    [
        A.LongestMaxSize(max_size=_VQGAN_INPUT_SIDE),
        A.Resize(_VQGAN_INPUT_SIDE, _VQGAN_INPUT_SIDE),
    ]
)

def _load_image_uint8(path):
    """Read the image at ``path`` into a uint8 array, closing every handle it opened."""
    # Image.open does not close a caller-supplied file object, and the
    # original code never closed the open_file handle either; `with` makes
    # the cleanup deterministic instead of relying on garbage collection.
    with open_file(path, 'rb') as fin, Image.open(fin) as img:
        return np.array(img).astype(np.uint8)


def process_images_batch(image_paths):
    """Load, resize and normalize a batch of images for the VQGAN.

    Each image is read with ``open_file``, resized by the module-level
    ``preprocessor`` and rescaled from 0..255 to [-1, 1].

    Args:
        image_paths: Iterable of image paths accepted by ``open_file``.

    Returns:
        float32 array stacking the processed images along a new leading
        axis, in the order of ``image_paths``. All images must come out of
        ``preprocessor`` with the same shape.
        NOTE(review): mixed channel counts (e.g. RGBA or grayscale next to
        RGB) would break the stacking -- confirm inputs are all RGB.
    """
    images = [_load_image_uint8(path) for path in image_paths]
    processed_images = np.array([preprocessor(image=img)["image"] for img in images])
    return (processed_images / 127.5 - 1.0).astype(np.float32)

def encode_images(vqgan, images):
    """Encode a batch with ``vqgan`` and return element ``[1]`` of the result as ints.

    NOTE(review): element 1 of ``vqgan.encode``'s output is presumably the
    codebook indices -- confirm against ``lwm.vqgan.VQGAN.encode``.

    Args:
        vqgan: Object exposing ``encode(images)`` (an ``lwm.vqgan.VQGAN`` in
            this script).
        images: Batch passed to ``vqgan.encode`` unchanged; no batch
            dimension is added here.

    Returns:
        Host array of element ``[1]`` cast to NumPy's default ``int`` dtype.
    """
    device_outputs = vqgan.encode(images)
    # Pull every output back from the accelerator before indexing into it.
    host_outputs = jax.device_get(device_outputs)
    codes = host_outputs[1]
    return codes.astype(int)

# Script body: tokenize every image listed in the input JSONL with the LWM
# VQGAN and write one LWM-format JSON record per input record.
vqgan = VQGAN('checkpoints/lwm_checkpoints/vqgan', replicate=False)

# Input JSONL: one JSON object per line. This script reads each record's
# 'image', 'language_instruction' and 'action' keys.
json_obj = []
with open('/home/t-sye/World-Model/Phenaki/analysis/bridge_window3_action_whole.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():  # tolerate blank lines (e.g. a trailing empty line)
            json_obj.append(json.loads(line))

batch_size = 16  # Adjust based on your GPU's memory
log_every = 1000  # print progress once per this many records
print(len(json_obj))

# Process json_obj[start_index:end_index]; narrow the range to run one shard.
start_index = 0
end_index = len(json_obj)

filename = '/home/t-sye/World-Model/data/bridge_window3_action_whole_lwm_format.jsonl'
# Stream records to a temporary file rather than holding every record in
# memory until the end, and move it into place only after all batches have
# succeeded so a crashed run never leaves a truncated file under `filename`.
tmp_filename = filename + '.tmp'
with open(tmp_filename, 'w', encoding='utf-8') as out_f:
    for i in range(start_index, end_index, batch_size):
        # Clamp to end_index so the last batch of a shard does not spill into
        # records belonging to the next shard.
        batch_json = json_obj[i:min(i + batch_size, end_index)]
        image_paths = [elem['image'] for elem in batch_json]
        processed_images = process_images_batch(image_paths)
        encoded_images = encode_images(vqgan, processed_images)
        # i advances by batch_size, so `i % log_every == 0` would only fire at
        # multiples of lcm(batch_size, log_every) -- or never, depending on
        # start_index. Exactly one batch start lands in the first batch_size
        # indices of each log_every-sized block, so this logs once per block.
        if i % log_every < batch_size:
            print("index", i)

        for j, json_elem in enumerate(batch_json):
            instruction = json_elem['language_instruction']
            # Key order matters only for byte-identical output; keep it stable.
            final_elem = {
                'instruction': f"<s> You are a helpful assistant. USER: {instruction} ASSISTANT:",
                'vision': list(map(str, encoded_images[j].flatten().tolist())),
                'image': json_elem['image'],
                'raw_actions': list(map(str, json_elem["action"])),
                'fields': '[instruction],[vision],action',
            }
            out_f.write(json.dumps(final_elem) + '\n')
os.replace(tmp_filename, filename)


