import spacy
import spacy_transformers
from typing import List, Dict

# Special tokens that `text_analyzer._align_indices` skips when matching
# tokenizer word-pieces against spaCy words.
# NOTE(review): these look like CLIP-style begin/end-of-text markers -- confirm
# they match the tokenizer that is actually passed in.
start_token = "<|startoftext|>"
end_token = "<|endoftext|>"

class text_analyzer():
    """Parse a prompt with spaCy and map the parse onto tokenizer positions.

    The class owns a spaCy pipeline (``self.parser``) and a tokenizer
    (``self.tokenizer``).  The most recently parsed document is kept on
    ``self.doc``.
    """

    def __init__(self, tokenizer):
        """Load the spaCy pipeline and remember the tokenizer.

        Args:
            tokenizer: Object handed to ``get_indices``; it is called on the
                prompt and must provide ``convert_ids_to_tokens``.
        """
        self.parser = spacy.load("en_core_web_trf")
        self.tokenizer = tokenizer

    def _align_indices(self, prompt, spacy_pairs):
        """Translate groups of spaCy tokens into groups of tokenizer positions.

        Args:
            prompt: The text that was parsed; it is re-tokenized here with
                ``self.tokenizer``.
            spacy_pairs: Iterable of token groups; each member must expose
                ``.text``.

        Returns:
            One list per input group.  Each entry is either a single position
            (the word is exactly one word-piece) or a list of positions (the
            word is split over several word-pieces).  Members that cannot be
            matched are silently left out, so a group may come back shorter
            than it went in, or empty.
        """
        wordpieces2indices = get_indices(self.tokenizer, prompt)
        special_tokens = (start_token, end_token)
        paired_indices = []
        # Positions already assigned to an earlier group.  This lets a word
        # that occurs several times in the prompt resolve to a fresh
        # occurrence for each group.
        collected_spacy_indices = set()

        for pair in spacy_pairs:
            # Positions assigned within the current group; guards against
            # re-using one position for a repeated modifier
            # (e.g. "a red red red bear").
            curr_collected_wp_indices = []
            for member in pair:
                for idx, wp in wordpieces2indices.items():
                    if wp in special_tokens:
                        continue

                    wp = wp.replace("</w>", "")
                    if member.text == wp:
                        if (
                            idx not in curr_collected_wp_indices
                            and idx not in collected_spacy_indices
                        ):
                            curr_collected_wp_indices.append(idx)
                            break
                    # The word-piece is only a prefix of the word: try to
                    # assemble the word from the following word-pieces.
                    elif member.text.startswith(wp):
                        wp_indices = align_wordpieces_indices(
                            wordpieces2indices, idx, member.text
                        )
                        if (
                            wp_indices
                            and wp_indices not in curr_collected_wp_indices
                            and all(
                                wp_idx not in collected_spacy_indices
                                for wp_idx in wp_indices
                            )
                        ):
                            curr_collected_wp_indices.append(wp_indices)
                            break

            # Only publish the positions once the whole group is processed.
            for collected_idx in curr_collected_wp_indices:
                if isinstance(collected_idx, list):
                    collected_spacy_indices.update(collected_idx)
                else:
                    collected_spacy_indices.add(collected_idx)

            paired_indices.append(curr_collected_wp_indices)

        return paired_indices

    def _extract_attribution_indices(self, prompt, object_indices=None):
        """Return, for every object position, the positions of its attributes.

        Args:
            prompt: Text to parse.
            object_indices: Tokenizer positions of the objects of interest.
                ``None`` (the default) is treated as "no objects".

        Returns:
            A list with one flat list of attribute positions per entry of
            ``object_indices`` (empty when nothing was found).  An object is
            matched to a group only when the group's *last* entry equals the
            object position.  If anything goes wrong during extraction the
            error is reported on stdout and every object gets an empty list.

        NOTE(review): a group whose last entry is a multi-word-piece list never
        compares equal to a plain integer position -- confirm this is intended.
        """
        # The declared default used to crash with TypeError because ``None``
        # was iterated, even inside the fallback handler.
        if object_indices is None:
            object_indices = []

        try:
            self.doc = self.parser(prompt)

            # Modifiers attached directly to a noun.
            pairs = extract_attribution_indices(self.doc)

            # Modifiers separated from the noun by a verb / auxiliary.
            pairs_2 = extract_attribution_indices_with_verb_root(self.doc)
            pairs_3 = extract_attribution_indices_with_verbs(self.doc)
            # Drop duplicates and groups fully contained in a longer group.
            pairs = unify_lists(pairs, pairs_2, pairs_3)

            print(f"Final pairs collected: {pairs}")
            paired_indices = self._align_indices(prompt, pairs)

            attribute_indices = []
            for obj_idx in object_indices:
                matched = False
                for group in paired_indices:
                    if len(group) != 0 and obj_idx == group[-1]:
                        attribute_indices.append(group[:-1])
                        matched = True
                if not matched:
                    attribute_indices.append([])

            # Multi-word-piece attributes are stored as nested lists; flatten.
            flattened_attribute_indices = []
            for indices in attribute_indices:
                current_indices = []
                for idx in indices:
                    if isinstance(idx, list):
                        current_indices.extend(idx)
                    else:
                        current_indices.append(idx)
                flattened_attribute_indices.append(current_indices)
            return flattened_attribute_indices
        except Exception as exc:
            # Best-effort: keep the caller running, but do not hide the cause
            # and do not swallow KeyboardInterrupt / SystemExit.
            print(f"Attribute extraction failed, returning empty attributes: {exc!r}")
            return [[] for _ in object_indices]

    def _extract_objects_indices(self, prompt):
        """Parse ``prompt`` and return ``extract_objects_indices`` of the parse.

        The parsed document is stored on ``self.doc`` and printed.
        """
        self.doc = self.parser(prompt)
        print(self.doc)
        return extract_objects_indices(self.doc)

