import argparse
import sys
import torch
from detectron2.config import get_cfg
from detectron2.engine import  default_setup
from freesolo import add_solo_config

# Part-of-speech tag names; presumably the tags treated as content words
# (not referenced by the code visible in this part of the file).
MAIN_POS = [
    'NOUN',
    'VERB',
    'ADJ',
    'PROPN',
    'NUM',
]

# Words that signal a relation between objects. The extraction helpers never
# pick these as nouns or attribute cues, and the attribute-cue scan in
# extract_noun_phrase stops at the first one it meets.
RELATION_WORDS = {
    # horizontal
    "left", "west",
    "right", "east",
    # vertical / depth
    "above", "north", "top", "back", "behind",
    "below", "south", "under", "front",
    # size / distance
    "bigger", "larger",
    "closer", "smaller", "tinier", "further",
    # containment
    "inside", "within", "contained",
    # interrogatives
    "who", "what", "which",
    # centre
    "middle",
}

# Attribute words treated like colours when collecting descriptive cues; the
# list also carries a few appearance words ('dark', 'blurry', 'light').
COLORS = [
    'red', 'blue', 'green', 'yellow', 'black', 'white',
    'pink', 'orange', 'purple', 'brown', 'gray', 'grey',
    'teal', 'beige', 'maroon', 'lavender', 'turquoise', 'cyan',
    'magenta', 'ivory', 'navy', 'gold', 'silver', 'bronze',
    'dark', 'blurry', 'light',
]

# Shape names treated as descriptive attribute cues, alongside colours and
# numerals, when building the target noun phrase.
SHAPES = [
    # polygons
    "triangle", "square", "pentagon", "hexagon", "heptagon",
    "octagon", "nonagon", "decagon", "dodecagon",
    # other planar figures
    "circle", "ellipse", "rectangle", "parallelogram", "rhombus",
    "trapezoid", "kite", "crescent", "sector", "annulus", "heart",
    # solids
    "tetrahedron", "cube", "octahedron", "dodecahedron", "icosahedron",
    "sphere", "hemisphere", "cylinder", "cone", "pyramid", "prism",
    "torus", "ellipsoid", "frustum",
    # curves
    "spiral", "helix", "lune", "lemniscate", "astroid", "cardioid",
    "hyperbola", "parabola",
]


def DFS_right(node, variables: list):
    """Find the right-most ``NN`` node below ``node`` and record its word.

    The subtree is walked depth-first, visiting children from last to first.
    When a node labelled ``'NN'`` is reached, the string form of its first
    child is stored in ``variables[1]``.

    Args:
        node: Tree node exposing ``is_leaf()``, ``label`` and ``children``.
        variables: Two-element search state; only ``variables[1]`` (the found
            word, ``''`` while nothing has been found) is read and written here.
    """
    # Leaves carry no label to test, and a non-empty slot means we are done.
    if node.is_leaf() or variables[1] != '':
        return

    if node.label != 'NN':
        # Keep descending, right-most child first.
        for child in reversed(node.children):
            DFS_right(child, variables)
        return

    variables[1] = str(node.children[0])

def DFS_left(node, variables: list):
    """Walk the tree left-to-right looking for an ``NP`` with no nested ``NP``.

    ``variables[0]`` counts the non-leaf ``NP`` nodes entered so far. After a
    node's children have been visited, an ``NP`` whose visit left that counter
    unchanged contains no further ``NP``; ``DFS_right`` is then run on it to
    pull out its right-most ``NN`` word into ``variables[1]``. The walk stops
    doing work as soon as ``variables[1]`` is non-empty.

    Args:
        node: Tree node exposing ``is_leaf()``, ``label`` and ``children``.
        variables: ``[np_counter, found_word]`` search state, updated in place.
    """
    if node.is_leaf() or variables[1] != '':
        return

    is_noun_phrase = node.label == 'NP'
    if is_noun_phrase:
        variables[0] += 1
    count_on_entry = variables[0]

    for child in node.children:
        DFS_left(child, variables)

    # Counter untouched by the children => this NP is the innermost one here.
    if is_noun_phrase and variables[0] == count_on_entry:
        DFS_right(node, variables)


