
""""""

from typing import List, Tuple

from .text import STOP_WORDS, nlp


class PreprocessError(Exception):
    pass


def span_intersect_span(span1: Tuple[int, int], span2: Tuple[int, int]):
    """Returns True if the given spans intersect"""
    return (span1[0] <= span2[0] < span1[1]) or (span2[0] <= span1[0] < span2[1])


def span_intersect_spanlist(span: Tuple[int, int], target_spans: List[Tuple[int, int]]):
    """Returns True if the given spans intersect with any in the given list"""
    for t in target_spans:
        if span_intersect_span(span, t):
            return True
    return False


def spanlist_intersect_spanlist(spans: List[Tuple[int, int]], target_spans: List[Tuple[int, int]]):
    """Returns True if the given spans intersect with any in the given list"""
    for s in spans:
        if span_intersect_spanlist(s, target_spans):
            return True
    return False


def consolidate_spans(spans: List[Tuple[int, int]], caption: str, rec=True):
    """Accepts a list of spans and the the corresponding caption.
    Returns a cleaned list of spans where:
        - Overlapping spans are merged
        - It is guaranteed that spans start and end on a word
    """
    sorted_spans = sorted(spans)
    cur_end = -1
    cur_beg = None
    final_spans: List[Tuple[int, int]] = []
    for s in sorted_spans:
        if s[0] >= cur_end:
            if cur_beg is not None:
                final_spans.append((cur_beg, cur_end))
            cur_beg = s[0]
        cur_end = max(cur_end, s[1])

    if cur_beg is not None:
        final_spans.append((cur_beg, cur_end))

    
    clean_spans: List[Tuple[int, int]] = []
    for s in final_spans:
        beg, end = s
        end = min(end, len(caption))
        while beg < len(caption) and not caption[beg].isalnum():
            beg += 1
        while end > 0 and not caption[end - 1].isalnum():
            end -= 1
        
        if end < len(caption) and caption[end] == "-":
            
            next_space = caption.find(" ", end)
            if next_space == -1:
                end = len(caption)
            else:
                end = next_space + 1
        if beg > 0 and caption[beg - 1] == "-":
            prev_space = caption.rfind(" ", 0, beg)
            if prev_space == -1:
                beg = 0
            else:
                beg = prev_space + 1
        if 0 <= beg < end <= len(caption):
            clean_spans.append((beg, end))
    if rec:
        return consolidate_spans(clean_spans, caption, False)
    return clean_spans


def get_canonical_spans(orig_spans: List[List[Tuple[int, int]]], orig_caption: str, whitespace_only=False):
    """This functions computes the spans after reduction of the caption to it's normalized version
    For example, if the caption is "There is a man wearing sneakers" and the span is [(11,14)] ("man"),
    then the normalized sentence is "man wearing sneakers" so the new span is [(0,3)]
    """
    
    
    new_spans = [sorted(spans) for spans in orig_spans]
    caption = orig_caption.lower()

    def remove_chars(pos, amount):
        for i in range(len(new_spans)):
            for j in range(len(new_spans[i])):
                if pos >= new_spans[i][j][1]:
                    continue
                beg, end = new_spans[i][j]
                if span_intersect_span(new_spans[i][j], (pos, pos + amount)):
                    
                    new_spans[i][j] = (beg, end - amount)
                else:
                    new_spans[i][j] = (beg - amount, end - amount)

    def change_chars(old_beg, old_end, delta):
        for i in range(len(new_spans)):
            for j in range(len(new_spans[i])):
                if old_beg >= new_spans[i][j][1]:
                    continue
                beg, end = new_spans[i][j]
                if span_intersect_span(new_spans[i][j], (old_beg, old_end)):
                    if not (new_spans[i][j][0] <= old_beg < old_end <= new_spans[i][j][1]):
                        raise PreprocessError(f"deleted spans should be contained in known span")
                    assert (
                        new_spans[i][j][0] <= old_beg < old_end <= new_spans[i][j][1]
                    ), "deleted spans should be contained in known span"
                    new_spans[i][j] = (beg, end + delta)
                else:
                    new_spans[i][j] = (beg + delta, end + delta)

    
    
    while caption[0] == " ":
        remove_chars(0, 1)
        caption = caption[1:]
    cur_start = 0
    pos = caption.find("  ", cur_start)
    while pos != -1:
        amount = 1
        
        remove_chars(pos, amount)
        caption = caption.replace("  ", " ", 1)
        pos = caption.find("  ", cur_start)
    
    
    if whitespace_only:
        return new_spans, caption

    
    for punct in [".", ",", "!", "?", ":"]:
        pos = caption.find(punct)
        while pos != -1:
            remove_chars(pos, len(punct))
            caption = caption.replace(punct, "", 1)
            pos = caption.find(punct)
    
    

    
    all_tokens = nlp(caption)
    tokens = []

    
    
    for t in all_tokens:
        if str(t) not in STOP_WORDS:
            tokens.append(t)
    
    for stop in STOP_WORDS:
        cur_start = 0
        pos = caption.find(stop, cur_start)
        while pos != -1:
            
            if (pos == 0 or caption[pos - 1] == " ") and (
                pos + len(stop) == len(caption) or caption[pos + len(stop)] == " "
            ):
                removed = stop
                spaces = 0
                if pos + len(stop) < len(caption) and caption[pos + len(stop)] == " ":
                    removed += " "
                    spaces += 1
                if pos > 0 and caption[pos - 1] == " ":
                    removed = " " + removed
                    spaces += 1
                if spaces == 0:
                    raise PreprocessError(
                        f"No spaces found in '{caption}', position={pos}, stopword={stop}, len={len(stop)}"
                    )
                assert spaces > 0
                replaced = "" if spaces == 1 else " "
                amount = len(removed) - len(replaced)
                
                remove_chars(pos, amount)
                caption = caption.replace(removed, replaced, 1)
                
                
            else:
                cur_start += 1
            pos = caption.find(stop, cur_start)

    
    

    
    final_caption = []
    if len(tokens) != len(caption.strip().split(" ")):
        raise PreprocessError(
            f"''{tokens}'', len={len(tokens)}, {caption.strip().split(' ')}, len={len(caption.strip().split(' '))}"
        )

    
    cur_beg = 0
    for i, w in enumerate(caption.strip().split(" ")):
        if tokens[i].lemma_[0] != "-":
            
            final_caption.append(tokens[i].lemma_)
            change_chars(cur_beg, cur_beg + len(w), len(tokens[i].lemma_) - len(w))
        else:
            
            final_caption.append(w)
        cur_beg += 1 + len(final_caption[-1])
        
        

    clean_caption = " ".join(final_caption)
    
    clean_spans = []
    for spans in new_spans:
        cur = []
        for s in spans:
            if 0 <= s[0] < s[1]:
                cur.append(s)
        clean_spans.append(cur)

    
    
    return clean_spans, clean_caption


def shift_spans(spans: List[Tuple[int, int]], offset: int) -> List[Tuple[int, int]]:
    final_spans = []
    for beg, end in spans:
        final_spans.append((beg + offset, end + offset))
    return final_spans
