"""
Evaluate generated images with Mask2Former (or another object-detector model) for GenEval-style checks, and with a VQA model for DPG-style scoring.
"""

import argparse
import base64
import copy
import functools
import json
import os
import re
import sys
import time
import uuid
import warnings
from io import BytesIO

warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import redis
from PIL import Image, ImageOps
import torch
import mmdet
from mmdet.apis import inference_detector, init_detector

import open_clip
from clip_benchmark.metrics import zeroshot_classification as zsc

from model.dpg.modelling_dpg import MPLUG
# Replace clip_benchmark's `tqdm` with an identity wrapper so zero-shot
# classification iterates its dataloader without progress-bar output.
zsc.tqdm = lambda it, *args, **kwargs: it

# Module-level evaluation defaults.
# NOTE(review): GenEvalInf keeps its own per-instance copies of these values
# (with THRESHOLD = 0.6 rather than 0.3). In the visible code only
# POSITION_THRESHOLD is read at module level, by relative_position().

# Default minimum detector score for non-"counting" prompts (a score must exceed it).
THRESHOLD = 0.3
# Default minimum detector score for prompts tagged "counting".
COUNTING_THRESHOLD =  0.9
# Default cap on detections kept per class.
MAX_OBJECTS =  16
# Default IoU threshold for same-class NMS; GenEvalInf only applies NMS when < 1.
NMS_THRESHOLD = 1.0
# Margin subtracted from each center-offset component, as a fraction of the two
# boxes' summed width/height, before spatial relations are derived.
POSITION_THRESHOLD = 0.1

def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, which matches the previous behavior.

    Returns:
        argparse.Namespace. ``options`` is converted from a list of
        ``KEY=VALUE`` strings into a dict. Only the first ``=`` splits, so
        values may themselves contain ``=``.

    Exits (via ``parser.error``) when an ``--options`` entry lacks ``=``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--imagedir", type=str)
    parser.add_argument("--outfile", type=str, default="results.jsonl")
    parser.add_argument("--model_config", type=str, default='../../mmdetection-3.x/configs/mask2former/mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco.py')
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--clip_model_path", type=str, default="")
    # Other arguments
    parser.add_argument("--options", nargs="*", type=str, default=[])
    parser.add_argument("--device", type=str, default='cuda')
    args = parser.parse_args(argv)

    options = {}
    for opt in args.options:
        key, sep, value = opt.partition("=")
        if not sep:
            # Report a usage error instead of dict()'s cryptic "sequence element" ValueError.
            parser.error(f"--options entries must be KEY=VALUE, got {opt!r}")
        options[key] = value
    args.options = options

    # The --model_config default is a non-empty path, so this fallback to the
    # config bundled next to the installed mmdet package only applies when an
    # empty value is passed explicitly (a `is None` check could never fire).
    if not args.model_config:
        args.model_config = os.path.join(
            os.path.dirname(mmdet.__file__),
            "../configs/mask2former/mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco.py"
        )
    return args


