import torch
from PIL import Image, ImageDraw, ImageOps
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import json
import pdb
import cv2
import numpy as np
from typing import Union

from .tools import is_platform_win
from .base_captioner import BaseCaptioner
import torch.distributed as dist


class BLIP2Captioner(BaseCaptioner):
    """Image captioner backed by the Hugging Face BLIP-2 OPT-2.7b model.

    By default ``inference`` answers a fixed "what does the image show?"
    prompt and returns the generated text. With ``dialogue=True`` it instead
    runs an interactive question/answer loop that reads questions from stdin.
    """

    _MODEL_NAME = "Salesforce/blip2-opt-2.7b"

    def __init__(
        self,
        device,
        dialogue: bool = False,
        enable_filter: bool = False,
        cache_dir=None,
    ):
        """Load the BLIP-2 processor and model.

        Args:
            device: Device that model inputs are moved to, e.g. ``"cuda:0"`` or
                ``"cpu"``. A ``torch.device`` is accepted as well.
            dialogue: If True, ``inference`` runs an interactive Q&A loop
                instead of producing a single caption.
            enable_filter: Forwarded to ``BaseCaptioner.__init__``; checked
                together with the per-call ``filter`` flag in ``inference``.
            cache_dir: Optional cache directory passed to ``from_pretrained``;
                ``None`` keeps the transformers default.
        """
        super().__init__(device, enable_filter)
        self.device = device
        self.dialogue = dialogue
        # fp16 only on CUDA. str() so a torch.device works as well as a string.
        self.torch_dtype = torch.float16 if "cuda" in str(device) else torch.float32
        self.processor = AutoProcessor.from_pretrained(
            self._MODEL_NAME, cache_dir=cache_dir
        )
        if is_platform_win():
            # Windows: load un-quantized weights in self.torch_dtype
            # (presumably because 8-bit loading is unavailable there -- confirm).
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                self._MODEL_NAME,
                device_map="sequential",
                torch_dtype=self.torch_dtype,
                cache_dir=cache_dir,
            )
        else:
            # Other platforms: load_in_8bit quantizes the weights to 8 bits.
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                self._MODEL_NAME,
                device_map="sequential",
                load_in_8bit=True,
                cache_dir=cache_dir,
            )

    def _generate(self, image, prompt: str) -> str:
        """Run one generation (max 50 new tokens) and return the stripped decoded text."""
        inputs = self.processor(image, text=prompt, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs, max_new_tokens=50)
        return self.processor.decode(out[0], skip_special_tokens=True).strip()

    @torch.no_grad()
    def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
        """Caption ``image``, or hold an interactive Q&A session about it.

        Args:
            image: A PIL image, a numpy array, or a file path (opened with PIL).
            filter: When True and the captioner was built with
                ``enable_filter=True``, pass the caption through
                ``self.filter_caption``. Ignored in dialogue mode.

        Returns:
            In single-shot mode, the generated caption. In dialogue mode, the
            answer to the last question asked, or ``""`` if the session was
            ended (by typing ``end``) before any question was asked.
        """
        if isinstance(image, str):  # input path
            image = Image.open(image)

        if not self.dialogue:
            captions = self._generate(
                image, "Question: what does the image show? Answer:"
            )
            if self.enable_filter and filter:
                captions = self.filter_caption(image, captions)
            return captions

        # Dialogue mode: read questions from stdin until "end". Every earlier
        # Q/A pair is prepended to the prompt as conversational context.
        context = []
        template = "Question: {} Answer: {}."
        # Pre-initialised so that ending the session immediately returns ""
        # instead of raising UnboundLocalError.
        captions = ""
        while True:
            input_texts = input()
            if input_texts == "end":
                break
            history = " ".join(template.format(q, a) for q, a in context)
            prompt = history + " Question: " + input_texts + " Answer:"
            captions = self._generate(image, prompt)
            context.append((input_texts, captions))
        return captions


if __name__ == "__main__":
    # Manual smoke test. Requires a GPU at cuda:4 and test_img/img2.jpg.
    dialogue = False
    # BLIP2Captioner.__init__ takes (device, dialogue, enable_filter); the
    # previous ``cache_dir=`` keyword was not accepted and raised TypeError.
    model = BLIP2Captioner(
        device="cuda:4",
        dialogue=dialogue,
    )
    image_path = "test_img/img2.jpg"
    # 224x224 binary mask with a square foreground region.
    seg_mask = np.zeros((224, 224))
    seg_mask[50:200, 50:200] = 1
    print(f"process image {image_path}")
    # NOTE(review): inference_seg is not defined in BLIP2Captioner itself;
    # presumably inherited from BaseCaptioner -- confirm.
    print(model.inference_seg(image_path, seg_mask))
