from absl.app import run
import math
from tqdm import tqdm
from PIL import Image
import decord
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from transformers import GenerationConfig
from tux import (
    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
    match_partition_rules, make_shard_and_gather_fns,
    with_sharding_constraint, tree_apply, open_file
)
# from lwm.delta_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.delta_llama_action import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.vqgan import VQGAN
import albumentations
import jax.numpy as jnp
import cv2
import sys
from einops import rearrange, repeat, pack, unpack
from jax import lax, random

def compute_eos_positions(output, input_ids_shape, eos_token_id, slice_start, slice_size):
    """Slice a window of tokens and flag every position from the first EOS on.

    Args:
        output: 2-D token array. Axis 0 is kept whole, axis 1 is windowed.
        input_ids_shape: Unused; kept so existing call sites keep working.
        eos_token_id: Token id treated as end-of-sequence.
        slice_start: First column of the window along axis 1.
        slice_size: Number of columns in the window.

    Returns:
        A ``(text_output, eos_positions)`` pair. ``text_output`` is the
        ``(output.shape[0], slice_size)`` window. ``eos_positions`` is a
        boolean array of the same shape that is True at the first EOS of each
        row and at every later column; rows without an EOS are all False.
    """
    # Use lax.dynamic_slice for dynamic slicing
    text_output = lax.dynamic_slice(output, (0, slice_start), (output.shape[0], slice_size))
    # A running count of EOS hits along the sequence axis becomes positive
    # exactly at the first EOS and stays positive afterwards. This works per
    # row and needs no data-dependent Python branching.
    eos_hits = (text_output == eos_token_id).astype(jnp.int32)
    eos_positions = jnp.cumsum(eos_hits, axis=1) > 0
    return text_output, eos_positions

