import argparse
from collections import deque
from typing import Optional, Sequence
import numpy as np
from PIL import Image
import csv
from absl import flags
import json

import torch
from torchvision import transforms as T, utils
import sys
import os

sys.path.append('/root/World-Model/Phenaki')
from phenaki_pytorch import CViViT_single_linear_nsvq_3


image_size = 256


def _to_rgb(img):
    """Return *img* unchanged if it is already RGB, otherwise an RGB copy."""
    if img.mode == 'RGB':
        return img
    return img.convert('RGB')


# Preprocessing applied to every input frame: force RGB, resize to a square
# image_size x image_size image, then convert to a float tensor
# (ToTensor yields a (C, H, W) tensor with values scaled into [0, 1]).
transform = T.Compose(
    [
        T.Lambda(_to_rgb),
        T.Resize((image_size, image_size)),
        # T.CenterCrop(image_size),
        T.ToTensor(),
    ]
)






# load_checkpoint "params::/home/t-sye/World-Model/checkpoints/bridge_carrot_256_batch_128_multitask_40traj/streaming_params_${checkpoint}"
# update-llama-config "dict(action_vocab_size=245,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \

# python -m lwm.latent_lwm_checking --load_cvivit_checkpoint "/home/t-sye/World-Model/phenaki/vast_ai_backup/whole_window3/ddp_gpu_2_batch224_layer8_seq4_code2/vae.20000.pt" --tokens_per_delta 4 --update_llama_config "dict(delta_vocab_size=2,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" --load_checkpoint "checkpoints/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"
# Commented-out experiments that read latent actions from
# latentvla_openx_inference.jsonl / LWMInference were dropped from this script;
# recover them from version control if needed.


def _parse_args():
    """Parse command-line options for the CViViT latent-action decoding check.

    The input/output image paths and the latent action indices used to be
    hard-coded in the script body; they are now options whose defaults are the
    old hard-coded values, so running without them behaves as before.
    """
    parser = argparse.ArgumentParser(description="LWM Inference script")
    parser.add_argument('--tokens_per_delta', type=int, default=4, help='Tokens per delta')
    parser.add_argument('--codebook_size', type=int, default=8)
    parser.add_argument('--load_cvivit_checkpoint', type=str, default="/root/checkpoints/lwm_checkpoints/vqgan")
    parser.add_argument('--first_image', type=str,
                        default='/root/World-Model/lwm/analysis/test4.jpg',
                        help='First conditioning frame')
    parser.add_argument('--next_image', type=str, default=None,
                        help='Second conditioning frame (defaults to --first_image)')
    parser.add_argument('--indices', type=int, nargs='+', default=[1, 4, 1, 4],
                        help='Latent action token ids passed to cvivit.inference '
                             'as user_action_token_num')
    parser.add_argument('--output_image', type=str,
                        default='/root/World-Model/lwm/analysis/test5.jpg',
                        help='Where to save the decoded frame')
    return parser.parse_args()


def _build_model(args):
    """Build the CViViT tokenizer on the GPU, load its checkpoint, set eval mode."""
    cvivit = CViViT_single_linear_nsvq_3(
        dim=1024,
        quant_dim=32,
        codebook_size=args.codebook_size,
        image_size=image_size,  # kept equal to the resize done by `transform`
        patch_size=32,
        temporal_patch_size=2,
        spatial_depth=8,
        temporal_depth=8,
        dim_head=64,
        heads=16,
        lookup_free_quantization=False,
        use_vgg_and_gan=False,
        code_seq_len=args.tokens_per_delta,
    ).cuda()
    print("loading")
    cvivit.load(args.load_cvivit_checkpoint)
    print("loaded")
    # Inference only: turn off training-mode behaviour (dropout etc.) that
    # nn.Module enables by default.
    cvivit.eval()
    return cvivit


def _load_frame(path):
    """Load the image at *path* as a (C, H, W) float tensor via `transform`.

    The file is opened in a `with` block so the handle is released as soon as
    the pixels have been read.
    """
    with Image.open(path) as pil_image:
        return transform(pil_image)


def _frames_to_pil(video):
    """Convert decoded frames with values nominally in [0, 1] to PIL images.

    NOTE(review): assumes *video* is laid out (frames, C, H, W), which is what
    the permute below expects; confirm against
    CViViT_single_linear_nsvq_3.inference.
    """
    video = torch.clamp(video, 0.0, 1.0)
    # Scale to 0..255 and round to nearest (same recipe as
    # torchvision.utils.save_image) instead of truncating, which biased every
    # pixel down by up to one level.
    video = video.mul(255.0).add_(0.5).clamp_(0.0, 255.0)
    frames = video.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
    print("video", frames[0].shape)
    return [Image.fromarray(frame) for frame in frames]


def _tile_horizontally(images, tile_size):
    """Paste *images* left to right onto one RGB canvas of tile_size-px slots.

    The canvas width grows with the number of images; previously it was fixed
    to a single slot, so every image after the first was clipped. An empty
    list still yields one black tile.
    """
    canvas = Image.new('RGB', (tile_size * max(len(images), 1), tile_size))
    for slot, im in enumerate(images):
        canvas.paste(im, (slot * tile_size, 0))
    return canvas


def main():
    """Decode the given latent action tokens from a two-frame clip and save it."""
    args = _parse_args()
    cvivit = _build_model(args)

    # Legacy output directory from earlier batch runs; nothing below writes to it.
    os.makedirs('inference_img', exist_ok=True)

    # Build a (1, C, 2, H, W) clip from two frames. By default the second frame
    # is the first one repeated, matching the previous hard-coded behaviour.
    first_frame = _load_frame(args.first_image)
    if args.next_image is None or args.next_image == args.first_image:
        second_frame = first_frame
    else:
        second_frame = _load_frame(args.next_image)
    clip = torch.stack([first_frame, second_frame], dim=1).unsqueeze(0).cuda()

    # Decoding needs no autograd graph; skipping it saves GPU memory.
    with torch.no_grad():
        video = cvivit.inference(clip, user_action_token_num=args.indices)

    # Only the first decoded frame is saved, as before.
    generated_images = [_frames_to_pil(video)[0]]
    _tile_horizontally(generated_images, image_size).save(args.output_image)


if __name__ == "__main__":
    main()

    
# # python -m lwm.latent_lwm_checking --tokens_per_delta 4 --update_llama_config "dict(delta_vocab_size=2,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" --load_checkpoint "params::checkpoints/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"