import os
import sys
import json
import argparse
import numpy as np
from PIL import Image
import jax
# from tux import open_file
from lwm.vqgan import VQGAN
import albumentations as A
import random
import tqdm
import sys 
sys.path.append('Phenaki')
from phenaki_pytorch import get_vla_dataset
from torch.utils.data import DataLoader
import torch
from phenaki_pytorch import CViViT_single_linear_nsvq_3

def get_preprocessor(image_aug=False):
    """Build an albumentations pipeline that brings images to 256x256.

    The pipeline always starts with ``LongestMaxSize(256)`` followed by
    ``Resize(256, 256)``.

    Args:
        image_aug: when truthy, three more steps are appended after the
            resize: a ``RandomResizedCrop`` (scale 0.9, ratio 1.0),
            brightness/contrast jitter, and hue/saturation/value jitter.

    Returns:
        An ``A.Compose`` over the selected steps.
    """
    steps = [
        A.LongestMaxSize(max_size=256),
        A.Resize(256, 256),
    ]
    if image_aug:
        steps += [
            A.RandomResizedCrop(256, 256, scale=[0.9, 0.9], ratio=[1.0, 1.0]),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=20),
        ]
    return A.Compose(steps)


def encode_images(vqgan, images):
    """Run ``vqgan.encode`` on ``images`` and return its token ids.

    The result of ``vqgan.encode`` is copied to host memory with
    ``jax.device_get``. Element ``[1]`` of that output is returned, cast to
    Python ``int`` dtype.
    """
    host_outputs = jax.device_get(vqgan.encode(images))
    token_ids = host_outputs[1]
    return token_ids.astype(int)

def _build_cvivit(args):
    """Construct the CViViT latent-action model on CUDA and load its checkpoint.

    Only depth (``args.layer``), ``args.codebook_size`` and
    ``args.code_seq_len`` vary; all other hyper-parameters are fixed.
    """
    cvivit = CViViT_single_linear_nsvq_3(
        dim=1024,
        quant_dim=32,
        codebook_size=args.codebook_size,
        image_size=256,
        patch_size=32,
        temporal_patch_size=2,
        spatial_depth=args.layer,
        temporal_depth=args.layer,
        dim_head=64,
        heads=16,
        lookup_free_quantization=False,
        use_vgg_and_gan=False,
        code_seq_len=args.code_seq_len,
        encode_last_frame_only=False,
        token_level_recon=False,
        nsvq_global=False,  # if true, use multiple nsvq codebooks
    ).cuda()
    cvivit.load(args.cvivit_checkpoint)
    # A freshly constructed nn.Module is in training mode; switch to eval so any
    # train-only layers (dropout, norm statistics) do not affect inference.
    cvivit.eval()
    return cvivit


def _cycle(loader):
    """Yield batches from ``loader`` forever, restarting it when exhausted.

    Prints a warning on every restart, because records produced after a
    restart duplicate earlier ones.

    Raises:
        RuntimeError: if a full pass over ``loader`` yields no batches (the
            naive cycle would otherwise spin forever).
    """
    passes = 0
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError("DataLoader yielded no batches; nothing to process.")
        passes += 1
        print(f"WARNING: data loader exhausted after pass {passes}; restarting. "
              "Subsequent records duplicate earlier ones.")


def _batch_to_records(batches, encoded_images, latents, num_samples):
    """Build one JSON-serialisable dict per sample in a collated batch.

    Args:
        batches: collated batch dict. Reads the keys 'dataset_name' (bytes),
            'lang', 'is_last' and 'action'.
        encoded_images: per-sample VQGAN token ids. Flattened and stringified
            into 'vision'.
        latents: per-sample CViViT codebook ids. Stored as 'delta'.
        num_samples: number of samples in the batch.

    Returns:
        A list of ``num_samples`` dicts. The key order is preserved for the
        JSON output.
    """
    records = []
    for j in range(num_samples):
        records.append({
            'dataset_name': batches['dataset_name'][j].decode('utf-8'),
            'instruction': batches['lang'][j],
            'vision': list(map(str, encoded_images[j].flatten().tolist())),
            'fields': '[instruction],[vision],delta',
            'is_last': batches['is_last'][j].item(),
            'raw_action': batches['action'][j].tolist(),
            'delta': latents[j].tolist(),
        })
    return records