class DeltaActionSampler:
    def __init__(self, FLAGS):
        """Set up the mesh, VQGAN, tokenizers and model from ``FLAGS``.

        Args:
            FLAGS: Flag container. This constructor reads ``tokenizer``,
                ``mesh_dim`` and ``vqgan_checkpoint``; ``_load_model`` reads
                further fields from it.
        """
        self.FLAGS = FLAGS
        print("FLAGS", FLAGS.tokenizer)
        self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
        self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
        # Prompt tokenizer: pads and truncates on the left.
        self.prefix_tokenizer = VideoLLaMAConfig.get_tokenizer(
            FLAGS.tokenizer, truncation_side='left', padding_side='left'
        )
        # Default-configured tokenizer, used for decoding and for BOS/EOS ids.
        self.tokenizer = VideoLLaMAConfig.get_tokenizer(FLAGS.tokenizer)
        # self.tokens_per_delta = 64
        # NOTE(review): min_buffer_size and sharded_rng are not read by any
        # method visible in this class (generate_video_pred draws its own rng).
        self.min_buffer_size = 256
        self.sharded_rng = next_rng()
        # Loads the checkpoint and sets self.config / params / model / model_ps.
        self._load_model()


    @property
    def block_size(self):
        return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
    
    @property
    def data_dim(self):
        return self.mesh.shape['dp'] * self.mesh.shape['fsdp']

    def _process_frame(self, images, size):
        """Resize frames and return VQGAN-ready and [0, 1]-scaled copies.

        Args:
            images: Iterable of frames, each convertible with ``np.array``
                and cast to ``uint8``.
            size: Target side length in pixels. The longest side is first
                scaled to ``size``, then the frame is resized to exactly
                ``size`` x ``size``.

        Returns:
            ``(vqgan_batch, processed_image)``. ``vqgan_batch`` is a float32
            array stacking every resized frame mapped by ``x / 127.5 - 1``.
            ``processed_image`` is a list of the same frames as float32
            arrays mapped by ``x / 255``.

        Raises:
            ValueError: If ``images`` is empty, since there is nothing to stack.
        """
        resize = albumentations.Compose([
            albumentations.LongestMaxSize(max_size=size),  # Longest side -> size
            albumentations.Resize(size, size),
        ])
        vqgan_frames = []
        processed_image = []
        for image in images:
            pixels = resize(image=np.array(image).astype(np.uint8))["image"]
            # TODO: check this? Two scalings are kept: [0, 1]-style for the
            # caller and [-1, 1]-style for the VQGAN encoder.
            processed_image.append((pixels / 255).astype(np.float32))
            vqgan_frames.append((pixels / 127.5 - 1.0).astype(np.float32))
        return np.stack(vqgan_frames, axis=0), processed_image


    def _read_process_vision(self, images):

        # path = '/mnt/sda/anon/World-As-Code/llava/playground/data/lang_table_separate_100/1.jpg'
        # f = open_file(path, 'rb')
        # if path.endswith('.png') or path.endswith('.jpg'):
        #     image = Image.open(f)
 
        vision, processed_image = self._process_frame(images, 256)
        
        B = 1
        encodings = []
        for i in range(0, len(vision), 1):
            v = vision[i:i + B]
            if len(v) % B == 0:
                n_pad = 0
            else:
                n_pad = B - len(v) % B
            v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0)))
            enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
            # print("enc", enc)
            enc = enc[n_pad:]
            for t in range(len(enc)):
                encodings.extend(enc[t].reshape(-1).tolist())
        return encodings, processed_image



    def construct_input(self, prompts):
        """Encode a prompt's frames into the vision-token input for the model.

        Only the last prompt contributes to the result: a single-row batch is
        returned, so earlier prompts are not encoded.

        Args:
            prompts: Sequence of dicts, each with an ``'image'`` entry holding
                the frames passed to ``_read_process_vision``.

        Returns:
            Dict with ``'input_ids'``, a ``(1, n_codes + 1)`` array of the
            vision codes followed by token 8193, and ``'processed_image'``,
            the second value returned by ``_read_process_vision``.

        Raises:
            ValueError: If ``prompts`` is empty.
        """
        prompts = list(prompts)
        if not prompts:
            raise ValueError("construct_input requires at least one prompt")

        vision, processed_image = self._read_process_vision(prompts[-1]['image'])
        # 8193 closes the run of vision codes.
        # NOTE(review): presumably the end-of-vision marker id; confirm
        # against the model vocabulary.
        tokens = list(vision) + [8193]
        return {
            'input_ids': np.expand_dims(tokens, axis=0),
            'processed_image': processed_image,
        }
             

    def _load_model(self):
        """Build the model config, load checkpoint params and shard them.

        Sets ``self.config``, ``self.params``, ``self.model`` and
        ``self.model_ps``. Reads ``load_llama_config``, ``llama``,
        ``update_llama_config``, ``mesh_dim``, ``load_checkpoint``, ``seed``
        and ``dtype`` from ``self.FLAGS``, plus the BOS/EOS ids from
        ``self.tokenizer``.
        """
        if self.FLAGS.load_llama_config != '':
            # Start from the stored config, but take the remat/scan settings
            # from the flag-provided config.
            llama_config = VideoLLaMAConfig.load_config(self.FLAGS.load_llama_config)
            updates = VideoLLaMAConfig(**self.FLAGS.llama)
            llama_config.update(dict(
                remat_block=updates.remat_block,
                remat_attention=updates.remat_attention,
                remat_mlp=updates.remat_mlp,
                scan_attention=updates.scan_attention,
                scan_mlp=updates.scan_mlp,
                scan_query_chunk_size=updates.scan_query_chunk_size,
                scan_key_chunk_size=updates.scan_key_chunk_size,
                scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
                scan_layers=updates.scan_layers,
                param_scan_axis=updates.param_scan_axis,
            ))
        else:
            llama_config = VideoLLaMAConfig(**self.FLAGS.llama)




        if self.FLAGS.update_llama_config != '':
            # NOTE(review): eval() executes the flag value as arbitrary Python.
            # Only safe while the flag comes from a trusted operator; consider
            # ast.literal_eval if the value is always a literal.
            llama_config.update(dict(eval(self.FLAGS.update_llama_config)))

        llama_config.update(dict(
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        ))
        llama_config.update(dict(mesh_dim=self.FLAGS.mesh_dim))
        self.config = llama_config

        # Everything below runs with the first CPU device as the default device.
        with jax.default_device(jax.devices("cpu")[0]):
            # The first element of the returned pair is discarded; only the
            # params are kept. max_buffer_size is 32 * 2**30 (presumably
            # bytes, i.e. 32 GiB; confirm against StreamingCheckpointer).
            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
                    self.FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
            )
            # _do_init=False: parameters come from the checkpoint loaded above.
            self.model = FlaxVideoLLaMAForCausalLM(
                llama_config, 
                input_shape=(512, 8192), 
                seed=self.FLAGS.seed, 
                _do_init=False,
                dtype=get_float_dtype_by_name(self.FLAGS.dtype),
            )

        
            # Partition specs for the params; also reused as in_shardings by
            # the pjit-compiled generate functions.
            self.model_ps = match_partition_rules(
                VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
            )
            shard_fns, _ = make_shard_and_gather_fns(
                self.model_ps, get_float_dtype_by_name(self.FLAGS.dtype)
            )

            # Apply the shard functions to the params under the mesh.
            with self.mesh:
                self.params = tree_apply(shard_fns, self.params)

    @cached_property
    def _forward_generate(self):
        """pjit-compiled text-generation step, built once per instance.

        The returned callable takes ``(params, rng, batch, n_tokens)`` and
        returns ``(text_output, new_rng)``. ``text_output`` holds only the
        newly generated columns, i.e. the generated sequences with the prompt
        columns of ``batch['input_ids']`` stripped. ``n_tokens`` is a static
        argument that the body does not use; the generation length is
        ``self.block_size``.
        """
        def fn(params, rng, batch, n_tokens):
            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
            rng_generator = JaxRNG(rng)

            # Python-side mutation: under pjit this runs while the function is
            # being traced, not on every call of the compiled function.
            self.model.config.sample_mode = 'text'
            output = self.model.generate(
                batch['input_ids'],
                vision_masks=batch['vision_masks'],
                attention_mask=batch['attention_mask'],
                delta_masks=batch['delta_masks'],
                action_masks=batch['action_masks'],
                params=params['params'],
                prng_key=rng_generator(),
                generation_config=GenerationConfig(
                    max_new_tokens=self.block_size,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            ).sequences

            # Keep only what was generated after the prompt. Host transfer is
            # left to the caller: values here are traced, so device_get would
            # have no effect inside the compiled function.
            prompt_len = batch['input_ids'].shape[1]
            text_output = output[:, prompt_len:]
            return text_output, rng_generator()

        return pjit(
            fn,
            in_shardings=(self.model_ps, PS(), PS()),
            out_shardings=(PS(), PS()),
            static_argnums=(3,)
        )
    
    @cached_property
    def _forward_action_generate(self):
        """pjit-compiled action-generation step, built once per instance.

        The returned callable takes
        ``(params, rng, batch, n_tokens, prefix_token, text_output)`` and
        returns ``(action_output, new_rng)``. The first row of every entry in
        ``batch`` is extended with the generated text tokens and then with
        ``prefix_token``; exactly 7 new tokens are generated and only those
        are returned. ``text_output`` is a static argument, so it must be
        hashable and each distinct value triggers a fresh trace.
        """
        def fn(params, rng, batch, n_tokens, prefix_token, text_output):
            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
            rng_generator = JaxRNG(rng)

            text_ids = jnp.array(text_output)

            def first_row(key):
                # Row 0 of the batch entry, with the batch axis restored.
                return jnp.expand_dims(batch[key][0], axis=0)

            def append_false(key):
                # The text and prefix positions are flagged False for this mask.
                with_text = jnp.concatenate(
                    [first_row(key), jnp.zeros(text_ids.shape, dtype=bool)], axis=1
                )
                return jnp.concatenate(
                    [with_text, jnp.zeros(prefix_token.shape, dtype=bool)], axis=1
                )

            action_vision_masks = append_false('vision_masks')

            # The appended text and prefix tokens are all attended to.
            attn_with_text = jnp.concatenate(
                [
                    first_row('attention_mask'),
                    jnp.ones(text_ids.shape, dtype=batch['attention_mask'].dtype),
                ],
                axis=1,
            )
            action_attn_mask = jnp.concatenate(
                [attn_with_text, jnp.ones(prefix_token.shape, dtype=bool)], axis=1
            )

            action_delta_masks = append_false('delta_masks')
            action_action_masks = append_false('action_masks')

            ids_with_text = jnp.concatenate([first_row('input_ids'), text_ids], axis=1)
            action_input_ids = jnp.concatenate([ids_with_text, prefix_token], axis=1)

            self.model.config.sample_mode = 'action'
            generated = self.model.generate_vision(
                action_input_ids,
                vision_masks=action_vision_masks,
                attention_mask=action_attn_mask,
                delta_masks=action_delta_masks,
                action_masks=action_action_masks,
                params=params['params'],
                prng_key=rng_generator(),
                generation_config=GenerationConfig(
                    max_new_tokens=7,
                    min_new_tokens=7,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            ).sequences
            # Strip the conditioning prefix; keep only the new tokens.
            prompt_len = action_input_ids.shape[1]
            action_output = generated[:, prompt_len:]
            return action_output, rng_generator()

        return pjit(
            fn,
            in_shardings=(self.model_ps, PS(), PS(), PS(), PS()),
            out_shardings=(PS(), PS()),
            static_argnums=(5,)
        )



    
    def generate_video_pred(self, prompts, images, max_input_length, processed_image):
        """Generate text for a prompt/image pair, then generate action tokens.

        The model input is laid out along axis 1 as
        ``[tokenized prompt | images | "</vision>" tokens]``.

        Args:
            prompts: List of prompt strings. They are tokenized with the
                left-padding/left-truncating tokenizer to exactly
                ``max_input_length`` tokens.
            images: 2-D array of vision token ids. It is concatenated with the
                tokenizer output on axis 1, so its leading dimension must
                match ``len(prompts)``.
            max_input_length: Padded/truncated length of the text prompt.
            processed_image: Unused; kept for interface compatibility.

        Returns:
            ``(text_output, action_output, None)``. ``text_output`` is a tuple
            of tuples of generated token ids, cut just before the first
            occurrence of the stop token in row 0 (or left whole when that
            token never appears). ``action_output`` is the host array of
            generated action tokens. The third slot is always ``None``.
        """
        sharded_rng = next_rng()

        inputs = self.prefix_tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=max_input_length,
            return_tensors='np'
        )
        prefix_for_gen = ["</vision>"] * len(prompts)
        inputs_for_gen = self.prefix_tokenizer(
            prefix_for_gen,
            return_tensors='np'
        )

        text_shape = inputs.input_ids.shape
        suffix_shape = inputs_for_gen.input_ids.shape

        def all_false_mask():
            # Mask over the full layout with no position flagged.
            return np.concatenate([
                np.zeros(text_shape, dtype=bool),
                np.zeros(images.shape, dtype=bool),
                np.zeros(suffix_shape, dtype=bool),
            ], axis=1)

        batch = dict(
            input_ids=np.concatenate([inputs.input_ids, images, inputs_for_gen.input_ids], axis=1),
            attention_mask=np.concatenate([
                inputs.attention_mask,
                np.ones(images.shape, dtype=inputs.attention_mask.dtype),
                inputs_for_gen.attention_mask,
            ], axis=1),
            # Only the image span counts as vision tokens.
            vision_masks=np.concatenate([
                np.zeros(text_shape, dtype=bool),
                np.ones(images.shape, dtype=bool),
                np.zeros(suffix_shape, dtype=bool),
            ], axis=1),
            delta_masks=all_false_mask(),
            action_masks=all_false_mask(),
        )

        with self.mesh:
            text_output, sharded_rng = self._forward_generate(
                self.params, sharded_rng, batch,
                self.FLAGS.tokens_per_action
            )
            text_output = jax.device_get(text_output)

            # Cut the text just before the first stop token found in row 0.
            # NOTE(review): 529 is presumably the tokenizer id that opens the
            # next tag; confirm against the vocabulary.
            stop_token_id = 529
            stop_hits = np.where(text_output[0] == stop_token_id)[0]
            if stop_hits.size:
                text_output = text_output[:, :stop_hits[0]]
            # With no stop token the whole generated text is kept; indexing an
            # empty hit list would otherwise raise IndexError.

            prefix_token = self.prefix_tokenizer(
                '<action>',
                return_tensors='np'
            )
            prefix_token = jnp.array(prefix_token['input_ids'])

            # The action step takes the text as a static (hashable) argument.
            text_output = tuple(tuple(row) for row in text_output.tolist())

            action_output, sharded_rng = self._forward_action_generate(
                self.params, sharded_rng, batch,
                self.FLAGS.tokens_per_action, prefix_token, text_output,
            )
            action_output = jax.device_get(action_output)

        return text_output, action_output, None

    def __call__(self, prompts):
        batch = self.construct_input(prompts)
        # prompts = []
        print("question", prompts[0]['question'])
        text_prompt = f"<s>You are a helpful assistant. USER: What action should the robot take to `{prompts[0]['question']}` ASSISTANT: <vision>"
        text_output, action_output, vision_output = self.generate_video_pred(prompts=[text_prompt], images=batch['input_ids'], max_input_length=128, processed_image=batch["processed_image"])
        output_text = []
        if text_output is not None:
            # print("text_output", list(self.tokenizer.batch_decode(text_output)))
            for text in list(self.tokenizer.batch_decode(text_output, skip_special_tokens=True)):
                if self.tokenizer.eos_token in text:
                    text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
                output_text.append(text)
        else: 
            output_text = ["Do not move. 10  10"]
        print("text output", output_text)
        print("action output", action_output)
        return output_text, vision_output, action_output
        # return vision_output
        