import os
import sys
import json
import argparse
import numpy as np
from PIL import Image
import jax
from tux import open_file
from lwm.vqgan import VQGAN
import albumentations as A
import random
import tqdm

def get_preprocessor(image_aug=False):
    """Build the albumentations pipeline applied to every frame.

    Args:
        image_aug: When truthy, random crop and colour-jitter transforms are
            appended after the deterministic resize; otherwise only the
            resize is applied.

    Returns:
        An ``A.Compose`` pipeline, invoked as ``pipeline(image=arr)["image"]``.
    """
    # Deterministic resize shared by both variants. The trailing Resize forces
    # a 256x256 square output regardless of the source aspect ratio.
    transforms = [
        A.LongestMaxSize(max_size=256),
        A.Resize(256, 256),
    ]
    if image_aug:
        transforms += [
            A.RandomResizedCrop(256, 256, scale=[0.9, 0.9], ratio=[1.0, 1.0]),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=20),
        ]
    return A.Compose(transforms)

def process_images_batch(image_paths, preprocessor):
    """Load, preprocess and normalise a batch of image files.

    Args:
        image_paths: Sequence of paths readable by ``open_file``.
        preprocessor: Callable used as ``preprocessor(image=arr)["image"]``,
            e.g. the pipeline returned by ``get_preprocessor``.

    Returns:
        A float32 array stacking the preprocessed images, with pixel values
        mapped from [0, 255] to [-1, 1].
    """
    images = []
    for path in image_paths:
        # Close each handle deterministically; previously one descriptor was
        # leaked per frame, which accumulates over a large dataset.
        with open_file(path, 'rb') as fin:
            # convert('RGB') decodes the image while the handle is still open
            # and forces 3 channels, so grayscale/RGBA/palette frames can be
            # stacked with the rest of the batch.
            images.append(np.array(Image.open(fin).convert('RGB'), dtype=np.uint8))
    processed_images = np.array([preprocessor(image=img)["image"] for img in images])
    return (processed_images / 127.5 - 1.0).astype(np.float32)

def encode_images(vqgan, images):
    """Encode a batch with ``vqgan`` and return the second encoder output as ints.

    ``vqgan.encode`` is expected to return an indexable result whose element
    at position 1 holds the token ids; the whole result is first copied to
    host memory with ``jax.device_get``.
    """
    host_outputs = jax.device_get(vqgan.encode(images))
    token_ids = host_outputs[1]
    return token_ids.astype(int)

_DEFAULT_INSTRUCTION_FILE = '/mnt/default/lwm/data/train_data/sthv2_instruction.json'


