import argparse
import clip
import torch
from collections import Counter

# Side length (pixels) that full images and mask crops are resized to before CLIP encoding.
Height = Width = 224

import torchvision.transforms as T
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
import tqdm
import gem


from data.dataset_refer_bert import ReferDataset
from model.backbone import clip_backbone, CLIPViTFM
from src.utils import extract_rela_word, relation_boxes, extract_noun_phrase, extract_nouns, gen_dir_mask, extract_dir_phrase, default_argument_parser, setup, Compute_IoU


import warnings

# Silence every warning category for the rest of the process.
warnings.simplefilter("ignore")


def get_tensor_coordinates(index, size=7):
    """Map a row-major flat ``index`` to its (row, col) cell in a ``size`` x ``size`` grid.

    Returns the tuple ``(index // size, index % size)``.
    """
    return index // size, index % size

def is_consecutive(coord1, coord2):
    """Return True when two (row, col) cells touch, diagonals included.

    Two cells are neighbours (8-connectivity) when both their row offset and
    their column offset are at most 1; identical cells also count. The column
    offset is only computed if the row offset already passes.
    """
    return all(abs(coord1[axis] - coord2[axis]) <= 1 for axis in (0, 1))

def custom_clustering(index_pairs, scores, threshold=0.95, size=14):
    """Group grid patches into spatially connected clusters, in ranked order.

    ``index_pairs`` and ``scores`` are walked in lockstep. Only the first index
    of each pair is clustered, and the walk stops at the first entry whose
    ``scores[k][1]`` is <= ``threshold``. A patch joins the cluster of the
    earliest-visited patch that touches it (8-connectivity on a ``size`` x
    ``size`` grid, via ``is_consecutive``); otherwise it opens a new cluster
    (ids start at 1).

    Afterwards, every patch that is alone in its cluster is re-assigned to the
    current cluster of the first patch (in dict insertion order) that touches
    it, if there is one.

    Returns:
        dict mapping flat patch index -> cluster id.
    """
    labels = {}
    next_label = 1
    visited = []

    for (patch, _), score in zip(index_pairs, scores):
        if score[1] <= threshold:
            break

        if patch not in labels:
            here = get_tensor_coordinates(patch, size)
            anchor = next(
                (seen for seen in visited
                 if is_consecutive(get_tensor_coordinates(seen, size), here)),
                None,
            )
            if anchor is None:
                labels[patch] = next_label
                next_label += 1
            else:
                # Join the cluster of the first earlier patch that touches this one.
                labels[patch] = labels[anchor]

        visited.append(patch)

    # Post-process: fold singleton clusters into a touching neighbour's cluster.
    label_sizes = Counter(labels.values())
    lonely_labels = {label for label, count in label_sizes.items() if count == 1}

    if lonely_labels:
        lonely_patches = [p for p, label in labels.items() if label in lonely_labels]
        for patch in lonely_patches:
            here = get_tensor_coordinates(patch, size)
            for other in labels:
                if other != patch and is_consecutive(here, get_tensor_coordinates(other, size)):
                    labels[patch] = labels[other]
                    break
    return labels


def masks_to_boxes(masks):
    """
    Compute the tight bounding box of every mask in a stack.

    Any nonzero pixel counts as foreground. Coordinates are inclusive pixel
    indices: the max corner is the last foreground column/row, not one past it.

    Args:
        masks (Tensor): [n, h, w] mask stack
    Returns:
        boxes (Tensor): [n, 4] float32 tensor on the default (normally CPU)
            device, rows are (x_min, y_min, x_max, y_max); a mask with no
            foreground gets an all-zero row.
    """
    num_masks, _height, _width = masks.shape
    boxes = torch.zeros((num_masks, 4), dtype=torch.float32)

    for k, mask in enumerate(masks):
        ys, xs = torch.nonzero(mask, as_tuple=True)
        if xs.numel() == 0:
            continue  # empty mask: leave the zero box in place
        corners = [xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item()]
        boxes[k] = torch.tensor(corners, dtype=torch.float32)

    return boxes
    
                    