def unify_lists(lists_1, lists_2, lists_3):
    """Merge three lists of groups, dropping duplicates and absorbed groups.

    The groups are concatenated and ordered by length (shortest first).  A
    group is kept unless an identical group was already kept, or
    ``is_sublist`` reports it as contained in a group further down the
    ordering.  Group members must be hashable.
    """
    candidates = sorted(lists_1 + lists_2 + lists_3, key=len)
    kept_keys = set()
    unified = []

    for position, candidate in enumerate(candidates):
        key = tuple(candidate)
        if key in kept_keys:
            # An identical group has already been emitted.
            continue

        absorbed = any(
            is_sublist(candidate, other) for other in candidates[position + 1:]
        )
        if not absorbed:
            unified.append(candidate)
            kept_keys.add(key)

    return unified

def is_sublist(sub, main):
    """Tell whether ``sub`` is strictly shorter than ``main`` and every
    element of ``sub`` also occurs in ``main`` (order is not considered)."""
    if len(sub) >= len(main):
        return False
    for element in sub:
        if element not in main:
            return False
    return True

def align_wordpieces_indices(
        wordpieces2indices, start_idx, target_word
):
    """Collect the consecutive word-piece positions that spell ``target_word``.

    Args:
        wordpieces2indices: Mapping ``position -> word-piece`` with positions
            ``0 .. len - 1``.
        start_idx: Position of the first word-piece of the candidate word.
        target_word: The word to assemble.

    Returns:
        The list of positions, starting at ``start_idx``, whose word-pieces
        (with the ``"</w>"`` marker removed) concatenate to exactly
        ``target_word``; an empty list when no such run exists.
    """
    wp_indices = [start_idx]
    assembled = wordpieces2indices[start_idx].replace("</w>", "")

    # Walk the following word-pieces (hence the +1) until the word is complete.
    for wp_idx in range(start_idx + 1, len(wordpieces2indices)):
        if assembled == target_word:
            break

        piece = wordpieces2indices[wp_idx].replace("</w>", "")
        if piece == target_word or not target_word.startswith(assembled + piece):
            # The next piece does not continue the word: no alignment.
            return []
        assembled += piece
        wp_indices.append(wp_idx)

    # The sequence may end before the word is complete; a partial run is not
    # an alignment and must not be reported as one.
    if assembled != target_word:
        return []
    return wp_indices

def extract_objects_indices(doc):
    """List the positions of all ``NOUN`` tokens in ``doc`` with their counts.

    Returns:
        A tuple ``(positions, counts)`` of equal-length lists.  ``positions``
        holds the 1-based position of every token whose ``pos_`` is ``NOUN``.
        ``counts`` holds 1 by default, or ``str2num`` of the noun's ``NUM``
        child (the last one wins when there are several).

    NOTE(review): the 1-based offset presumably accounts for a leading
    start-of-text token and assumes one word-piece per word -- confirm.
    """
    noun_positions = []
    noun_counts = []
    for position, token in enumerate(doc, start=1):
        if token.pos_ != "NOUN":
            continue
        quantity = 1
        for dependent in token.children:
            if dependent.pos_ == "NUM":
                quantity = str2num(str(dependent))
        noun_positions.append(position)
        noun_counts.append(quantity)
    return noun_positions, noun_counts

