import json
import os
from typing import Dict, List, Any
from dataclasses import dataclass
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import time
import json
import os
from pathlib import Path
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoModelForZeroShotObjectDetection
from qwen_vl_utils import process_vision_info, smart_resize
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from PIL import Image
import os  
import re  
import random
import cv2
import numpy as np
import warnings
warnings.filterwarnings("ignore")
from vllm import LLM, SamplingParams
vllm_available = True

from utils import (
    extract_fields_2,
    get_img_crop,
    bbox_transform,
)

# Seed Python's global `random` generator once, at import time, so that any
# draws made through the `random` module in this process are repeatable.
# NOTE(review): nothing visible in this file draws from `random`; presumably
# this is for imported helpers -- confirm it is still needed.
random.seed(42) 

class AgentVLMVLLM:
    """
    A class that encapsulates the multimodal interleaved reasoning process
    and manages multi-turn dialogues.

    The agent repeatedly asks a vLLM-served vision-language model to continue
    its answer. Whenever the model emits a ``{"bbox_2d": [...]}`` "crop" call,
    the requested region is cut out of the original image, saved under
    ``temp_dir`` and appended to the conversation as an extra image before
    generation resumes.

    Attributes:
        temp_dir: Directory that receives the temporary crop images.
        sample_id: Counter incremented on every ``process`` call; part of the
            crop file name.
        cnt_tmp_img: Number of crops saved for the current sample.
        prompt: Instruction suffix appended to every question.
        img_crop_path: ``str.format`` template for crop file paths.
        stop_pattern: Regex used to detect a ``{...}`` tool call in the text.
        model / processor: Generation backend and its chat/vision processor.
        sampling_params: Sampling configuration passed to ``model.generate``.
        device: Fallback device label used in file names and log messages.
        min_pixels / max_pixels: Pixel budget used to resize the input image.
    """

    def __init__(self, model=None,
                 processor=None,
                 temp_dir="./agent_crops_tmp",
                 device="cuda",
                 min_pixels=3136,
                 max_pixels=802816,
                 ):
        """
        Create the agent and make sure the crop directory exists.

        Args:
            model: Generation backend exposing ``generate`` (may be set later
                through ``set_model``).
            processor: Processor exposing ``apply_chat_template`` and a
                ``tokenizer`` (may be set later through ``set_model``).
            temp_dir: Directory for temporary crop images; created if missing.
            device: Device label used when the model has no ``device``
                attribute.
            min_pixels: Lower pixel bound for resizing the input image.
            max_pixels: Upper pixel bound for resizing the input image.
        """
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)
        self.sample_id = 0
        self.cnt_tmp_img = 0

        self.prompt = """\nYou need to first think about the reasoning process in your mind, and then provide the answer. When thinking, you should call the "crop" tool (format: {"bbox_2d": [x1, y1, x2, y2]}) to focus on the key areas in the image. The reasoning process and the answer are included in the <think> </think> and <answer> </answer> tags respectively."""
        # Constants
        self.delta_img = '<|vision_start|><|image_pad|><|vision_end|>'
        self.img_crop_path = os.path.join(self.temp_dir, "{device}_sample_{sample}_crop_{cnt}.png")

        # Pattern matching a JSON-like "{...}" tool call inside generated text.
        self.stop_pattern = r'\{.*?\}'

        # Initialize models
        self.model = model
        self.processor = processor
        # Generation halts right where a bbox list would be closed, so that
        # the crop can be inserted before the model keeps going.
        self.sampling_params = SamplingParams(
            max_tokens=1024,
            stop_token_ids=[],
            stop=["]}", "]\n}", "] }"],
        )
        self.device = device
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels

    def set_model(self, model, processor, sampling_params=None):
        """
        Set the model and processor for the agent and purge old crop files.

        Args:
            model: Generation backend to use from now on.
            processor: Matching processor.
            sampling_params: Optional replacement for the default sampling
                parameters; ``None`` keeps the current ones.

        Side effects:
            Every regular file directly inside ``temp_dir`` is deleted
            (sub-directories are left alone). Deletion failures are reported
            on stdout and otherwise ignored.
        """
        self.model = model
        self.processor = processor
        if sampling_params is not None:
            self.sampling_params = sampling_params

        # The directory may have been removed since __init__; recreate it so
        # the listing below (and later crop saves) cannot fail on a missing
        # path.
        os.makedirs(self.temp_dir, exist_ok=True)
        for filename in os.listdir(self.temp_dir):
            file_path = os.path.join(self.temp_dir, filename)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError as e:
                # Best effort: a file we cannot delete must not block setup.
                print(f"Error deleting file {file_path}: {e}")

    def _crop_path(self, cnt):
        """Return the file path of crop number ``cnt`` for the current sample."""
        device = getattr(self.model, 'device', self.device)
        return self.img_crop_path.format(
            device=str(device).replace(':', "_"),
            sample=self.sample_id,
            cnt=cnt,
        )

    @staticmethod
    def _append_assistant_text(text_messages, text):
        """Append ``text`` to the assistant turn, creating that turn if needed."""
        entry = {"type": "text", "text": text}
        if len(text_messages) == 1:
            # Only the user turn exists so far: open the assistant turn.
            text_messages.append({"role": "assistant", "content": [entry]})
        else:
            text_messages[-1]['content'].append(entry)

    def crop_and_save_region(self, img, raw_img, bbox, scale_factor=None):
        """
        Crop and save a region from the image.

        Args:
            img: Processed (resized) image the model was shown; only its
                ``size`` is used, to map ``bbox`` back onto ``raw_img``.
            raw_img: Original PIL image the crop is taken from.
            bbox: Bounding box coordinates [x1, y1, x2, y2] expressed in the
                coordinate frame of ``img``.
            scale_factor: Unused; kept for backward compatibility.

        Returns:
            The zoom factor applied to the crop (between 1.0 and 2.0), or
            ``None`` when ``bbox`` is empty or no crop could be produced.
            On success the crop is written to disk and ``cnt_tmp_img`` is
            incremented.
        """
        if not bbox:
            return None

        bbox = bbox_transform(bbox, img.size[0], img.size[1], raw_img.size[0], raw_img.size[1])

        # Crop the region
        cropped_img = get_img_crop(raw_img, bbox)
        if cropped_img is None:
            print(f"## DEBUG: Invalid bbox: {bbox}")
            return None

        orig_width, orig_height = raw_img.size
        crop_width, crop_height = cropped_img.size
        # Fraction of the original image covered by the crop.
        r = (crop_width * crop_height) / (orig_width * orig_height)

        # Small crops are magnified 2x, crops covering at least half of the
        # image are kept at 1x, and the factor falls linearly in between.
        if r < 0.125:
            actual_scale = 2.0
        elif r >= 0.5:
            actual_scale = 1.0
        else:
            actual_scale = 2.0 - (r - 0.125) / 0.375

        h1, w1 = smart_resize(
            int(actual_scale * crop_height),
            int(actual_scale * crop_width),
            factor=28,
            min_pixels=4*28*28,
            max_pixels=1024*28*28,
        )
        cropped_img = cropped_img.resize((w1, h1), Image.LANCZOS)

        self.cnt_tmp_img += 1
        output_path = self._crop_path(self.cnt_tmp_img)
        cropped_img.save(output_path)
        print(f"## DEBUG: Saved cropped image to {output_path} (scale={actual_scale})")

        return actual_scale

    def process(self, image_path, question, max_iterations=12):
        """
        Process an image with a question using multimodal interleaved reasoning.

        Args:
            image_path: Path of the image to reason about.
            question: Question text; the agent's instruction prompt is
                appended to it.
            max_iterations: Maximum number of generation rounds.

        Returns:
            Tuple ``(generated_ids, prompt_ids_len, full_response,
            img_messages, text_messages)``:
              * generated_ids: token ids of the last round's prompt followed
                by the tokens generated in that round (empty list if no round
                ran).
              * prompt_ids_len: token length of the first round's prompt
                (``None`` if no round ran).
              * full_response: concatenation of all generated text segments.
              * img_messages: message list holding the original image and
                every attached crop.
              * text_messages: the full user/assistant conversation.
        """
        self.cnt_tmp_img = 0
        self.sample_id += 1

        # Load the original image and release the file handle right away;
        # convert() yields an in-memory copy that outlives the file.
        with Image.open(image_path) as source_img:
            raw_img = source_img.convert("RGB")
        width, height = raw_img.size
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=28,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        # Same geometry the model is given; bboxes it emits refer to it.
        real_input_img = raw_img.resize((resized_width, resized_height))
        question += self.prompt

        # Conversation (text + images) rendered through the chat template.
        text_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": question},
                ],
            }
        ]
        # Image-only view of the conversation handed to process_vision_info.
        img_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                ],
            }
        ]

        full_response = ""
        crop_flag = False       # a fresh crop is waiting to be attached
        prompt_ids_len = None
        generated_ids = []      # defined even when the loop body never runs
        finished = False        # the loop ended on its own, not on the cap

        iterations = 0
        while iterations < max_iterations:
            iterations += 1

            # Add cropped image to messages if available
            if self.cnt_tmp_img > 0 and crop_flag:
                crop_flag = False
                add_img = {
                    "type": "image",
                    "image": self._crop_path(self.cnt_tmp_img),
                }
                img_messages[0]['content'].append(add_img)
                text_messages[-1]['content'].append(add_img)

            text = self.processor.apply_chat_template(
                text_messages, tokenize=False, add_generation_prompt=len(text_messages) == 1
            )

            # Re-open the assistant turn so the model continues it instead of
            # starting a new one.
            if text.endswith("<|im_end|>\n"):
                text = text[:-len("<|im_end|>\n")] + "\n"

            # Process inputs
            image_inputs, video_inputs = process_vision_info(img_messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                return_tensors="pt",
                padding=True,
                padding_side="left",
                add_special_tokens=False,
            )
            prompt_ids = inputs['input_ids'][0].cpu().tolist()
            if prompt_ids_len is None:
                prompt_ids_len = len(prompt_ids)

            mm_data = {}
            if image_inputs is not None:
                mm_data["image"] = image_inputs
            llm_input = {
                "prompt": text,
                "multi_modal_data": mm_data,
            }

            outputs = self.model.generate([llm_input], sampling_params=self.sampling_params)
            completion = outputs[0].outputs[0]
            new_token_ids = list(completion.token_ids)
            generated_text = completion.text
            generated_ids = prompt_ids + new_token_ids

            if not new_token_ids:
                # Nothing was generated: there is no last token to inspect and
                # retrying the identical prompt would make no progress.
                print("##DEBUG: Empty completion, stopping")
                finished = True
                break

            # Check if we have a final answer
            if new_token_ids[-1] == self.processor.tokenizer.eos_token_id or "</answer>" in generated_text:
                full_response += generated_text
                self._append_assistant_text(text_messages, generated_text)
                finished = True
                break

            if "bbox_2d" in generated_text:
                # Generation halted on one of the stop strings, which is
                # presumably stripped from the returned text (vLLM default),
                # so the closing brackets are restored here.
                generated_text += "]}\n"

            # Extract operation from JSON
            op = extract_fields_2(generated_text)
            print(f"##DEBUG: Extracted operation: {op}")
            if op:
                # Deliberately broad: any failure in the tool call is reported
                # back to the model rather than aborting the whole sample.
                try:
                    actual_scale = self.crop_and_save_region(real_input_img, raw_img, op["bbox_2d"])
                    # Close a markdown code fence the model may have opened
                    # around its tool call.
                    if re.search(self.stop_pattern, generated_text, re.DOTALL) and "```json" in generated_text:
                        generated_text += "```\n"
                    # Attach an image next round only if a crop was actually
                    # saved; otherwise a stale crop would be attached again.
                    crop_flag = actual_scale is not None
                except Exception as e:
                    print(f"##DEBUG: - Error processing op: {e}")
                    generated_text += "Invalid operation!"

            # Add to the full response
            full_response += generated_text
            self._append_assistant_text(text_messages, generated_text)

        # Report only when the loop ran out of rounds without finishing.
        if not finished:
            print('*' * 80)
            print(f"## ERROR: {self.device} Exceeded maximum iterations")
            print('*' * 80)
        return generated_ids, prompt_ids_len, full_response, img_messages, text_messages
    