@torch.no_grad()
def main(args, Height, Width):
    """Run zero-shot referring-expression segmentation evaluation and log IoU results.

    For every sample of ``args.split`` the pipeline:
      1. proposes candidate masks with a Mask2Former (``maskformer2``) predictor;
      2. scores candidates against the referring sentence with CLIP, mixing
         token-masked and crop visual features (weight ``v``) with a
         sentence / noun-phrase text ensemble (weight ``r``);
      3. clusters patch-level text similarities (``custom_clustering``) and adds,
         per cluster, the candidate mask with the highest IoU against it;
      4. reranks the shortlist with spatial-relationship scores
         (``relation_boxes``) and GEM attention, then scores the winner against
         the ground-truth mask.

    An oracle score (best IoU against the ground truth over the shortlist) is
    accumulated alongside. Results are appended to
    ``./result_log/our_method_with_free_solo_ours_pceonlyadded_16.txt``.

    Evaluation only: the whole function runs under ``torch.no_grad()`` so the
    text and GEM encoders do not build autograd graphs.

    Args:
        args: parsed CLI namespace. This function reads ``eval_only``, ``split``,
            ``dataset``, ``splitBy``, ``top_k`` and ``ten_percent``; ``setup`` and
            ``ReferDataset`` may read more.
        Height, Width: size that CLIP inputs and mask crops are resized to.

    Raises:
        AssertionError: if ``args.eval_only`` is falsy.
    """
    import os

    assert args.eval_only, 'Only eval_only available!'
    # Return value is unused; kept in case ``setup`` has side effects (TODO confirm).
    setup(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    preprocess = gem.get_gem_img_transform()
    dataset = ReferDataset(args,
                           image_transforms=None,
                           target_transforms=None,
                           preprocessor=preprocess,
                           split=args.split,
                           eval_mode=True)

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    from detectron2.config import get_cfg
    from detectron2.data.detection_utils import read_image
    from detectron2.projects.deeplab import add_deeplab_config
    from detectron2.utils.logger import setup_logger
    import sys
    from mask2former import add_maskformer2_config
    from predictor import VisualizationDemo

    def setup_cfg(args):
        # load config from file and command-line arguments
        cfg = get_cfg()
        add_deeplab_config(cfg)
        add_maskformer2_config(cfg)
        cfg.merge_from_file(args.config_file)
        cfg.merge_from_list(args.opts)
        cfg.freeze()
        return cfg

    class MaskFormerArgs:
        # Stand-in for a CLI namespace: only the two attributes setup_cfg reads.
        config_file = "maskformer2_swin_large_IN21k_384_bs16_100ep.yaml"
        opts = []

    cfg = setup_cfg(MaskFormerArgs())
    demo = VisualizationDemo(cfg)

    mode = 'ViT'  # or 'Res'
    assert mode in ('Res', 'ViT'), 'Specify mode(Res or ViT)'

    Model = clip_backbone(model_name='RN50').to(device) if mode == 'Res' else CLIPViTFM(model_name='ViT-B/16').to(device)
    Model = Model.float()

    import stanza
    stanza.download('en')
    nlp = stanza.Pipeline('en')

    # Accumulators for the final (reranked) prediction.
    cum_I, cum_U = 0, 0
    m_IoU = []

    # Accumulators for the oracle: best shortlist candidate per expression.
    real_cum_I, real_cum_U = 0, 0
    real_m_IoU = []

    v = 0.85 if args.dataset == 'refcocog' else 0.95  # token-masked vs. crop feature weight
    r = 0.5  # sentence vs. noun-phrase text weight

    created_mask_num_lst = []
    our_mask_num_lst = []
    threshold_lst = []

    softmax0 = torch.nn.Softmax(0)

    dim = 14  # patch grid side: 14 for ViT-B/16, 7 for ViT-B/32

    layer = 8  # args.layer
    delta = 0.3  # args.delta
    alpha = 0.7  # args.alpha
    top_k = args.top_k

    gem_model = gem.create_gem_model(
        model_name='ViT-B/16', pretrained='openai', device=device
    )

    # --ten_percent evaluates only the first 10% of the split (at least one sample).
    # The original code broke out on the very first iteration instead.
    data_num = max(1, int(len(data_loader) * 0.1)) if args.ten_percent else None

    tbar = tqdm.tqdm(data_loader)
    for i, data in enumerate(tbar):
        if data_num is not None and i >= data_num:
            break

        image, target, clip_embedding, sentence_raw = data
        clip_embedding, target = clip_embedding.squeeze(1), target.to(device)

        file_name = image[0]['file_name'][0]
        r_img = read_image(file_name, format="BGR")
        _, _, pred_masks = demo.run_on_image(r_img)

        if len(pred_masks) == 0:
            print('No pred masks')
            continue
        pred_boxes = masks_to_boxes(pred_masks)  # [N, 4] (x_min, y_min, x_max, y_max)

        original_imgs = torch.stack([T.Resize((img_h, img_w))(img.to(pred_masks.device)) for img, img_h, img_w in
                                     zip(image[0]['image'], image[0]['height'], image[0]['width'])], dim=0)  # e.g. [1, 3, 428, 640]
        resized_imgs = torch.stack([T.Resize((Height, Width))(img.to(pred_masks.device)) for img in image[0]['image']], dim=0)  # [1, 3, Height, Width]

        # Fill value for pixels outside each candidate mask.
        pixel_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1).to(pred_masks.device)

        cropped_imgs = []
        for pred_box, pred_mask in zip(pred_boxes, pred_masks):
            pred_mask, pred_box = pred_mask.type(torch.uint8), pred_box.type(torch.int)

            masked_image = original_imgs * pred_mask[None, None, ...] + (1 - pred_mask[None, None, ...]) * pixel_mean

            x1, y1, x2, y2 = pred_box.tolist()
            # Zero-width/height box: extend to the image edge so the crop is non-empty.
            if x1 == x2:
                x2 = masked_image.shape[-1]
            if y1 == y2:
                y2 = masked_image.shape[-2]
            masked_image = TF.resized_crop(masked_image.squeeze(0), y1, x1, (y2 - y1), (x2 - x1), (Height, Width))
            cropped_imgs.append(masked_image)

        cropped_imgs = torch.stack(cropped_imgs, dim=0)

        ####### calculate hybrid features #######
        mask_features = Model.feature_map_masking(resized_imgs, pred_masks) if mode == 'Res' else Model(resized_imgs, pred_masks, masking_type='token_masking', masking_block=9)

        crop_features = Model.get_gloval_vector(cropped_imgs) if mode == 'Res' else Model(cropped_imgs, pred_masks=None, masking_type='crop')

        orig_visual_feature = v * mask_features + (1 - v) * crop_features

        # Patch-level features of the whole image, used for clustering below.
        visual_features = Model.get_gloval_vector(resized_imgs) if mode == 'Res' else Model(resized_imgs, pred_masks=None, masking_type='ours_local', masking_block=layer)

        for sentence, _ in zip(sentence_raw, range(clip_embedding.size(-1))):
            sentence = sentence[0].lower()
            doc = nlp(sentence)
            sentence_for_spacy = [token.text for token in doc.sentences[0].words if token.text != ' ']

            orig_sentence_for_spacy = ' '.join(sentence_for_spacy)
            dirflag = extract_dir_phrase(orig_sentence_for_spacy, nlp, False)
            sentence_token = clip.tokenize(orig_sentence_for_spacy).to(device)

            _, noun_phrase, _, _ = extract_noun_phrase(orig_sentence_for_spacy, nlp, need_index=True)

            noun_phrase_token = clip.tokenize(noun_phrase).to(device)
            noun_phrase_features = Model.get_text_feature(noun_phrase_token) if mode == 'Res' else Model.model.encode_text(noun_phrase_token)

            sentence_features = Model.get_text_feature(sentence_token) if mode == 'Res' else Model.model.encode_text(sentence_token)

            text_ensemble = r * sentence_features + (1 - r) * noun_phrase_features

            orig_similarity = Model.calculate_similarity_score(orig_visual_feature, text_ensemble) if mode == 'Res' else Model.calculate_score(orig_visual_feature, text_ensemble)
            orig_max_index = torch.argmax(orig_similarity)

            # Negative text feature: mean embedding of the other noun phrases.
            # NOTE(review): uses the ViT API (Model.model.encode_text / calculate_score)
            # regardless of `mode`.
            other_noun_phrases, nouns = extract_nouns(orig_sentence_for_spacy, nlp)
            other_noun_features = torch.zeros(1, 512).to(device)
            cnt_other_nouns = 0
            for other_noun in other_noun_phrases:
                noun_token = clip.tokenize('a photo of ' + other_noun).to(device)
                other_noun_features += Model.model.encode_text(noun_token)
                cnt_other_nouns += 1
            if cnt_other_nouns != 0:
                other_noun_features = other_noun_features / cnt_other_nouns

            orig_similarity_Neg = Model.calculate_score(orig_visual_feature, other_noun_features)

            # Make both score vectors at least 1-D (a single candidate squeezes to 0-D).
            orig_similarity = orig_similarity.squeeze()
            if orig_similarity.dim() < 1:
                orig_similarity = orig_similarity.unsqueeze(0)

            orig_similarity_Neg = orig_similarity_Neg.squeeze()
            if orig_similarity_Neg.dim() < 1:
                orig_similarity_Neg = orig_similarity_Neg.unsqueeze(0)

            ############ patch-level similarity clustering ############
            similarity = Model.calculate_similarity_score(visual_features, text_ensemble) if mode == 'Res' else Model.calculate_score(visual_features, text_ensemble)
            similarity = similarity.squeeze(0)

            chosen_sim = similarity[:, 0].reshape(dim, dim)  # [dim, dim]

            mask_h, mask_w = pred_masks[0].shape

            # Patches ranked by text similarity, as (row, col) pairs.
            sorted_indices = torch.argsort(chosen_sim.flatten(), descending=True)
            rows = torch.div(sorted_indices, chosen_sim.shape[0], rounding_mode='floor')
            cols = sorted_indices % chosen_sim.shape[0]
            indices_2d = torch.stack((rows, cols), dim=1)  # [dim*dim, 2]

            # Consecutive pairs in rank order, with (visual sim, text sim, next text sim).
            test_index_pairs = []
            test_scores = []
            for idx in range(indices_2d.shape[0] - 1):
                row, col = indices_2d[idx].tolist()
                row_next, col_next = indices_2d[idx + 1].tolist()
                flat, flat_next = dim * row + col, dim * row_next + col_next
                test_index_pairs.append((flat, flat_next))

                sim_ij = F.cosine_similarity(visual_features[0][flat], visual_features[0][flat_next], dim=0)
                text_sim_ij = chosen_sim[row, col] / 100
                text_sim_ij_next = chosen_sim[row_next, col_next] / 100
                test_scores.append((sim_ij, text_sim_ij, text_sim_ij_next))

            # Lower the threshold in 0.05 steps until at least one patch is clustered.
            cluster_dict = {}
            threshold = delta
            while not cluster_dict:
                threshold -= 0.05
                cluster_dict = custom_clustering(test_index_pairs, test_scores, threshold, dim)
            threshold_lst.append(threshold)

            # Paint cluster ids onto a dim x dim grid (0 = unclustered).
            cluster_grid = torch.zeros((dim, dim), dtype=torch.long)
            for idx, cluster in cluster_dict.items():
                row, col = divmod(idx, dim)
                cluster_grid[row, col] = cluster

            # One upsampled soft mask per cluster, stacked as [n_clusters, H, W].
            cluster_masks = []
            for cluster in set(cluster_dict.values()):
                mask = (cluster_grid == cluster).int().to(chosen_sim.device)
                mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), size=(mask_h, mask_w), mode='bilinear', align_corners=True)[0]
                cluster_masks.append(mask)
            stacked_mask = torch.cat(cluster_masks, 0).to(pred_masks.device)[:top_k]

            # Shortlist: CLIP's top candidate, then per cluster the candidate with highest IoU.
            best_indices = [orig_max_index.item()]
            for cluster_mask in stacked_mask:
                cluster_ious = [Compute_IoU(cluster_mask, pred_mask, 0, 0, [])[0] for pred_mask in pred_masks]
                best_indices.append(cluster_ious.index(max(cluster_ious)))

            partial_pred_masks = pred_masks[torch.tensor(best_indices)]
            created_mask_num_lst.append(partial_pred_masks.shape[0])
            our_mask_num_lst.append(partial_pred_masks.shape[0])

            # Oracle: IoU of the best shortlist candidate against the ground truth.
            real_temp_m_IoU = []
            real_save_cum_I = []
            real_save_cum_U = []
            for candidate in partial_pred_masks:
                _, real_temp_m_IoU, temp_cum_I, temp_cum_U = Compute_IoU(candidate, target, 0, 0, real_temp_m_IoU)
                real_save_cum_I.append(temp_cum_I)
                real_save_cum_U.append(temp_cum_U)

            real_temp_m_IoU_s = max(real_temp_m_IoU)
            gt_max_idx = real_temp_m_IoU.index(real_temp_m_IoU_s)

            real_cum_I += real_save_cum_I[gt_max_idx]
            real_cum_U += real_save_cum_U[gt_max_idx]
            real_m_IoU.append(real_temp_m_IoU_s)

            # Spatial Relationship Guidance
            relaflag = extract_rela_word(orig_sentence_for_spacy, nlp)
            score_clip = softmax0(orig_similarity)
            score_clip_Neg = softmax0(orig_similarity_Neg)
            k1 = min(top_k, score_clip.shape[0], partial_pred_masks.shape[0])
            k2 = min(6, score_clip_Neg.shape[0])

            maxidxs = best_indices[:k1]  # Ours

            _, maxNegidxs = torch.topk(score_clip_Neg.view(-1), k=k2)

            topscores = np.zeros(k1)
            # reranking the score
            if len(nouns) == 0:
                # No other nouns: relate each shortlisted candidate to the others.
                for rank, idx_i in enumerate(maxidxs):
                    for idx_j in maxidxs:
                        s = relation_boxes(pred_boxes[idx_i],
                                           pred_boxes[idx_j],
                                           score_clip[idx_i],
                                           score_clip[idx_j],
                                           relaflag)
                        topscores[rank] = topscores[rank] + s
            else:
                # Relate each candidate to the masks best matching the other nouns.
                for rank, idx_i in enumerate(maxidxs):
                    for idx_j in maxNegidxs:
                        s = relation_boxes(pred_boxes[idx_i],
                                           pred_boxes[idx_j],
                                           score_clip[idx_i],
                                           score_clip_Neg[idx_j],
                                           relaflag)
                        topscores[rank] = topscores[rank] + s

            topscores = softmax0(torch.Tensor(topscores).to(device))

            # ########  Spatial Coherence Guidance ########
            imgattn = gem_model(image[0]['tensor_img'].to(device), [noun_phrase])[0]
            imgattn = T.Resize((int(image[0]['height']), int(image[0]['width'])))(imgattn)[0]
            imgattn = imgattn.to(device)

            imgattn = (imgattn - imgattn.min()) / (imgattn.max() - imgattn.min())

            pmask = gen_dir_mask(dirflag, imgattn.shape[0], imgattn.shape[1], imgattn.device)
            imgattn = imgattn * pmask  # Spatial Position Guidance

            imgattn = imgattn / imgattn.mean()

            if relaflag == "big":
                black = 1.95
            elif relaflag == "small":
                black = 1.5
            else:
                black = 1.8

            # Mean attention inside each mask (weighted 2-black) minus mean outside (weighted black).
            score_gem_list = []
            for pred_mask in pred_masks:
                pred_mask = pred_mask.type(torch.uint8)
                score_gemtmp = (imgattn * (2 - black) * pred_mask / (pred_mask.sum())).sum() - (imgattn * black * (1 - pred_mask) / ((1 - pred_mask).sum())).sum()
                score_gem_list.append(torch.Tensor([score_gemtmp]))
            score_gem = torch.stack(score_gem_list, dim=0).to(device)

            for rank in range(k1):
                topscores[rank] = topscores[rank] * (1 - alpha) + alpha * score_gem[maxidxs[rank]][0]
            our_final_idx = maxidxs[torch.argmax(topscores)]

            our_final_mask = pred_masks[our_final_idx]
            _, m_IoU, cum_I, cum_U = Compute_IoU(our_final_mask, target, cum_I, cum_U, m_IoU)

    overall = cum_I * 100.0 / cum_U
    mean_IoU = torch.mean(torch.tensor(m_IoU)) * 100.0

    real_overall = real_cum_I * 100.0 / real_cum_U
    real_mean_IoU = torch.mean(torch.tensor(real_m_IoU)) * 100.0

    log_path = './result_log/our_method_with_free_solo_ours_pceonlyadded_16.txt'
    # Create the log directory up front; a missing directory used to crash only after the full run.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, 'a', encoding='utf-8') as log_file:
        log_file.write(f'\n\n CLIP Model (ViT-B/16): {mode} layer: {layer} delta: {delta} alpha: {alpha} topk: {top_k}'
                       f'\nDataset: {args.dataset} / {args.split} / {args.splitBy}'
                       f'\nOverall IoU / mean IoU')
        log_file.write(f'\n{overall:.2f} / {mean_IoU:.2f}')
        log_file.write(f'\n{real_overall:.2f} / {real_mean_IoU:.2f}')
        # Each record starts on its own line (these were previously glued together).
        log_file.write(f'\nThreshold {np.mean(threshold_lst):.2f} / {np.std(threshold_lst):.2f}')
        log_file.write(f'\nNumber of selected candidates: {np.mean(created_mask_num_lst):.2f} / {np.std(created_mask_num_lst):.2f}')
        log_file.write(f'\nNumber of created clusters: {np.mean(our_mask_num_lst):.2f} / {np.std(our_mask_num_lst):.2f}')



if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    # Output directory and FreeSOLO checkpoint; this replaces whatever `opts`
    # value the argument parser produced.
    args.opts = [
        'OUTPUT_DIR', 'training_dir/FreeSOLO_pl',
        'MODEL.WEIGHTS', 'checkpoints/FreeSOLO_R101_30k_pl.pth',
    ]
    print(args.opts)
    main(args, Height, Width)