def extract_noun_phrase(text, nlp, need_index=False):
    """Extract the target ("agent") noun phrase from the first sentence of ``text``.

    The agent word comes from ``DFS_left`` on the constituency tree. If that
    finds nothing, the last ``NOUN`` that is neither a relation word nor
    ``'half'`` is used, and ``'[UNK]'`` if there is none. The returned phrase is
    every word up to and including the agent (sliced by the agent's ``id``),
    followed by colour / numeral / shape words that appear before the first
    relation word and are not already a substring of that phrase.

    Args:
        text: Expression to analyse.
        nlp: Callable returning a parsed document with ``sentences[0]`` exposing
            ``constituency`` and ``words`` (presumably a Stanza pipeline).
        need_index: When true, return the extended tuple described below.

    Returns:
        The phrase string, or ``(agents, phrase, agent_id, num_words)`` when
        ``need_index`` is true. ``agents`` lists the agent word, nouns seen
        before it and attribute words; ``agent_id`` is ``-1`` when the agent
        never matched a word's text.
    """
    def is_attribute(word):
        # Colour words, numerals and shape names count as descriptive cues.
        return word.text in COLORS or word.pos == 'NUM' or word.text in SHAPES

    sentence = nlp(text).sentences[0]

    search_state = [0, '']
    DFS_left(sentence.constituency, search_state)
    agent = search_state[1]

    words = sentence.words
    if agent == '':
        # Fallback: the last plain noun in the sentence.
        for candidate in reversed(words):
            if (candidate.pos == 'NOUN'
                    and candidate.text not in RELATION_WORDS
                    and candidate.text not in ['half']):
                agent = candidate.text
                break

    if agent == '':
        agent = '[UNK]'

    agent_id = -1
    agents = []
    word_lst = []
    agent_seen = False
    for word in words:
        word_lst.append(word.text)
        if word.text == agent:
            # A later occurrence overwrites an earlier one.
            agent_id = word.id
            agents.append(word.text)
            agent_seen = True
            continue
        if word.pos == 'NOUN' and not agent_seen:
            agents.append(word.text)
        if is_attribute(word) and word.text not in RELATION_WORDS:
            agents.append(word.text)

    # assumes word.id is 1-based so the slice includes the agent word -- TODO confirm
    head_words = word_lst[:agent_id]
    agent = ' '.join(head_words)

    positive_cues_lst = []
    for word in words:
        # Nothing after the first relation word is considered.
        if word.text in RELATION_WORDS:
            break
        # Substring test against the phrase built so far.
        if word.text not in agent and is_attribute(word):
            positive_cues_lst.append(word.text)

    agent = ' '.join(head_words + positive_cues_lst)

    if need_index:
        return agents, agent, agent_id, len(words)
    return agent
    

def extract_nouns(text, nlp, need_index=False):
    """Collect the nouns of ``text`` that lie outside the head noun phrase.

    Words of the first sentence are scanned from last to first. A ``NOUN`` is
    kept unless its text is a substring of the phrase returned by
    ``extract_noun_phrase`` or is one of ``RELATION_WORDS``.

    Args:
        text: Expression to analyse.
        nlp: Callable returning a parsed document (called here and again
            inside ``extract_noun_phrase``).
        need_index: When true, also return each noun's character span.

    Returns:
        ``(noun_phrases, nouns)``, or ``(noun_phrases, nouns_index, nouns)``
        when ``need_index`` is true. ``noun_phrases`` and ``nouns`` hold the
        same texts in two separate lists; ``nouns_index`` holds
        ``(start_char, end_char)`` pairs.
    """
    doc = nlp(text)
    head_noun = extract_noun_phrase(text, nlp)

    noun_phrases, nouns, nouns_index = [], [], []
    for word in reversed(doc.sentences[0].words):
        if word.pos != 'NOUN':
            continue
        # Original author's note: only keep nouns that are not in head_noun,
        # as HybridGL does(?); the `not in` variant also gave good results.
        if word.text in head_noun or word.text in RELATION_WORDS:
            continue
        noun_phrases.append(word.text)
        nouns_index.append((word.start_char, word.end_char))
        nouns.append(word.text)

    if need_index:
        return noun_phrases, nouns_index, nouns
    return noun_phrases, nouns


