import cv2
import torch
import re
import numpy as np
from typing import List, Union
import nltk
import inflect
from transformers import AutoTokenizer
from torchvision import transforms as T
import pdb
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark import layers as L
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
from maskrcnn_benchmark.utils import cv2_util

engine = inflect.engine()

# spaCy is optional. When it (or the 'en_core_web_trf' pipeline) cannot be
# loaded, the name ``spacy_tokenzier`` is bound to None so that callers can
# detect the situation, and the NLTK resources needed by the NLTK-based
# noun-phrase extraction are downloaded instead.
try:
    import spacy
    spacy_tokenzier = spacy.load('en_core_web_trf')
except Exception:
    # ``except Exception`` (rather than a bare ``except``) keeps
    # KeyboardInterrupt / SystemExit propagating during the slow model load.
    spacy_tokenzier = None
    print("can't find spacy package, use nltk instead")
    nltk.download('punkt')
    nltk.download('averaged_perceptron_tagger')
    
import timeit


class GLIPDemo(object):
    """Single-image phrase-grounding demo built around a detection model.

    The class builds the model from ``cfg`` (optionally), extracts noun
    phrases from a caption, runs the model on one image and offers helpers
    to filter the resulting detections and to draw them with OpenCV.

    Attributes written by the methods below and read by others:
        entities: list of phrase strings, set by ``run_ner`` /
            ``run_spacy_ner``; used to name boxes in ``overlay_entity_names``.
        tokens_positive, positive_map_label_to_token, plus: set by
            ``compute_prediction``.
        color: optional fixed drawing colour, set by ``run_on_web_image`` /
            ``visualize_with_predictions``; read by
            ``compute_colors_for_labels``.
    """

    def __init__(self,
                 cfg,
                 confidence_threshold=0.7,
                 min_image_size=None,
                 show_mask_heatmaps=False,
                 masks_per_dim=5,
                 load_model=True
                 ):
        """Store the configuration and, if requested, build and load the model.

        Args:
            cfg: project configuration node; it is cloned into ``self.cfg``.
            confidence_threshold: a single number, or a sequence of
                thresholds (see ``_post_process``).
            min_image_size: size passed to ``T.Resize``; ``None`` disables
                resizing.
            show_mask_heatmaps: when true the drawing entry points return a
                mask montage instead of box overlays.
            masks_per_dim: grid size of that montage.
            load_model: when false no model is built and no checkpoint is
                read; ``self.model`` is then ``None``.
        """
        self.cfg = cfg.clone()
        # The device is recorded even without a model so that later code
        # reading ``self.device`` does not fail with AttributeError.
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.model = None
        if load_model:
            self.model = build_detection_model(cfg)
            self.model.eval()
            self.model.to(self.device)
            checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=cfg.OUTPUT_DIR)
            checkpointer.load(cfg.MODEL.WEIGHT)

        self.min_image_size = min_image_size
        self.show_mask_heatmaps = show_mask_heatmaps
        self.masks_per_dim = masks_per_dim

        self.transforms = self.build_transform()

        mask_threshold = -1 if show_mask_heatmaps else 0.5
        self.masker = Masker(threshold=mask_threshold, padding=1)
        # Multipliers used by compute_colors_for_labels to derive a colour
        # per label.
        self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        self.cpu_device = torch.device("cpu")
        self.confidence_threshold = confidence_threshold
        # No caption has been parsed yet.
        self.entities = []

        self.tokenizer = self.build_tokenizer()

    def build_transform(self):
        """Build the image preprocessing pipeline applied before the model.

        Returns:
            A ``T.Compose`` that converts to PIL, optionally resizes with
            ``self.min_image_size``, converts to a tensor, rescales or
            reorders channels, and normalises with the configured mean/std.
        """
        cfg = self.cfg

        # Images are expected to come from OpenCV (BGR channel order). With
        # TO_BGR255 the [0, 1] tensor is scaled back to [0, 255]; otherwise
        # the channel order is reversed.
        if cfg.INPUT.TO_BGR255:
            to_bgr_transform = T.Lambda(lambda x: x * 255)
        else:
            to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])

        normalize_transform = T.Normalize(
            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
        )

        if self.min_image_size is not None:
            resize_transform = T.Resize(self.min_image_size)
        else:
            resize_transform = lambda x: x

        return T.Compose(
            [
                T.ToPILImage(),
                resize_transform,
                T.ToTensor(),
                to_bgr_transform,
                normalize_transform,
            ]
        )

    def build_tokenizer(self):
        """Create the tokenizer named by ``cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE``.

        Returns:
            The tokenizer, or ``None`` when the configured type is neither
            ``"bert-base-uncased"`` nor ``"clip"``.
        """
        cfg = self.cfg
        tokenizer_type = cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE
        if tokenizer_type == "bert-base-uncased":
            return AutoTokenizer.from_pretrained("bert-base-uncased")
        if tokenizer_type == "clip":
            from transformers import CLIPTokenizerFast
            if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
                # A dedicated mask token is supplied when the MLM loss is on.
                return CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                         from_slow=True, mask_token='ðŁĴĳ</w>')
            return CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                     from_slow=True)
        return None

    def run_ner(self, caption):
        """Locate noun phrases in ``caption`` using the NLTK-based helpers.

        Sets ``self.entities`` to the cleaned noun phrases.

        Returns:
            A list with one ``[[start, end]]`` character span per occurrence
            of a phrase in the lower-cased caption, capped at 80 spans.
            NOTE(review): a phrase occurring several times (or not at all)
            yields a different number of spans than entries in
            ``self.entities`` -- confirm callers tolerate that.
        """
        noun_phrases = [remove_punctuation(phrase) for phrase in find_noun_phrases(caption)]
        noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
        self.entities = noun_phrases

        lowered = caption.lower()
        tokens_positive = []
        for entity in noun_phrases:
            # The phrase is literal text, not a pattern: escape it so regex
            # metacharacters (e.g. a backslash) can neither raise nor match
            # unintended text.
            for m in re.finditer(re.escape(entity), lowered):
                if len(tokens_positive) > 79:
                    break
                tokens_positive.append([[m.start(), m.end()]])

        return tokens_positive

    def run_spacy_ner(self, caption):
        """Locate groundable noun chunks in ``caption`` with the spaCy pipeline.

        Requires the module-level ``spacy_tokenzier`` to be loaded. Sets
        ``self.entities`` to the text of the kept chunks.

        Returns:
            A list with one ``[[start_char, end_char]]`` span per kept
            chunk. When nothing is kept, ``self.entities`` becomes
            ``["object"]`` and the single span ``[[0, 1]]`` is returned.
        """
        doc = spacy_tokenzier(caption)
        noun_chunks = list(doc.noun_chunks)

        # Some captions use these words to summarize the whole image at the
        # beginning or the end ("a photo of ...", "... in this picture"), or
        # to say the picture was created by someone; such a chunk is dropped.
        wide_nouns = {
            "image", "images", "photo", "photos", "description", "descriptions", "picture", "pictures",
            "snapshot", "snapshots", "photograph", "photographs", "shot", "shots", "pic", "pics",
            "portrayal", "portrayals", "illustration", "illustrations", "visual", "visuals",
            "view", "views", "perspective", "perspectives", "explanation", "explanations",
            "depiction", "depictions", "rendering", "renderings", "representation", "fig.", 'fig',
            "representations", "photography", "snapshotting", "shooting",
            "capture", "captures", "visualisation", "visualisations", "screenshot", "screenshots",
            "painting", "paintings", "drawing", "drawings", "sketch", "sketches",
        }
        if len(noun_chunks) > 2:
            # At most one chunk is removed: first, else last, else
            # second-to-last.
            if noun_chunks[0].root.text.lower() in wide_nouns:
                noun_chunks = noun_chunks[1:]
            elif noun_chunks[-1].root.text.lower() in wide_nouns:
                noun_chunks = noun_chunks[:-1]
            elif noun_chunks[-2].root.text.lower() in wide_nouns:
                noun_chunks = noun_chunks[:-2] + [noun_chunks[-1]]

        # Abstract nouns rarely correspond to a visible entity in an image.
        abstract_nouns = {
            "time", "life", "love", "freedom", "happiness",
            "wisdom", "peace", "justice", "hope", "courage",
            "faith", "understanding", "advantage", "pursuit",
        }
        filtered_noun_chunks = [
            chunk for chunk in noun_chunks
            if chunk.root.text.lower() not in abstract_nouns
        ]

        # Measure words ("a group of ...") describe a quantity, not an entity.
        quant_nouns = {
            "group", "dozens", "cluster", "collection", "groups",
            "clusters", "batch", "bunch", "bundle", "crowd", "flock",
            "crowds", "gang", "heap", "herd", "set", "mass", "stack",
            "pile", "horde", "swarm", "army", "throng", "array",
            "multitude", "host", "bevy", "pack", "ensemble",
            "assortment", "conglomeration", "bouquet",
        }
        merged_noun_chunks = []
        tokens_positive = []
        # Only chunks before the last two are eligible for being skipped.
        skip_limit = len(filtered_noun_chunks) - 2
        for i, chunk in enumerate(filtered_noun_chunks):
            is_measure_word = (
                i < skip_limit
                and chunk.root.text.lower() in quant_nouns
                and f"{chunk.text} of" in caption
            )
            if is_measure_word:
                print(f"ignore measure nouns: {chunk.root.text}")
                continue
            merged_noun_chunks.append(chunk.text)
            tokens_positive.append([[chunk.start_char, chunk.end_char]])

        self.entities = merged_noun_chunks

        if len(tokens_positive) == 0:
            self.entities = ["object"]
            tokens_positive = [[[0, 1]]]

        return tokens_positive

    def inference(self, original_image, original_caption):
        """Run the model and filter with ``self.confidence_threshold``.

        Returns:
            The detections kept by ``_post_process_fixed_thresh``.
        """
        predictions = self.compute_prediction(original_image, original_caption)
        return self._post_process_fixed_thresh(predictions)

    def run_on_web_image(self,
            original_image,
            original_caption,
            thresh=0.5,
            custom_entity = None,
            alpha = 0.0,
            color = 255,
            text_size = 1,
            text_pixel = 2,
            box_pixel = 3):
        """Detect on ``original_image`` and draw the detections on a copy.

        Args:
            original_image: image array; it is copied before drawing.
            original_caption: caption text to ground.
            thresh: score threshold forwarded to ``_post_process``.
            custom_entity: forwarded to ``compute_prediction``.
            alpha: accepted for interface compatibility; not used here.
            color: drawing colour, used only if no ``color`` attribute has
                been set on the instance before.
            text_size, text_pixel: forwarded to ``overlay_entity_names``.
            box_pixel: forwarded to ``overlay_boxes``.

        Returns:
            ``(image, top_predictions)``; with ``show_mask_heatmaps`` the
            result of ``create_mask_montage`` instead.
        """
        self.color = getattr(self, 'color', color)
        predictions = self.compute_prediction(original_image, original_caption, custom_entity)
        top_predictions = self._post_process(predictions, thresh)

        result = original_image.copy()
        if self.show_mask_heatmaps:
            return self.create_mask_montage(result, top_predictions)
        result = self.overlay_boxes(result, top_predictions, box_pixel=box_pixel)
        result = self.overlay_entity_names(result, top_predictions, text_size=text_size, text_pixel=text_pixel)
        if self.cfg.MODEL.MASK_ON:
            result = self.overlay_mask(result, top_predictions)
        return result, top_predictions

    def visualize_with_predictions(self,
            original_image,
            predictions,
            thresh=0.5,
            alpha=0.0,
            box_pixel=3,
            text_size = 1,
            text_pixel = 2,
            text_offset = 10,
            text_offset_original = 4,
            color = 255):
        """Draw already-computed ``predictions`` on a copy of ``original_image``.

        The predictions are resized to the image size, filtered with
        ``_post_process(thresh)`` and drawn with ``overlay_red_boxes`` and
        ``overlay_entity_names``. ``self.color`` is set to ``color``.

        Returns:
            ``(image, top_predictions)``; with ``show_mask_heatmaps`` the
            result of ``create_mask_montage`` instead.
        """
        self.color = color
        # shape[:2] also works for 2-D (single-channel) arrays.
        height, width = original_image.shape[:2]
        predictions = predictions.resize((width, height))
        top_predictions = self._post_process(predictions, thresh)

        result = original_image.copy()
        if self.show_mask_heatmaps:
            return self.create_mask_montage(result, top_predictions)
        result = self.overlay_red_boxes(result, top_predictions, alpha=alpha, box_pixel=box_pixel)
        result = self.overlay_entity_names(
            result,
            top_predictions,
            text_size=text_size,
            text_pixel=text_pixel,
            text_offset=text_offset,
            text_offset_original=text_offset_original,
        )
        if self.cfg.MODEL.MASK_ON:
            result = self.overlay_mask(result, top_predictions)
        return result, top_predictions

    def compute_prediction(self, original_image, original_caption, custom_entity = None):
        """Run the model on one image/caption pair.

        Side effects: sets ``self.tokens_positive``, ``self.plus``,
        ``self.positive_map_label_to_token`` and (through the phrase
        extraction) ``self.entities``.

        Args:
            original_image: image array accepted by ``self.transforms``.
            original_caption: caption text.
            custom_entity: accepted for interface compatibility; currently
                not used.

        Returns:
            The first prediction returned by the model, moved to the CPU
            and resized to the original image size.
        """
        image = self.transforms(original_image)
        image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to(self.device)

        tokenized = self.tokenizer([original_caption], return_tensors="pt")

        # The module announces an NLTK fallback when spaCy cannot be loaded;
        # honour it here instead of failing on the missing spaCy pipeline.
        if globals().get("spacy_tokenzier") is not None:
            tokens_positive = self.run_spacy_ner(original_caption)
        else:
            tokens_positive = self.run_ner(original_caption)
            if not tokens_positive:
                # Same placeholder as run_spacy_ner uses for "nothing found".
                self.entities = ["object"]
                tokens_positive = [[[0, 1]]]
        self.tokens_positive = tokens_positive

        positive_map = create_positive_map(tokenized, tokens_positive)

        # Labels are shifted by one for the VLDYHEAD architecture.
        plus = 1 if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD" else 0

        positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus)
        self.plus = plus
        self.positive_map_label_to_token = positive_map_label_to_token

        with torch.no_grad():
            predictions = self.model(image_list, captions=[original_caption], positive_map=positive_map_label_to_token)
            predictions = [o.to(self.cpu_device) for o in predictions]

        # A single image is passed at a time.
        prediction = predictions[0]

        # Map the prediction back onto the original image size.
        height, width = original_image.shape[:2]
        return prediction.resize((width, height))

    def _select_and_sort(self, predictions, scalar_threshold=None):
        """Keep detections scoring above their threshold, best score first.

        When ``self.confidence_threshold`` is a number or a one-element
        sequence, every detection is compared against ``scalar_threshold``
        (or, if that is ``None``, against the configured value itself).
        Otherwise the threshold of a detection with label ``lb`` is
        ``self.confidence_threshold[lb - 1]``.
        """
        scores = predictions.get_field("scores")
        conf = self.confidence_threshold

        if isinstance(conf, (int, float)):
            fixed = conf if scalar_threshold is None else scalar_threshold
            thresh = scores.clone().fill_(fixed)
        elif len(conf) == 1:
            fixed = conf[0] if scalar_threshold is None else scalar_threshold
            thresh = scores.clone().fill_(fixed)
        else:
            thresh = scores.clone()
            for i, lb in enumerate(predictions.get_field("labels").tolist()):
                thresh[i] = conf[lb - 1]

        keep = torch.nonzero(scores > thresh).squeeze(1)
        predictions = predictions[keep]

        _, idx = predictions.get_field("scores").sort(0, descending=True)
        return predictions[idx]

    def _post_process_fixed_thresh(self, predictions):
        """Filter with ``self.confidence_threshold`` and sort by descending score."""
        return self._select_and_sort(predictions)

    def _post_process(self, predictions, threshold=0.5):
        """Filter with ``threshold`` and sort by descending score.

        ``threshold`` is used when ``self.confidence_threshold`` is a number
        or a one-element sequence; a longer sequence takes precedence and is
        indexed by ``label - 1``.
        """
        return self._select_and_sort(predictions, threshold)

    def compute_colors_for_labels(self, labels):
        """Return one uint8 colour triple per label.

        Colours are derived from the label and ``self.palette``. If a
        usable ``self.color`` is set, every entry is replaced by it.

        Args:
            labels: 1-D integer tensor of labels.

        Returns:
            A NumPy ``uint8`` array with one row per label.
        """
        colors = (30 * (labels[:, None] - 1) + 1) * self.palette
        colors = (colors % 255).numpy().astype("uint8")

        fixed_color = getattr(self, "color", None)
        if fixed_color is None:
            return colors
        try:
            return (colors * 0 + fixed_color).astype("uint8")
        except (TypeError, ValueError):
            # A colour that cannot be broadcast falls back to the palette.
            return colors

    def overlay_boxes(self, image, predictions, alpha=0.5, box_pixel = 3):
        """Draw label-coloured boxes and blend them over ``image``.

        Args:
            image: image array; it is not modified.
            predictions: object providing ``bbox`` and a ``labels`` field.
            alpha: weight of the drawn layer in ``cv2.addWeighted``.
            box_pixel: line thickness.

        Returns:
            The blended image.
        """
        labels = predictions.get_field("labels")
        boxes = predictions.bbox

        colors = self.compute_colors_for_labels(labels).tolist()
        new_image = image.copy()
        for box, color in zip(boxes, colors):
            top_left, bottom_right = box.to(torch.int64).reshape(2, 2).tolist()
            new_image = cv2.rectangle(
                new_image, tuple(top_left), tuple(bottom_right), tuple(color), box_pixel)

        # Blend the drawn layer with the untouched image.
        return cv2.addWeighted(new_image, alpha, image, 1 - alpha, 0)

    def overlay_red_boxes(self, image, predictions, alpha=0.5, box_pixel = 3):
        """Draw boxes with the fixed colour ``(0, 0, 255)`` on a copy of ``image``.

        Args:
            image: image array; it is not modified.
            predictions: a ``BoxList`` or an iterable of boxes (tensors or
                array-likes with four coordinates).
            alpha: accepted for interface compatibility; not used.
            box_pixel: line thickness.

        Returns:
            The copy with the boxes drawn.
        """
        boxes = predictions.bbox if isinstance(predictions, BoxList) else predictions

        new_image = image.copy()
        for box in boxes:
            if isinstance(box, torch.Tensor):
                box = box.to(torch.int64)
            else:
                # ``np.int`` no longer exists in current NumPy releases; use
                # an explicit integer dtype instead.
                box = np.asarray(box).astype(np.int64)
            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
            new_image = cv2.rectangle(
                new_image, tuple(top_left), tuple(bottom_right), (0, 0, 255), box_pixel)

        return new_image

    def overlay_scores(self, image, predictions):
        """Write each detection's score at the left edge, vertical middle, of its box."""
        scores = predictions.get_field("scores")
        boxes = predictions.bbox

        for box, score in zip(boxes, scores):
            box = box.to(torch.int64)
            anchor = (int(box[0]), int((box[1] + box[3]) / 2))
            image = cv2.putText(image, '%.3f' % score, anchor,
                                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)

        return image

    def overlay_entity_names(self, image, predictions, names=None, text_size=1.0, text_pixel=2, text_offset = 10, text_offset_original = 4):
        """Write ``"<name>:<score>"`` next to the top-left corner of each box.

        Args:
            image: image array, drawn on in place.
            predictions: a ``BoxList``, or a mapping with ``'scores'``,
                ``'labels'`` and ``'boxes'``.
            names: optional names, one per box. When ``None`` the names are
                looked up in ``self.entities`` by label (``'object'`` if the
                label is out of range or no entities are available).
            text_size, text_pixel: font scale and thickness.
            text_offset: distance used both to detect a label close to an
                earlier one and to shift it.
            text_offset_original: upward shift applied to every label.

        Returns:
            ``image``.
        """
        if isinstance(predictions, BoxList):
            scores = predictions.get_field("scores").tolist()
            labels = predictions.get_field("labels").tolist()
            boxes = predictions.bbox
        else:
            scores = predictions['scores']
            labels = predictions['labels']
            boxes = predictions['boxes']

        if names is None:
            self.plus = 1 if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD" else 0
            entities = getattr(self, "entities", None)
            if entities and self.plus:
                new_labels = [
                    entities[i - self.plus] if i <= len(entities) else 'object'
                    for i in labels
                ]
            else:
                new_labels = ['object' for _ in labels]
        else:
            new_labels = names

        template = "{}:{:.2f}"
        previous_locations = []
        for box, score, label in zip(boxes, scores, new_labels):
            # Work on plain floats: unpacking a tensor box yields views, and
            # the in-place shifts below would otherwise move the box itself.
            x, y = float(box[0]), float(box[1])
            s = template.format(label, score).replace("_", " ").replace("(", "").replace(")", "")

            # Move the text up once for every earlier label it would sit on.
            for x_prev, y_prev in previous_locations:
                if abs(x - x_prev) < abs(text_offset) and abs(y - y_prev) < abs(text_offset):
                    y -= text_offset

            # Keep the text from leaving the top of the image.
            if y < 2 * text_offset:
                y += text_offset + text_offset_original

            cv2.putText(
                image, s, (int(x), int(y) - text_offset_original),
                cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), text_pixel, cv2.LINE_AA
            )
            previous_locations.append((int(x), int(y)))

        return image

    def overlay_mask(self, image, predictions):
        """Draw the contours of each predicted mask in its label colour.

        Returns:
            The image with the contours drawn.
        """
        masks = predictions.get_field("mask").numpy()
        labels = predictions.get_field("labels")

        colors = self.compute_colors_for_labels(labels).tolist()

        for mask, color in zip(masks, colors):
            # First channel of the mask with a trailing singleton axis.
            binary = mask[0, :, :, None].astype(np.uint8)
            contours, _hierarchy = cv2_util.findContours(
                binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            image = cv2.drawContours(image, contours, -1, color, 2)

        return image

    def create_mask_montage(self, image, predictions):
        """Tile the downscaled masks into a ``masks_per_dim`` x ``masks_per_dim`` grid.

        Args:
            image: not used.
            predictions: object providing a ``mask`` field.

        Returns:
            ``(montage, None)`` where ``montage`` is the grid rendered with
            the JET colour map.
        """
        per_dim = self.masks_per_dim
        masks = L.interpolate(
            predictions.get_field("mask").float(), scale_factor=1 / per_dim
        ).byte()
        height, width = masks.shape[-2:]

        max_masks = per_dim ** 2
        masks = masks[:max_masks]
        # Pad with empty masks when there are fewer detections than cells.
        if len(masks) < max_masks:
            masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8)
            masks_padded[: len(masks)] = masks
            masks = masks_padded
        masks = masks.reshape(per_dim, per_dim, height, width)

        result = torch.zeros(
            (per_dim * height, per_dim * width), dtype=torch.uint8
        )
        for row in range(per_dim):
            for col in range(per_dim):
                result[row * height:(row + 1) * height,
                       col * width:(col + 1) * width] = masks[row, col]

        return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET), None


