import os

from functools import partial
from pathlib import Path
from tqdm import tqdm

import numpy as np
import time

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
# pylab.rcParams['figure.figsize'] = 20, 12

import cv2
import base64
import io

import requests
from io import BytesIO
import argparse

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.engine.predictor_glip_spacy import GLIPDemo

from merge_nouns import merge_nouns
import spacy

import pdb


def main():
    """Run GLIP grounding over TSV shards and optionally save the results.

    Command-line arguments select a range of sub-directories (named by a
    zero-padded 5-digit index) under ``--laion2b_root_dir`` and, inside each,
    a range of files named ``<subdir>_<5-digit index>.tsv``.

    Each TSV row is split on tabs and read as: column 0 an identifier (printed
    as "hash"), column 1 the caption, column 2 a base64-encoded image; columns
    3 and 4 are passed through unchanged. For every row the image is decoded,
    ``glip_predictor.inference(image, caption)`` is called, the detections are
    filtered with ``nms`` and turned into a list of
    ``[start, end, x1/w, y1/h, x2/w, y2/h, score]`` entries, which is also
    handed to ``merge_nouns``.

    When ``--laion2b_save_root_dir`` is given, one output TSV per input file is
    written under the same sub-directory name. Rows that fail to load or to
    run through the model, and rows without detections, are written back with
    their first five columns only; all other rows get a sixth column holding
    ``str({'phrase': ..., 'expression_v1': ...})``. Input files whose output
    file already exists and can be read are skipped. Without a save directory
    the rows are processed but nothing is written.
    """
    parser = argparse.ArgumentParser(description="PyTorch Detection to Grounding Inference")
    parser.add_argument(
        "--config-file",
        default="configs/pretrain/glip_Swin_L.yaml",
        type=str,
        help="path to config file",
    )
    parser.add_argument(
        "--weight_file",
        default='/tmp/glip_large_model.pth',
        type=str,
        help="wget https://penzhanwu2bbs.blob.core.windows.net/data/GLIPv1_Open/models/glip_large_model.pth -O /tmp/glip_large_model.pth",
    )
    # NOTE(review): --local_rank is parsed but cfg.local_rank is hard-coded to 0 below.
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER
    )
    parser.add_argument(
        "--laion2b_root_dir",
        default='/path/to/coyo_filtered_tsvs',
        type=str,
        help="as description"
    )
    parser.add_argument(
        "--laion2b_subdir_start",
        default=0,
        type=int,
        help="as description"
    )
    parser.add_argument(
        "--laion2b_subdir_num",
        default=1,
        type=int,
        help="as description"
    )

    parser.add_argument(
        "--laion2b_subdir_subindex_start",
        default=0,
        type=int,
        help="as description"
    )
    parser.add_argument(
        "--laion2b_subdir_subindex_num",
        default=1,
        type=int,
        help="as description"
    )

    parser.add_argument(
        "--laion2b_save_root_dir",
        default=None,
        type=str,
        help="as description"
    )

    parser.add_argument(
        "--vis_image_save_dir",
        default='./output/vis_laion2b_2/',
        type=str,
        help="as description"
    )

    parser.add_argument(
        "--confidence_threshold",
        default=0.65,
        type=float,
        help="as description"
    )

    parser.add_argument(
        "--visualize",
        action='store_true',
        help="as description"
    )

    args = parser.parse_args()

    print(f"Parameters:\n{args}")

    if args.vis_image_save_dir is not None:
        Path(args.vis_image_save_dir).mkdir(parents=True, exist_ok=True)

    cfg.local_rank = 0
    cfg.num_gpus = 1
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(["MODEL.WEIGHT", args.weight_file])
    cfg.merge_from_list(["MODEL.DEVICE", "cuda"])
    print(cfg)

    subdir_list = [str(i).zfill(5) for i in range(args.laion2b_subdir_start, args.laion2b_subdir_start + args.laion2b_subdir_num)]

    # init the model
    print("init model")
    glip_predictor = GLIPDemo(
        cfg,
        min_image_size=800,
        confidence_threshold=args.confidence_threshold,
        show_mask_heatmaps=False
    )
    nlp = spacy.load('en_core_web_trf')

    # Predicted labels are shifted by this offset before indexing the entity
    # and token-position lists. The config does not change per row, so the
    # offset is computed once instead of inside the row loop.
    if cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
        plus = 1
    else:
        plus = 0

    # LAION-2B data structure
    for subdir in subdir_list:
        subdir_path = os.path.join(args.laion2b_root_dir, subdir)
        if not os.path.exists(subdir_path):
            print(f"Note: {subdir_path} not exist")
            continue

        # prepare the save dir; create it if it does not exist
        save_subdir_path = None
        if args.laion2b_save_root_dir is not None:
            save_subdir_path = os.path.join(args.laion2b_save_root_dir, subdir)
            Path(save_subdir_path).mkdir(parents=True, exist_ok=True)

        # list the files available in the given sub-directory
        subdir_path_files = os.listdir(subdir_path)
        print(f'Files scanned in this dir: {subdir_path_files}')
        available_files = set(subdir_path_files)

        subdir_file_start = args.laion2b_subdir_subindex_start
        subdir_file_end = args.laion2b_subdir_subindex_num + subdir_file_start

        subdir_files = [f"{subdir}_{str(i).zfill(5)}.tsv" for i in range(subdir_file_start, subdir_file_end)]
        subdir_files = [i for i in subdir_files if i in available_files]
        print(f'Files will be processed in this node: {subdir_files}')

        for filename in subdir_files:
            # determine the file
            file_path = os.path.join(subdir_path, filename)
            if not os.path.isfile(file_path):
                print(f"{file_path} is not a file")
                continue
            print(f"load tsv file from {file_path}")

            # Decide whether this file still has to be processed:
            #   * no save dir              -> process, keep nothing (prediction_list is None)
            #   * readable output exists   -> skip the file
            #   * output missing/unreadable -> process and collect rows for writing
            prediction_list = None
            save_file_path = None
            if save_subdir_path is not None:
                save_file_path = os.path.join(save_subdir_path, filename)
                if os.path.exists(save_file_path):
                    try:
                        # Reading the whole file is only a validity probe.
                        with open(save_file_path, 'r', encoding='utf8') as t:
                            t.read()
                    except Exception:
                        print(f"{filename} exists in the target dir {save_subdir_path}, but cannot load it, re-process it!")
                    else:
                        print(f"{filename} exists in the target dir {save_subdir_path}, skip it!")
                        continue
                else:
                    print(f"{filename} does not exist in the target dir {save_subdir_path}, process it!")
                # rows are collected here and written once the file is done
                prediction_list = []

            with open(file_path, 'r', encoding='utf8') as f:
                lines = f.read().strip().split('\n')

            _iter = tqdm(lines) if args.visualize else lines
            _iter_num = len(lines)
            for i, doc_str in enumerate(_iter):
                start_time = time.time()
                item = doc_str.strip().split('\t')
                # tsv_str = "{}\t{}\t{}\t{}\t{}\n".format(meta_info["hash"], text_data, encoded_string, meta_info["width"], meta_info["height"])

                # Bind these before the try so the error message below can
                # always be formatted, even when decoding fails on the very
                # first row or before the sizes are known.
                caption = None
                image_h = None
                image_w = None
                try:
                    caption = item[1]
                    pil_img = Image.open(io.BytesIO(base64.b64decode(item[2]))).convert("RGB")
                    # reverse the channel order of the RGB array
                    image = np.array(pil_img)[:, :, [2, 1, 0]]
                    image_h = pil_img.height
                    image_w = pil_img.width
                except Exception:
                    # KeyboardInterrupt/SystemExit are not Exception subclasses
                    # and therefore still propagate.
                    print(f"ERROR during loading the image, hash {item[0]}, image height {image_h} width {image_w}\n caption {caption}")
                    if prediction_list is not None:
                        prediction_list.append('\t'.join(item[:5])+'\n')
                    continue

                try:
                    prediction = glip_predictor.inference(image, caption)
                except Exception:
                    print(f"skip image due to some reasons, hash {item[0]}, image height {image_h} width {image_w}\n caption {caption}")
                    if prediction_list is not None:
                        prediction_list.append('\t'.join(item[:5])+'\n')
                    continue

                pre_bbox = prediction.bbox.numpy()
                pre_score = prediction.get_field('scores').numpy()
                pre_label = prediction.get_field("labels").numpy()

                if len(pre_label) == 0:
                    ################################ NOTE DO Not drop the img with no grounding results
                    # Only record the row when results are being collected;
                    # prediction_list is None when no save dir was given.
                    if prediction_list is not None:
                        prediction_list.append('\t'.join(item[:5])+'\n')
                    continue

                entity_names = glip_predictor.entities
                tokens_positive = glip_predictor.tokens_positive

                pre_entity_labels = [entity_names[l - plus] for l in pre_label]
                pre_entity_labels = np.asarray(pre_entity_labels)

                # drop boxes labeled 'object' and suppress near-duplicate boxes
                pre_bbox, pre_score, pre_label, pre_entity_labels = nms(pre_bbox, pre_score, pre_label, pre_entity_labels)
                pre_token_pos = [tokens_positive[l - plus][0] for l in pre_label]

                # turn it into a string, and append it to predictions list
                if prediction_list is not None:
                    pre_list = []
                    for _pos, _box, _score, _phrase in zip(pre_token_pos, pre_bbox, pre_score, pre_entity_labels):
                        # sanity check: the character span must reproduce the phrase
                        assert _phrase == caption[_pos[0]:_pos[1]]
                        # box corners are stored relative to the image size
                        pre_list.append([_pos[0], _pos[1], _box[0]/image_w, _box[1]/image_h, _box[2]/image_w, _box[3]/image_h, _score])

                    # just to match the input format of function 'merge_nouns'
                    p_item = [item[0], caption, 0, 0, 0, pre_list]
                    grounding_list = merge_nouns(nlp, p_item)
                    final_grounding_list = {'phrase': pre_list, 'expression_v1': grounding_list}
                    pre_list = str(final_grounding_list)

                    tsv_str = "{}\t{}\t{}\t{}\t{}\t{}\n".format(item[0], item[1], item[2], item[3], item[4], pre_list)
                    prediction_list.append(tsv_str)

                    end_time = time.time()
                    print(f"[{str(i).zfill(4)}/{_iter_num}] {end_time - start_time:.2f} s/iter")

            if prediction_list is not None:
                print(f"Writing to {save_file_path}")
                with open(save_file_path, 'w', encoding='utf8') as f:
                    f.writelines(prediction_list)
                    
