import spacy
import re, ast
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
# Enlarge matplotlib's default figure size to 20 x 12 inches (width, height)
# for any figures this script displays.
pylab.rcParams['figure.figsize'] = 20, 12

import cv2
import base64
import io

import pdb


def main(laion_file='./output/00000_00000.tsv'):
    """Run noun-chunk merging over every row of a LAION-style TSV file.

    Each tab-separated row is read as: column 1 = caption, column 3 = image
    width, column 4 = image height. Rows with 5 columns only have their meta
    info printed; rows with a 6th column (the grounding list) are passed to
    ``merge_nouns``.

    Args:
        laion_file: path of the TSV file to process (defaults to the example
            file previously hard-coded here).
    """
    nlp = spacy.load('en_core_web_trf')

    with open(laion_file, 'r', encoding='utf8') as f:
        # Stream rows instead of loading the whole file into memory.
        for line in f:
            item = line.strip().split('\t')
            if item == ['']:
                # Blank line: nothing to process.
                continue

            try:
                image_h = int(item[4])
                image_w = int(item[3])
                caption = item[1]
            except (IndexError, ValueError):
                # Too few columns or non-integer size: skip this row only.
                print("ERROR during loading this image")
                continue

            if len(item) == 5:  # original version, no grounding column
                print(f'Meta info: image height {image_h} width {image_w} with caption {caption}')
            elif len(item) == 6:
                # NOTE(review): the merged grounding list is currently
                # discarded; persist it here if the output is needed.
                merge_nouns(nlp, item)
            
            
# Tokens that can mislead spaCy's parse; noun-chunk extension stops at them.
_SPECIAL_CHARS = frozenset(
    ['|', ':', ';', '@', '^', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', '.']
)


def _parse_grounding(item):
    """Return the grounding list stored in ``item[5]``, or None if unusable.

    ``item[5]`` may already be a list, or the string repr of one (decoded
    with ``ast.literal_eval``, which only evaluates Python literals).
    """
    raw = item[5]
    if isinstance(raw, str):
        try:
            return ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError) as e:
            print(f"Cannot parse grounding list for image hash {item[0]}: {e}")
            return None
    if isinstance(raw, list):
        return raw
    print(f"Unsupport type {type(raw)}")
    return None


def _extend_noun_chunk(doc, nc, caption):
    """Return ``(start_char, end_char)`` of ``nc`` extended over its root's subtree.

    The span grows token by token from the chunk towards the root's left and
    right edges and stops at the first special character, or at a '-' when
    the caption contains a spaced dash. Returns None when nothing can be
    added, or when the root has conjuncts (the extension would swallow the
    coordinated phrases).
    """
    # A spaced dash in the caption usually separates segments, not words.
    spaced_dash = ' -' in caption or '- ' in caption

    def is_stop(token):
        return token.text in _SPECIAL_CHARS or (token.text == '-' and spaced_dash)

    # Bounds are token indices, end-exclusive. They start at the chunk itself,
    # so a side that cannot grow keeps the chunk's own boundary. The
    # previous code fell back to the root's full edge, which jumped past
    # the very special character that stopped the walk.
    left = nc.start
    while left > nc.root.left_edge.i and not is_stop(doc[left - 1]):
        left -= 1
    right = nc.end
    while right <= nc.root.right_edge.i and not is_stop(doc[right]):
        right += 1

    if left == nc.start and right == nc.end:
        return None

    if nc.root.conjuncts:
        print(f"\n filter {nc.text} -> {doc[nc.root.left_edge.i:nc.root.right_edge.i+1]} in {caption} \n because root {nc.root.text} contains conjuncts {nc.root.conjuncts}")
        return None

    extended = doc[left:right]
    print('\n', caption)
    print(f"{nc.text} \t is extended to \t {extended.text}")
    return extended.start_char, extended.end_char


def merge_nouns(nlp, item):
    """Extend grounded phrases of one TSV row to their full noun phrases.

    Args:
        nlp: a callable returning a spaCy ``Doc`` for a text (presumably the
            pipeline from ``spacy.load``); only ``doc.noun_chunks`` and
            token/span attributes are used.
        item: one TSV row split on tabs; ``item[1]`` is the caption and
            ``item[5]`` the grounding list, either a list or its string repr,
            shaped as
            ``[[phrase_start, phrase_end, x1_norm, y1_norm, x2_norm, y2_norm, score], ...]``.

    Returns:
        A new grounding list with the same entries, where a phrase that
        exactly matches a noun chunk may have its character span widened.
        Returns None when the row has no usable grounding list.
    """
    caption = item[1]
    if len(item) < 6:
        print(f"This length of line is {len(item)}, image hash {item[0]}, caption {caption}")
        print("Skip it")
        return None

    grounding_list = _parse_grounding(item)
    if grounding_list is None:
        return None

    doc = nlp(caption)
    # Noun chunks do not overlap, so each (start_char, end_char) pair
    # identifies at most one chunk; index them once instead of rescanning
    # for every prediction.
    chunks_by_span = {(nc.start_char, nc.end_char): nc for nc in doc.noun_chunks}

    process_grounding_list = []
    for prediction in grounding_list:
        phrase_start, phrase_end, x1_norm, y1_norm, x2_norm, y2_norm, score = prediction
        nc = chunks_by_span.get((phrase_start, phrase_end))
        if nc is not None:
            extended = _extend_noun_chunk(doc, nc, caption)
            if extended is not None:
                phrase_start, phrase_end = extended
        process_grounding_list.append([phrase_start, phrase_end, x1_norm, y1_norm, x2_norm, y2_norm, score])

    return process_grounding_list

# Run the example pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()