from absl.app import run
import math
from tqdm import tqdm
from PIL import Image
import decord
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from transformers import GenerationConfig
from tux import (
    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
    match_partition_rules, make_shard_and_gather_fns,
    with_sharding_constraint, tree_apply, open_file
)
# from lwm.delta_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.delta_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM 
from lwm.vqgan import VQGAN
import albumentations
import jax.numpy as jnp
import cv2
import sys
import torch
from einops import rearrange, repeat, pack, unpack
sys.path.append('/data/anon/World-Model/Phenaki')
# sys.path.append('/data/anon/byeongguk-world-model/World-Model/data/Phenaki')
from phenaki_pytorch import CViViT_single_linear_nsvq_2, CViViT_single_linear_nsvq_3

class DeltaSampler:
    def __init__(self, FLAGS):
        """Set up the mesh, VQGAN encoder, tokenizers and model parameters.

        Args:
            FLAGS: Flag container. This constructor reads ``mesh_dim``,
                ``vqgan_checkpoint`` and ``tokenizer`` from it; the remaining
                fields are read later by ``_load_model`` and
                ``generate_video_pred``.
        """
        self.FLAGS = FLAGS
        self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
        self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)

        # Two tokenizers over the same vocabulary: the "prefix" one pads and
        # truncates on the left, the plain one uses the tokenizer defaults.
        left_aligned = {'truncation_side': 'left', 'padding_side': 'left'}
        self.prefix_tokenizer = VideoLLaMAConfig.get_tokenizer(
            FLAGS.tokenizer, **left_aligned
        )
        self.tokenizer = VideoLLaMAConfig.get_tokenizer(FLAGS.tokenizer)

        # Not read anywhere else in the visible part of this file.
        self.min_buffer_size = 256
        # generate_video_pred draws its own key via next_rng() on every call.
        self.sharded_rng = next_rng()

        # Must run last: it relies on self.FLAGS and self.tokenizer being
        # set, and its sharding step uses self.mesh.
        self._load_model()


    @property
    def block_size(self):
        """Larger of the query/key scan chunk sizes times the 'sp' mesh axis size."""
        cfg = self.config
        widest_chunk = max(cfg.scan_query_chunk_size, cfg.scan_key_chunk_size)
        return widest_chunk * self.mesh.shape['sp']
    
    @property
    def data_dim(self):
        """Product of the 'dp' and 'fsdp' mesh axis sizes."""
        axis_sizes = self.mesh.shape
        return axis_sizes['dp'] * axis_sizes['fsdp']

    def _process_frame(self, images, size):
        """Resize frames to 256x256 and return them under two normalisations.

        Args:
            images: Iterable of images; each one is converted with
                ``np.array(...)`` and cast to ``uint8``.
            size: Not read by this method; the target resolution is fixed at
                256. Kept so existing call sites keep working.

        Returns:
            A pair ``(vqgan_batch, unit_frames)``:
              * ``vqgan_batch`` - float32 array holding every frame scaled as
                ``pixel / 127.5 - 1.0``, joined along a new leading axis.
              * ``unit_frames`` - list of float32 frames scaled as
                ``pixel / 255``.
        """
        # The longest side is first brought to 256, then the frame is forced
        # to a 256x256 square, so the aspect ratio is not preserved (the
        # padding-based alternative is disabled).
        to_square = albumentations.Compose([
            albumentations.LongestMaxSize(max_size=256),
            albumentations.Resize(256, 256),
        ])

        unit_frames = []
        vqgan_frames = []
        for image in images:
            pixels = to_square(image=np.array(image).astype(np.uint8))["image"]
            # TODO: the original author flagged these two scalings for review.
            unit_frames.append((pixels / 255).astype(np.float32))
            vqgan_frames.append((pixels / 127.5 - 1.0).astype(np.float32)[None])

        return np.concatenate(vqgan_frames, axis=0), unit_frames


    def _read_process_vision(self, images):
        """Preprocess frames and turn them into a flat list of VQGAN code ids.

        Args:
            images: Iterable of images, forwarded to ``_process_frame``.

        Returns:
            A pair ``(encodings, processed_image)``: ``encodings`` is a flat
            Python list of ints holding the codes of every frame in order,
            and ``processed_image`` is the second value returned by
            ``_process_frame``.
        """
        vision, processed_image = self._process_frame(images, 256)

        frames_per_call = 1  # frames handed to the VQGAN encoder at once
        encodings = []
        for start in range(len(vision)):
            chunk = vision[start:start + frames_per_call]
            # Front-pad a short chunk up to a full group; with one frame per
            # call this is always zero.
            n_pad = (-len(chunk)) % frames_per_call
            padded = np.pad(chunk, [(n_pad, 0), (0, 0), (0, 0), (0, 0)])
            # Element [1] of the encoder output is used as the code indices
            # (presumably element [0] holds the embeddings - confirm against
            # VQGAN.encode).
            codes = jax.device_get(self.vqgan.encode(padded))[1].astype(int)
            for frame_codes in codes[n_pad:]:
                encodings.extend(frame_codes.reshape(-1).tolist())
        return encodings, processed_image



    def construct_input(self, prompts):
        """Encode a prompt's images into the vision-token row fed to the model.

        Only the last entry of ``prompts`` contributes to the result. Earlier
        entries used to be encoded and then thrown away; they are now skipped
        so the VQGAN encoder is not run for output that is never used.

        Args:
            prompts: Non-empty iterable of mappings, each with an ``'image'``
                entry that is passed to ``_read_process_vision``.

        Returns:
            dict with two entries:
              * ``'input_ids'`` - array with a leading axis of length 1
                holding the vision codes followed by the token id 8193.
              * ``'processed_image'`` - the second value returned by
                ``_read_process_vision`` for that prompt.

        Raises:
            ValueError: If ``prompts`` is empty (previously this surfaced as
                an ``UnboundLocalError`` at the return statement).
        """
        prompts = list(prompts)
        if len(prompts) == 0:
            raise ValueError("construct_input requires at least one prompt")

        # NOTE(review): __call__ takes the question text from prompts[0], so
        # with more than one prompt the text and the image come from
        # different entries - confirm whether multi-prompt input is intended.
        vision, processed_image = self._read_process_vision(prompts[-1]['image'])

        # 8193 is appended after the vision codes - presumably the
        # end-of-vision marker in the model vocabulary; TODO confirm.
        tokens = list(vision)
        tokens.append(8193)

        return {
            'input_ids': np.expand_dims(tokens, axis=0),
            'processed_image': processed_image,
        }
             

    def _load_model(self):
        """Build the model config, load checkpoint weights and shard them.

        Reads ``load_llama_config``, ``llama``, ``update_llama_config``,
        ``mesh_dim``, ``load_checkpoint``, ``seed`` and ``dtype`` from
        ``self.FLAGS`` and sets ``self.config``, ``self.params``,
        ``self.model`` and ``self.model_ps``.
        """
        flags = self.FLAGS

        if flags.load_llama_config == '':
            llama_config = VideoLLaMAConfig(**flags.llama)
        else:
            # Start from the stored config, then let the flag-provided config
            # override only the remat/scan execution settings listed below.
            llama_config = VideoLLaMAConfig.load_config(flags.load_llama_config)
            updates = VideoLLaMAConfig(**flags.llama)
            overridden_fields = (
                'remat_block',
                'remat_attention',
                'remat_mlp',
                'scan_attention',
                'scan_mlp',
                'scan_query_chunk_size',
                'scan_key_chunk_size',
                'scan_mlp_chunk_size',
                'scan_layers',
                'param_scan_axis',
            )
            llama_config.update(
                {field: getattr(updates, field) for field in overridden_fields}
            )

        # NOTE: a CViViT (phenaki_pytorch) decoder used to be constructed and
        # loaded here from FLAGS.phenaki_checkpoint; that path is disabled.

        if flags.update_llama_config != '':
            # SECURITY: eval() runs arbitrary Python taken from the flag
            # value. Only pass trusted strings here.
            llama_config.update(dict(eval(flags.update_llama_config)))

        llama_config.update({
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
        })
        llama_config.update({'mesh_dim': flags.mesh_dim})
        self.config = llama_config

        # Load and prepare everything with the CPU as default device; the
        # shard functions then place the parameters under the mesh.
        cpu_device = jax.devices("cpu")[0]
        with jax.default_device(cpu_device):
            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
                flags.load_checkpoint,
                disallow_trainstate=True,
                max_buffer_size=32 * 2 ** 30,
            )
            float_dtype = get_float_dtype_by_name(flags.dtype)
            self.model = FlaxVideoLLaMAForCausalLM(
                llama_config,
                input_shape=(512, 8192),
                seed=flags.seed,
                _do_init=False,
                dtype=float_dtype,
            )

            partition_rules = VideoLLaMAConfig.get_partition_rules(
                llama_config.scan_layers, llama_config.param_scan_axis
            )
            self.model_ps = match_partition_rules(partition_rules, self.params)
            shard_fns, _ = make_shard_and_gather_fns(self.model_ps, float_dtype)

            with self.mesh:
                self.params = tree_apply(shard_fns, self.params)

    @cached_property
    def _forward_generate(self):
        """pjit-wrapped generation step, created once per instance.

        The wrapped callable takes ``(params, rng, batch, n_tokens)`` with
        ``n_tokens`` static, and returns ``(new_tokens, None, next_rng_key)``.
        ``new_tokens`` holds only the columns generated after the prompt.
        """
        def generate_step(params, rng, batch, n_tokens):
            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
            rng_generator = JaxRNG(rng)

            # min == max forces exactly n_tokens new tokens.
            generation_config = GenerationConfig(
                max_new_tokens=n_tokens,
                min_new_tokens=n_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
                sample_mode='delta'
            )
            generated = self.model.generate_vision(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                vision_masks=batch['vision_masks'],
                delta_masks=batch['delta_masks'],
                params=params['params'],
                prng_key=rng_generator(),
                generation_config=generation_config,
            )

            # Drop the prompt columns; keep only what was generated.
            prompt_len = batch['input_ids'].shape[1]
            new_tokens = generated.sequences[:, prompt_len:]

            # No text is produced on this path, hence the None placeholder.
            return new_tokens, None, rng_generator()

        return pjit(
            generate_step,
            in_shardings=(self.model_ps, PS(), PS()),
            out_shardings=(PS(), PS(), PS()),
            static_argnums=(3,)
        )

    
    def generate_video_pred(self, prompts, images, max_input_length, processed_image):
        """Sample delta tokens for text prompts plus pre-encoded vision tokens.

        The model input is laid out along axis 1 as
        ``[text prompt | vision tokens | "</vision><delta>"]`` and
        ``FLAGS.tokens_per_delta`` new tokens are generated after it.

        Args:
            prompts: List of prompt strings. They are tokenized with
                ``prefix_tokenizer`` using ``padding='max_length'`` and
                truncation to ``max_input_length``.
            images: Integer array of vision token ids. It is concatenated
                with the text ids along axis 1, so it must be 2-D with a
                leading size equal to ``len(prompts)``.
            max_input_length: Token length the text prompt is padded or
                truncated to.
            processed_image: Unused. It only fed the CViViT reconstruction
                path, which is disabled; kept so callers need not change.

        Returns:
            Tuple ``(user_action_token_num, text_output, vision_output)``:
            the first generated token of the first sequence, then ``None``
            twice (this path yields neither text nor a decoded image).
        """
        sharded_rng = next_rng()

        inputs = self.prefix_tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=max_input_length,
            return_tensors='np'
        )
        # Suffix that closes the vision span and opens the delta span.
        inputs_for_gen = self.prefix_tokenizer(
            ["</vision><delta>"] * len(prompts),
            return_tensors='np'
        )

        text_ids, text_mask = inputs.input_ids, inputs.attention_mask
        suffix_ids, suffix_mask = inputs_for_gen.input_ids, inputs_for_gen.attention_mask

        input_ids = np.concatenate([text_ids, images, suffix_ids], axis=1)
        batch = dict(
            input_ids=input_ids,
            # Every vision token is attended to.
            attention_mask=np.concatenate(
                [text_mask, np.ones(images.shape, dtype=text_mask.dtype), suffix_mask],
                axis=1,
            ),
            # True exactly on the vision-token columns.
            vision_masks=np.concatenate([
                np.zeros(text_ids.shape, dtype=bool),
                np.ones(images.shape, dtype=bool),
                np.zeros(suffix_ids.shape, dtype=bool),
            ], axis=1),
            # No column of the prompt is a delta token.
            delta_masks=np.zeros(input_ids.shape, dtype=bool),
        )

        with self.mesh:
            vision_output, text_output, _ = self._forward_generate(
                self.params, sharded_rng, batch,
                self.FLAGS.tokens_per_delta
            )
            text_output = jax.device_get(text_output)
            vision_output = jax.device_get(vision_output)

        print(vision_output)
        # Only the first generated token of the first sequence is reported.
        user_action_token_num = vision_output[0][0]
        print("user_action_token_num",user_action_token_num)

        # TODO: current image, next image condition -> action
        # TODO: nsvq cvivit -> inference (decode from codebook)
        # Decoding the generated tokens back into an image (CViViT / VQGAN)
        # is disabled, so no vision output is returned.
        vision_output = None

        return user_action_token_num, text_output, vision_output

    def __call__(self, prompts):
        batch = self.construct_input(prompts)
        text_prompt = f"<s>You are a helpful assistant. USER: {prompts[0]['question']} <vision>"
        user_action_token_num, text_output, vision_output = self.generate_video_pred(prompts=[text_prompt], images=batch['input_ids'], max_input_length=128, processed_image=batch["processed_image"])
        output_text = []
        if text_output is not None:
            print("text_output", list(self.tokenizer.batch_decode(text_output)))
            for text in list(self.tokenizer.batch_decode(text_output, skip_special_tokens=True)):
                if self.tokenizer.eos_token in text:
                    text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
                output_text.append(text)
        else: 
            output_text = ["Do not move. 10  10"]
        return user_action_token_num, vision_output, output_text
        # return vision_output
        