def extract_dir_phrase(text, nlp, need_index=False):
    """Pick the direction word of ``text`` whose ``head`` index is smallest.

    Every word of the first sentence is checked against a fixed vocabulary of
    direction words. Among the matches, the one with the lowest ``head`` value
    wins; on a tie the earliest word is kept.

    Args:
        text: Expression to analyse.
        nlp: Callable returning a parsed document whose ``sentences[0].words``
            items expose ``text`` and ``head``.
        need_index: When true, also return the winning ``head`` value.

    Returns:
        One of ``"left"``, ``"right"``, ``"middle"``, ``"up"``, ``"down"`` or
        ``"none"``; with ``need_index`` a ``(flag, head)`` tuple where ``head``
        is ``999`` if no direction word was found.
    """
    direction_of = {
        "left": "left",
        "right": "right",
        "middle": "middle", "between": "middle",
        "up": "up", "top": "up", "above": "up",
        "down": "down", "under": "down", "bottom": "down", "low": "down",
    }

    dirflag = "none"
    diridx = 999  # sentinel larger than any head index expected here
    for token in nlp(text).sentences[0].words:
        direction = direction_of.get(token.text)
        if direction is not None and token.head < diridx:
            dirflag = direction
            diridx = token.head

    return (dirflag, diridx) if need_index else dirflag
        

# Nouns that refer to a part of the picture rather than an object; when any of
# them is among the extracted nouns, extract_rela_word reports "none".
NULL_KEYWORDS = {
    "half", "image", "part", "photo",
    "picture", "region", "section", "side",
}

# Vocabulary for each relation flag produced by extract_rela_word.
LEFT_KEYWORDS = {"west", "left"}
RIGHT_KEYWORDS = {"east", "right"}
UP_KEYWORDS = {"top", "above", "north", "behind", "back"}
DOWN_KEYWORDS = {"under", "below", "south", "front"}
BIG_KEYWORDS = {"larger", "bigger", "closer"}
SMALL_KEYWORDS = {"smallest", "smaller", "tinier", "further"}
WITHIN_KEYWORDS = {"within", "inside", "contained", "between"}


def gen_dir_mask(dirflag,height,width,device):
    """Build a ``height`` x ``width`` weighting mask for a horizontal direction.

    * ``"left"``   -- every row falls linearly from 1 to 0 across the width.
    * ``"right"``  -- every row rises linearly from 0 to 1.
    * ``"middle"`` -- every row rises 0 -> 1 over the first ``width // 2``
      columns and falls 1 -> 0 over the remaining ones.
    * anything else -- a mask of ones.

    Args:
        dirflag: Direction name.
        height: Number of rows.
        width: Number of columns.
        device: Target device; when falsy the mask is returned as created.

    Returns:
        torch.Tensor: the mask (for the ramp cases an expanded view of one row).
    """
    if dirflag == "left":
        ramp = torch.linspace(1, 0, width)
    elif dirflag == "right":
        ramp = torch.linspace(0, 1, width)
    elif dirflag == "middle":
        half = width // 2
        rising = torch.linspace(0, 1, half)
        falling = torch.linspace(1, 0, width - half)
        ramp = torch.cat([rising, falling])
    else:
        ramp = None

    if ramp is None:
        pmask = torch.ones(height, width)
    else:
        # Repeat the single row down the image without copying it.
        pmask = ramp.expand(height, width)

    return pmask.to(device) if device else pmask
      

