import torch
from PIL import Image
import cv2
from baseline import Baseline
from typing import List
import re
import numpy as np

class LLaVaBaseline(Baseline):
    """
    Baseline backed by a locally loaded Hugging Face LLaVA model.

    If the model name contains ``video`` (case-insensitive) a
    LLaVA-NeXT-Video model is loaded and the supplied images are treated as
    the frames of one clip; otherwise a single-image LLaVA model is loaded.
    Model output is only accepted when it is a JSON object with an
    ``"answer"`` key; generation is retried a bounded number of times
    otherwise.
    """

    # Upper bound on generate-and-validate attempts per request. Retrying is
    # done in a loop (not by recursion) so a model that never emits valid
    # JSON produces a clear error instead of exhausting the call stack.
    MAX_GENERATION_ATTEMPTS = 10

    def __init__(self, model_name: str, use_cot: bool = False, quantization_bits: int = None):
        """
        Initialize the Hugging Face-based Vision-Language Model with optional quantization.

        :param model_name: The Hugging Face model to load.
        :param use_cot: Whether to use chain-of-thought reasoning.
        :param quantization_bits: Quantization precision (4 or 8), or ``None``
            to load the model without quantization.
        :raises ValueError: If ``quantization_bits`` is neither ``None``, 4 nor 8.
        :raises NotImplementedError: If a non-video model name does not contain ``llava``.
        """
        super().__init__(use_cot)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

        # The quantization config is the only difference between the
        # quantized and unquantized load calls, so build it once.
        load_kwargs = {}
        if quantization_bits is not None:
            load_kwargs["quantization_config"] = self._build_quantization_config(quantization_bits)

        if 'video' in model_name.lower():
            from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration

            self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
                self.model_name, **load_kwargs
            )
            self.processor = LlavaNextVideoProcessor.from_pretrained(self.model_name)
        else:
            from transformers import AutoProcessor, LlavaForConditionalGeneration

            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_name, **load_kwargs
            )

            # Load processor (always the llava-1.5-13b one for LLaVA models)
            if 'llava' in self.model_name.lower():
                self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-13b-hf")
            else:
                raise NotImplementedError
            # Register the extra special tokens and grow the embedding matrix
            # so the new token ids have rows to index into.
            self.processor.tokenizer.add_tokens(["<image>", "<pad>"], special_tokens=True)
            self.model.resize_token_embeddings(len(self.processor.tokenizer))

        # NOTE: transformers places bitsandbytes-quantized models on the
        # accelerator while loading and rejects `.to()` on them (always for
        # 8-bit, and for 4-bit on older releases), so only unquantized
        # models are moved explicitly.
        if quantization_bits is None:
            self.model.to(self.device)
        self.baseline_type = 'local'

    @staticmethod
    def _build_quantization_config(quantization_bits: int):
        """
        Build the bitsandbytes quantization config for the requested precision.

        :param quantization_bits: 4 or 8.
        :return: A ``transformers.BitsAndBytesConfig`` instance.
        :raises ValueError: If ``quantization_bits`` is not 4 or 8.
        """
        if quantization_bits not in (4, 8):
            raise ValueError("Only 4-bit and 8-bit quantization are supported.")

        from transformers import BitsAndBytesConfig

        if quantization_bits == 4:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
        # BitsAndBytesConfig has no compute-dtype option for 8-bit loading,
        # so only the load flag is passed.
        return BitsAndBytesConfig(load_in_8bit=True)

    @staticmethod
    def _load_rgb_image(image):
        """
        Convert one input image to an RGB PIL image.

        :param image: Either a path string (read with ``cv2.imread``) or an
            image array in OpenCV's BGR channel order.
        :return: The image as an RGB ``PIL.Image.Image``.
        :raises OSError: If a path is given and OpenCV cannot read it.
        """
        if isinstance(image, str):
            bgr = cv2.imread(image)
            # cv2.imread signals failure by returning None instead of raising.
            if bgr is None:
                raise OSError(f"cv2 could not read image file: {image!r}")
        else:
            bgr = image
        return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)).convert("RGB")

    @staticmethod
    def _is_json_answer(generated_text) -> bool:
        """
        Check whether the text is a JSON object containing an ``"answer"`` key.

        :param generated_text: The candidate model output; may be ``None``.
        :return: ``True`` only if the text parses as a JSON object with an ``"answer"`` key.
        """
        import json

        if generated_text is None:
            return False
        try:
            parsed = json.loads(generated_text)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError subclass.
            return False
        return isinstance(parsed, dict) and 'answer' in parsed

    def _generate_json_answer(self, inputs, prompt_conversation, log_labels) -> str:
        """
        Generate until the model returns a valid JSON answer or attempts run out.

        :param inputs: The processor output to unpack into ``self.model.generate``.
        :param prompt_conversation: The chat-format prompt, printed when an attempt fails.
        :param log_labels: ``(failure banner, prompt label, response label)`` used
            for the diagnostic prints of a failed attempt.
        :return: The text following ``ASSISTANT:`` in the first valid response.
        :raises RuntimeError: If no attempt yields a JSON object with an ``"answer"`` key.
        """
        failure_banner, prompt_label, response_label = log_labels
        for _ in range(self.MAX_GENERATION_ATTEMPTS):
            # Generate output tokens
            generate_ids = self.model.generate(**inputs, max_new_tokens=2000)

            # Decode the generated output and keep only the assistant's reply
            decoded = self.processor.batch_decode(generate_ids, skip_special_tokens=True)
            generated_text = self.extract_assistant_text(decoded[0])

            if self._is_json_answer(generated_text):
                return generated_text

            print(failure_banner)
            print(prompt_label, prompt_conversation)
            print(response_label, generated_text)

        raise RuntimeError(
            f"Model did not return a JSON object with an 'answer' key after "
            f"{self.MAX_GENERATION_ATTEMPTS} attempts."
        )

    def extract_assistant_text(self, input_text):
        """
        Extract everything to the right of the first ``ASSISTANT:`` marker.

        :param input_text: The full decoded model output.
        :return: The text after ``ASSISTANT:`` (leading whitespace skipped),
            or ``None`` if the marker is absent.
        """
        match = re.search(r'ASSISTANT:\s*(.*)', input_text, re.DOTALL)
        return match.group(1) if match else None

    def generate_text_individual(self, text: str, image_filepaths: List[str] = None) -> str:
        """
        Generate text based on the provided input text and images without using past conversations.

        :param text: The input text.
        :param image_filepaths: The images to condition on. Each entry may be
            an image file path or an already-loaded BGR image array. Image
            models require exactly one entry; video models require at least
            two, which are used as the frames of one clip.
        :return: The generated text, a JSON object string with an ``"answer"`` key.
        :raises AssertionError: If the number of images does not fit the model type.
        :raises RuntimeError: If the model never returns a valid JSON answer.
        """
        # A fresh list per call; a mutable default would be shared between calls.
        images = [] if image_filepaths is None else image_filepaths

        if 'video' not in self.model_name.lower():
            # We only support 1 image currently
            assert len(images) == 1, "HuggingFace VLMs (LLaVa, SpatialVLM) only support 1 image."

            # Load and process the image
            image = self._load_rgb_image(images[-1])

            prompt_conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": f"{text}"},
                    ],
                },
            ]

            # Format prompt with text and images
            prompt = self.processor.apply_chat_template(prompt_conversation, add_generation_prompt=True)

            # Tokenize inputs
            inputs = self.processor(
                text=prompt, images=image, padding=True, return_tensors="pt"
            ).to(self.model.device, torch.float16)

            log_labels = (
                "FAILED to return as json---------------trying again",
                'prompt:',
                'llm response:',
            )
        else:
            # Load and process the images as frames of a video
            assert len(images) > 1, "Please provide at least two image file paths for video frames."

            # Stack the frames to create a video array: (num_frames, height, width, 3)
            clips = np.stack([np.array(self._load_rgb_image(frame)) for frame in images])

            prompt_conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{text}"},
                        {"type": "video"},
                    ],
                },
            ]

            # Format prompt with text and video
            prompt = self.processor.apply_chat_template(prompt_conversation, add_generation_prompt=True)

            # Tokenize inputs
            inputs = self.processor(
                text=prompt,
                videos=clips,
                padding=True,
                return_tensors="pt"
            ).to(self.model.device)

            log_labels = (
                "FAILED to return as JSON---------------trying again",
                'Prompt:',
                'LLM response:',
            )

        return self._generate_json_answer(inputs, prompt_conversation, log_labels)

    def generate_text_using_past_conversations(self, text: str, image_filepaths: List[str] = None) -> str:
        """
        Generate text using the past conversations for chain-of-thought reasoning.

        Not implemented for this baseline; calling it always raises.

        :param text: The input text.
        :param image_filepaths: A list of images (unused).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError