import os
import re
import cv2
import torch
import random
import tempfile
import numpy as np
from PIL import Image
from abc import ABC, abstractmethod
from mistralai import Mistral
from utils import apply_mask_to_image, image_to_base64, extract_patches_from_mask, sample_crops_by_area
from open_clip import create_model_from_pretrained, get_tokenizer
import torch.nn.functional as F

# Torch device for the BiomedCLIP model and its inputs; fall back to CPU so the
# module still works on hosts without a CUDA GPU (a hard-coded "cuda:0" fails there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

class BaseValidator(ABC):
    """Shared scaffolding for mask validators.

    Concrete validators prepare their backend in ``_load_model`` (invoked from
    ``__init__``) and implement ``predict``; ``create_question`` produces the
    multiple-choice prompt used by the chat-style validators.
    """

    def __init__(self):
        # Construction only triggers the subclass-specific backend setup.
        self._load_model()

    @abstractmethod
    def _load_model(self):
        """Set up the client or model used by ``predict``."""

    @abstractmethod
    def predict(self, args: dict, image_path: str, mask: np.ndarray) -> dict:
        """Validate ``mask`` for the image at ``image_path`` and return a result dict."""

    def create_question(self, args):
        """Build a shuffled multiple-choice prompt about a masked image region.

        The candidates are ``args.prompt_target`` plus ``args.contrastive_parts``.
        They are shuffled in place with the ``random`` module and followed by
        two fixed options: the whole-image label ``args.prompt_image`` and
        "nothing reasonable".

        Returns:
            ``(question, ground_truth)``, where ``ground_truth`` is the 1-based
            position of the first occurrence of ``args.prompt_target`` in the
            shuffled candidate list.
        """
        options = [args.prompt_target] + args.contrastive_parts
        random.shuffle(options)

        listing = ''.join('%d) %s; ' % (position, label)
                          for position, label in enumerate(options, start=1))
        answer = options.index(args.prompt_target) + 1

        # The two fixed options are numbered right after the shuffled candidates.
        whole_image_id = len(options) + 1
        nothing_id = len(options) + 2

        question = (
            f"The image is a part extracted from a whole image of {args.prompt_image}."
            " Your task is to guess what the non-background region of the image is."
            f" The answer is one from the following candidates: {listing}{whole_image_id}) {args.prompt_image}; {nothing_id}) nothing reasonable;"
            " Your answer should start from the sentence:\n"
            "\"The result is: $CID. The reason is\", where $CID is the index of the above candidates starting from 1."
        )
        return question, answer

class ValModelMistral(BaseValidator):
    """Validate a mask by asking a Mistral vision model a multiple-choice question.

    The API key is read from the ``MISTRAL_API_KEY`` environment variable, so
    that credentials stay out of source control.
    """

    # Vision-language model queried through the Mistral chat API.
    model_name = "pixtral-12b-2409"
    # Answer prefix that ``create_question`` instructs the model to emit.
    _ANSWER_RE = re.compile(r"The result is:\s*(\d+)")

    def _load_model(self):
        api_key = os.environ.get("MISTRAL_API_KEY")
        if not api_key:
            raise RuntimeError("MISTRAL_API_KEY environment variable is not set")
        self.model_validator = Mistral(api_key=api_key)

    def predict(self, args: dict, image_path: str, mask: np.ndarray) -> dict:
        """Ask the model which candidate the masked region shows.

        Args:
            args: namespace consumed by ``create_question`` (``prompt_target``,
                ``contrastive_parts``, ``prompt_image``).
            image_path: path of the source image.
            mask: mask passed to ``apply_mask_to_image``.

        Returns:
            dict with ``score`` (True iff the parsed answer index equals the
            ground truth; False when the reply has no parsable index),
            ``question`` and ``reason`` (the raw model reply).
        """
        question, ground_truth = self.create_question(args)
        masked_image = apply_mask_to_image(image_path, mask)
        base64_image = image_to_base64(masked_image)
        inputs = [
            {'type': 'text', 'text': question},
            {'type': 'image_url', 'image_url': f"data:image/jpeg;base64,{base64_image}"},
        ]
        messages = [{"role": "user", "content": inputs}]
        chat_response = self.model_validator.chat.complete(model=self.model_name, messages=messages)
        msg = chat_response.choices[0].message.content
        # A reply that ignores the requested format counts as a wrong answer instead of crashing.
        match = self._ANSWER_RE.search(msg) if isinstance(msg, str) else None
        predict = int(match.group(1)) if match else None
        return {'score': predict == ground_truth, 'question': question, 'reason': msg}

