import torch
from logzero import logger

from base import BaseVQA, work
import os
import time
from qwen2_5_vl_dy import Qwen2_5_VLForConditionalGeneration
from transformers import AutoProcessor
from qwen_vl_dy_utils import process_vision_info
from constant import caption_prompt, plan_prompt, score_prompt, action_prompt, trigger_prompt
import torch
from parse import Parser, extract_answer
import torch.nn.functional as F
from decord import VideoReader, cpu, gpu
from utils import get_seq, parse_json_loose, process_video_with_bbox
from typing import Any, Dict, List, Optional


# Video frame-sampling settings. The pixel budget is fixed at 224x224
# (min == max); frame count is bounded to [MIN_FRAMES, MAX_FRAMES].
MIN_PIXELS = MAX_PIXELS = 224 * 224
MIN_FRAMES, MAX_FRAMES = 2, 180
FPS = 1


# Hugging Face checkpoint id for the Qwen2.5-VL model.
CKPT_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
# Placeholder: replace with the path of the video to analyse.
video_path = "your video path"

class StreamAgentVQA(BaseVQA):
    """Streaming video QA agent built on a Qwen2.5-VL model.

    The video is split into a sequence of clips (via ``get_seq``). Each clip is
    captioned with ``caption_prompt``, conditioned on the previous caption, which
    serves as a running memory. When ``stream=True``, the agent also does three
    things for each clip:

    - asks for a plan;
    - asks (``trigger_prompt``) whether it can stop early;
    - may choose a visual tool, which is applied to the *following* clip.

    The final answer is produced by ``video_open_qa``.
    """

    def video_qa(self, question, max_new_tokens=1024):
        """Answer ``question`` with ``self.qa_model.question_answering``."""
        pred_answer = self.qa_model.question_answering(question, max_new_tokens=max_new_tokens)
        return pred_answer

    def _ensure_model(self):
        """Load the captioning model/processor from ``CKPT_PATH`` only if missing.

        Any ``self.model`` / ``self.processor`` that is already set is reused, so
        the checkpoint is loaded at most once per instance.
        """
        if getattr(self, "model", None) is None:
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                CKPT_PATH,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                device_map="auto",
            )
        if getattr(self, "processor", None) is None:
            self.processor = AutoProcessor.from_pretrained(CKPT_PATH)

    @staticmethod
    def _build_video_messages(path):
        """Return the single-turn chat message that drives frame sampling of ``path``."""
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": path,
                        "min_pixels": MIN_PIXELS,
                        "max_pixels": MAX_PIXELS,
                        "min_frames": MIN_FRAMES,
                        "max_frames": MAX_FRAMES,
                        "fps": FPS,
                    },
                    {
                        "type": "text",
                        "text": "",
                    },
                ],
            }
        ]

    @staticmethod
    def _parse_action(action_planner):
        """Parse the action planner's reply into ``(tool_name, bbox)``.

        Returns ``None`` in three cases: no tool is requested (``"No Tool"``),
        the reply cannot be parsed, or the reply lacks
        ``Action.tool_name`` / ``Action.bbox``.
        """
        try:
            parsed = parse_json_loose(action_planner)
        except Exception:
            logger.warning("Could not parse action planner output: %r", action_planner)
            return None
        if parsed == "No Tool":
            return None
        try:
            action_spec = parsed["Action"]
            tool_name = action_spec["tool_name"]
        except (KeyError, TypeError, IndexError):
            logger.warning("Action planner output has no Action.tool_name: %r", parsed)
            return None
        if tool_name == "No Tool":
            return None
        try:
            bbox = action_spec["bbox"]
        except (KeyError, TypeError, IndexError):
            logger.warning("Action %r has no bbox: %r", tool_name, parsed)
            return None
        return tool_name, bbox

    @staticmethod
    def _apply_action(video_info, action, bbox):
        """Apply ``action`` to the clip stored in ``video_info[1]`` in place (best effort).

        On failure the clip is left unmodified and the error is logged.
        """
        try:
            if action == "Crop and Zoom In":
                video_info[1] = process_video_with_bbox(video_info[1], bbox)
            elif action == "Object Traction":
                # NOTE(review): the tool label must match whatever action_prompt emits
                # (possibly a typo for "Object Tracking"). draw_bboxes_on_video_tensor
                # is neither defined nor imported in this file, so this branch fails
                # (and is logged) until it is provided.
                video_info[1] = draw_bboxes_on_video_tensor(video_info[1], bbox)
        except Exception:
            logger.exception("Tool %r failed; using the unmodified clip", action)

    @torch.inference_mode()
    def analyze_a_video(
        self,
        video_sample,
        *,
        question: Optional[str] = None,
        agent_prompt: str = "",
        stream: bool = False,
    ):
        """Caption the video clip by clip, then answer ``question``.

        Args:
            video_sample: Accepted for interface compatibility. NOTE(review): it is
                currently unused; frames come from the module-level ``video_path``.
            question: Question to answer. Required.
            agent_prompt: Text appended to ``question`` before the final QA call.
            stream: Enable the plan / trigger / tool-use loop on every clip.

        Returns:
            The result of ``self.video_open_qa`` on ``question + agent_prompt``.

        Raises:
            ValueError: If ``question`` is not given. This is checked up front,
                before any model work.
        """
        if question is None:
            raise ValueError("analyze_a_video() requires a `question` keyword argument")

        # NOTE(review): the loaded frames are not used below; the call is kept in
        # case load_video has side effects or validates the path.
        self.load_video(video_path)

        self.qa_model.clear_cache()
        self.qa_model.encode_init_prompt()
        self._ensure_model()

        video_info_all = process_vision_info(
            [self._build_video_messages(video_path)], return_video_kwargs=True
        )
        video_seq_list = get_seq(video_info_all)

        memory = ""
        # (tool_name, bbox) chosen on the previous clip, applied to the next one.
        pending_action = None

        for video_info in video_seq_list:
            if pending_action is not None:
                self._apply_action(video_info, *pending_action)

            # NOTE(review): get_response is neither defined nor imported in this file.
            caption_prompt_format = caption_prompt.format(memory=memory)
            memory = get_response(caption_prompt_format, video_info, self.model, self.processor)

            if not stream:
                continue

            plan_prompt_format = plan_prompt.format(question=question, memory=memory)
            plan = get_response(plan_prompt_format, video_info, self.model, self.processor)
            trigger_prompt_format = trigger_prompt.format(question=question, memory=memory, plan=plan)
            trigger = get_response(trigger_prompt_format, video_info, self.model, self.processor)
            decision = extract_answer(trigger)
            if decision == 'yes':
                print("now we need to answer the question")
                break
            elif decision == 'no':
                print("continue")

            action_prompt_format = action_prompt.format(question=question, memory=memory, plan=plan)
            action_planner = get_response(action_prompt_format, video_info, self.model, self.processor)
            pending_action = self._parse_action(action_planner)

        # NOTE(review): this class defines `video_qa`; `video_open_qa` is presumably
        # provided by BaseVQA — confirm.
        qa_results = self.video_open_qa(question + agent_prompt, max_new_tokens=256)
        return qa_results


def _main():
    """Script entry point: hand the agent class to ``base.work``."""
    work(StreamAgentVQA)


if __name__ == "__main__":
    _main()



