import os
import sys
import json
import argparse
import numpy as np
from PIL import Image
import jax
from tux import open_file
from lwm.vqgan import VQGAN
import albumentations as A

def get_preprocessor(image_aug=False):
    """Build the albumentations pipeline applied to every input image.

    Both variants start with ``LongestMaxSize(max_size=256)`` followed by
    ``Resize(256, 256)``. When ``image_aug`` is truthy, a random resized crop
    (scale fixed at 0.9, aspect ratio fixed at 1.0), a brightness/contrast
    jitter and a hue/saturation/value jitter are appended after them.

    Args:
        image_aug: Truthy to append the augmentation transforms.

    Returns:
        An ``A.Compose`` instance; call it as ``preprocessor(image=img)`` and
        read the result from the ``"image"`` key.
    """
    # Announce which variant is being built before constructing anything.
    print("@@@@" if image_aug else "no img aug")

    # Resizing steps shared by both variants.
    transforms = [
        A.LongestMaxSize(max_size=256),
        A.Resize(256, 256),
    ]
    if image_aug:
        transforms.extend([
            A.RandomResizedCrop(256, 256, scale=[0.9, 0.9], ratio=[1.0, 1.0]),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=20),
        ])
    return A.Compose(transforms)

def process_images_batch(image_paths, preprocessor):
    """Load a batch of images, preprocess them and rescale the pixel values.

    Each path is opened with ``open_file``, decoded with PIL, cast to
    ``uint8`` and passed through ``preprocessor(image=...)``. The stacked
    result is mapped with ``v / 127.5 - 1.0`` (so 0..255 becomes -1..1) and
    returned as ``float32``.

    Args:
        image_paths: Iterable of image locations accepted by ``open_file``.
        preprocessor: Callable taking ``image=<ndarray>`` and returning a
            mapping with an ``"image"`` entry (e.g. the result of
            ``get_preprocessor``).

    Returns:
        ``float32`` array stacking the preprocessed images along axis 0.
        An empty ``image_paths`` yields an empty ``float32`` array.
    """
    print(image_paths)

    images = []
    for path in image_paths:
        # PIL decodes lazily, so the pixels are materialized with np.array()
        # while the handle is still open; the ``with`` then closes it instead
        # of leaking one file descriptor per image.
        with open_file(path, 'rb') as image_file:
            images.append(np.array(Image.open(image_file)).astype(np.uint8))

    # NOTE(review): images are not converted to a common mode, so mixing RGB
    # with grayscale/RGBA inputs would presumably not stack cleanly - confirm
    # the dataset is RGB-only.
    processed_images = np.array([preprocessor(image=img)["image"] for img in images])
    return (processed_images / 127.5 - 1.0).astype(np.float32)

def encode_images(vqgan, images):
    """Run ``images`` through ``vqgan.encode`` and return integer codes.

    Args:
        vqgan: Object exposing ``encode(images)``; the result must be
            indexable, with the wanted array at position 1.
        images: Batch passed unchanged to ``vqgan.encode``.

    Returns:
        Element 1 of the encoder output, fetched to host memory and cast to
        ``int`` (presumably the discrete token indices - confirm against the
        VQGAN implementation).
    """
    encoder_output = vqgan.encode(images)
    # Pull the whole result back to the host before indexing into it.
    host_output = jax.device_get(encoder_output)
    codes = host_output[1]
    return codes.astype(int)

def _build_record(json_elem):
    """Convert one input entry into the dict written to the output JSONL."""
    conversations = json_elem['conversations']
    # The first turn carries the user prompt; drop the image placeholder.
    instruction = conversations[0]['value'].replace('<image>\n', "")
    raw_actions = conversations[1]['raw_actions']
    return {
        'instruction': f"<s> You are a helpful assistant. USER: {instruction} ASSISTANT:",
        'id': json_elem['id'],
        'image': json_elem['image'],
        'raw_actions': list(map(str, raw_actions)),
        'fields': '[instruction],[vision],action',
    }


def main(args):
    """Convert a slice of the input JSON into instruction/action JSONL records.

    Reads the list stored in ``args.input_json``, walks the entries in
    ``[args.start_index, args.end_index)`` in steps of ``args.batch_size`` and
    writes one JSON object per line to ``args.output_file``.

    NOTE: the VQGAN encoding step is currently disabled, so no ``vision``
    entry is emitted even though ``fields`` still lists ``[vision]``. The
    preprocessor, the VQGAN and the resolved image paths are still built so
    that re-enabling the step is a local change.

    Args:
        args: Namespace with ``input_json``, ``output_file``,
            ``checkpoint_path``, ``data_dir``, ``batch_size``,
            ``start_index``, ``end_index`` and ``image_aug``.

    Raises:
        KeyError: If an entry lacks ``image``, ``id``, ``conversations`` or
            the ``value`` / ``raw_actions`` keys read from it.
    """
    # Initialize preprocessor with or without augmentation
    preprocessor = get_preprocessor(image_aug=args.image_aug)
    vqgan = VQGAN(args.checkpoint_path, replicate=False)

    # Close the input file deterministically instead of leaking the handle.
    with open(args.input_json, "r") as input_file:
        json_obj = json.load(input_file)

    absolute_path = args.data_dir
    batch_size = args.batch_size
    start_index = args.start_index

    # Compare against None explicitly so that ``--end-index 0`` means
    # "process nothing" rather than silently falling back to the full file.
    num_entries = len(json_obj)
    end_index = num_entries if args.end_index is None else min(args.end_index, num_entries)

    total_list = []
    for i in range(start_index, end_index, batch_size):
        # Clamp the final batch so it never runs past end_index.
        batch_json = json_obj[i:min(i + batch_size, end_index)]
        image_paths = [
            os.path.join(absolute_path, elem['image'].replace('/rummy/llava', '/918_rummy'))
            for elem in batch_json
        ]
        # Encoding step (disabled):
        # processed_images = process_images_batch(image_paths, preprocessor)
        # encoded_images = encode_images(vqgan, processed_images)
        # ...then add 'vision': list(map(str, encoded_images[j].flatten().tolist()))
        # to the j-th record of the batch.

        if i % 1000 == 0:
            print("Processing index", i)

        total_list.extend(_build_record(json_elem) for json_elem in batch_json)

    with open(args.output_file, 'w') as output_file:
        output_file.writelines(json.dumps(traj) + '\n' for traj in total_list)

if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description="Process images and generate JSONL outputs.")

    # Mandatory input/output locations.
    required_options = (
        ('--input-json', 'Path to input JSON file.'),
        ('--output-file', 'Path to output JSONL file.'),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Optional settings as (flag, type, default, help), in help-output order.
    optional_options = (
        ('--checkpoint-path', str, 'checkpoints/lwm_checkpoints/vqgan', 'Path to the VQGAN checkpoint directory.'),
        ('--data-dir', str, '/home/t-sye/World-Model/', 'Directory containing image files.'),
        ('--batch-size', int, 1, 'Batch size for processing images.'),
        ('--start-index', int, 0, 'Start index for processing the JSON objects.'),
        ('--end-index', int, None, 'End index for processing the JSON objects.'),
        ('--image-aug', int, 0, 'Flag to apply image augmentations.'),
    )
    for flag, value_type, default_value, help_text in optional_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()
    main(args)
