import re
import torch
from abc import ABC, abstractmethod
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
from augmentation import image_augmentation, select_hp

class BaseGroundingModel(ABC):
    """
    Common interface for grounding models.

    Instantiating a concrete subclass immediately calls ``_load_model``, so
    subclasses only have to implement the two abstract hooks declared here.
    """

    def __init__(self):
        # Loading is eager: it happens as part of construction.
        self._load_model()

    @abstractmethod
    def _load_model(self):
        """Load the grounding model and any processor/tokenizer it needs."""

    @abstractmethod
    def predict(self, args: dict, image_path: str, hp: dict):
        """Return the predicted bounding box as ``(x1, y1, x2, y2)``."""

class GrdModelCogVLM(BaseGroundingModel):
    """
    Grounding model backed by ``THUDM/cogvlm-grounding-base-hf``.

    The model is given a grounding prompt plus an augmented image; the first
    bracketed run of numbers in its decoded answer is read as a bounding box
    and rescaled from a 0-1000 range to pixel coordinates of that image.
    """

    # One device for both the weights and every input tensor. A bare 'cuda'
    # resolves to the *current* CUDA device, which is not necessarily device 0,
    # so the inputs must name the same device as the model explicitly.
    _DEVICE = 'cuda:0'

    # First "[ ... ]" group made only of digits, whitespace, dots and commas.
    _BBOX_PATTERN = re.compile(r"\[([\d\s.,]+)\]")

    # Parsed coordinates are divided by this value before being multiplied by
    # the image width/height.
    _COORD_SCALE = 1000.0

    def _load_model(self):
        """Load the tokenizer and the grounding model (bfloat16, eval mode)."""
        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
        self.model_grounding = AutoModelForCausalLM.from_pretrained(
            'THUDM/cogvlm-grounding-base-hf',
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        ).to(self._DEVICE).eval()

    @classmethod
    def _parse_bbox(cls, response, w, h):
        """Turn a decoded model response into a pixel-space box.

        Args:
            response: Decoded text generated by the model.
            w: Width in pixels of the image that was fed to the model.
            h: Height in pixels of the image that was fed to the model.

        Returns:
            ``(x1, y1, x2, y2)``. When the response contains no bracketed
            group, the group does not hold exactly four values, or any value
            is not a valid number, the full-image box ``(0, 0, w, h)`` is
            returned instead.
        """
        full_image = (0, 0, w, h)

        match = cls._BBOX_PATTERN.search(response)
        if match is None:
            return full_image

        try:
            coords = [float(tok) for tok in match.group(1).replace(',', ' ').split()]
        except ValueError:
            # The character class also admits tokens such as "..." or "1.2.3",
            # which float() rejects; treat them like any other unparsable answer.
            return full_image

        if len(coords) != 4:
            return full_image

        x1, y1, x2, y2 = coords
        scale_x = w / cls._COORD_SCALE
        scale_y = h / cls._COORD_SCALE
        return x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y

    def predict(self, args: dict, image_path: str, hp: dict):
        """Predict a bounding box for the image at ``image_path``.

        Args:
            args: Run configuration. NOTE: despite the ``dict`` annotation it
                is read by attribute (``args.grounding_prompt_list``), so a
                namespace-like object is required.
            image_path: Path handed to ``image_augmentation``.
            hp: Hyper-parameters; ``hp['grd_prompt_id']`` indexes the prompt
                list and ``select_hp(hp, 'grd')`` picks the augmentation
                settings.

        Returns:
            ``(x1, y1, x2, y2)`` in pixel coordinates of the augmented image,
            or ``(0, 0, width, height)`` when no valid box can be parsed from
            the model output.
        """
        query = args.grounding_prompt_list[hp['grd_prompt_id']]
        # presumably returns an array that Image.fromarray accepts - verify
        # against the augmentation module.
        image = image_augmentation(image_path, select_hp(hp, 'grd'))
        image = Image.fromarray(image)
        w, h = image.size

        inputs = self.model_grounding.build_conversation_input_ids(
            self.tokenizer, query=query, images=[image])

        # Add a batch dimension and move everything to the model's device.
        inputs = {
            'input_ids': inputs['input_ids'].unsqueeze(0).to(self._DEVICE),
            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(self._DEVICE),
            'attention_mask': inputs['attention_mask'].unsqueeze(0).to(self._DEVICE),
            'images': [[inputs['images'][0].to(self._DEVICE).to(torch.bfloat16)]]
        }

        gen_kwargs = {"max_length": 2048, "do_sample": False}
        with torch.no_grad():
            outputs = self.model_grounding.generate(**inputs, **gen_kwargs)
            # Keep only the newly generated tokens, dropping the prompt prefix.
            outputs = outputs[:, inputs['input_ids'].shape[1]:]

        response = self.tokenizer.decode(outputs[0])
        return self._parse_bbox(response, w, h)
