from absl.app import run
import math
from tqdm import tqdm
from PIL import Image
import decord
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from transformers import GenerationConfig
from tux import (
    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
    match_partition_rules, make_shard_and_gather_fns,
    with_sharding_constraint, tree_apply, open_file
)
# from lwm.delta_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.delta_llama_action import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.vqgan import VQGAN
import albumentations
import jax.numpy as jnp
import cv2
import sys
from einops import rearrange, repeat, pack, unpack
from jax import lax, random

def compute_eos_positions(output, input_ids_shape, eos_token_id, slice_start, slice_size):
    """Slice the generated tokens out of ``output`` and mask everything from EOS on.

    Args:
        output: 2-D token array; axis 0 is indexed per row, axis 1 is sliced.
        input_ids_shape: Unused; kept so existing call sites keep working.
        eos_token_id: Token id that marks the end of a sequence.
        slice_start: Column at which the slice starts.
        slice_size: Number of columns to take (must be a concrete int because
            it is used as a slice size).

    Returns:
        ``(text_output, eos_positions)`` where ``text_output`` is
        ``output[:, slice_start:slice_start + slice_size]`` and
        ``eos_positions`` is a boolean array of the same shape that is True,
        independently for each row, at the first EOS token and at every
        position after it. Rows without an EOS token are all False.
    """
    text_output = lax.dynamic_slice(
        output, (0, slice_start), (output.shape[0], slice_size)
    )
    is_eos = text_output == eos_token_id
    # A running count along the sequence axis becomes positive at the first
    # EOS and stays positive afterwards. Unlike `jnp.where` plus a Python
    # `if`, this has a static shape and no data-dependent control flow.
    eos_positions = jnp.cumsum(is_eos.astype(jnp.int32), axis=1) > 0
    return text_output, eos_positions