def extract_attribution_indices(doc):
    """Group every head noun of ``doc`` with the modifiers hanging off it.

    A head is a ``NOUN``/``PROPN`` token whose own dependency label is not a
    modifier label.  Its direct children are taken when they carry a modifier
    label or ``acl``; below those, descendants are taken when they carry a
    modifier label or ``conj``.  Heads without any such token are skipped.

    Returns:
        A list of groups ``[modifier, ..., head]`` with the head last.  The
        result is also printed.
    """
    modifier_deps = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp", "nummod"]
    direct_deps = modifier_deps + ["acl"]  # verbs attached via acl count as attributes
    nested_deps = modifier_deps + ["conj"]
    head_pos = ["NOUN", "PROPN"]

    groups = []
    for head in doc:
        is_head = head.pos_ in head_pos and head.dep_ not in modifier_deps
        if not is_head:
            continue

        attached = []
        pending = []
        for dependent in head.children:
            if dependent.dep_ not in direct_deps:
                continue
            attached.append(dependent)
            pending.extend(dependent.children)

        # Depth-first walk below the direct modifiers.
        while pending:
            candidate = pending.pop()
            if candidate.dep_ in nested_deps:
                attached.append(candidate)
                pending.extend(candidate.children)

        if attached:
            groups.append(attached + [head])

    print("subtrees", groups)
    return groups

def extract_attribution_indices_with_verbs(doc):
    """Group head nouns with modifiers that are reached through a verb.

    Handles constructions such as "a dog that is red", where an auxiliary
    sits between the noun ('dog') and its modifier ('red').  The walk goes
    through relative-clause (``relcl``) children of the noun; ``AUX``/``VERB``
    tokens are traversed but never added to a group.

    Every head noun in ``doc`` is processed.

    Returns:
        A list of groups ``[modifier, ..., head]`` with the head last; an
        empty list when nothing qualifies.
    """
    subtrees = []
    modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp", "nummod",
                 "relcl"]
    for w in doc:
        if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers:
            continue
        subtree = []
        stack = []
        for child in w.children:
            if child.dep_ in modifiers:
                if child.pos_ not in ["AUX", "VERB"]:
                    subtree.append(child)
                stack.extend(child.children)

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                # 'is' and other verbs are not attributes themselves; only
                # their children are of interest.
                if node.pos_ not in ["AUX", "VERB"]:
                    subtree.append(node)
                stack.extend(node.children)
        if subtree:
            subtree.append(w)
            subtrees.append(subtree)
    # Return only after every token was visited (this used to sit inside the
    # loop, ending the scan at the first head noun and yielding None when
    # there was no head noun at all).
    return subtrees

def extract_attribution_indices_with_verb_root(doc):
    """Group the noun and modifier children of each auxiliary in ``doc``.

    Handles constructions such as "the dog is red", where the auxiliary is
    the common parent of the noun ('dog') and its modifier ('red').  Only
    ``AUX`` tokens whose own dependency label is not a modifier label are
    used as anchors; the anchor itself is never part of the group.

    Returns:
        A list of groups.  Each group holds at least two direct children of
        the anchor (nouns or modifiers, excluding ``AUX``/``VERB``), followed
        by deeper descendants labelled as modifier or ``conj`` (excluding
        ``AUX``).  Members appear in discovery order, so the noun is not
        necessarily last.
    """
    modifier_deps = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp", "nummod"]
    nested_deps = modifier_deps + ["conj"]
    groups = []

    for anchor in doc:
        if anchor.pos_ != "AUX" or anchor.dep_ in modifier_deps:
            continue

        members = []
        pending = []
        for dependent in anchor.children:
            relevant = (
                dependent.dep_ in modifier_deps
                or dependent.pos_ in ["NOUN", "PROPN"]
            )
            if not relevant:
                continue
            if dependent.pos_ not in ["AUX", "VERB"]:
                members.append(dependent)
            pending.extend(dependent.children)

        # A noun/modifier pair needs at least two direct members.
        if len(members) < 2:
            continue

        while pending:
            candidate = pending.pop()
            if candidate.dep_ not in nested_deps:
                continue
            # Auxiliaries are walked through, not collected.
            if candidate.pos_ != "AUX":
                members.append(candidate)
            pending.extend(candidate.children)

        groups.append(members)

    return groups

def get_indices(tokenizer, prompt: str) -> Dict[int, str]:
    """Map every token position of the tokenized ``prompt`` to its token string.

    Args:
        tokenizer: Callable returning an object with ``input_ids`` when applied
            to the prompt, and offering ``convert_ids_to_tokens``.
        prompt: Text to tokenize.

    Returns:
        ``{position: token}`` for positions ``0 .. len(input_ids) - 1``, in
        sequence order.
    """
    ids = tokenizer(prompt).input_ids
    tokens = tokenizer.convert_ids_to_tokens(ids)
    return dict(zip(range(len(ids)), tokens))

def str2num(str_num):
    """Convert a small number word or digit string to an integer.

    Recognises "one" to "four" and "1" to "4", ignoring letter case and
    surrounding whitespace (so a sentence-initial "Two" or a literal "2" is
    understood).  Any other input yields 5, as before.
    """
    known_numbers = {
        "one": 1, "1": 1,
        "two": 2, "2": 2,
        "three": 3, "3": 3,
        "four": 4, "4": 4,
    }
    return known_numbers.get(str(str_num).strip().lower(), 5)
