import argparse
from collections import deque
from typing import Optional, Sequence
import numpy as np
from PIL import Image
import csv
from absl import flags
import json
from lwm.delta_sampler_override_bridge import DeltaSampler
from lwm.delta_llama import VideoLLaMAConfig

from tux import define_flags_with_default, JaxDistributedConfig, set_random_seed
# import torch
# from torchvision import transforms as T, utils
import sys

# sys.path.append('/home/t-sye/World-Model/Phenaki')
# from phenaki_pytorch import CViViT_single_linear_nsvq_3


# image_size = 256
# transform = T.Compose([
#             T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
#             T.Resize((image_size, image_size)),
#             # T.CenterCrop(image_size),
#             T.ToTensor()
#         ])



class FLAGSClass:
    """Attribute-style view over a plain dict of flag values.

    Every key of the mapping passed to the constructor becomes an instance
    attribute holding the corresponding value, so ``FLAGSClass({"seed": 1}).seed``
    evaluates to ``1``.
    """

    def __init__(self, flag_dict):
        # Copy each entry onto the instance, one attribute per key.
        entries = flag_dict.items()
        for attr_name, attr_value in entries:
            setattr(self, attr_name, attr_value)

class LWMInference:
    """Run a ``DeltaSampler`` on a single image plus an optional text prompt.

    The keyword arguments given to the constructor are wrapped in a
    ``FLAGSClass`` and handed to ``DeltaSampler`` as its ``FLAGS`` object.
    """

    def __init__(
        self,
        image_size: int = 256,
        **kwargs,
    ) -> None:
        """Build the sampler.

        Args:
            image_size: Stored on the instance as ``self.image_size``; not
                otherwise used by this class.
            **kwargs: Flag values forwarded to ``DeltaSampler``. Must contain
                ``tokens_per_delta`` (a missing key raises ``KeyError``).
        """
        print("kwargs", kwargs)
        sampler_flags = FLAGSClass(kwargs)
        self.model = DeltaSampler(FLAGS=sampler_flags)

        self.image_size = image_size
        self.tokens_per_delta = kwargs['tokens_per_delta']
        self.task_description = None

    def inference(self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
        """Sample from the model for one image and return its first output entry.

        Args:
            image: ``uint8`` array; converted to a PIL image before being
                passed to the model.
            task_description: Text used as the prompt's ``question`` field.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Element ``0`` of the first of the three values returned by
            ``self.model``.
            NOTE(review): the declared return annotation (a tuple of two dicts)
            does not match this; confirm the actual type against
            ``DeltaSampler`` and correct the annotation.

        Raises:
            AssertionError: If ``image.dtype`` is not ``np.uint8``.
        """
        assert image.dtype == np.uint8
        pil_image = Image.fromarray(image)
        prompt = {'image': [pil_image], 'question': task_description}

        # The model returns three values; only the first is used here.
        sampler_output = self.model([prompt])
        user_action_token_num = sampler_output[0]
        first_entry = user_action_token_num[0]
        print("norm raw actions", first_entry)

        return first_entry


# load_checkpoint "params::/home/t-sye/World-Model/checkpoints/bridge_carrot_256_batch_128_multitask_40traj/streaming_params_${checkpoint}"
# update-llama-config "dict(action_vocab_size=245,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \

# python -m lwm.latent_lwm_checking --load_cvivit_checkpoint "/home/t-sye/World-Model/phenaki/vast_ai_backup/whole_window3/ddp_gpu_2_batch224_layer8_seq4_code2/vae.20000.pt" --tokens_per_delta 4 --update_llama_config "dict(delta_vocab_size=2,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" --load_checkpoint "checkpoints/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"
if __name__ == "__main__":
    # Script entry point: run LWMInference over every record of a JSONL
    # dataset and write one JSON line per record with the sampled output.
    parser = argparse.ArgumentParser(description="LWM Inference script")
    parser.add_argument('--tokens_per_delta', type=int, default=1, help='Tokens per delta')
    parser.add_argument('--vqgan_checkpoint', type=str, default="/home/t-sye/World-Model/checkpoints/lwm_checkpoints/vqgan")
    parser.add_argument('--vocab_file', type=str, default='/home/t-sye/World-Model/checkpoints/lwm_checkpoints/tokenizer.model')
    parser.add_argument('--multi_image', type=int, default=1)
    # NOTE(review): argparse applies ``dict(<string>)`` to a value given on the
    # command line, which fails for any non-empty string, so this option is
    # effectively default-only.
    parser.add_argument('--jax_distributed', type=dict, default=JaxDistributedConfig.get_default_config())
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--mesh_dim', type=str, default="1,-1,1,1")
    parser.add_argument('--dtype', type=str, default="bf16")
    parser.add_argument('--load_llama_config', type=str, default="7b")
    parser.add_argument('--update_llama_config', type=str, default="")
    parser.add_argument('--load_checkpoint', type=str, default="")
    # Dataset locations. The defaults are the paths that used to be hard-coded,
    # so invoking the script without these options behaves as before.
    parser.add_argument(
        '--input_jsonl',
        type=str,
        default='/home/t-sye/World-Model/data/817_multiobject_sink_rand_location_llava_train.jsonl',
        help='JSONL file with one record per line (image, instruction, action, raw_actions)',
    )
    parser.add_argument(
        '--output_jsonl',
        type=str,
        default='/home/t-sye/World-Model/data/817_multiobject_sink_rand_location_llava_train_openx_c8_s4_150K_latent_inference_trained.jsonl',
        help='Destination JSONL file for the inference results',
    )
    parser.add_argument(
        '--image_root',
        type=str,
        default='/home/t-sye/World-Model/',
        help="Prefix prepended verbatim to each record's image field (include the trailing slash)",
    )

    args = parser.parse_args()

    args.tokenizer = VideoLLaMAConfig.get_tokenizer_config()
    args.llama = VideoLLaMAConfig.get_default_config()
    args.tokenizer.vocab_file = args.vocab_file

    # Load all records up front; blank lines (e.g. a trailing newline at the
    # end of the file) are skipped instead of crashing json.loads.
    with open(args.input_jsonl, 'r', encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]

    JaxDistributedConfig.initialize(args.jax_distributed)
    set_random_seed(args.seed)

    lwm = LWMInference(
        image_size=256,
        # codebook_size=args.codebook_size,
        tokens_per_delta=args.tokens_per_delta,
        vqgan_checkpoint=args.vqgan_checkpoint,
        vocab_file=args.vocab_file,
        multi_image=args.multi_image,
        jax_distributed=args.jax_distributed,
        seed=args.seed,
        mesh_dim=args.mesh_dim,
        dtype=args.dtype,
        load_llama_config=args.load_llama_config,
        update_llama_config=args.update_llama_config,
        load_checkpoint=args.load_checkpoint,
        tokenizer=args.tokenizer,
        llama=args.llama
    )

    results = []
    for record in records:
        # Records reference the "blocking" dataset directory; the images are
        # read from the non-blocking directory instead.
        image_path = args.image_root + record['image']
        image_path = image_path.replace('multiobject_sink_rand_location_blocking_llava', 'multiobject_sink_rand_location_llava')
        # Strip the chat template so only the bare instruction text remains.
        instruction = record['instruction'].replace('<s> You are a helpful assistant. USER: ', "").replace(' ASSISTANT:', "")

        # Decode the pixels inside the context manager so the underlying file
        # handle is released per record rather than leaked across the loop.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image)

        indices = lwm.inference(image, instruction)
        print("indices", indices)
        results.append({"image": record['image'], "latent_action": indices.tolist(), "action": record['action'], "raw_actions": record['raw_actions'], "instruction": record['instruction']})

    # Save the collected results, one JSON object per line.
    with open(args.output_jsonl, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item) + "\n" for item in results)
    
    
    # first_image = json_obj[0]['image'][0]
    # # TODO: load next image
    # next_image = json_obj[0]['image'][1]
    # # print(json_obj[0])
    # instruction = json_obj[0]['instruction'].replace('<s> You are a helpful assistant. USER: ', "").replace(' ASSISTANT:', "")
    # image_1 = Image.open(first_image)
    # image_2 = Image.open(next_image)
    # image = np.array(image_1)
    # # generated_images.append(image_1)
    # # generated_images.append(image_2)

    # # img = transform(image_1).unsqueeze(1)
    # # img2 = transform(image_2).unsqueeze(1)
    # # img = torch.cat([img, img2], dim=1).unsqueeze(0).cuda()




    # JaxDistributedConfig.initialize(args.jax_distributed)
    # set_random_seed(args.seed)

    # lwm = LWMInference(
    #     image_size=256,
    #     # codebook_size=args.codebook_size,
    #     tokens_per_delta=args.tokens_per_delta,
    #     vqgan_checkpoint=args.vqgan_checkpoint,
    #     vocab_file=args.vocab_file,
    #     multi_image=args.multi_image,
    #     jax_distributed=args.jax_distributed,
    #     seed=args.seed,
    #     mesh_dim=args.mesh_dim,
    #     dtype=args.dtype,
    #     load_llama_config=args.load_llama_config,
    #     update_llama_config=args.update_llama_config,
    #     load_checkpoint=args.load_checkpoint,
    #     tokenizer=args.tokenizer,
    #     llama=args.llama
    # )
    # indices = lwm.inference(image, instruction)
    # print("indices", indices)

    # # video = cvivit.inference(img, user_action_token_num=indices)
    # # video = torch.clamp(video, 0.0, 1.0)
    # # video *= 255.0
    # # video = video.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
    # # print("video", video[0].shape)
    # # video = [Image.fromarray(image) for image in video]
    # # generated_images.append(video[0])

    # # total_image = Image.new('RGB', (256*3, 256))
    # # x_offset = 0
    # # for im in generated_images:
    # #     total_image.paste(im, (x_offset,0))
    # #     x_offset += 256
    # # total_image.save(f'test.jpg')

    
# python -m lwm.latent_lwm_checking --tokens_per_delta 4 --update_llama_config "dict(delta_vocab_size=8,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" --load_checkpoint "params::checkpoints/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"