def _load_instructions(path):
    """Load the instruction JSON and index its labels by integer id.

    The file is read as a JSON list of objects with ``id`` and ``label``
    keys; ``id`` is passed through ``int()`` so it can be matched against
    numeric folder names.
    """
    with open(path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    return {int(item['id']): item['label'] for item in records}


def _list_video_folders(data_dir):
    """Return the entries of ``data_dir`` sorted by name.

    ``os.listdir`` order is arbitrary and filesystem-dependent; sorting makes
    the ``--start-index``/``--end-index`` slice reproducible, so separate
    shard jobs cover disjoint sets of folders.
    """
    return sorted(os.listdir(data_dir))


def _frame_step(frame_path):
    """Return the basename text after its last '_' and before the next '.'."""
    return os.path.basename(frame_path).split('_')[-1].split('.')[0]


def _encode_folder(vqgan, preprocessor, data_dir, folder, instruction, num_frames):
    """Randomly sample up to ``num_frames`` frames of one folder and encode them.

    Returns one JSON-serialisable record per sampled frame, or an empty list
    when nothing was sampled (empty folder or ``num_frames`` == 0) rather
    than sending an empty batch to the encoder.
    """
    frame_names = sorted(os.listdir(os.path.join(data_dir, folder)))
    sampled = random.sample(frame_names, min(num_frames, len(frame_names)))
    if not sampled:
        return []

    frame_paths = [os.path.join(data_dir, folder, name) for name in sampled]
    tokens = encode_images(vqgan, process_images_batch(frame_paths, preprocessor))
    return [
        {
            'id': f'{folder}_{_frame_step(path)}',
            'image': path,
            'instruction': instruction,
            'vision': [str(tok) for tok in tokens[j].flatten().tolist()],
            'fields': '[instruction],[vision],delta',
        }
        for j, path in enumerate(frame_paths)
    ]


def main(args):
    """Encode sampled frames of each video folder with VQGAN and write JSONL.

    Reads from ``args``: ``data_dir``, ``checkpoint_path``, ``output_file``,
    ``batch_size`` (frames sampled per folder), ``start_index``/``end_index``
    (slice over the sorted folder list; ``end_index`` None means all),
    ``image_aug``, plus the optional ``instruction_file`` and ``seed``
    (defaults are used when these attributes are absent).
    """
    print("Initializing preprocessor...")
    # Initialize preprocessor with or without augmentation
    preprocessor = get_preprocessor(image_aug=args.image_aug)

    seed = getattr(args, 'seed', None)
    if seed is not None:
        random.seed(seed)  # makes frame sampling reproducible

    print("Reading JSON file...")
    instruction_file = getattr(args, 'instruction_file', None) or _DEFAULT_INSTRUCTION_FILE
    instruction_dict = _load_instructions(instruction_file)

    print("Initializing VQGAN...")
    vqgan = VQGAN(args.checkpoint_path, replicate=False)

    print("Reading image folders...")
    all_folders = _list_video_folders(args.data_dir)
    print(f"Total folders: {len(all_folders)}")

    # `is not None` so an explicit --end-index 0 selects nothing instead of
    # silently meaning "every folder".
    if args.end_index is None:
        end_index = len(all_folders)
    else:
        end_index = min(args.end_index, len(all_folders))
    shard = all_folders[args.start_index:end_index]

    # Stream records to disk as each folder finishes instead of holding the
    # whole dataset in memory until the end.
    with open(args.output_file, 'w', encoding='utf-8') as out:
        for folder in tqdm.tqdm(shard):
            try:
                folder_id = int(folder)
            except ValueError:
                continue  # non-numeric entry (e.g. a hidden file), not a video id
            if folder_id not in instruction_dict:
                continue  # Skip if no instruction is found for this folder

            records = _encode_folder(
                vqgan, preprocessor, args.data_dir, folder,
                instruction_dict[folder_id], args.batch_size,
            )
            out.writelines(json.dumps(rec) + '\n' for rec in records)


if __name__ == "__main__":
    print("Starting the data preprocessing...")
    parser = argparse.ArgumentParser(description="Process images and generate JSONL outputs.")
    # os.getenv returns None when AMLT_OUTPUT_DIR is unset, which used to yield
    # the literal path "None/vqgan_inference_sthv2.jsonl"; fall back to cwd.
    parser.add_argument('--output-file', type=str, default=os.path.join(os.getenv('AMLT_OUTPUT_DIR', '.'), 'vqgan_inference_sthv2.jsonl'), help='Path to output JSONL file.')
    parser.add_argument('--checkpoint-path', type=str, default='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan', help='Path to the VQGAN checkpoint directory.')
    parser.add_argument('--data-dir', type=str, default='/mnt/default/lwm/data/sthv2/rawframes2', help='Directory containing one sub-directory of frames per video.')
    parser.add_argument('--instruction-file', type=str, default=_DEFAULT_INSTRUCTION_FILE, help='JSON file mapping video ids to instruction labels.')
    parser.add_argument('--batch-size', type=int, default=1, help='Number of frames randomly sampled (and encoded as one batch) per video folder.')
    parser.add_argument('--start-index', type=int, default=0, help='Start index (inclusive) into the sorted list of video folders.')
    parser.add_argument('--end-index', type=int, default=None, help='End index (exclusive) into the sorted list of video folders; default processes to the end.')
    parser.add_argument('--image-aug', type=int, default=0, help='Flag to apply image augmentations.')
    parser.add_argument('--seed', type=int, default=None, help='Optional random seed for reproducible frame sampling.')

    args = parser.parse_args()
    main(args)