if __name__ == "__main__":
    # Demo: load a checkpoint into vLLM, wire it into the agent and run two
    # sample questions through the crop-and-reason loop.
    min_pixels = 4*28*28
    max_pixels = 2048*28*28
    # Other checkpoints used with this script:
    #   "./ckpt/Qwen2.5-VL-7B-Instruct"
    #   "./ckpt/Qwen2.5-VL-rl-0514"
    model_path = "./ckpt/Qwen2.5-VL-0514_1708-ck200"

    processor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)

    vllm_device = "cuda:0"
    # Ensure that training and inference use the same processor for images.
    pixel_budget = {
        "max_pixels": max_pixels,
        "min_pixels": min_pixels,
    }
    llm = LLM(
        model=model_path,
        device=vllm_device,
        gpu_memory_utilization=0.5,
        dtype=torch.bfloat16,
        limit_mm_per_prompt={"image": 16, "video": 10},
        enable_prefix_caching=True,
        enforce_eager=True,
        mm_processor_kwargs=pixel_budget,
        max_model_len=8192,
    )
    # Generation stops where a bbox list would be closed so the agent can
    # insert the crop before the model continues.
    sampling_params = SamplingParams(
        temperature=0.0,
        top_p=0.95,
        top_k=50,
        max_tokens=1024,
        stop_token_ids=[],
        stop=["]}", "]\n}", "] }"],
    )
    agent = AgentVLMVLLM(min_pixels=min_pixels, max_pixels=max_pixels, temp_dir="./agent_crops_case")
    agent.set_model(llm, processor, sampling_params)

    # (image, question) pairs; every question gets the same answer-format hint.
    answer_hint = "\nAnswer the question using a single word or phrase"
    demo_cases = (
        (
            "./assets/t2.png",
            "What is the blue thing behind the man sitting at the table with a blue plastic tablecloth?",
        ),
        (
            "./assets/10255.jpeg",
            "How many of the employees are very satisfied with their organization's COVID-19 response?",
        ),
    )
    for image_path, base_question in demo_cases:
        question = base_question + answer_hint
        generated_ids, prompt_ids_len, full_response, img_messages, text = agent.process(image_path, question)
        print("Full Response:", full_response)
    