from absl.app import run
import math
from tqdm import tqdm
from PIL import Image
import decord
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from transformers import GenerationConfig
from tux import (
    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
    match_partition_rules, make_shard_and_gather_fns,
    with_sharding_constraint, tree_apply, open_file
)
# from lwm.delta_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.delta_llama_cont_action import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.vqgan import VQGAN
import albumentations
import jax.numpy as jnp
import cv2
import sys
from einops import rearrange, repeat, pack, unpack


class DeltaContActionSampler:
    def __init__(self, FLAGS):
        self.FLAGS = FLAGS
        print("FLAGS", FLAGS.tokenizer)
        self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
        self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
        self.prefix_tokenizer = VideoLLaMAConfig.get_tokenizer(
            FLAGS.tokenizer, truncation_side='left', padding_side='left'
        )
        self.tokenizer = VideoLLaMAConfig.get_tokenizer(FLAGS.tokenizer)
        # self.tokens_per_delta = 64
        self.min_buffer_size = 256
        self.sharded_rng = next_rng()
        self._load_model()


    @property
    def block_size(self):
        return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
    
    @property
    def data_dim(self):
        return self.mesh.shape['dp'] * self.mesh.shape['fsdp']

    def _process_frame(self, images, size):
        preprocessor_finetune = albumentations.Compose([
                albumentations.LongestMaxSize(max_size=256),  # Resize the longest side to 256
                # albumentations.PadIfNeeded(min_height=256, min_width=256, border_mode=0, value=(0, 0, 0))
                albumentations.Resize(256, 256), 
            ])
        image_vqgan_list = []
        processed_image = []
        for image in images:
            img_array = np.array(image).astype(np.uint8)
                
            image_vqgan = preprocessor_finetune(image=img_array)["image"]
            # TODO: check this?
            image_vqgan2 = (image_vqgan/255).astype(np.float32)
            image_vqgan = (image_vqgan/127.5 - 1.0).astype(np.float32)
            # print(image_vqgan2.max(), image_vqgan2.min())
            processed_image.append(image_vqgan2)
            image_vqgan_list.append(image_vqgan[None])
            # image_vqgan = torch.tensor(image_vqgan.transpose(2,0,1)[None]).to(dtype=torch.float32)
        image_vqgan_list = np.concatenate(image_vqgan_list, axis=0)
        return image_vqgan_list, processed_image


    def _read_process_vision(self, images):
 
        vision, processed_image = self._process_frame(images, 256)
        
        B = 1
        encodings = []
        for i in range(0, len(vision), 1):
            v = vision[i:i + B]
            if len(v) % B == 0:
                n_pad = 0
            else:
                n_pad = B - len(v) % B
            v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0)))
            enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
            # print("enc", enc)
            enc = enc[n_pad:]
            for t in range(len(enc)):
                encodings.extend(enc[t].reshape(-1).tolist())

        return encodings, processed_image



    def construct_input(self, prompts, raw_pixel=None):
        if raw_pixel is None:
            for i, prompt in enumerate(prompts):
                vision, processed_image = self._read_process_vision(prompt['image'])
                tokens, vm = [], []
                tokens.extend(vision)
                vm.extend([True] * len(vision))
                tokens.extend([8193])
                vm.extend([True] * len([8193]))
        else:
            tokens = raw_pixel
            # tokens = ["2509", "6002", "5504", "7649", "1344", "7485", "4910", "2758", "7661", "729", "11", "7343", "3043", "1211", "2838", "13", "1016", "6656", "4018", "6240", "7742", "2823", "5498", "2786", "2005", "462", "7348", "29", "3056", "7958", "5349", "8052", "2509", "6403", "643", "7370", "5615", "204", "4076", "1646", "5343", "2463", "326", "7147", "462", "1834", "427", "6421", "6266", "752", "646", "5235", "6700", "4569", "3762", "1644", "4721", "2009", "4521", "2903", "7965", "750", "8044", "7370", "3844", "1683", "6162", "6152", "499", "6317", "5207", "274", "555", "454", "8163", "1215", "4138", "2432", "556", "7181", "4321", "6367", "3918", "3286", "1908", "2323", "4125", "3336", "3313", "424", "7652", "7911", "1772", "4721", "6964", "7161", "8118", "1437", "1296", "2596", "12", "3871", "5205", "6481", "4912", "7723", "6751", "7262", "3695", "6606", "1865", "4357", "4472", "352", "7413", "5215", "7367", "610", "3919", "847", "6216", "2025", "7761", "1921", "6116", "1287", "5527", "7527", "6860", "3932", "1183", "251", "3268", "5500", "2669", "3621", "5712", "7283", "562", "7509", "4519", "1592", "1489", "5000", "6071", "8118", "4207", "2821", "4507", "5666", "7024", "2430", "3507", "5684", "7601", "7739", "503", "4113", "2137", "2437", "1865", "8085", "3342", "4727", "3547", "4130", "4685", "5478", "5389", "4723", "6350", "7089", "3173", "7415", "2076", "6797", "1822", "5738", "5030", "2037", "35", "7644", "292", "7833", "459", "7094", "1834", "407", "63", "4259", "3123", "2789", "4042", "5249", "1300", "4311", "3449", "3214", "4453", "7893", "2450", "814", "2090", "4986", "6968", "1820", "4040", "1551", "1228", "1840", "5319", "6426", "7834", "4926", "5656", "4170", "4431", "931", "6", "6174", "4750", "428", "859", "3967", "7689", "2525", "7106", "7933", "633", "513", "7472", "2847", "6139", "5187", "2143", "5903", "647", "6262", "7892", "2971", "6765", "620", "4724", "7613", "6559", "6228", "6826", "3022", "3975", "2437", "3388", "3679", "4448", "1100", "5442", "4646"]
            tokens = [int(token) for token in tokens]
            tokens.append(8193)
            processed_image = None
        # print("tokens",tokens)
        return {
            'input_ids': np.expand_dims(tokens, axis=0),
            'processed_image': processed_image,
        }
             

    def _load_model(self):
        if self.FLAGS.load_llama_config != '':
            llama_config = VideoLLaMAConfig.load_config(self.FLAGS.load_llama_config)
            updates = VideoLLaMAConfig(**self.FLAGS.llama)
            llama_config.update(dict(
                remat_block=updates.remat_block,
                remat_attention=updates.remat_attention,
                remat_mlp=updates.remat_mlp,
                scan_attention=updates.scan_attention,
                scan_mlp=updates.scan_mlp,
                scan_query_chunk_size=updates.scan_query_chunk_size,
                scan_key_chunk_size=updates.scan_key_chunk_size,
                scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
                scan_layers=updates.scan_layers,
                param_scan_axis=updates.param_scan_axis,
            ))
        else:
            llama_config = VideoLLaMAConfig(**self.FLAGS.llama)




        if self.FLAGS.update_llama_config != '':
            llama_config.update(dict(eval(self.FLAGS.update_llama_config)))

        llama_config.update(dict(
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        ))
        llama_config.update(dict(mesh_dim=self.FLAGS.mesh_dim))
        self.config = llama_config

        with jax.default_device(jax.devices("cpu")[0]):
            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
                    self.FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
            )
            self.model = FlaxVideoLLaMAForCausalLM(
                llama_config, 
                input_shape=(512, 8192), 
                seed=self.FLAGS.seed, 
                _do_init=False,
                dtype=get_float_dtype_by_name(self.FLAGS.dtype),
            )

        
            self.model_ps = match_partition_rules(
                VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
            )
            shard_fns, _ = make_shard_and_gather_fns(
                self.model_ps, get_float_dtype_by_name(self.FLAGS.dtype)
            )

            with self.mesh:
                self.params = tree_apply(shard_fns, self.params)

    @cached_property
    def _forward_generate_gt(self):
        def fn(params, rng, batch, n_tokens):
            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
            rng_generator = JaxRNG(rng)
            
            action_input_ids = batch['input_ids']
            action_vision_masks = batch['vision_masks']
            action_attn_mask = batch['attention_mask']
            action_delta_masks = batch['delta_masks']
            # delta_output = gt_delta
            delta_output = None

            self.model.config.sample_mode='action'
            action_logits, gripper_logits = self.model.module.apply(
                params,
                action_input_ids,
                attention_mask=action_attn_mask,
                vision_masks=action_vision_masks,
                delta_masks=action_delta_masks,
                deterministic=True,
                # rngs={'dropout': rng_generator()},
            ).logits

            action_logits = action_logits[0,-1,:]
            gripper_value = jnp.argmax(gripper_logits[0,-1,:])
            

            return delta_output, action_logits, gripper_value, rng_generator()
        return pjit(
            fn,
            in_shardings=(self.model_ps, PS(), PS()),
            out_shardings=(PS(), PS(), PS(), PS()),
            static_argnums=(3,)
        )

    @cached_property
    def _forward_generate(self):
        def fn(params, rng, batch, n_tokens):
            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
            rng_generator = JaxRNG(rng)
            

            self.model.config.sample_mode='delta'
            output = self.model.generate_vision(
                batch['input_ids'],
                vision_masks=batch['vision_masks'],
                attention_mask=batch['attention_mask'],
                delta_masks=batch['delta_masks'],
                params=params['params'],
                prng_key=rng_generator(),
                generation_config=GenerationConfig(
                    max_new_tokens=n_tokens,
                    min_new_tokens=n_tokens,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            ).sequences

            delta_output= output[:,batch['input_ids'].shape[1]:]


            action_vision_masks = jnp.concatenate([jnp.expand_dims(batch['vision_masks'][0], axis=0),jnp.zeros(jnp.array(delta_output).shape, dtype=bool)], axis=1)
            action_attn_mask = jnp.concatenate([jnp.expand_dims(batch['attention_mask'][0], axis=0), jnp.ones(jnp.array(delta_output).shape, dtype= batch['attention_mask'].dtype)], axis=1)
            action_delta_masks = jnp.concatenate([jnp.expand_dims(batch['delta_masks'][0], axis=0),jnp.ones(jnp.array(delta_output).shape, dtype=bool)], axis=1)
            
            action_input_ids = output
        

            self.model.config.sample_mode='action'
            action_logits, gripper_logits = self.model.module.apply(
                params,
                action_input_ids,
                attention_mask=action_attn_mask,
                vision_masks=action_vision_masks,
                delta_masks=action_delta_masks,
                deterministic=True,
                # rngs={'dropout': rng_generator()},
            ).logits

            action_logits = action_logits[0,-1,:]
            gripper_value = jnp.argmax(gripper_logits[0,-1,:])
            

            return delta_output, action_logits, gripper_value, rng_generator()
        return pjit(
            fn,
            in_shardings=(self.model_ps, PS(), PS()),
            out_shardings=(PS(), PS(), PS(), PS()),
            static_argnums=(3,)
        )

    
    def generate_video_pred(self, prompts, images, max_input_length, gt_delta=None):
        
        sharded_rng = next_rng()
        # print("image shape", images.shape)

        # ***********

        inputs = self.prefix_tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=max_input_length,
            return_tensors='np'
        )
        # print(inputs.input_ids.shape)
        # prefix_for_gen = ["</vision>  ASSISTANT:"] * len(prompts)
        prefix_for_gen = ["</vision><delta>"] * len(prompts)
        # prefix_for_gen = ["</vision><delta>"] * len(prompts)
        inputs_for_gen = self.prefix_tokenizer(
            prefix_for_gen,
            return_tensors='np'
        )

        if gt_delta is None:
            batch = dict(
                input_ids=np.concatenate([inputs.input_ids, images, inputs_for_gen.input_ids], axis=1),
                attention_mask=np.concatenate([inputs.attention_mask, np.ones(images.shape, dtype=inputs.attention_mask.dtype), inputs_for_gen.attention_mask], axis=1),
                vision_masks=np.concatenate([
                    np.zeros(inputs.input_ids.shape, dtype=bool),
                    np.ones(images.shape, dtype=bool),
                    np.zeros(inputs_for_gen.input_ids.shape, dtype=bool),
                ], axis=1),
                delta_masks=np.concatenate([
                    np.zeros(inputs.input_ids.shape, dtype=bool),
                    np.zeros(images.shape, dtype=bool),
                    np.zeros(inputs_for_gen.input_ids.shape, dtype=bool),
                ], axis=1),
            )
        else:
            gt_delta = np.array(gt_delta)
            gt_delta = np.expand_dims(gt_delta, axis=0)
            print("gt delta shape", gt_delta.shape)
            batch = dict(
                input_ids=np.concatenate([inputs.input_ids, images, inputs_for_gen.input_ids, gt_delta], axis=1),
                attention_mask=np.concatenate([inputs.attention_mask, np.ones(images.shape, dtype=inputs.attention_mask.dtype), inputs_for_gen.attention_mask, np.ones(gt_delta.shape, dtype=inputs.attention_mask.dtype)], axis=1),
                vision_masks=np.concatenate([
                    np.zeros(inputs.input_ids.shape, dtype=bool),
                    np.ones(images.shape, dtype=bool),
                    np.zeros(inputs_for_gen.input_ids.shape, dtype=bool),
                    np.zeros(gt_delta.shape, dtype=bool),
                ], axis=1),
                delta_masks=np.concatenate([
                    np.zeros(inputs.input_ids.shape, dtype=bool),
                    np.zeros(images.shape, dtype=bool),
                    np.zeros(inputs_for_gen.input_ids.shape, dtype=bool),
                    np.ones(gt_delta.shape, dtype=bool),
                ], axis=1),
            )
            
        with self.mesh:
            print("gt delta is not None", gt_delta)
            if gt_delta is not None:
                delta_output, action_output, gripper_value, sharded_rng = self._forward_generate_gt(
                    self.params, sharded_rng, batch, 
                    self.FLAGS.tokens_per_delta
                )
                delta_output = gt_delta
            else:
                delta_output, action_output, gripper_value, sharded_rng = self._forward_generate(
                    self.params, sharded_rng, batch, 
                    self.FLAGS.tokens_per_delta
                )
                delta_output = jax.device_get(delta_output)
            action_output = jax.device_get(action_output)
            # print("original action", action_output)
            # norm_mean = np.array([0.00044061258321564776,-3.489882665825449e-05,0.0003598753898464112,-0.000219091608573299,-0.0014239931609872339,-0.00021917253424783405])
            # norm_std = np.array([0.011353923389502165,0.014732306464712396,0.011232106825340872,0.025059465771357696,0.0278502974403774,0.06337970196981181])
            # action_output = action_output * norm_std + norm_mean
            
            gripper_value = jax.device_get(gripper_value)
            action_output = np.append(action_output, float(gripper_value))
            # **************
            print("delta and action", delta_output, action_output)
            
        return None, action_output, delta_output

    def __call__(self, prompts, raw_pixel=None, gt_delta=None):
        batch = self.construct_input(prompts, raw_pixel)
        # prompts = []
        # print("question", prompts[0]['question'])
        text_prompt = f"<s> You are a helpful assistant. USER: What action should the robot take to `{prompts[0]['question']}` ASSISTANT: <vision>"
        # text_prompt = f"<s> You are a helpful assistant. USER: What action should the robot take to `Put the stuffed toy in the pot` ASSISTANT: <vision>"
        user_action_token_num, action_output, delta_output = self.generate_video_pred(prompts=[text_prompt], images=batch['input_ids'], max_input_length=128, gt_delta=gt_delta)
        print("action output", action_output)
        return user_action_token_num, delta_output, action_output
        # return vision_output
        