def extract_rela_word(text, nlp):
    """Classify the relation expressed in ``text``.

    If any noun returned by ``extract_nouns`` is in ``NULL_KEYWORDS`` the
    relation is ``"none"``. Otherwise every word of the first sentence is
    matched against the relation vocabularies, and the match with the smallest
    ``head`` index wins (the earliest word on a tie).

    Args:
        text: Expression to analyse.
        nlp: Callable returning a parsed document whose ``sentences[0].words``
            items expose ``text`` and ``head``.

    Returns:
        str: one of ``"left"``, ``"right"``, ``"up"``, ``"down"``, ``"big"``,
        ``"small"``, ``"within"`` or ``"none"``.
    """
    _, nouns = extract_nouns(text, nlp)
    if set(nouns) & NULL_KEYWORDS:
        return "none"

    # Checked in this order; the first vocabulary containing the word decides.
    keyword_groups = (
        (LEFT_KEYWORDS, "left"),
        # The original compared the word to the whole set with `==`, which is
        # never true, so "right"/"east" were silently ignored.
        (RIGHT_KEYWORDS, "right"),
        (UP_KEYWORDS, "up"),
        (DOWN_KEYWORDS, "down"),
        (BIG_KEYWORDS, "big"),
        (SMALL_KEYWORDS, "small"),
        (WITHIN_KEYWORDS, "within"),
    )

    relaflag = "none"
    deep2head = 999  # sentinel larger than any head index expected here
    for token in nlp(text).sentences[0].words:
        matched = next(
            (flag for keywords, flag in keyword_groups if token.text in keywords),
            None,
        )
        if matched is not None and token.head < deep2head:
            relaflag = matched
            deep2head = token.head
    return relaflag


def relation_boxes(boxi, boxj, scorei, scorej, relaword):
    """Score box ``boxi`` against box ``boxj`` under a relation.

    Boxes are indexed as ``[x, y, w, h]``: the code uses ``box[0] + box[2]``
    as the far x edge and ``box[2] * box[3]`` as the area.

    * ``"left"`` / ``"right"`` / ``"up"`` / ``"down"`` compare box centres.
    * ``"big"`` / ``"small"`` compare box areas.
    * ``"within"`` uses the intersection area divided by the area of ``boxi``.
    * ``"none"`` or any other word returns ``scorei`` and ignores ``boxj``.

    For the comparison relations the result is ``scorei * scorej`` multiplied
    by the boolean outcome, i.e. zero when the relation does not hold.

    Args:
        boxi: Candidate box.
        boxj: Reference box (not accessed for ``"none"``/unknown relations).
        scorei: Score of the candidate box.
        scorej: Score of the reference box.
        relaword: Relation name.

    Returns:
        The combined score.
    """
    def centre(box, axis):
        # axis 0 -> x centre, axis 1 -> y centre
        return box[axis] + box[axis + 2] / 2

    def area(box):
        return box[2] * box[3]

    if relaword == "left":
        return scorei * scorej * (centre(boxi, 0) < centre(boxj, 0))
    if relaword == "right":
        return scorei * scorej * (centre(boxi, 0) > centre(boxj, 0))
    if relaword == "up":
        return scorei * scorej * (centre(boxi, 1) < centre(boxj, 1))
    if relaword == "down":
        return scorei * scorej * (centre(boxi, 1) > centre(boxj, 1))
    if relaword == "big":
        return scorei * scorej * (area(boxi) > area(boxj))
    if relaword == "small":
        return scorei * scorej * (area(boxi) < area(boxj))
    if relaword == "within":
        # Intersection rectangle; clamping keeps its sides non-negative.
        left = max(boxi[0], boxj[0])
        right = max(left, min(boxi[0] + boxi[2], boxj[0] + boxj[2]))
        top = max(boxi[1], boxj[1])
        bottom = max(top, min(boxi[1] + boxi[3], boxj[1] + boxj[3]))
        return scorei * scorej * (right - left) * (bottom - top) / (boxi[2] * boxi[3])

    # "none" and unrecognised relations: the candidate's own score.
    return scorei