class ValModelLlaVa(BaseValidator):
    """Validate a mask by querying a LLaVA inference server over HTTP."""

    # Seconds to wait for the server; None waits indefinitely (the requests default).
    request_timeout = None
    # Answer prefix that ``create_question`` instructs the model to emit.
    _ANSWER_RE = re.compile(r"The result is:\s*(\d+)")

    def _load_model(self):
        # No local weights: inference is delegated to this endpoint.
        self.url = "http://localhost:8000/llava/infer"

    def predict(self, args: dict, image_path: str, mask: np.ndarray) -> dict:
        """Send the masked image and question to the server and score the reply.

        Args:
            args: namespace consumed by ``create_question``.
            image_path: path of the source image.
            mask: mask passed to ``apply_mask_to_image``.

        Returns:
            dict with ``score`` (True iff the parsed answer index equals the
            ground truth; False when the reply has no parsable index),
            ``question`` and ``reason`` (the raw ``'response'`` text).

        Raises:
            RuntimeError: if the masked image cannot be written as PNG.
            requests.HTTPError: if the server answers with an error status.
        """
        # `requests` is not imported at module level; importing it here keeps
        # the module importable when only the other validators are used.
        import requests

        question, ground_truth = self.create_question(args)
        masked_image = apply_mask_to_image(image_path, mask)

        fd, temp_image_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        try:
            if not cv2.imwrite(temp_image_path, masked_image):
                raise RuntimeError(f"cv2.imwrite failed for {temp_image_path}")
            with open(temp_image_path, "rb") as img_file:
                response = requests.post(
                    self.url,
                    data={"prompt": question},
                    files={"image": img_file},
                    timeout=self.request_timeout,
                )
        finally:
            # Always delete the temporary PNG so repeated calls do not leak files.
            os.remove(temp_image_path)

        response.raise_for_status()
        msg = response.json()['response']
        # A reply that ignores the requested format counts as a wrong answer instead of crashing.
        match = self._ANSWER_RE.search(msg) if isinstance(msg, str) else None
        predict = int(match.group(1)) if match else None
        return {'score': predict == ground_truth, 'question': question, 'reason': msg}