def main(args):
    """Encode a dataset split into a JSONL file of vision tokens and latent actions.

    Each DataLoader batch goes through two encoders:

    * the whole clip goes through CViViT, giving codebook ids (stored as
      'delta');
    * a single frame per clip goes through the VQGAN, giving tokens (stored as
      'vision').

    One JSON line per sample is written to ``args.output_file`` as the batch
    is processed. ``args.num_steps`` batches are processed (default 12000 when
    the attribute is absent). If the loader runs out first, it restarts with a
    warning.
    """
    cvivit = _build_cvivit(args)

    print("Initializing preprocessor...")
    # NOTE(review): the preprocessor is built but never applied below; images
    # come straight from the dataset. TODO: apply it or remove it.
    preprocessor = get_preprocessor(image_aug=args.image_aug)

    print("Initializing VQGAN...")
    vqgan = VQGAN(args.checkpoint_path, replicate=False)

    print("Reading image folders...")
    ds = get_vla_dataset(
        train=True,
        data_root_dir=args.root_folder,
        dataset_name=args.open_x_dataset_name,
        shuffle_buffer_size=10000,
        shuffle=False,
        index=args.index,
        divider=args.divider,
    )
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False, num_workers=1)
    print("length ds", len(ds))

    # Previously hard-coded; now configurable via --num_steps.
    num_steps = getattr(args, 'num_steps', 12000)
    batch_iter = _cycle(dl)

    traj_cnt = 0
    # Stream records to disk so a crash mid-run keeps completed work and
    # memory does not grow with the number of steps.
    with open(args.output_file, 'w', encoding='utf-8') as f:
        for _ in tqdm.tqdm(range(num_steps)):
            batches = next(batch_iter)

            # img value is from 0 to 1
            clips = batches['pixel_values'].cuda().squeeze(1)
            with torch.no_grad():
                latents = cvivit.inference(clips, return_only_codebook_ids=True)
            latents = latents.cpu().detach().numpy()

            for is_last_elem in batches['is_last']:
                if is_last_elem:
                    traj_cnt += 1
                    print("traj_cnt", traj_cnt)

            # Index 0 along batched axis 3, drop singleton axis 1, then move
            # channels last -> (B, H, W, C) numpy. Equivalent to the former
            # per-sample `[step][:, :, 0].squeeze(0)` + stack.
            # Presumably axis 3 is time, i.e. this takes the first frame -- TODO confirm.
            frames = (
                batches['pixel_values'][:, :, :, 0]
                .squeeze(1)
                .permute(0, 2, 3, 1)
                .numpy()
            )
            encoded_images = encode_images(vqgan, frames)

            for record in _batch_to_records(batches, encoded_images, latents, len(frames)):
                f.write(json.dumps(record) + '\n')

if __name__ == "__main__":
    # Command-line entry point: parse options and run main().
    print("Starting the data preprocessing...")
    parser = argparse.ArgumentParser(description="Process images and generate JSONL outputs.")
    parser.add_argument('--output-file', type=str, default="/mnt/default/lwm/data/open-x/vqgan_inference_openx.jsonl", help='Path to output JSONL file.')
    parser.add_argument('--checkpoint-path', type=str, default='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan', help='Path to the VQGAN checkpoint directory.')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size for processing images.')
    parser.add_argument('--image-aug', type=int, default=0, help='Flag to apply image augmentations.')
    parser.add_argument('--root_folder', type=str, default='/mnt/default/lwm/data/open-x/datasets', help='Root folder of the dataset')
    parser.add_argument('--open_x_dataset_name', type=str, default='oxe_magic_soup_plus', help='Name of the OpenX dataset')
    parser.add_argument('--index', type=int, default=0, help='Index of the dataset part to process (see --divider)')
    parser.add_argument('--divider', type=int, default=100, help='Number of parts the dataset is divided into; --index selects one')
    parser.add_argument('--interval', type=int, default=200000, help='Interval of the dataset (currently unused)')
    parser.add_argument('--num_steps', type=int, default=12000, help='Number of DataLoader batches to process')
    parser.add_argument('--layer', type=int, default=8, help='Layer of the model')
    parser.add_argument('--codebook_size', type=int, default=8, help='Codebook size of the model')
    parser.add_argument('--code_seq_len', type=int, default=4, help='Code sequence length of the model')
    parser.add_argument('--cvivit_checkpoint', type=str, default='/mnt/default/lwm/data/checkpoints/phenaki/vast_ai_backup/open-x/openx_c8_s4_l8_b448_100K_gpu2_noglobal_re/vae.85000.pt', help='Path to the CViViT checkpoint')
    args = parser.parse_args()
    main(args)