class DeltaActionSampler:
    def __init__(self, FLAGS):
        """Set up the mesh, VQGAN, tokenizers and model described by ``FLAGS``.

        Reads ``FLAGS.tokenizer``, ``FLAGS.mesh_dim`` and
        ``FLAGS.vqgan_checkpoint`` here; ``_load_model`` reads further fields.
        """
        self.FLAGS = FLAGS
        tokenizer_spec = FLAGS.tokenizer
        print("FLAGS", tokenizer_spec)

        self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
        self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)

        # Two tokenizers over the same vocabulary: the "prefix" one pads and
        # truncates on the left, the plain one uses the tokenizer defaults.
        self.prefix_tokenizer = VideoLLaMAConfig.get_tokenizer(
            tokenizer_spec,
            truncation_side='left',
            padding_side='left',
        )
        self.tokenizer = VideoLLaMAConfig.get_tokenizer(tokenizer_spec)

        self.min_buffer_size = 256
        self.sharded_rng = next_rng()

        # Must run last: it reads self.FLAGS, self.tokenizer and self.mesh.
        self._load_model()


    @property
    def block_size(self):
        return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
    
    @property
    def data_dim(self):
        return self.mesh.shape['dp'] * self.mesh.shape['fsdp']

    def _process_frame(self, images, size):
        """Resize frames to ``size`` x ``size`` and normalise them two ways.

        Args:
            images: Iterable of frames; each is converted with ``np.array`` and
                cast to uint8 (PIL images or arrays both work with that call).
            size: Target side length in pixels. The previous implementation
                ignored this argument and always used 256.

        Returns:
            A pair ``(vqgan_frames, display_frames)``:
            ``vqgan_frames`` is a float32 array stacking every resized frame
            along a new leading axis, scaled to ``[-1, 1]``;
            ``display_frames`` is a list of the same frames as float32 arrays
            scaled to ``[0, 1]``.

        Raises:
            ValueError: If ``images`` is empty.
        """
        resize = albumentations.Compose([
            # Shrink/grow so the longest side equals `size`, then force a
            # square output (this second step does not keep the aspect ratio).
            albumentations.LongestMaxSize(max_size=size),
            albumentations.Resize(size, size),
        ])

        vqgan_frames = []
        display_frames = []
        for image in images:
            pixels = np.array(image).astype(np.uint8)
            resized = resize(image=pixels)["image"]
            display_frames.append((resized / 255).astype(np.float32))
            vqgan_frames.append((resized / 127.5 - 1.0).astype(np.float32))

        if not vqgan_frames:
            raise ValueError("_process_frame needs at least one image")
        return np.stack(vqgan_frames, axis=0), display_frames


    def _read_process_vision(self, images):

        # path = '/mnt/sda/anon/World-As-Code/llava/playground/data/lang_table_separate_100/1.jpg'
        # f = open_file(path, 'rb')
        # if path.endswith('.png') or path.endswith('.jpg'):
        #     image = Image.open(f)
 
        vision, processed_image = self._process_frame(images, 256)
        
        B = 1
        encodings = []
        for i in range(0, len(vision), 1):
            v = vision[i:i + B]
            if len(v) % B == 0:
                n_pad = 0
            else:
                n_pad = B - len(v) % B
            v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0)))
            enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
            # print("enc", enc)
            enc = enc[n_pad:]
            for t in range(len(enc)):
                encodings.extend(enc[t].reshape(-1).tolist())
        return encodings, processed_image



    def construct_input(self, prompts):
        """Turn a prompt's images into model-ready vision token ids.

        Args:
            prompts: Non-empty sequence of dicts, each with an ``'image'``
                entry that ``_read_process_vision`` accepts.

        Returns:
            Dict with ``'input_ids'`` (int array of shape ``(1, n_tokens)``:
            the vision codes followed by token id 8193) and
            ``'processed_image'`` (second value of ``_read_process_vision``).

        Raises:
            ValueError: If ``prompts`` is empty (this used to surface as an
                ``UnboundLocalError``).
        """
        if len(prompts) == 0:
            raise ValueError("construct_input requires at least one prompt")

        # NOTE(review): only the LAST prompt has ever reached the output. The
        # previous loop encoded every prompt but overwrote the result on each
        # iteration. That behaviour is kept, minus the discarded encodes.
        # Confirm whether multi-prompt batching was intended.
        vision, processed_image = self._read_process_vision(prompts[-1]['image'])

        # 8193 is presumably the end-of-vision marker of the vision
        # vocabulary; TODO confirm against the model config.
        tokens = list(vision) + [8193]
        return {
            'input_ids': np.expand_dims(tokens, axis=0),
            'processed_image': processed_image,
        }
             

    def _load_model(self):
        """Build the model config, load the checkpoint and shard the parameters.

        Sets ``self.config``, ``self.params``, ``self.model`` and
        ``self.model_ps``. Reads ``self.FLAGS``, ``self.tokenizer`` and
        ``self.mesh``, so those must already be assigned.
        """
        flags = self.FLAGS

        if flags.load_llama_config == '':
            llama_config = VideoLLaMAConfig(**flags.llama)
        else:
            # Start from the stored config, but take the memory/scan related
            # settings from the command-line flags.
            llama_config = VideoLLaMAConfig.load_config(flags.load_llama_config)
            updates = VideoLLaMAConfig(**flags.llama)
            overridden_fields = (
                'remat_block',
                'remat_attention',
                'remat_mlp',
                'scan_attention',
                'scan_mlp',
                'scan_query_chunk_size',
                'scan_key_chunk_size',
                'scan_mlp_chunk_size',
                'scan_layers',
                'param_scan_axis',
            )
            llama_config.update({
                field: getattr(updates, field) for field in overridden_fields
            })

        if flags.update_llama_config != '':
            # NOTE(review): eval() executes arbitrary Python taken from a flag
            # value. Acceptable only while the flag comes from a trusted
            # operator; consider ast.literal_eval if the string is always a
            # plain literal.
            llama_config.update(dict(eval(flags.update_llama_config)))

        llama_config.update({
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
        })
        llama_config.update({'mesh_dim': flags.mesh_dim})
        self.config = llama_config

        cpu_device = jax.devices("cpu")[0]
        with jax.default_device(cpu_device):
            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
                flags.load_checkpoint,
                disallow_trainstate=True,
                max_buffer_size=32 * 2 ** 30,
            )
            # _do_init=False: parameters come from the checkpoint above.
            self.model = FlaxVideoLLaMAForCausalLM(
                llama_config,
                input_shape=(512, 8192),
                seed=flags.seed,
                _do_init=False,
                dtype=get_float_dtype_by_name(flags.dtype),
            )

            partition_rules = VideoLLaMAConfig.get_partition_rules(
                llama_config.scan_layers, llama_config.param_scan_axis
            )
            self.model_ps = match_partition_rules(partition_rules, self.params)
            shard_fns, _ = make_shard_and_gather_fns(
                self.model_ps, get_float_dtype_by_name(flags.dtype)
            )

            # Apply the per-leaf shard functions under the mesh context.
            with self.mesh:
                self.params = tree_apply(shard_fns, self.params)

    @cached_property
    def _forward_generate(self):
        """Build the pjit-compiled sampling function (created once per instance).

        The returned callable takes ``(params, rng, batch, n_tokens)``;
        ``n_tokens`` is a static argument. ``batch`` must contain
        ``input_ids``, ``vision_masks``, ``attention_mask``, ``delta_masks``
        and ``action_masks``. It returns
        ``(text_output, action_output, action_attn_mask, next_rng_key)``.

        The body of ``fn`` runs under tracing, so it must not convert traced
        values to NumPy or branch on their contents. The previous version did
        both (``np.array(text_output)`` / ``np.where`` / ``if ...size``) to
        build a mask that was never used; that code is removed.
        """
        def fn(params, rng, batch, n_tokens):
            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
            rng_generator = JaxRNG(rng)
            prompt_len = batch['input_ids'].shape[1]

            # Pass 1: text continuation of the prompt.
            self.model.config.sample_mode='text'
            output = self.model.generate(
                batch['input_ids'],
                vision_masks=batch['vision_masks'],
                attention_mask=batch['attention_mask'],
                delta_masks=batch['delta_masks'],
                action_masks=batch['action_masks'],
                params=params['params'],
                prng_key=rng_generator(),
                generation_config=GenerationConfig(
                    max_new_tokens=self.block_size,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            ).sequences
            # Keep only the newly generated tokens.
            text_output = output[:, prompt_len:]

            prefix_token = self.prefix_tokenizer(
                '<action>',
                return_tensors='np'
            )
            prefix_token = jnp.array(prefix_token['input_ids'])

            # Shapes are concrete during tracing, so this prints real numbers
            # (once per compilation, not once per call).
            print(output.shape, text_output.shape, prefix_token.shape)

            # Attention mask covering: first row of the prompt mask, the
            # generated text, and the '<action>' prefix tokens.
            action_attn_mask = jnp.concatenate([
                batch['attention_mask'][:1],
                jnp.ones((1, text_output.shape[1]), dtype=batch['attention_mask'].dtype),
                jnp.ones(prefix_token.shape, dtype=bool),
            ], axis=1)

            # Pass 2: action tokens.
            # NOTE(review): this pass is conditioned on the ORIGINAL batch
            # only. The previous code also assembled `output + '<action>'`
            # inputs and matching masks but never passed them to
            # generate_vision; those unused tensors were dropped. Confirm
            # whether the action pass was meant to see the generated text.
            self.model.config.sample_mode='action'
            action_output = self.model.generate_vision(
                batch['input_ids'],
                vision_masks=batch['vision_masks'],
                attention_mask=batch['attention_mask'],
                delta_masks=batch['delta_masks'],
                action_masks=batch['action_masks'],
                params=params['params'],
                prng_key=rng_generator(),
                generation_config=GenerationConfig(
                    max_new_tokens=n_tokens,
                    min_new_tokens=n_tokens,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            ).sequences
            action_output = action_output[:, prompt_len:]
            return text_output, action_output, action_attn_mask, rng_generator()

        # in_shardings lists the three non-static arguments
        # (params, rng, batch); n_tokens (index 3) is static.
        return pjit(
            fn,
            in_shardings=(self.model_ps, PS(), PS()),
            out_shardings=(PS(), PS(), PS(), PS()),
            static_argnums=(3,)
        )

    
    def generate_video_pred(self, prompts, images, max_input_length, processed_image):
        """Assemble the text + vision batch and run the compiled sampler.

        Args:
            prompts: List of text prompts; tokenized with left padding and
                truncation to ``max_input_length``.
            images: 2-D array of vision token ids, concatenated after the text
                tokens along axis 1.
            max_input_length: Fixed token length of the text segment.
            processed_image: Not used by this method.

        Returns:
            ``(None, action_output, None)`` where ``action_output`` is the
            second value produced by ``self._forward_generate``, fetched to
            the host.
        """
        sharded_rng = next_rng()

        inputs = self.prefix_tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=max_input_length,
            return_tensors='np'
        )
        # Closing tag appended after the vision tokens.
        inputs_for_gen = self.prefix_tokenizer(
            ["</vision>"] * len(prompts),
            return_tensors='np'
        )

        text_ids = inputs.input_ids
        closing_ids = inputs_for_gen.input_ids
        text_attn = inputs.attention_mask
        segment_shapes = (text_ids.shape, images.shape, closing_ids.shape)

        def segment_mask(vision_flag):
            # Boolean mask over [text | vision | closing]; only the vision
            # segment takes `vision_flag`, the other two are always False.
            text_shape, vision_shape, closing_shape = segment_shapes
            return np.concatenate([
                np.zeros(text_shape, dtype=bool),
                np.full(vision_shape, vision_flag, dtype=bool),
                np.zeros(closing_shape, dtype=bool),
            ], axis=1)

        batch = {
            'input_ids': np.concatenate([text_ids, images, closing_ids], axis=1),
            'attention_mask': np.concatenate([
                text_attn,
                np.ones(images.shape, dtype=text_attn.dtype),
                inputs_for_gen.attention_mask,
            ], axis=1),
            'vision_masks': segment_mask(True),
            'delta_masks': segment_mask(False),
            'action_masks': segment_mask(False),
        }

        with self.mesh:
            text_output, action_output, eos_positions, sharded_rng = self._forward_generate(
                self.params,
                sharded_rng,
                batch,
                self.FLAGS.tokens_per_action,
            )
            text_output = jax.device_get(text_output)
            action_output = jax.device_get(action_output)
            eos_positions = jax.device_get(eos_positions)

            print("text output and corresponding eos output", text_output, eos_positions)

        return None, action_output, None

    def __call__(self, prompts):
        batch = self.construct_input(prompts)
        # prompts = []
        print("question", prompts[0]['question'])
        text_prompt = f"<s>You are a helpful assistant. USER: What action should the robot take to `{prompts[0]['question']}` ASSISTANT: <vision>"
        user_action_token_num, action_output, vision_output = self.generate_video_pred(prompts=[text_prompt], images=batch['input_ids'], max_input_length=128, processed_image=batch["processed_image"])
        output_text = []
        # if text_output is not None:
        #     # print("text_output", list(self.tokenizer.batch_decode(text_output)))
        #     for text in list(self.tokenizer.batch_decode(text_output, skip_special_tokens=True)):
        #         if self.tokenizer.eos_token in text:
        #             text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
        #         output_text.append(text)
        # else: 
        #     output_text = ["Do not move. 10  10"]
        print("action output", action_output)
        return user_action_token_num, vision_output, action_output
        # return vision_output
        