import torch
from logzero import logger


class Abstract_AAAI:
    """Base class that streams a video into a language model's KV-cache.

    Frames are encoded chunk by chunk and appended to ``self.kv_cache`` so
    that questions can later be answered against the cached context.

    This class does not define ``self.model``, ``self.device`` or
    ``self.dtype``; they must be provided (e.g. by a subclass) before the
    encoding methods run. Subclasses must also implement
    ``_get_video_features`` and ``question_answering``.
    """

    # Multimodal processor; called in _encode_video_chunk to produce
    # "pixel_values_videos" and "video_grid_thw".
    processor = None
    # past_key_values returned by the most recent model forward; None until
    # encode_init_prompt / encode_video run, and again after clear_cache().
    kv_cache = None

    def __init__(self, processor, n_frame_tokens, init_prompt_ids, n_local, topk, chunk_size):
        """Store the processor and cache configuration.

        Args:
            processor: callable used to preprocess video frames.
            n_frame_tokens: stored for subclasses; not read by this class.
            init_prompt_ids: ids of the initial prompt; a tensor is used
                as-is, anything else is wrapped into a batch of one on first
                use (presumably a flat sequence of token ids).
            n_local: maximum number of visual tokens a single encoded chunk
                may produce (checked in ``_encode_video_chunk``).
            topk: stored for subclasses; not read by this class.
            chunk_size: stored for subclasses; not read by this class
                (``encode_video`` takes its own ``encode_chunk_size``).
        """
        self.processor = processor
        self.n_frame_tokens = n_frame_tokens
        self.init_prompt_ids = init_prompt_ids
        self.n_local = n_local
        self.topk = topk
        self.chunk_size = chunk_size

    def clear_cache(self):
        """Drop the KV-cache and release cached CUDA memory when CUDA exists."""
        self.kv_cache = None
        # ipc_collect() lazily initializes CUDA and raises on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    @torch.inference_mode()
    def encode_init_prompt(self):
        """Run the init prompt through the model and store its KV-cache."""
        if not isinstance(self.init_prompt_ids, torch.Tensor):
            # Add a leading batch dimension of size 1.
            self.init_prompt_ids = torch.as_tensor([self.init_prompt_ids], device=self.device)
        logger.debug('encode_init_prompt begin !')
        output = self.model(input_ids=self.init_prompt_ids, use_cache=True, return_dict=True)
        self.kv_cache = output.past_key_values
        logger.debug('encode_init_prompt end !')
        logger.debug('qqqqqqqqqqqqqqq type(kv_cache[0]): %s qqqqqqqqqqqqqqqqqqqq', type(self.kv_cache[0]))

    def _get_video_features(self, pixel_values_videos, video_grid_thw=None):
        """Turn preprocessed video pixels into language-model input embeddings.

        Must be overridden. ``_encode_video_chunk`` calls it with both
        ``pixel_values_videos`` and ``video_grid_thw`` and accepts either an
        ``(N, D)`` or a ``(1, N, D)`` tensor in return.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(f'{type(self).__name__} must implement _get_video_features')

    @torch.inference_mode()
    def _encode_video_chunk(self, video_chunk):
        """Encode one chunk of frames and extend ``self.kv_cache`` with it.

        Args:
            video_chunk: channels-last frames ``(F, H, W, C)``, as a tensor or
                anything ``torch.as_tensor`` accepts.

        Raises:
            AssertionError: if the chunk yields more than ``self.n_local``
                visual tokens.
        """
        video_chunk = torch.as_tensor(video_chunk)
        # (F, H, W, C) -> (F, C, H, W). This must be a permute: view/reshape
        # keeps memory order and would mix pixel values across channels.
        video_chunk = video_chunk.permute(0, 3, 1, 2).contiguous()
        logger.debug('video_chunk shape: %s', video_chunk.shape)
        video_inputs = self.processor(text="", images=None, videos=video_chunk, return_tensors="pt")
        logger.debug('video_inputs: %s', video_inputs)
        pixel_values_videos = video_inputs["pixel_values_videos"].to(self.device, self.dtype)
        logger.debug('pixel_values_videos shape: %s', pixel_values_videos.shape)
        logger.debug('pixel_values_videos device: %s', pixel_values_videos.device)
        video_grid_thw = video_inputs["video_grid_thw"]
        video_features = self._get_video_features(pixel_values_videos, video_grid_thw)
        logger.debug('------------------Finished Encode------------------')
        logger.debug('video_features shape: %s', video_features.shape)

        # The model needs batched embeddings (1, N, D); accept unbatched (N, D).
        if video_features.dim() == 2:
            video_features = video_features.unsqueeze(0)
        # Count tokens on the sequence axis, not the hidden dimension.
        n_tokens = video_features.shape[1]
        assert self.n_local >= n_tokens, f'n_local: {self.n_local}, video_features: {n_tokens}'

        output = self.model(inputs_embeds=video_features, past_key_values=self.kv_cache, use_cache=True, return_dict=True)
        logger.debug('------------------Finished Decode------------------')
        self.kv_cache = output.past_key_values

    @torch.inference_mode()
    def encode_video(self, video, encode_chunk_size=32):  # video: (Nv, H, W, 3)
        """Encode ``video`` chunk by chunk, extending ``self.kv_cache``.

        Args:
            video: channels-last frames ``(Nv, H, W, 3)``; must support
                ``.shape`` and slicing along the first axis.
            encode_chunk_size: frames per forward pass. The last chunk holds
                the remaining ``Nv % encode_chunk_size`` frames, if any.

        Raises:
            ValueError: if ``encode_chunk_size`` is not positive.
        """
        if encode_chunk_size <= 0:
            raise ValueError(f'encode_chunk_size must be positive, got {encode_chunk_size}')
        num_frames = video.shape[0]
        logger.debug('/////////////len(video): %d///////////////', num_frames)
        # Stepped range covers the full chunks and the shorter tail uniformly;
        # slicing past the end is clipped.
        for start_idx in range(0, num_frames, encode_chunk_size):
            chunk_video = video[start_idx:start_idx + encode_chunk_size]
            self._encode_video_chunk(chunk_video)
            logger.debug('KV-Cache RAM usage: %.1f GB', self.calc_memory_usage() / (1024**3))

    @torch.inference_mode()
    def question_answering(self, input_text, max_new_tokens=128):
        """Answer ``input_text`` against the cached context. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(f'{type(self).__name__} must implement question_answering')

    def calc_memory_usage(self):
        """Estimate the CPU memory held by the KV-cache.

        Extrapolates from layer 0 (``n_layers * kv_cache[0].calculate_cpu_memory()``),
        so it assumes every layer holds the same amount. The unit is whatever
        ``calculate_cpu_memory`` returns (``encode_video`` logs it as bytes).
        Returns 0 when there is no cache or it has no layers.
        """
        if self.kv_cache is None or len(self.kv_cache) == 0:
            return 0
        n_layers = len(self.kv_cache)
        return n_layers * self.kv_cache[0].calculate_cpu_memory()