def timed(fn):
    """Decorator that prints ``fn``'s wall-clock run time to stderr after each call.

    ``functools.wraps`` keeps the wrapped function's ``__name__``, ``__doc__``
    and other metadata. Nothing is printed if ``fn`` raises.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        startt = time.perf_counter()
        result = fn(*args, **kwargs)
        endt = time.perf_counter()
        print(f'Function {fn.__name__!r} executed in {endt - startt:.3f}s', file=sys.stderr)
        return result
    return wrapper

# Load models

@timed
def load_models(model_config='../../geneval_model/mmdetection-3.x/configs/mask2former/mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco.py', 
                model_path="../../geneval_model/mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco.pth", 
                device='cuda', 
                clip_model_path='../../geneval_model/ViT-L-14.pt', 
                obj_path='../../geneval_model/object_names.txt'):
    """Load the object detector, the CLIP ViT-L-14 model, and the detector class names.

    Args:
        model_config: mmdet config file for the detector.
        model_path: Detector checkpoint path.
        device: Device both models are placed on.
        clip_model_path: ``pretrained`` argument passed to open_clip for ViT-L-14.
        obj_path: Text file with one class name per line. Line order must match
            the detector's label indices, since callers index this list by label.

    Returns:
        ``(object_detector, (clip_model, transform, tokenizer), classnames)``.
    """
    object_detector = init_detector(model_config, model_path, device=device)

    clip_arch = "ViT-L-14"
    clip_model, _, transform = open_clip.create_model_and_transforms(clip_arch, pretrained=clip_model_path, device=device)
    tokenizer = open_clip.get_tokenizer(clip_arch)

    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    # Lines are not filtered, so that index i keeps mapping to label i.
    with open(obj_path, encoding="utf-8") as cls_file:
        classnames = [line.strip() for line in cls_file]

    return object_detector, (clip_model, transform, tokenizer), classnames


# Color vocabulary for CLIP zero-shot color classification; predicted indices
# are mapped back to names through this list.
COLORS = ["red", "orange", "yellow", "green", "blue", "purple", "pink", "brown", "black", "white"]
# Cache of zero-shot color classifiers keyed by object class name. It is filled
# lazily by GenEvalInf.color_classification and, being module-level, shared by
# every instance.
COLOR_CLASSIFIERS = {}

# Evaluation parts

class ImageCrops(torch.utils.data.Dataset):
    """Per-object crops of one image, used to feed CLIP during color classification.

    For each ``(box, mask)`` in ``objects``, pixels outside ``mask`` (when a mask
    is given) are replaced by ``bgcolor``, and the result is cropped to ``box``.

    Args:
        image: Source PIL image (converted to RGB).
        objects: Sequence of ``(box, mask)`` pairs. ``box`` starts with
            ``x0, y0, x1, y1``. ``mask`` is ``None`` or a 2-D array shaped
            like the image's (height, width).
        transform: Callable applied to each crop (e.g. the CLIP preprocessing transform).
        bgcolor: Fill color for masked-out pixels, or ``"original"`` to keep the
            original pixels as background. Defaults to ``"#999"``, the value
            previously hard-coded.
        crop: Crop each item to its box (default ``True``). ``False`` yields the
            full-size masked image.
    """

    def __init__(self, image: Image.Image, objects, transform, bgcolor="#999", crop=True):
        self._image = image.convert("RGB")
        if bgcolor == "original":
            self._blank = self._image.copy()
        else:
            self._blank = Image.new("RGB", image.size, color=bgcolor)
        self._objects = objects
        self.transform = transform
        self._crop = crop

    def __len__(self):
        return len(self._objects)

    def __getitem__(self, index):
        """Return ``(transform(crop), 0)``.

        The constant ``0`` fills the label slot of each sample. It is presumably
        ignored by ``zsc.run_classification``; verify against clip_benchmark.
        """
        box, mask = self._objects[index]
        image = self._image
        if mask is not None:
            # PIL sizes are (width, height); masks are (height, width).
            expected = tuple(self._image.size[::-1])
            if expected != tuple(mask.shape):
                raise ValueError(
                    f"mask for object {index} has shape {tuple(mask.shape)}, expected {expected}"
                )
            image = Image.composite(self._image, self._blank, Image.fromarray(mask))
        if self._crop:
            image = image.crop(box[:4])

        return (self.transform(image), 0)
    

def relative_position(obj_a, obj_b, threshold=None):
    """Return the spatial relations of object A relative to object B.

    Each object is a ``(box, mask)`` pair whose ``box`` starts with
    ``x0, y0, x1, y1`` in image coordinates (y grows downward).

    The center offset A - B is shrunk per axis by ``threshold * (dim_a + dim_b)``,
    so nearby or overlapping objects yield no relation. The remaining direction
    is then compared against the original offset length.

    Args:
        obj_a, obj_b: ``(box, mask)`` pairs; only the box is used.
        threshold: Margin as a fraction of the summed box extents. ``None``
            (default) uses the module-level ``POSITION_THRESHOLD``.

    Returns:
        set[str]: Subset of ``{"left of", "right of", "above", "below"}``;
        empty when the objects are too close to call.
    """
    if threshold is None:
        threshold = POSITION_THRESHOLD
    # boxes[k] == [[x0, y0], [x1, y1]] for object k.
    boxes = np.array([obj_a[0], obj_b[0]])[:, :4].reshape(2, 2, 2)
    center_a, center_b = boxes.mean(axis=-2)
    dim_a, dim_b = np.abs(np.diff(boxes, axis=-2))[..., 0, :]
    offset = center_a - center_b
    # Shrink each offset component toward 0 by the margin, keeping its sign.
    revised_offset = np.maximum(np.abs(offset) - threshold * (dim_a + dim_b), 0) * np.sign(offset)
    if np.all(np.abs(revised_offset) < 1e-3):
        return set()
    # offset is non-zero here, because revised_offset is.
    dx, dy = revised_offset / np.linalg.norm(offset)
    relations = set()
    if dx < -0.5:
        relations.add("left of")
    if dx > 0.5:
        relations.add("right of")
    if dy < -0.5:
        relations.add("above")
    if dy > 0.5:
        relations.add("below")
    return relations

def compute_iou(box_a, box_b):
    """Return the intersection-over-union of axis-aligned boxes.

    Boxes are ``(x0, y0, x1, y1, ...)`` with inclusive pixel coordinates, hence
    the ``+ 1`` in widths and heights. Entries past index 3 (e.g. a score) are
    ignored.

    The inputs broadcast with numpy. ``compute_iou(box, boxes)`` with ``boxes``
    of shape ``(N, 4+)`` returns an ``(N,)`` array, which is how the NMS step
    calls it. Two single boxes give a Python ``float``. A zero union area gives
    IoU 0.
    """
    a = np.asarray(box_a, dtype=float)
    b = np.asarray(box_b, dtype=float)

    def _area(x0, y0, x1, y1):
        # Negative extents (empty overlap) clamp to zero area.
        return np.maximum(x1 - x0 + 1, 0) * np.maximum(y1 - y0 + 1, 0)

    i_area = _area(
        np.maximum(a[..., 0], b[..., 0]), np.maximum(a[..., 1], b[..., 1]),
        np.minimum(a[..., 2], b[..., 2]), np.minimum(a[..., 3], b[..., 3]),
    )
    u_area = (
        _area(a[..., 0], a[..., 1], a[..., 2], a[..., 3])
        + _area(b[..., 0], b[..., 1], b[..., 2], b[..., 3])
        - i_area
    )
    nonzero = u_area != 0
    # Divide by 1 where the union is empty so no division-by-zero warning is raised.
    iou = np.where(nonzero, i_area / np.where(nonzero, u_area, 1.0), 0.0)
    return float(iou) if iou.ndim == 0 else iou

# NOTE(review): not referenced in the visible code; GenEvalInf and DPGEval
# receive their device as a constructor argument instead.
DEVICE='cuda'

class GenEvalInf:
    """GenEval-style scorer: detect objects in each image and check them against prompt metadata.

    An image scores ``weight`` when every clause of its metadata holds and
    ``0.0`` otherwise (see ``evaluate``).
    """

    def __init__(self, weight, device):
        """Load the detector, CLIP model and class names, and freeze both models.

        Args:
            weight: Score given to an image that satisfies its metadata.
            device: Device passed to ``load_models`` (e.g. ``'cuda'``).
        """
        self.device = device
        self.object_detector, (self.clip_model, self.transform, self.tokenizer), self.classnames = load_models(device=device)

        # Both models are inference-only: eval mode and no gradient tracking.
        # requires_grad_ is the real autograd switch; assigning a misspelled
        # attribute (e.g. `require_grad`) would silently do nothing.
        for model in (self.object_detector, self.clip_model):
            model.eval()
            for param in model.parameters():
                param.requires_grad_(False)

        # Per-instance thresholds used by this class. NOTE: THRESHOLD (0.6)
        # differs from the module-level THRESHOLD (0.3).
        self.THRESHOLD = 0.6
        self.COUNTING_THRESHOLD = 0.9
        self.MAX_OBJECTS = 16
        self.NMS_THRESHOLD = 1.0  # NMS only runs when this is < 1
        # Not read by this class; relative_position() uses the module-level value.
        self.POSITION_THRESHOLD = 0.1
        self.weight = weight

    def color_classification(self, image, bboxes, classname):
        """Predict a color from ``COLORS`` for each object via CLIP zero-shot classification.

        Args:
            image: PIL image the objects were detected in.
            bboxes: Sequence of ``(box, mask_or_None)`` pairs.
            classname: Object class name, inserted into the prompt templates.

        Returns:
            list[str]: One color name per entry of ``bboxes``.
        """
        # Build the classifier once per class name and cache it.
        # NOTE(review): the cache is module-level and keyed only by class name,
        # so it assumes all instances share the same CLIP model/tokenizer.
        if classname not in COLOR_CLASSIFIERS:
            COLOR_CLASSIFIERS[classname] = zsc.zero_shot_classifier(
                self.clip_model, self.tokenizer, COLORS,
                [
                    f"a photo of a {{c}} {classname}",
                    f"a photo of a {{c}}-colored {classname}",
                    f"a photo of a {{c}} object"
                ],
                device=self.device
            )
        clf = COLOR_CLASSIFIERS[classname]
        dataloader = torch.utils.data.DataLoader(
            ImageCrops(image, bboxes, self.transform),
            batch_size=16, num_workers=4
        )
        with torch.no_grad():
            pred, _ = zsc.run_classification(self.clip_model, clf, dataloader, device=self.device)
            return [COLORS[index.item()] for index in pred.argmax(1)]

    @torch.no_grad()
    def evaluate_image(self, image, objects, metadata):
        """Check detected ``objects`` against the ``metadata`` clauses for one image.

        Args:
            image: PIL image (only used for color classification).
            objects: Dict mapping class name to a list of ``(box, mask)`` pairs,
                most confident first.
            metadata: Dict with optional ``'include'`` and ``'exclude'`` lists.
                Each clause has ``'class'`` and ``'count'``. Include clauses may
                add ``'color'`` and ``'position'`` = ``(relation, target)``,
                where ``target`` indexes an earlier include clause.

        Assumptions:
            * Include clauses combine with AND; exclude clauses with OR.
            * Clauses are independent: duplicating one does not change the result.
            * Color and position are checked only on the ``count`` most
              confident objects of a class, so ``objects`` must be sorted.

        Returns:
            ``(correct, reason)``, where ``reason`` joins failure messages with newlines.
        """
        correct = True
        reason = []
        matched_groups = []  # one entry per include clause: its objects, or None if it failed
        # Check for expected objects
        for req in metadata.get('include', []):
            classname = req['class']
            matched = True
            found_objects = objects.get(classname, [])[:req['count']]
            if len(found_objects) < req['count']:
                correct = matched = False
                reason.append(f"expected {classname}>={req['count']}, found {len(found_objects)}")
            else:
                if 'color' in req:
                    # Color check
                    colors = self.color_classification(image, found_objects, classname)
                    if colors.count(req['color']) < req['count']:
                        correct = matched = False
                        reason.append(
                            f"expected {req['color']} {classname}>={req['count']}, found " +
                            f"{colors.count(req['color'])} {req['color']}; and " +
                            ", ".join(f"{colors.count(c)} {c}" for c in COLORS if c in colors)
                        )
                if 'position' in req and matched:
                    # Relative position check
                    expected_rel, target_group = req['position']
                    # An index that does not refer to an existing earlier clause
                    # is treated like a failed target instead of raising IndexError.
                    if -len(matched_groups) <= target_group < len(matched_groups):
                        targets = matched_groups[target_group]
                    else:
                        targets = None
                    if targets is None:
                        correct = matched = False
                        reason.append(f"no target for {classname} to be {expected_rel}")
                    else:
                        for obj in found_objects:
                            for target_obj in targets:
                                true_rels = relative_position(obj, target_obj)
                                if expected_rel not in true_rels:
                                    correct = matched = False
                                    reason.append(
                                        f"expected {classname} {expected_rel} target, found " +
                                        f"{' and '.join(true_rels)} target"
                                    )
                                    break
                            if not matched:
                                break
            matched_groups.append(found_objects if matched else None)
        # Check for non-expected objects
        for req in metadata.get('exclude', []):
            classname = req['class']
            # .get() avoids a KeyError when the class is absent (e.g. count == 0).
            n_found = len(objects.get(classname, []))
            if n_found >= req['count']:
                correct = False
                reason.append(f"expected {classname}<{req['count']}, found {n_found}")
        return correct, "\n".join(reason)

    def _select_detections(self, bboxes, labels, scores, segm, confidence_threshold):
        """Group raw detections by class name, keeping the most confident ones.

        For each class, drop detections whose score does not exceed
        ``confidence_threshold``, sort the rest by descending score, and keep at
        most ``self.MAX_OBJECTS``. Greedy NMS is applied when
        ``self.NMS_THRESHOLD < 1``. Classes with no surviving detection are
        omitted.

        Returns:
            Dict mapping class name to a list of ``(bbox, mask_or_None)`` in
            descending-score order.
        """
        detected = {}
        for index, classname in enumerate(self.classnames):
            # Filter detections for the current class
            class_indices = np.where(labels == index)[0]
            class_bboxes = bboxes[class_indices]
            class_scores = scores[class_indices]
            class_masks = segm[class_indices] if segm is not None else None

            # Sort by confidence, apply the threshold, then cap the count.
            ordering = np.argsort(class_scores)[::-1]
            ordering = ordering[class_scores[ordering] > confidence_threshold]
            ordering = ordering[:self.MAX_OBJECTS]

            kept = []
            while ordering.size > 0:
                best = ordering[0]
                kept.append((class_bboxes[best], None if class_masks is None else class_masks[best]))
                ordering = ordering[1:]
                if self.NMS_THRESHOLD < 1 and ordering.size > 0:
                    # One compute_iou call per remaining box (single-box inputs).
                    ious = np.array([compute_iou(class_bboxes[best], class_bboxes[j]) for j in ordering])
                    ordering = ordering[ious < self.NMS_THRESHOLD]
            if kept:
                detected[classname] = kept
        return detected

    def _detect(self, image, metadata):
        """Run the detector on ``image`` and return detections filtered for ``metadata``."""
        # mmdet treats ndarray input like an mmcv-loaded image (BGR channel
        # order), so convert the PIL image to RGB, then flip the channels.
        bgr = np.ascontiguousarray(np.asarray(image.convert("RGB"))[:, :, ::-1])
        pred_instances = inference_detector(self.object_detector, bgr).pred_instances
        bboxes = pred_instances.bboxes.cpu().numpy()
        labels = pred_instances.labels.cpu().numpy()
        scores = pred_instances.scores.cpu().numpy()
        segm = pred_instances.masks.cpu().numpy() if hasattr(pred_instances, 'masks') else None
        # Counting prompts use the stricter counting threshold.
        if metadata.get('tag') == "counting":
            confidence_threshold = self.COUNTING_THRESHOLD
        else:
            confidence_threshold = self.THRESHOLD
        return self._select_detections(bboxes, labels, scores, segm, confidence_threshold)

    @torch.no_grad()
    def evaluate(self, input_image, metadatas):
        """Score each image against its metadata.

        Args:
            input_image: Iterable of PIL images.
            metadatas: Iterable of metadata dicts, paired with ``input_image`` by position.

        Returns:
            list[float]: ``self.weight`` for each image whose metadata is fully
            satisfied, else ``0.0``.
        """
        all_scores = []
        for pil_image, metadata in zip(input_image, metadatas):
            # Apply the EXIF orientation before detection so the detector, the
            # masks and the color classifier all see the same pixel layout.
            image = ImageOps.exif_transpose(pil_image)
            detected = self._detect(image, metadata)
            is_correct, _reason = self.evaluate_image(image, detected, metadata)
            all_scores.append(self.weight if is_correct else 0.0)
        return all_scores

    
class DPGEval:
    """DPG-style scorer: score images by VQA over per-prompt question sets.

    Runs either locally with an mPLUG model (``use_api=False``) or remotely by
    enqueueing requests in Redis for a separate worker (``use_api=True``).
    """

    def __init__(self,
                 use_api: bool = False,
                 device = None,
                 ckpt: str = None,
                 host: str = '',
                 port: int = 0,
                 db: int = 0):
        """Set up either a Redis client (API mode) or a local VQA model.

        Raises:
            NotImplementedError: in local mode, when ``ckpt`` is missing or is
                not an mPLUG checkpoint path.
        """
        self.use_api = use_api
        if self.use_api:
            self.redis_client = redis.StrictRedis(host=host, port=port, db=db)
        else:
            self.device = device
            # Guard against ckpt=None, which would otherwise raise TypeError on `in`.
            if ckpt and 'mplug' in ckpt:
                self.vqa_model = MPLUG(ckpt=ckpt, device=self.device)
            else:
                raise NotImplementedError(f'Only mplug checkpoints are supported in dpg (got ckpt={ckpt!r}).')

        print('******Init dpg eval******')

    def _encoder_image_to_base64(self, images):
        """Encode each image as a base64 PNG string.

        Args:
            images: List of PIL images or numpy arrays (arrays go through ``Image.fromarray``).

        Raises:
            NotImplementedError: if ``images`` is not a list or contains an unsupported type.
        """
        if not isinstance(images, list):
            raise NotImplementedError('Only list of images is supported')

        base64_list = []
        for image in images:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            if not isinstance(image, Image.Image):
                raise NotImplementedError('Only np.ndarray and PIL.Image is supported')

            with BytesIO() as buffer:
                image.save(buffer, format="png")
                base64_list.append(base64.b64encode(buffer.getvalue()).decode('utf-8'))

        return base64_list

    def _modelart_request(self, data_dict_list, redis_client, timeout=None, poll_interval=1.0):
        """Enqueue a batch in Redis and block until the worker reports completion.

        Protocol used here: the payload's ``str()`` is stored at
        ``request:<id>`` and ``<id>`` is pushed onto ``request_queue``. Progress
        is read from ``progress:<id>`` as ``"<n_finished>_<percent>"``. At 100%,
        the result is read from ``response:<id>``, and all three keys are deleted.

        Args:
            data_dict_list: List of request dicts.
            redis_client: Redis-like client (``set``/``rpush``/``get``/``delete``).
            timeout: Maximum seconds to wait. ``None`` (default) waits
                indefinitely, as before.
            poll_interval: Seconds to sleep between polls.

        Returns:
            The worker's response, evaluated from its text form.

        Raises:
            TimeoutError: if ``timeout`` elapses before a response arrives.
        """
        start_time = time.time()
        request_id = str(uuid.uuid4())

        redis_client.set(f"request:{request_id}", str(data_dict_list))

        redis_client.rpush('request_queue', request_id)
        print(f"Sent request with ID: {request_id}")

        progress_last = None
        while True:
            progress = redis_client.get(f'progress:{request_id}')
            if progress and (progress != progress_last):
                ep_time = time.time() - start_time

                progress_list = progress.decode().split('_')
                n_finished, perc = int(progress_list[0]), int(progress_list[1])
                print(f"{n_finished}/{len(data_dict_list)}, {perc}% , Time: {ep_time}")
                if perc == 100:
                    response = redis_client.get(f"response:{request_id}")
                    if response is not None:
                        redis_client.delete(f"response:{request_id}")
                        redis_client.delete(f"progress:{request_id}")
                        redis_client.delete(f"request:{request_id}")
                        print(f"Received response for {request_id}")
                        # SECURITY NOTE(review): eval() executes whatever text the
                        # worker stored in Redis. Consider ast.literal_eval or JSON
                        # if the worker's output can contain only plain literals.
                        return eval(response.decode())
                    # 100% reported before the response was written: forget this
                    # progress value so the next poll re-checks for the response.
                    progress = None

            progress_last = progress
            if timeout is not None and time.time() - start_time >= timeout:
                raise TimeoutError(f"No response for request {request_id} within {timeout}s")
            time.sleep(poll_interval)

    def _evaluate_dpg_locally(self, image_list, value_list):
        """Score each ``(image, questions)`` pair with the local VQA model."""
        return [
            self.vqa_model.evaluate_one_sample(value, image)
            for image, value in zip(image_list, value_list)
        ]

    def evaluate(self, image_list, value_list):
        """Return one DPG score per image.

        Args:
            image_list: Images to score (in API mode, a list of PIL images or arrays).
            value_list: Per-image question dicts, paired by position.

        Returns:
            list: Scores in the order of ``image_list``. API-mode scores are
            converted to ``float``.
        """
        if not self.use_api:
            return self._evaluate_dpg_locally(image_list, value_list)

        base64_list = self._encoder_image_to_base64(image_list)
        if not base64_list:
            # Nothing to score: skip the round-trip through the queue.
            return []

        send_dict_list = [
            {
                "question_dict": value_list[idx],
                "image_base64": encoded,
                "idx": idx,
                'tasks': ['temp001'],
            }
            for idx, encoded in enumerate(base64_list)
        ]

        tt_list = self._modelart_request(data_dict_list=send_dict_list, redis_client=self.redis_client)
        # The worker may answer out of order; restore the request order.
        tt_list.sort(key=lambda x: x["idx"])
        return [float(tt['temp001']) for tt in tt_list]