def nms(boxes, scores, labels, entities, iou_threshold=0.95):
    """Drop 'object' detections, then greedily suppress overlapping boxes.

    Detections whose entity equals the string ``"object"`` are removed first.
    The remaining boxes are visited from highest to lowest score; each visited
    box is kept and every later box whose IoU with it exceeds
    ``iou_threshold`` is discarded. Suppression ignores ``labels`` and
    ``entities`` (boxes of different phrases can suppress each other).

    Args:
        boxes: 2-D array, one row per box, columns read as x1, y1, x2, y2.
        scores: 1-D array of confidences, aligned with ``boxes``.
        labels: 1-D array aligned with ``boxes``; passed through.
        entities: 1-D array of phrase strings aligned with ``boxes``.
        iou_threshold: boxes with IoU strictly greater than this are suppressed.

    Returns:
        Tuple ``(boxes, scores, labels, entities)`` restricted to the kept
        detections, ordered by descending score.
    """
    # filter the bounding box labeled 'object'
    not_object = entities != "object"
    boxes = boxes[not_object]
    scores = scores[not_object]
    labels = labels[not_object]
    entities = entities[not_object]

    # calculate the area for each bounding box
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # sort bounding box based on confidence
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        rest = order[1:]
        keep.append(i)

        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])

        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter_area = w * h

        union_area = areas[i] + areas[rest] - inter_area
        # A zero union (both boxes degenerate) would give 0/0 = NaN, and
        # NaN <= threshold is False, which silently dropped the box. Treat
        # such pairs as non-overlapping (IoU 0) instead.
        iou = np.divide(
            inter_area,
            union_area,
            out=np.zeros_like(union_area),
            where=union_area > 0,
        )

        inds = np.where(iou <= iou_threshold)[0]
        order = order[inds + 1]

    # explicit integer dtype keeps indexing valid when nothing was kept
    keep = np.asarray(keep, dtype=np.intp)
    return boxes[keep], scores[keep], labels[keep], entities[keep]

def insert_bounding_boxes(text, boxes, positions):
    """Annotate spans of ``text`` with their boxes in ``[span](box)`` form.

    ``positions`` holds ``(start, end)`` character offsets and ``boxes[i]`` is
    the value attached to ``positions[i]``. Each span is emitted as
    ``[text[start:end]](box)``. When a span starts at the same offset as the
    one immediately before it, only an extra ``(box)`` is emitted instead of
    repeating the bracketed text. Text outside the spans is copied as is.

    Returns the annotated string.
    """
    pieces = []
    cursor = 0
    prev_start = None
    for idx, (start, end) in enumerate(positions):
        # text between the end of the previous span and this span's start
        gap = text[cursor:start]
        box_tag = f'({boxes[idx]})'
        if idx > 0 and start == prev_start:
            # same start as the previous span: add just another box
            pieces.append(gap + box_tag)
        else:
            pieces.append(gap)
            pieces.append(f'[{text[start:end]}]{box_tag}')
        cursor = end
        prev_start = start
    pieces.append(text[cursor:])
    return ''.join(pieces)
              
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
