import argparse
from collections import deque
from typing import Optional, Sequence
import numpy as np
from PIL import Image
import csv
from absl import flags
import json
from lwm.delta_sampler_override_bridge import DeltaSampler
from lwm.delta_llama import VideoLLaMAConfig

from tux import define_flags_with_default, JaxDistributedConfig, set_random_seed
# import torch
# from torchvision import transforms as T, utils
import sys

# sys.path.append('/root/World-Model/Phenaki')
# from phenaki_pytorch import CViViT_single_linear_nsvq_3


# image_size = 256
# transform = T.Compose([
#             T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
#             T.Resize((image_size, image_size)),
#             # T.CenterCrop(image_size),
#             T.ToTensor()
#         ])



class FLAGSClass:
    """Attribute-style view of a flag dictionary.

    Every ``key: value`` pair of the mapping passed to the constructor is
    set on the instance as an attribute named ``key`` holding ``value``.
    """

    def __init__(self, flag_dict):
        for entry in flag_dict.items():
            attr_name, attr_value = entry
            setattr(self, attr_name, attr_value)

class LWMInference:
    """Thin wrapper that runs a ``DeltaSampler`` on a single-image prompt.

    All keyword arguments given to the constructor are exposed to the sampler
    as attributes of a ``FLAGSClass`` instance.
    """

    def __init__(
        self,
        image_size: int = 256,
        **kwargs,
    ) -> None:
        """Build the sampler from keyword flags.

        Args:
            image_size: Stored on the instance as ``image_size``; this class
                does not use it otherwise.
            **kwargs: Flag values handed to ``DeltaSampler`` through a
                ``FLAGSClass`` instance. Must contain ``tokens_per_delta``.

        Raises:
            KeyError: If ``tokens_per_delta`` is not supplied.
        """
        print("kwargs", kwargs)

        # Look up the required flag before building the sampler so a missing
        # key fails immediately rather than after model construction.
        tokens_per_delta = kwargs['tokens_per_delta']

        # Named ``sampler_flags`` so the module-level ``flags`` import is not
        # shadowed inside this method.
        sampler_flags = FLAGSClass(kwargs)

        self.model = DeltaSampler(FLAGS=sampler_flags)
        self.image_size = image_size
        self.tokens_per_delta = tokens_per_delta
        self.task_description = None

    def inference(self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs):
        """Run the sampler on one image and return the first sample's output.

        Args:
            image: ``uint8`` array; it is converted with
                ``PIL.Image.fromarray`` before being passed to the sampler.
            task_description: Text placed in the prompt's ``question`` field.
            *args: Accepted for call compatibility; ignored.
            **kwargs: Accepted for call compatibility; ignored.

        Returns:
            Element ``0`` of the first of the three values returned by the
            sampler. NOTE(review): presumably an array of latent-action token
            indices (the calling script uses ``.tolist()`` on it) -- confirm
            against ``DeltaSampler``.

        Raises:
            AssertionError: If ``image`` is not of dtype ``uint8``.
        """
        assert image.dtype == np.uint8, f"expected a uint8 image, got {image.dtype}"
        pil_image = Image.fromarray(image)
        prompts = [{'image': [pil_image], 'question': task_description}]

        # The sampler returns three values; only the first is used here.
        user_action_token_num, _vision_output, _text_output = self.model(prompts)
        norm_raw_actions = user_action_token_num[0]
        print("norm raw actions", norm_raw_actions)

        return norm_raw_actions


# load_checkpoint "params::/home/t-sye/World-Model/checkpoints/bridge_carrot_256_batch_128_multitask_40traj/streaming_params_${checkpoint}"
# update-llama-config "dict(action_vocab_size=245,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \

# python -m lwm.latent_lwm_checking --load_cvivit_checkpoint "/home/t-sye/World-Model/phenaki/vast_ai_backup/whole_window3/ddp_gpu_2_batch224_layer8_seq4_code2/vae.20000.pt" --tokens_per_delta 4 --update_llama_config "dict(delta_vocab_size=2,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" --load_checkpoint "checkpoints/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"
if __name__ == "__main__":
    # Debug driver: loads a JSONL dataset, builds the sampler, runs inference
    # on one fixed sample and prints the result.
    parser = argparse.ArgumentParser(description="LWM Inference script")
    parser.add_argument('--tokens_per_delta', type=int, default=4, help='Tokens per delta')
    parser.add_argument('--vqgan_checkpoint', type=str, default="/root/checkpoints/lwm_checkpoints/vqgan")
    parser.add_argument('--vocab_file', type=str, default='/root/checkpoints/lwm_checkpoints/tokenizer.model')
    parser.add_argument('--multi_image', type=int, default=1)
    # NOTE(review): ``type=dict`` cannot convert a non-empty command-line
    # string (``dict("...")`` raises ValueError), so in practice only the
    # default is usable. Left unchanged; pick a real parser if this flag must
    # be settable from the command line.
    parser.add_argument('--jax_distributed', type=dict, default=JaxDistributedConfig.get_default_config())
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--mesh_dim', type=str, default="1,-1,1,1")
    parser.add_argument('--dtype', type=str, default="bf16")
    parser.add_argument('--load_llama_config', type=str, default="7b")
    parser.add_argument('--update_llama_config', type=str, default="")
    parser.add_argument('--load_checkpoint', type=str, default="")
    parser.add_argument('--codebook_size', type=int, default=8)

    args = parser.parse_args()

    # Config objects are attached to ``args`` so they can be forwarded to the
    # sampler together with the parsed flags.
    args.tokenizer = VideoLLaMAConfig.get_tokenizer_config()
    args.llama = VideoLLaMAConfig.get_default_config()
    args.tokenizer.vocab_file = args.vocab_file

    # NOTE: the CViViT model that decoded predicted indices back into frames
    # used to be constructed and loaded here; that path is currently disabled.

    # Load the JSONL dataset: one JSON object per line.
    with open('/root/data/data_0919/data/bridge_window3_codebook8_codeseqlen4_unshuffled_filtered.jsonl', 'r', encoding='utf-8') as f:
        json_obj = [json.loads(line) for line in f]

    JaxDistributedConfig.initialize(args.jax_distributed)
    set_random_seed(args.seed)

    lwm = LWMInference(
        image_size=256,
        # codebook_size=args.codebook_size,
        tokens_per_delta=args.tokens_per_delta,
        vqgan_checkpoint=args.vqgan_checkpoint,
        vocab_file=args.vocab_file,
        multi_image=args.multi_image,
        jax_distributed=args.jax_distributed,
        seed=args.seed,
        mesh_dim=args.mesh_dim,
        dtype=args.dtype,
        load_llama_config=args.load_llama_config,
        update_llama_config=args.update_llama_config,
        load_checkpoint=args.load_checkpoint,
        tokenizer=args.tokenizer,
        llama=args.llama
    )

    # Only this one sample is inspected; every other entry is skipped.
    debug_index = 1241
    for index, json_elem in enumerate(json_obj):
        if index != debug_index:
            continue

        first_image = json_elem['image'][0]
        print(json_elem)
        # Paths in the dataset were recorded on another machine; remap the
        # prefix to the local data root.
        first_image = first_image.replace('/home/t-sye', '/root/data')
        instruction = json_elem['instruction'].replace('<s> You are a helpful assistant. USER: ', "").replace(' ASSISTANT:', "")

        # The second frame of the pair was only needed by the disabled CViViT
        # decoding path, so it is no longer opened. The context manager
        # closes the image file once the pixels are copied into the array.
        with Image.open(first_image) as image_1:
            image = np.array(image_1)

        indices = lwm.inference(image, instruction)
        print("indices", indices)

        # The script stops after the single debug sample. Nothing runs after
        # this loop, so leaving it ends the program just as the previous
        # ``exit()`` call did; the result collection that followed that call
        # was unreachable and has been dropped.
        break

    # save the list of lists to a jsonl file
    # with open('/root/World-Model/data/latentvla_openx_inference.jsonl', 'w') as f:
    #     for item in indices_list:
    #         f.write(json.dumps(item))
    #         f.write("\n")
    
    
    # first_image = json_obj[0]['image'][0]
    # # TODO: load next image
    # next_image = json_obj[0]['image'][1]
    # # print(json_obj[0])
    # instruction = json_obj[0]['instruction'].replace('<s> You are a helpful assistant. USER: ', "").replace(' ASSISTANT:', "")
    # image_1 = Image.open(first_image)
    # image_2 = Image.open(next_image)
    # image = np.array(image_1)
    # generated_images.append(image_1)
    # generated_images.append(image_2)

    # img = transform(image_1).unsqueeze(1)
    # img2 = transform(image_2).unsqueeze(1)
    # img = torch.cat([img, img2], dim=1).unsqueeze(0).cuda()




    # JaxDistributedConfig.initialize(args.jax_distributed)
    # set_random_seed(args.seed)

    # lwm = LWMInference(
    #     image_size=256,
    #     # codebook_size=args.codebook_size,
    #     tokens_per_delta=args.tokens_per_delta,
    #     vqgan_checkpoint=args.vqgan_checkpoint,
    #     vocab_file=args.vocab_file,
    #     multi_image=args.multi_image,
    #     jax_distributed=args.jax_distributed,
    #     seed=args.seed,
    #     mesh_dim=args.mesh_dim,
    #     dtype=args.dtype,
    #     load_llama_config=args.load_llama_config,
    #     update_llama_config=args.update_llama_config,
    #     load_checkpoint=args.load_checkpoint,
    #     tokenizer=args.tokenizer,
    #     llama=args.llama
    # )
    # indices = lwm.inference(image, instruction)
    # print("indices", indices)

#     video = cvivit.inference(img, user_action_token_num=indices)
#     video = torch.clamp(video, 0.0, 1.0)
#     video *= 255.0
#     video = video.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
#     print("video", video[0].shape)
#     video = [Image.fromarray(image) for image in video]
#     generated_images.append(video[0])

#     total_image = Image.new('RGB', (256*3, 256))
#     x_offset = 0
#     for im in generated_images:
#         total_image.paste(im, (x_offset,0))
#         x_offset += 256
#     total_image.save(f'test.jpg')

    
# # python -m lwm.latent_lwm_checking --tokens_per_delta 4 --update_llama_config "dict(delta_vocab_size=2,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" --load_checkpoint "params::checkpoints/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"