class ValModelBioMedCLIP(BaseValidator):
    """Score a candidate mask with BiomedCLIP plus geometric sanity rules.

    ``predict`` sums two parts: a zero-shot classification probability for
    ``args.prompt_target`` and a relative image-text matching score. The sum is
    forced to 0 when the mask fails the area or centre-box checks configured
    on ``args``.
    """

    # Hugging Face Hub id shared by the model and its tokenizer.
    MODEL_ID = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
    # Token length passed to the tokenizer for every text prompt.
    CONTEXT_LENGTH = 256
    # Minimum fraction of foreground pixels that must lie inside the centre box.
    MIN_CENTER_FRACTION = 0.5

    def _load_model(self):
        # Load the model and config files from the Hugging Face Hub.
        self.model, self.preprocess = create_model_from_pretrained(self.MODEL_ID)
        self.tokenizer = get_tokenizer(self.MODEL_ID)
        self.model.to(device)
        self.model.eval()

    def predict(self, args, image_path: str, mask: np.ndarray) -> dict:
        """Score ``mask`` as a segmentation of ``args.prompt_target``.

        Args:
            args: namespace read for ``prompt_target``, ``contrastive_parts``,
                ``prompt_image``, ``validation_prompt_list``, ``area_range``,
                ``center_x_range`` and ``center_y_range``.
            image_path: path of the source image (read with ``cv2.imread``).
            mask: foreground pixels are those equal to 1. This assumes a 0/1
                (or bool) array with the image's height and width; TODO
                confirm with callers.

        Returns:
            dict with ``score`` (cls + matching score, or 0 when a rule check
            fails or patch extraction errors), ``question`` (str of the class
            list) and a human-readable ``reason``.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")
        h, w = image.shape[:2]
        image_area = h * w
        mask_area = np.sum(mask)

        # Masks covering under half the image use the patch from
        # extract_patches_from_mask; larger masks use the masked full image.
        if mask_area < 0.5 * image_area:
            try:
                masked_image, _ = extract_patches_from_mask(image_path, mask)
            except Exception:
                return {'score': 0, 'question': '', 'reason': 'unexpected error in extract_patches_from_mask'}
            patch_h, patch_w = masked_image.shape[:2]
            target_area = patch_h * patch_w
        else:
            masked_image = apply_mask_to_image(image_path, mask)
            target_area = mask_area

        # Index 0 is the target; the crops from sample_crops_by_area serve as references.
        image_inputs = [Image.fromarray(masked_image)]
        image_inputs += sample_crops_by_area(image_path, target_area)
        images = torch.stack([self.preprocess(img) for img in image_inputs]).to(device)

        classes = [args.prompt_target] + args.contrastive_parts
        target_probs = self._classification_probs(images, classes, args.prompt_image)
        score_cls = target_probs[0].item()

        similarity = self._relative_matching(images, args.validation_prompt_list)
        score_matching = similarity[0].item() / (similarity.sum().item() + 1e-6)

        score = score_cls + score_matching

        passed, p, r = self._rule_check(args, mask, h, w, mask_area)
        if not passed:
            score = 0

        p_min, p_max = args.area_range
        reason = 'cls score: %.4f, matching score: %.4f, total: %.2f\n' % (score_cls, score_matching, score)
        reason += 'Cls Probs: %s\n' % (str(target_probs.numpy().tolist()))
        reason += 'Matching: %s\n' % (str(similarity.cpu().numpy().tolist()))
        reason += 'Abnormal Check: p=%.4f (%.2f, %.2f), r=%.4f' % (p, p_min, p_max, r)

        return {'score': score, 'question': str(classes), 'reason': reason}

    def _classification_probs(self, images, classes, prompt_image):
        """Softmax over ``classes`` for each image; return the ``classes[0]`` column on CPU."""
        template = ' in a %s image' % prompt_image
        texts = self.tokenizer([c + template for c in classes], context_length=self.CONTEXT_LENGTH).to(device)
        with torch.no_grad():
            image_features, text_features, logit_scale = self.model(images, texts)
        logits = logit_scale * image_features @ text_features.t()
        return logits.softmax(dim=-1)[:, 0].cpu()

    def _relative_matching(self, images, descriptions):
        """Mean logit-scaled similarity of each image to ``descriptions``, min-max normalised across images."""
        texts = self.tokenizer(descriptions, context_length=self.CONTEXT_LENGTH).to(device)
        with torch.no_grad():
            image_features, text_features, logit_scale = self.model(images, texts)
        # Dot products of the model's features (cosine if they are L2-normalised),
        # averaged over all descriptions.
        similarity = (logit_scale * image_features @ text_features.t()).mean(dim=1)
        # Normalising across images scores the target relative to the reference crops.
        return (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-6)

    @classmethod
    def _rule_check(cls, args, mask, h, w, mask_area):
        """Return ``(passed, p, r)`` for the area-fraction and centre-box checks."""
        p_min, p_max = args.area_range
        p = 1.0 * mask_area / (h * w)

        x1, x2 = int(args.center_x_range[0] * w), int(args.center_x_range[1] * w)
        y1, y2 = int(args.center_y_range[0] * h), int(args.center_y_range[1] * h)
        total_positive = np.sum(mask == 1)
        positive_in_region = np.sum(mask[y1:y2, x1:x2] == 1)
        # Fraction of foreground inside the box. max(..., 1) only guards an
        # empty mask; a `1 + total` denominator would bias small masks downwards.
        r = 1.0 * positive_in_region / max(total_positive, 1)

        passed = p_min <= p <= p_max and r >= cls.MIN_CENTER_FRACTION
        return passed, p, r