def default_argument_parser(epilog=None):
    """
    Create a parser with some common arguments used by detectron2 users,
    extended with the CLIP / REFER dataset options used by this project.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage.

    Returns:
        argparse.ArgumentParser:
    """
    # NOTE(review): the default epilog below still mentions --dist-url, which
    # this parser does not define.
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config-file", default="configs/freesolo/freesolo_30k.yaml", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    # NOTE(review): store_false makes eval_only default to True and passing the
    # flag turns it off, the reverse of what the help text suggests. Left as is
    # because callers may rely on the current default.
    parser.add_argument("--eval-only", action="store_false", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    parser.add_argument("--delta", type=float, default=0.5, help="initial threshold")
    parser.add_argument("--alpha", type=float, default=0.5, help="proportion of spatial coherence")
    parser.add_argument("--layer", type=int, default=10, help="exit layer")
    parser.add_argument("--top_k", type=int, default=3, help="number of top clusters")
    parser.add_argument("--ten_percent", action="store_true", help="whether to use the first 10 percent of the data for the hyperparameter search")

    # An unused `port` value used to be computed here through os.getuid();
    # `os` is not imported by this module, so building the parser raised
    # NameError on every non-Windows platform. Nothing consumed the value.
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )

    # 'ViT-L/14@336px' was misspelled 'ViT0L/14@336px', which made the real
    # model name impossible to select.
    parser.add_argument('--clip_model', default='RN50', help='CLIP model name', choices=['RN50', 'RN101', 'RN50x4', 'RN50x64',
                                                                                          'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'])
    parser.add_argument('--visual_proj_path', default='./pretrain/', help='')
    parser.add_argument('--dataset', default='refcocog', help='refcoco, refcoco+, or refcocog')
    parser.add_argument('--split', default='val', help='only used when testing, testA, testB')
    parser.add_argument('--splitBy', default='umd', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    parser.add_argument('--unseen', action='store_true', help='Whether to test unseen mode')
    parser.add_argument('--seen', action='store_true', help='Whether to test seen mode')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    parser.add_argument('--show_results', action='store_true', help='Whether to show results ')

    return parser


def setup(args):
    """
    Build the frozen config for this run and perform the default setup.

    Starts from ``get_cfg()``, registers the SOLO options, merges the file
    named by ``args.config_file`` and then the ``args.opts`` list, freezes
    the result and hands it to ``default_setup``.

    Args:
        args: Parsed command-line arguments providing ``config_file`` and ``opts``.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    add_solo_config(config)

    # The config file is merged first; the command-line pairs are applied after it.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()

    default_setup(config, args)
    return config


def Compute_IoU(pred, target, cum_I, cum_U, mean_IoU=None):
    """Compute the IoU of one prediction and update the running totals.

    Args:
        pred: Predicted mask, combined element-wise with ``target`` through
            ``torch.logical_and`` / ``torch.logical_or``.
        target: Ground-truth mask. If it is not boolean it is cast to ``bool``
            and its leading dimension is squeezed.
        cum_I: Running sum of intersections; updated with ``+=``.
        cum_U: Running sum of unions; updated with ``+=``.
        mean_IoU: List of per-sample IoUs to append to. A fresh list is created
            when omitted, so separate evaluations do not share history.

    Returns:
        tuple: ``(this_iou, mean_IoU, cum_I, cum_U)``. ``this_iou`` is the
        float ``0.0`` when the union is empty, otherwise a tensor.
    """
    # A `[]` default would be created once and shared by every call.
    if mean_IoU is None:
        mean_IoU = []

    if target.dtype != torch.bool:
        target = target.type(torch.bool).squeeze(0)

    intersection = torch.sum(torch.logical_and(pred, target))
    union = torch.sum(torch.logical_or(pred, target))

    # Both masks empty: define the IoU as 0 rather than dividing by zero.
    this_iou = 0.0 if union == 0 else intersection * 1.0 / union

    cum_I += intersection
    cum_U += union
    mean_IoU.append(this_iou)

    return this_iou, mean_IoU, cum_I, cum_U