def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0):
    """Map each row index (shifted by ``plus``) to the indices of its non-zero columns.

    Args:
        positive_map: 2-D tensor; row ``i`` is inspected with ``torch.nonzero``.
        plus: offset added to every row index to form the dictionary key.

    Returns:
        A dict ``{row_index + plus: [column indices with non-zero value]}``.
    """
    return {
        row_index + plus: torch.nonzero(row, as_tuple=True)[0].tolist()
        for row_index, row in enumerate(positive_map)
    }


def create_positive_map(tokenized, tokens_positive, max_len=256):
    """Build a row-normalised map from phrases to the tokens they cover.

    ``positive_map[j, k]`` is non-zero iff token ``k`` lies inside one of the
    character spans of phrase ``j``; each non-empty row sums to (almost) 1.

    Args:
        tokenized: tokenizer output providing ``char_to_token(char_index)``,
            which returns a token index or ``None``.
        tokens_positive: list with one entry per phrase; each entry is a list
            of ``[begin_char, end_char]`` spans (end exclusive).
        max_len: number of token columns in the map (256 by default).

    Returns:
        A float tensor of shape ``(len(tokens_positive), max_len)``.

    Raises:
        Exception: whatever ``char_to_token`` raises for the exact span
            boundaries is re-raised after printing diagnostics.
    """

    def _first_token(char_positions):
        # Return the first token found at the given character positions;
        # None when none maps to a token or a lookup fails.
        try:
            for pos in char_positions:
                token = tokenized.char_to_token(pos)
                if token is not None:
                    return token
        except Exception:
            return None
        return None

    positive_map = torch.zeros((len(tokens_positive), max_len), dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        if not isinstance(tok_list, list) or not tok_list or len(tok_list[0]) != 2:
            print("Unusual tokens_positive: ", tok_list)
        for (beg, end) in tok_list:
            try:
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception:
                print("beg:", beg, "end:", end)
                print("token_positive:", tokens_positive)
                raise

            # A boundary may fall on a character that belongs to no token
            # (e.g. whitespace); probe up to two characters further inside
            # the span.
            if beg_pos is None:
                beg_pos = _first_token((beg + 1, beg + 2))
            if end_pos is None:
                end_pos = _first_token((end - 2, end - 3))
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos: end_pos + 1].fill_(1)

    # The epsilon keeps all-zero rows from dividing by zero.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)


def find_noun_phrases(caption: str) -> List[str]:
    """Extract noun phrases from the lower-cased ``caption`` with NLTK.

    The caption is tokenised and POS-tagged, then chunked with the grammar
    ``NP: {<DT>?<JJ.*>*<NN.*>+}``.

    Returns:
        The words of every ``NP`` subtree joined by single spaces, in the
        order the subtrees are visited.
    """
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(caption.lower()))

    chunker = nltk.RegexpParser("NP: {<DT>?<JJ.*>*<NN.*>+}")
    parse_tree = chunker.parse(tagged_tokens)

    return [
        ' '.join(leaf[0] for leaf in subtree.leaves())
        for subtree in parse_tree.subtrees()
        if subtree.label() == 'NP'
    ]


def remove_punctuation(text: str) -> str:
    """Delete a fixed set of punctuation characters and strip outer whitespace.

    Characters not in the set (for example ``-``, ``_`` and ``/``) are kept.
    """
    unwanted = "|:;@()[]{}^'\"’`?$%#!&*+,."
    # One translation pass removes every listed character.
    return text.translate(str.maketrans('', '', unwanted)).strip()