import os
import sys
import json
import argparse
import numpy as np
from PIL import Image
import jax
from tux import open_file
from lwm.vqgan import VQGAN
import albumentations as A

def get_preprocessor(image_aug=False):
    """Build the albumentations pipeline applied to every frame.

    Both variants first cap the longest side at 256 px and then resize to a
    fixed 256x256 (the second step does not preserve aspect ratio). When
    ``image_aug`` is truthy, a fixed-scale (0.9) square random crop back to
    256x256 plus brightness/contrast and hue/saturation/value jitter are
    appended after the resize.

    Args:
        image_aug: truthy to append the random augmentation steps.

    Returns:
        An ``A.Compose`` pipeline, called as ``pipeline(image=arr)["image"]``.
    """
    steps = [
        A.LongestMaxSize(max_size=256),
        A.Resize(256, 256),
    ]
    if image_aug:
        steps += [
            A.RandomResizedCrop(256, 256, scale=[0.9, 0.9], ratio=[1.0, 1.0]),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=20),
        ]
    return A.Compose(steps)

def process_images_batch(image_paths, preprocessor):
    """Load, preprocess and normalize a batch of images for VQGAN encoding.

    Args:
        image_paths: iterable of paths accepted by ``tux.open_file``.
        preprocessor: albumentations pipeline, called as
            ``preprocessor(image=arr)["image"]``.

    Returns:
        A float32 array of the preprocessed images stacked along a new leading
        axis, rescaled from [0, 255] to [-1, 1].
    """
    images = []
    for path in image_paths:
        # Use a context manager so each file handle is closed promptly (the
        # previous version leaked one handle per image). PIL decodes lazily,
        # so convert('RGB') must run while the file is still open; it also
        # makes grayscale / RGBA / palette inputs yield 3 channels, so all
        # images in a batch share a shape and can be stacked.
        with open_file(path, 'rb') as fin:
            images.append(np.asarray(Image.open(fin).convert('RGB'), dtype=np.uint8))

    processed_images = np.array([preprocessor(image=img)["image"] for img in images])
    return (processed_images / 127.5 - 1.0).astype(np.float32)

def encode_images(vqgan, images):
    """Run the VQGAN encoder and return its token indices on the host.

    Args:
        vqgan: object exposing ``encode(images)``; element ``[1]`` of its
            result is taken as the code indices.
        images: batch passed straight to ``vqgan.encode``.

    Returns:
        The second element of the (host-transferred) encode result, cast to
        Python ``int`` dtype.
    """
    host_result = jax.device_get(vqgan.encode(images))
    token_ids = host_result[1]
    return token_ids.astype(int)

def _frame_batches(folder_path, batch_size, margin=10):
    """Yield lists of ``(frame_index, image_path)`` for one video folder.

    Frame indices run from ``margin`` up to (excluding) ``n - margin``, where
    ``n`` is the number of entries in ``folder_path``. The trailing margin is
    enforced for every batch size; previously the last batch could run past
    it when ``batch_size > 1``. Each yielded list has at most ``batch_size``
    entries.
    """
    # List the directory once instead of once per frame.
    num_entries = len(os.listdir(folder_path))
    stop = num_entries - margin
    for start in range(margin, stop, batch_size):
        yield [
            # Zero-pad to 5 digits: identical to the old names for indices
            # 10-999, and still correct for indices >= 1000.
            (idx, os.path.join(folder_path, f'img_{idx:05d}.jpg'))
            for idx in range(start, min(start + batch_size, stop))
        ]


def main(args):
    """Encode frames of each video folder with VQGAN and write JSONL records.

    Every sub-directory of ``args.data_dir`` whose numeric name has an entry in
    ``data/sthv2_instruction.json`` has its frames read in batches (see
    ``_frame_batches``), preprocessed, VQ-encoded, and written one record per
    frame to ``args.output_file``.

    Args:
        args: namespace with ``output_file``, ``checkpoint_path``, ``data_dir``,
            ``batch_size``, ``start_index``, ``end_index`` and ``image_aug``.
            ``start_index``/``end_index`` slice the name-sorted folder list.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than 1.
    """
    if args.batch_size < 1:
        raise ValueError(f"--batch-size must be >= 1, got {args.batch_size}")

    # Initialize preprocessor with or without augmentation
    preprocessor = get_preprocessor(image_aug=args.image_aug)

    # Map numeric video id -> instruction label for O(1) lookup.
    with open('data/sthv2_instruction.json', 'r') as f:
        instruction_dict = {int(item['id']): item['label'] for item in json.load(f)}

    vqgan = VQGAN(args.checkpoint_path, replicate=False)
    data_dir = args.data_dir

    # Sort so that --start-index/--end-index select a reproducible shard;
    # os.listdir order is arbitrary.
    all_folders = sorted(
        folder for folder in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, folder))
    )
    selected_folders = all_folders[args.start_index:args.end_index]

    # Stream records to disk rather than holding every record in memory.
    with open(args.output_file, 'w') as out:
        for cnt, selected_folder in enumerate(selected_folders, start=1):
            print(f"Processing folder {selected_folder}", cnt)
            try:
                folder_id = int(selected_folder)
            except ValueError:
                continue  # Not a numeric video id; nothing to look up.

            if folder_id not in instruction_dict:
                continue  # Skip if no instruction is found for this folder
            instruction = instruction_dict[folder_id]

            # Frames live under the configured data_dir, not a hard-coded path.
            folder_path = os.path.join(data_dir, selected_folder)
            for batch in _frame_batches(folder_path, args.batch_size):
                img_file_list = [path for _, path in batch]
                processed_images = process_images_batch(img_file_list, preprocessor)
                encoded_images = encode_images(vqgan, processed_images)
                for (frame_idx, img_file), codes in zip(batch, encoded_images):
                    final_elem = {
                        'id': f'{selected_folder}_{frame_idx}',
                        'image': img_file,
                        'instruction': instruction,
                        'vision': list(map(str, codes.flatten().tolist())),
                        'fields': '[instruction],[vision],delta',
                    }
                    out.write(json.dumps(final_elem) + '\n')

if __name__ == "__main__":
    # Command-line entry point: parse options and run main().
    parser = argparse.ArgumentParser(description="Process images and generate JSONL outputs.")
    parser.add_argument('--output-file', type=str, required=True, help='Path to output JSONL file.')
    parser.add_argument('--checkpoint-path', type=str, default='checkpoints/lwm_checkpoints/vqgan', help='Path to the VQGAN checkpoint directory.')
    parser.add_argument('--data-dir', type=str, default='/home/t-sye/rawframes/', help='Directory containing one sub-folder of extracted frames per video.')
    parser.add_argument('--batch-size', type=int, default=1, help='Number of frames encoded per VQGAN call (must be >= 1).')
    parser.add_argument('--start-index', type=int, default=0, help='Index of the first video folder to process.')
    parser.add_argument('--end-index', type=int, default=None, help='Index one past the last video folder to process (default: all).')
    parser.add_argument('--image-aug', type=int, default=0, help='Set to 1 to apply random image augmentations.')

    args = parser.parse_args()
    # Reject a non-positive batch size up front with a proper CLI error.
    if args.batch_size < 1:
        parser.error(f"--batch-size must be >= 1, got {args.batch_size}")
    main(args)
