import argparse
from collections import deque
from typing import Optional, Sequence
import numpy as np
from PIL import Image
import csv
from absl import flags
import json
from lwm.delta_sampler_override_bridge import DeltaSampler
from lwm.delta_llama import VideoLLaMAConfig

from tux import define_flags_with_default, JaxDistributedConfig, set_random_seed
# import torch
# from torchvision import transforms as T, utils
import sys

# sys.path.append('/root/World-Model/Phenaki')
# from phenaki_pytorch import CViViT_single_linear_nsvq_3


# image_size = 256
# transform = T.Compose([
#             T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
#             T.Resize((image_size, image_size)),
#             # T.CenterCrop(image_size),
#             T.ToTensor()
#         ])



class FLAGSClass:
    """Attribute-style view over a flat mapping of flag names to values.

    Each key of ``flag_dict`` becomes an instance attribute bound to the
    matching value, e.g. ``FLAGSClass({"seed": 1}).seed == 1``.
    """

    def __init__(self, flag_dict):
        for name in flag_dict:
            setattr(self, name, flag_dict[name])

class LWMInference:
    """Thin wrapper that feeds one image plus an instruction to a ``DeltaSampler``.

    Every keyword argument given to the constructor is exposed as an attribute
    of a ``FLAGSClass`` instance and handed to ``DeltaSampler(FLAGS=...)``.
    """

    def __init__(
        self,
        image_size: int = 256,
        **kwargs,
    ) -> None:
        """Build the sampler.

        Args:
            image_size: stored as ``self.image_size``; ``inference`` does not
                use it.
            **kwargs: sampler flags forwarded to ``DeltaSampler``. Must
                include ``tokens_per_delta``.

        Raises:
            KeyError: if ``tokens_per_delta`` is missing from ``kwargs``.
        """
        print("kwargs", kwargs)
        # Read the required flag before constructing the sampler so a missing
        # value fails fast instead of after the (expensive) model load.
        self.tokens_per_delta = kwargs['tokens_per_delta']
        flags = FLAGSClass(kwargs)

        self.model = DeltaSampler(FLAGS=flags)
        self.image_size = image_size
        # Instruction used by ``inference`` when the caller passes none.
        self.task_description = None

    def inference(self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs):
        """Run the sampler on a single image and return its first action-token output.

        Args:
            image: ``uint8`` array that ``PIL.Image.fromarray`` accepts.
            task_description: instruction text; when None, falls back to
                ``self.task_description``.
            *args, **kwargs: accepted for call-site compatibility; ignored.

        Returns:
            ``user_action_token_num[0]``, the first item of the sampler's first
            return value (presumably the token indices for the single prompt
            — confirm against ``DeltaSampler``).

        Raises:
            ValueError: if ``image`` is not ``uint8``.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if image.dtype != np.uint8:
            raise ValueError(f"expected a uint8 image, got dtype {image.dtype}")
        if task_description is None:
            task_description = self.task_description

        prompts = [{'image': [Image.fromarray(image)], 'question': task_description}]

        user_action_token_num, _vision_output, _text_output = self.model(prompts)
        indices = user_action_token_num[0]
        print("norm raw actions", indices)
        return indices


# Example flag values (shell syntax; ${checkpoint} is a checkpoint step, e.g. 10725):
# --load_checkpoint "params::/home/t-sye/World-Model/checkpoints/bridge_carrot_256_batch_128_multitask_40traj/streaming_params_${checkpoint}"
# --update_llama_config "dict(action_vocab_size=245,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)"

# Example invocation (NOTE: --load_cvivit_checkpoint is not defined by this script's argument parser):
# python -m lwm.latent_lwm_checking --load_cvivit_checkpoint "/home/t-sye/World-Model/phenaki/vast_ai_backup/whole_window3/ddp_gpu_2_batch224_layer8_seq4_code2/vae.20000.pt" --tokens_per_delta 4 --update_llama_config "dict(delta_vocab_size=2,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" --load_checkpoint "checkpoints/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"
def parse_dict_literal(text: str) -> dict:
    """Parse a command-line string holding a Python dict literal.

    Used as the argparse ``type=`` for ``--jax_distributed``. The previous
    ``type=dict`` called ``dict(text)``, which raises for every non-empty
    string, so the flag could not be set from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a dict literal.
    """
    import ast

    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError) as err:
        raise argparse.ArgumentTypeError(
            f"expected a Python dict literal, got {text!r}"
        ) from err
    if not isinstance(value, dict):
        raise argparse.ArgumentTypeError(f"expected a Python dict literal, got {text!r}")
    return value


def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this inference script."""
    parser = argparse.ArgumentParser(description="LWM Inference script")
    parser.add_argument('--tokens_per_delta', type=int, default=4, help='Tokens per delta')
    parser.add_argument('--vqgan_checkpoint', type=str, default="/root/checkpoints/lwm_checkpoints/vqgan")
    parser.add_argument('--vocab_file', type=str, default='/root/checkpoints/lwm_checkpoints/tokenizer.model')
    parser.add_argument('--multi_image', type=int, default=1)
    # None means "use JaxDistributedConfig.get_default_config()"; resolved in main().
    # A supplied dict is handed to JaxDistributedConfig.initialize as-is
    # (presumably merged over tux's defaults — confirm).
    parser.add_argument('--jax_distributed', type=parse_dict_literal, default=None,
                        help='Python dict literal passed to JaxDistributedConfig.initialize')
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--mesh_dim', type=str, default="1,-1,1,1")
    parser.add_argument('--dtype', type=str, default="bf16")
    parser.add_argument('--load_llama_config', type=str, default="7b")
    parser.add_argument('--update_llama_config', type=str, default="")
    parser.add_argument('--load_checkpoint', type=str, default="")
    # Parsed for CLI compatibility but not forwarded to LWMInference.
    parser.add_argument('--codebook_size', type=int, default=8)
    parser.add_argument('--image_path', type=str, default='/root/World-Model/lwm/analysis/test8.jpg',
                        help='Image to run inference on')
    parser.add_argument('--instruction', type=str, default="take broccoli out of pan",
                        help='Task instruction paired with the image')
    return parser


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Load the model, run it on one image/instruction pair and print the result.

    Args:
        argv: arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)
    if args.jax_distributed is None:
        args.jax_distributed = JaxDistributedConfig.get_default_config()

    args.tokenizer = VideoLLaMAConfig.get_tokenizer_config()
    args.llama = VideoLLaMAConfig.get_default_config()
    args.tokenizer.vocab_file = args.vocab_file

    JaxDistributedConfig.initialize(args.jax_distributed)
    set_random_seed(args.seed)

    lwm = LWMInference(
        image_size=256,
        tokens_per_delta=args.tokens_per_delta,
        vqgan_checkpoint=args.vqgan_checkpoint,
        vocab_file=args.vocab_file,
        multi_image=args.multi_image,
        jax_distributed=args.jax_distributed,
        seed=args.seed,
        mesh_dim=args.mesh_dim,
        dtype=args.dtype,
        load_llama_config=args.load_llama_config,
        update_llama_config=args.update_llama_config,
        load_checkpoint=args.load_checkpoint,
        tokenizer=args.tokenizer,
        llama=args.llama,
    )

    # ``with`` releases the file handle. convert("RGB") turns grayscale/RGBA/CMYK
    # inputs into 3 channels (presumably what the sampler expects — confirm).
    with Image.open(args.image_path) as pil_image:
        image = np.array(pil_image.convert("RGB"))

    indices = lwm.inference(image, args.instruction)
    print("indices", indices)


if __name__ == "__main__":
    main()

