import numpy as np
import torch

from SGM.textGraphParsor_utils import get_graph_phrases
from PIL import Image

from scipy import stats as sts
import time
import alpha_clip
from torchvision import transforms

from old_inference import detect_mask


class HICEScore():
    """HICE caption scorer built on an Alpha-CLIP style vision-language model.

    A global score compares the whole image (all-ones alpha mask) with each
    whole caption. Local scores compare region masks produced by
    ``detect_mask`` with phrases extracted from the caption's parsed graphs
    (``sg_model.parse`` + ``get_graph_phrases``). When reference captions are
    supplied, text-text (TT) similarities against the references are computed
    alongside the image-text (IT) ones.
    """

    def __init__(self, vlm_model, preprocess, sg_model, mobilesamv2, ObjAwareModel, predictor, device, args=None):
        self.args = args  # forwarded unchanged to detect_mask
        # Used via .visual(images, alphas) and .encode_text(token_ids) below.
        self.vlm_model = vlm_model
        self.preprocess = preprocess  # applied to each RGB PIL image before the vision encoder
        self.sg_model = sg_model  # .parse(list_of_sentences) -> one graph per sentence
        # The next three models are only forwarded to detect_mask.
        self.mobilesamv2 = mobilesamv2
        self.ObjAwareModel = ObjAwareModel
        self.predictor = predictor
        # Resizes/normalizes a mask into the alpha input of the vision encoder.
        self.mask_transform = transforms.Compose([
            transforms.Resize((224, 224)),  # change to (336,336) when using ViT-L/14@336px
            transforms.Normalize(0.5, 0.26)
        ])
        self.device = device

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _split_sentences(captions):
        """Split every caption on '. ' and drop empty pieces."""
        return [[s for s in caption.split('. ') if s != ''] for caption in captions]

    @staticmethod
    def _split_by_lengths(items, lengths):
        """Cut the flat sequence *items* into consecutive chunks of the given *lengths*."""
        chunks = []
        start = 0
        for length in lengths:
            chunks.append(items[start:start + length])
            start += length
        return chunks

    @staticmethod
    def _span_starts(lengths):
        """Return the start offset of each consecutive span of the given *lengths*."""
        starts = []
        total = 0
        for length in lengths:
            starts.append(total)
            total += length
        return starts

    @staticmethod
    def _max_mean(matrix, dim):
        """Mean of ``matrix.max(dim)`` values, or 0.0 when *matrix* is empty.

        An empty matrix happens when a caption yields no phrases or an image
        yields no regions; ``max`` would raise on it.
        """
        if matrix.numel() == 0:
            return 0.0
        return matrix.max(dim)[0].mean().item()

    def _load_full_images(self, image_paths):
        """Open, RGB-convert and preprocess each image; return them concatenated
        along dim 0 as a half-precision tensor on ``self.device``."""
        full_images = []
        for image_path in image_paths:
            # Close the file handle once the pixels are converted.
            with Image.open(image_path) as image:
                raw_image = image.convert("RGB")
            full_images.append(self.preprocess(raw_image).unsqueeze(0).half().to(self.device))
        return torch.cat(full_images)

    def _full_alpha(self, full_images):
        """Alpha map selecting the whole image, built from an all-ones mask.

        Assumes ``preprocess`` yields (C, H, W) tensors, so the result is
        (1, 1, 224, 224) after ``mask_transform``.
        """
        full_mask = torch.ones_like(full_images[0].unsqueeze(0))
        full_alpha = self.mask_transform(full_mask[:, 0])
        return full_alpha.half().unsqueeze(dim=0).to(self.device)

    def _caption_graphs(self, captions):
        """Parse every sentence of every caption and regroup the graphs per caption."""
        cut_captions = self._split_sentences(captions)
        flat_sentences = [s for sentences in cut_captions for s in sentences]
        flat_graphs = self.sg_model.parse(flat_sentences)
        return self._split_by_lengths(flat_graphs, [len(s) for s in cut_captions])

    def _tokenize_groups(self, phrase_lists):
        """Tokenize each phrase list; return (concatenated ids, per-list counts)."""
        ids = [alpha_clip.tokenize(phrases).to(self.device) for phrases in phrase_lists]
        return torch.cat(ids), [t.shape[0] for t in ids]

    def _detect_region_alphas(self, image_paths):
        """Run detect_mask; return (concatenated region alphas, regions per image).

        detect_mask is consumed here as returning one tensor per image whose
        first dimension counts that image's regions.
        """
        region_alphas = detect_mask(image_paths, self.mobilesamv2, self.ObjAwareModel, self.predictor,
                                    self.mask_transform, self.args, self.device, verbose=True)
        region_nums = [alpha.shape[0] for alpha in region_alphas]
        return torch.cat(region_alphas), region_nums

    def _encode(self, images, alphas, text_ids):
        """Encode images (with alpha masks) and token ids; return L2-normalized features."""
        with torch.no_grad():
            image_features = self.vlm_model.visual(images, alphas)
            text_features = self.vlm_model.encode_text(text_ids)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return image_features, text_features

    # --------------------------------------------------------------- public API

    def __call__(self, image_path, pred_captions, reference_captions=None, res_dict=None, verbose=False):
        """Score captions against a single image.

        Args:
            image_path: path of the image file.
            pred_captions: list of candidate captions. All of them contribute to
                the global score; only the first caption's phrases are used for
                the local scores.
            reference_captions, res_dict: unused; kept for interface compatibility.
            verbose: print the extracted local phrases.

        Returns:
            Tuple of 0-dim tensors: (mean global score, mean local precision,
            mean local recall), all scaled by 2.5.
        """
        full_image = self._load_full_images([image_path])
        full_alpha = self._full_alpha(full_image)
        pred_texts = alpha_clip.tokenize(pred_captions).to(self.device)
        image_features, text_features = self._encode(full_image, full_alpha, pred_texts)
        global_scores = 2.5 * image_features @ text_features.T

        # Local similarity: detected regions vs. phrases of the first caption.
        pred_graphs = self._caption_graphs(pred_captions)
        local_texts = get_graph_phrases(pred_graphs[0], {})
        if verbose:
            print("Local texts:", local_texts)
        local_textids = alpha_clip.tokenize(local_texts).to(self.device)

        region_alphas, _ = self._detect_region_alphas([image_path])
        region_images = full_image.repeat(region_alphas.shape[0], 1, 1, 1)
        image_features, text_features = self._encode(region_images, region_alphas, local_textids)

        local_scores = 2.5 * image_features @ text_features.T  # (regions, phrases)
        local_precisions = local_scores.max(0)[0]  # best region per phrase
        local_recalls = local_scores.max(1)[0]  # best phrase per region

        return global_scores.mean(), local_precisions.mean(), local_recalls.mean()

    def batch_compute_scores(self, image_paths, candidate_captions, reference_captions=None, res_dict=None, verbose=False):
        """Dispatch to the reference-based or reference-free batch scorer."""
        if reference_captions is not None:
            return self.batch_compute_scores_with_references(
                image_paths, candidate_captions, reference_captions, verbose=verbose)
        return self.batch_compute_scores_without_references(image_paths, candidate_captions, verbose=verbose)

    def batch_compute_scores_without_references(self, image_paths, candidate_captions, res_dict=None, verbose=False):
        """Reference-free HICE scores for a batch of (image, caption) pairs.

        Args:
            image_paths: list of image file paths.
            candidate_captions: one caption string per image.
            res_dict: unused; kept for interface compatibility.
            verbose: print per-stage timings.

        Returns:
            Dict of float64 numpy arrays (one entry per image) under the keys
            "hice", "global", "local", "local_precision", "local_recall".
        """
        assert len(image_paths) == len(candidate_captions)
        batch_size = len(image_paths)
        t_prev = time.time()

        def lap(label):
            # Report the time spent since the previous checkpoint.
            nonlocal t_prev
            now = time.time()
            if verbose:
                print(label, now - t_prev)
            t_prev = now

        full_images = self._load_full_images(image_paths)
        lap("prepare full images")
        full_alpha = self._full_alpha(full_images)
        global_textids = alpha_clip.tokenize(candidate_captions).to(self.device)
        lap("prepare global alpha and texts")

        type_dict = {}
        pred_graphs = self._caption_graphs(candidate_captions)
        local_texts = [get_graph_phrases(graphs, type_dict) for graphs in pred_graphs]
        lap("parse captions")
        batch_local_textids, batch_local_textnums = self._tokenize_groups(local_texts)
        lap("tokenize caption")

        # Previously called undefined detect_batch_imgs/convert_boxes_to_alpha;
        # use the same mask detector as the reference-based path.
        batch_region_alphas, batch_local_regionnums = self._detect_region_alphas(image_paths)
        lap("detect images")
        batch_region_images = torch.cat([full_images[idx].repeat(num, 1, 1, 1)
                                         for idx, num in enumerate(batch_local_regionnums)])

        # Rows: [whole images..., regions...]; columns: [captions..., phrases...].
        input_images = torch.cat([full_images, batch_region_images])
        input_alphas = torch.cat([full_alpha.repeat(batch_size, 1, 1, 1), batch_region_alphas])
        input_texts = torch.cat([global_textids, batch_local_textids])
        lap("prepare vl input")

        image_features, text_features = self._encode(input_images, input_alphas, input_texts)
        lap("vlm forward")

        IT_similarity_matrices = 2.5 * image_features @ text_features.T
        # The 2.5 scale is already applied above; the old code applied it twice here.
        global_scores = torch.diag(
            IT_similarity_matrices[:batch_size, :batch_size]).detach().cpu().numpy().astype(np.float64)

        region_starts = self._span_starts(batch_local_regionnums)
        text_starts = self._span_starts(batch_local_textnums)
        local_IT_similarity_matrices = [
            IT_similarity_matrices[batch_size + r0:batch_size + r0 + nr, batch_size + t0:batch_size + t0 + nt]
            for r0, nr, t0, nt in zip(region_starts, batch_local_regionnums, text_starts, batch_local_textnums)]
        local_precisions = np.array([self._max_mean(m, 0) for m in local_IT_similarity_matrices])
        local_recalls = np.array([self._max_mean(m, 1) for m in local_IT_similarity_matrices])
        # Clamp at 0 (as the reference path does): the harmonic mean needs non-negative inputs.
        local_scores = sts.hmean(np.stack([np.maximum(local_precisions, 0), np.maximum(local_recalls, 0)]), axis=0)
        hice_scores = sts.hmean(np.stack([np.maximum(global_scores, 0), np.maximum(local_scores, 0)]), axis=0)
        lap("compute scores")

        return {"hice": hice_scores,
                "global": global_scores,
                "local": local_scores,
                "local_precision": local_precisions,
                "local_recall": local_recalls,
                }

    def batch_compute_scores_with_references(self, image_paths, candidate_captions, reference_captions=None, res_dict=None, verbose=False):
        """Reference-based HICE scores for a batch of images.

        Args:
            image_paths: list of image file paths.
            candidate_captions: one caption string per image.
            reference_captions: one list of reference caption strings per image
                (required despite the ``None`` default).
            res_dict, verbose: unused; kept for interface compatibility.

        Returns:
            Dict of float64 numpy arrays (one entry per image) with image-text
            ("*_IT"), text-text ("*_TT") and combined ("ref*") scores.

        Raises:
            ValueError: if ``reference_captions`` is None.
        """
        assert len(image_paths) == len(candidate_captions)
        if reference_captions is None:
            raise ValueError("reference_captions is required; "
                             "use batch_compute_scores_without_references instead")
        assert len(image_paths) == len(reference_captions)
        batch_size = len(image_paths)

        full_images = self._load_full_images(image_paths)
        full_alpha = self._full_alpha(full_images)
        global_textids = alpha_clip.tokenize(candidate_captions).to(self.device)

        ref_lens = [len(refs) for refs in reference_captions]
        flatten_refs = [ref for refs in reference_captions for ref in refs]
        global_referenceids = alpha_clip.tokenize(flatten_refs).to(self.device)

        pred_cut_captions = self._split_sentences(candidate_captions)
        pred_cut_flatten_captions = [s for sentences in pred_cut_captions for s in sentences]
        pred_lens = [len(sentences) for sentences in pred_cut_captions]

        # Parse references and candidate sentences with a single parser call.
        all_graphs = self.sg_model.parse(flatten_refs + pred_cut_flatten_captions)
        ref_graphs = self._split_by_lengths(all_graphs[:len(flatten_refs)], ref_lens)
        pred_graphs = self._split_by_lengths(all_graphs[len(flatten_refs):], pred_lens)

        # type_dict is shared and passed to every call (candidates first, then references).
        type_dict = {}
        local_texts = [get_graph_phrases(graphs, type_dict) for graphs in pred_graphs]
        local_references = [get_graph_phrases(graphs, type_dict) for graphs in ref_graphs]
        batch_local_textids, batch_local_textnums = self._tokenize_groups(local_texts)
        batch_local_referenceids, batch_local_referencenums = self._tokenize_groups(local_references)

        batch_region_alphas, batch_local_regionnums = self._detect_region_alphas(image_paths)
        batch_region_images = torch.cat([full_images[idx].repeat(num, 1, 1, 1)
                                         for idx, num in enumerate(batch_local_regionnums)])

        # Image rows: [whole images..., regions...].
        # Text rows: [captions..., caption phrases..., references..., reference phrases...].
        input_images = torch.cat([full_images, batch_region_images])
        input_alphas = torch.cat([full_alpha.repeat(batch_size, 1, 1, 1), batch_region_alphas])
        input_texts = torch.cat([global_textids, batch_local_textids, global_referenceids, batch_local_referenceids])

        image_features, all_text_features = self._encode(input_images, input_alphas, input_texts)
        num_candidate_texts = len(global_textids) + len(batch_local_textids)
        text_features = all_text_features[:num_candidate_texts]
        reference_features = all_text_features[num_candidate_texts:]

        IT_similarity_matrices = 2.5 * image_features @ text_features.T
        TT_similarity_matrices = reference_features @ text_features.T

        # Global scores: whole image vs. whole caption, and references vs. whole caption.
        global_IT_scores = torch.diag(
            IT_similarity_matrices[:batch_size, :batch_size]).detach().cpu().numpy().astype(np.float64)
        ref_starts = self._span_starts(ref_lens)
        global_TT_scores = np.array([TT_similarity_matrices[r0:r0 + n, idx].mean().item()
                                     for idx, (r0, n) in enumerate(zip(ref_starts, ref_lens))])

        num_global_refs = len(flatten_refs)
        region_starts = self._span_starts(batch_local_regionnums)
        text_starts = self._span_starts(batch_local_textnums)
        reference_starts = self._span_starts(batch_local_referencenums)
        device = IT_similarity_matrices.device
        local_IT_similarity_matrices = []
        local_TT_similarity_matrices = []
        for idx in range(batch_size):
            region_first = batch_size + region_starts[idx]
            text_first = batch_size + text_starts[idx]
            # IT: [whole image + its regions] x [whole caption + its phrases].
            # dtype=torch.long keeps index tensors integral even when a range is empty.
            it_rows = torch.tensor([idx] + list(range(region_first, region_first + batch_local_regionnums[idx])),
                                   dtype=torch.long, device=device)
            it_cols = torch.tensor([idx] + list(range(text_first, text_first + batch_local_textnums[idx])),
                                   dtype=torch.long, device=device)
            local_IT_similarity_matrices.append(
                IT_similarity_matrices.index_select(0, it_rows).index_select(1, it_cols))
            # TT: reference phrases x caption phrases (both contiguous spans).
            ref_first = num_global_refs + reference_starts[idx]
            local_TT_similarity_matrices.append(
                TT_similarity_matrices[ref_first:ref_first + batch_local_referencenums[idx],
                                       text_first:text_first + batch_local_textnums[idx]])

        local_IT_precisions = np.array([self._max_mean(m, 0) for m in local_IT_similarity_matrices])
        # Recall drops column 0 (the whole-caption embedding) so only phrases count.
        # A caption without phrases scores 0 instead of crashing.
        local_IT_recalls = np.array([self._max_mean(m[:, 1:], 1) for m in local_IT_similarity_matrices])
        local_IT_scores = sts.hmean(np.stack([np.maximum(local_IT_precisions, 0), np.maximum(local_IT_recalls, 0)]), axis=0)

        local_TT_precisions = np.array([self._max_mean(m, 0) for m in local_TT_similarity_matrices])
        local_TT_recalls = np.array([self._max_mean(m, 1) for m in local_TT_similarity_matrices])
        local_TT_scores = sts.hmean(np.stack([np.maximum(local_TT_precisions, 0), np.maximum(local_TT_recalls, 0)]), axis=0)

        hice_IT_scores = sts.hmean(np.stack([np.maximum(global_IT_scores, 0), np.maximum(local_IT_scores, 0)]), axis=0)
        hice_TT_scores = sts.hmean(np.stack([np.maximum(global_TT_scores, 0), np.maximum(local_TT_scores, 0)]), axis=0)
        refhice_scores = sts.hmean(np.stack([hice_IT_scores, hice_TT_scores]), axis=0)
        refhice_global_scores = sts.hmean(np.stack([np.maximum(global_IT_scores, 0), np.maximum(global_TT_scores, 0)]), axis=0)
        refhice_local_scores = sts.hmean(np.stack([np.maximum(local_IT_scores, 0), np.maximum(local_TT_scores, 0)]), axis=0)

        return {
            "hice_IT": hice_IT_scores,
            "global_IT": global_IT_scores,
            "local_IT": local_IT_scores,
            "local_IT_precision": local_IT_precisions,
            "local_IT_recall": local_IT_recalls,
            "hice_TT": hice_TT_scores,
            "global_TT": global_TT_scores,
            "local_TT": local_TT_scores,
            "local_TT_precision": local_TT_precisions,
            "local_TT_recall": local_TT_recalls,
            "refhice": refhice_scores,
            "refglobal": refhice_global_scores,
            "reflocal": refhice_local_scores,
        }



