import numpy as np


def iterate_dependencies(parent):
    """
    Walk the tree below `parent` depth-first, yielding (head, child) pairs.

    parent - node exposing a `children` iterable whose items expose
        `children` themselves.

    Each (head, child) pair is yielded before any pair found beneath that
    child, so the output is in pre-order.
    """
    for kid in parent.children:
        yield parent, kid
        yield from iterate_dependencies(kid)

### Stuff for ILASP #####
def process_examples(nlp, examples, **kwargs):
    """
    Lazily formalise every example, yielding one result per example.

    nlp - parser handed to formalise_text.
    examples - iterable of (premise_text, concepts_text) pairs.
    kwargs - extra keyword arguments forwarded unchanged to formalise_text.
    """
    yield from (
        formalise_text(nlp, premise, concepts, **kwargs)
        for premise, concepts in examples
    )
        
def write_examples_to_file(nlp, examples, ofname, sep='=====', **kwargs):
    """
    Formalise `examples` and write the resulting terms to the file `ofname`.

    nlp - parser forwarded to process_examples.
    examples - forwarded to process_examples, which unpacks each item as a
        (premise_text, concepts_text) pair.
    ofname - path of the output file; opened in 'w' mode, so any existing
        content is replaced.
    sep - marker written on both sides of each `example[k]` header line.
    kwargs - extra keyword arguments forwarded unchanged to process_examples.

    Each example produces a header line followed by one line per term.
    """
    # Encoding is pinned to UTF-8: the terms embed token text, so relying on
    # the platform default could fail or produce differing files for
    # non-ASCII input.
    with open(ofname, 'w', encoding='utf-8') as ofile:
        for k, terms in enumerate(process_examples(nlp, examples, **kwargs)):
            ofile.write(sep + f"example[{k}]" + sep + '\n')
            ofile.writelines(str(term) + '\n' for term in terms)

def formalise_text(nlp, premise_text, concepts_text, include_tokens=False, single_sentence_premise=True):
    """
    Encode a premise and its concept sentence(s) as a list of fact strings.

    nlp - callable that parses a text into a document. The document is used
        through len(), iteration, `.sents`, and the per-token attributes
        `.i`, `.lower_`, `.dep_` and `.children` (presumably a spaCy
        pipeline - TODO confirm).
    premise_text - text of the premise; must parse to a single sentence.
    concepts_text - text whose sentences are each encoded as a concept of the
        premise.
    include_tokens - when True, also emit one token fact per premise token.
    single_sentence_premise - must be truthy; multi-sentence premises are not
        supported yet.

    Returns the list of facts: for the premise a premise fact, optional token
    facts, the start/succ word-order facts and the root/dep tree facts; then
    for each concept sentence a concept fact, its word-order facts and its
    tree facts. Concept tokens are renamed to the index of their best match
    in the premise (see arg_best_match), so concept facts refer to premise
    token ids.

    Raises ValueError if single_sentence_premise is falsy, if the premise is
    empty or has more than one sentence, or (from arg_best_match) if a
    concept token has no match in the premise.
    """
    if not single_sentence_premise:
        raise ValueError("Do not currently manage multi-sentence premises")
    premise = nlp(premise_text)
    if len(premise) == 0:
        raise ValueError("Empty premise.")

    terms = []
    sent_counter = 0
    for sent in premise.sents:
        premise_id = sent_counter
        # Only single-sentence premises get this far, so a second sentence
        # is always an error.
        if premise_id > 0:
            raise ValueError(f"Multi-sentence premise passed to single-sentence premise encoder:\n\tpremise: {premise_text}\n\tsentences{[str(sent) for sent in premise.sents]}")
        terms.append(predicate_premise(premise_id))
        if include_tokens:
            terms.extend(
                predicate_token(premise_id, tok.i, f"'{tok.lower_}'")
                for tok in sent
            )
        # Word order: the first token, then each adjacent pair.
        terms.append(predicate_sentence_start(premise_id, sent[0].i))
        terms.extend(
            predicate_sentence_successor(premise_id, a.i, b.i)
            for a, b in zip(sent[:-1], sent[1:])
        )
        # Dependency tree: the root, then every (parent, child) edge.
        root = sent.root
        terms.append(predicate_doc_root(premise_id, root.i))
        terms.extend(
            predicate_doc_dependency(premise_id, parent.i, child.i, child.dep_)
            for parent, child in iterate_dependencies(root)
        )
        sent_counter += 1

    concepts = nlp(concepts_text)
    # The candidate list is the same for every concept sentence.
    premise_toks = list(premise)
    for sent in concepts.sents:
        concept_id = sent_counter
        terms.append(predicate_concept(premise_id, concept_id))
        try:
            concept_ref_ids = {tok.i: arg_best_match(tok, premise_toks) for tok in sent}
        except Exception:
            # Report which pair failed, then let the error propagate.
            print(f"Error encountered with premise:\n\t{premise}\nand concept sentence\n\t{sent}.")
            raise
        terms.append(predicate_sentence_start(concept_id, concept_ref_ids[sent[0].i]))
        for tok1, tok2 in zip(sent[:-1], sent[1:]):
            terms.append(predicate_sentence_successor(
                concept_id, concept_ref_ids[tok1.i], concept_ref_ids[tok2.i]))
        # construct the tree of the concept referencing the original sentence ids
        terms.extend(formalise_dependency_tree(concept_id, sent, concept_ref_ids))
        sent_counter += 1
    return terms


def formalise_dependency_tree(sent_id, sent, reference_ids=None):
    """
    Build the root/dep facts describing the dependency tree of a sentence.

    sent_id - identifier used for the sentence in the generated facts.
    sent - nlp sentence
    reference_ids (optional) - Mapping from a token's index (`tok.i`) to a
        reference index. It allows you to rename the tokens by some other system.
        Useful when wanting to construct a tree of a concept using the indices of
        the tokens in the premise. When omitted, every token keeps its own index.

    Returns a list of fact strings: the root fact first, then one dep fact per
    (parent, child) pair in the order iterate_dependencies yields them.
    """
    def rename(tok_index):
        # Without a mapping the renaming is the identity on tok.i itself. A
        # positional list(range(len(sent))) would be indexed by tok.i, which
        # presumably is a document-level index (spaCy Token.i - confirm) and
        # so can exceed len(sent) for a sentence that is not first.
        if reference_ids is None:
            return tok_index
        return reference_ids[tok_index]

    root = sent.root
    deps = [predicate_doc_root(sent_id, rename(root.i))]
    for parent, child in iterate_dependencies(root):
        deps.append(predicate_doc_dependency(
            sent_id, rename(parent.i), rename(child.i), child.dep_))
    return deps

def predicate_premise(sent_id):
    """Return the fact declaring sentence `sent_id` to be a premise."""
    return "premise(sent{}).".format(sent_id)

def predicate_concept(premise_id, concept_id):
    """Return the fact tying sentence `concept_id` to premise `premise_id` as a concept."""
    premise_name = f"sent{premise_id}"
    concept_name = f"sent{concept_id}"
    return f"concept({premise_name}, {concept_name})."

def predicate_token(sent_id, tok_id, tok_text):
    """
    Return the fact giving the text of token `tok_id` in sentence `sent_id`.

    tok_text is inserted verbatim, so any quoting must be done by the caller.
    """
    template = "token(sent{sent}, tok{tok}, {text})."
    return template.format(sent=sent_id, tok=tok_id, text=tok_text)

def predicate_sentence_start(sent_id, tok_id):
    """Return the fact marking token `tok_id` as the start of sentence `sent_id`."""
    arguments = ", ".join((f"sent{sent_id}", f"tok{tok_id}"))
    return "start(" + arguments + ")."

def predicate_sentence_successor(sent_id, tok1_id, tok2_id):
    """Return the fact stating that `tok2_id` follows `tok1_id` in sentence `sent_id`."""
    return "succ(sent{}, tok{}, tok{}).".format(sent_id, tok1_id, tok2_id)

def predicate_doc_root(sent_id, root_id):
    """Return the fact naming token `root_id` as the root of sentence `sent_id`."""
    sentence_name = f"sent{sent_id}"
    root_name = f"tok{root_id}"
    return f"root({sentence_name}, {root_name})."

def predicate_doc_dependency(sent_id, parent_id, child_id, dep):
    """
    Return the fact for a dependency edge labelled `dep` from `parent_id`
    to `child_id` within sentence `sent_id`.
    """
    arguments = (f"{dep}", f"sent{sent_id}", f"tok{parent_id}", f"tok{child_id}")
    return "dep(" + ", ".join(arguments) + ")."

# syntactic comparisons
def arg_best_match(anchor, toks, position_factor=1):
    """
    Score anchor against every element of toks with token_match_score and
    return the index of the highest-scoring one.

    Ties go to the lowest index, since np.argmax returns the first maximum.
    The scoring assumes the candidates share the anchor's text; if
    tok[i].text != tok[j].text then all bets are off.

    Raises ValueError when the best score is 0, i.e. nothing matched.
    """
    match_scores = []
    for candidate in toks:
        match_scores.append(token_match_score(anchor, candidate, position_factor))
    winner = np.argmax(match_scores)
    if match_scores[winner] != 0:
        return winner
    raise ValueError(f"Cannot find good match for {anchor.text}")
    

def token_match_score(tok1, tok2, position_factor=1):
    """
    Score how well tok2 matches tok1; higher is better, 0 means no match.

    0 - the lower-cased texts differ.
    1 - the texts agree but the dependency labels differ and neither token
        is a ROOT.
    Otherwise 2, plus the ancestor and child match counts, plus
    position_factor for each positional agreement: same index, and same
    distance from the end of the respective document.
    """
    if tok1.lower_ != tok2.lower_:
        return 0
    # Labels are compatible when they are equal or when either token is a root.
    if not (tok1.dep == tok2.dep or tok1.dep_ == 'ROOT' or tok2.dep_ == 'ROOT'):
        return 1
    total = 2
    total += token_ancestor_match(tok1, tok2)
    total += token_child_match(tok1, tok2)
    same_index = tok1.i == tok2.i
    if same_index:
        total += position_factor
    same_offset_from_end = (len(tok1.doc) - tok1.i) == (len(tok2.doc) - tok2.i)
    if same_offset_from_end:
        total += position_factor
    return total

def token_ancestor_match(tok1, tok2):
    """
    Count how many leading ancestors of tok1 and tok2 have the same text.

    The two `ancestors` sequences are compared pairwise in order; counting
    stops at the first differing pair or when either sequence runs out.
    """
    matched = 0
    for left, right in zip(tok1.ancestors, tok2.ancestors):
        if left.text != right.text:
            break
        matched += 1
    return matched

def token_child_match(tok1, tok2):
    """
    Count how many leading children of tok1 and tok2 have the same text.

    The two `children` sequences are compared pairwise in order; counting
    stops at the first differing pair or when either sequence runs out.
    """
    agreeing = 0
    for ours, theirs in zip(tok1.children, tok2.children):
        if ours.text == theirs.text:
            agreeing += 1
            continue
        return agreeing
    return agreeing


### Deprecated ### 
def formalise_text_depr(nlp, text, sent_counter=0):
    """
    Deprecated encoder: turn every sentence of `text` into fact strings.

    nlp - callable that parses `text` into a document.
    text - the text to encode.
    sent_counter - number given to the first sentence; later sentences count
        up from it.

    Returns (terms, sent_counter): the facts and the next unused sentence
    number. Per sentence the facts are the token facts (punctuation is
    written as `punct`), the lemma facts for non-punctuation tokens, the
    start/succ word-order facts and the root/dep_<label> tree facts. An empty
    document gives an empty list and leaves the counter unchanged.
    """
    doc = nlp(text)
    facts = []
    if len(doc) == 0:
        return facts, sent_counter
    for sentence in doc.sents:
        sent_name = f"sent{sent_counter}"

        token_facts, lemma_facts = [], []
        for word in sentence:
            if word.pos_ == "PUNCT":
                token_facts.append(f"token({sent_name}, tok{word.i}, punct).")
                continue
            token_facts.append(f"token({sent_name}, tok{word.i}, {word.lower_}).")
            lemma_facts.append(f"lemma({word.lower_}, {word.lemma_}).")

        order_facts = [f"start({sent_name}, tok{sentence[0].i})."]
        order_facts.extend(
            f"succ({sent_name}, tok{prev.i}, tok{nxt.i})."
            for prev, nxt in zip(sentence[:-1], sentence[1:])
        )

        head = sentence.root
        tree_facts = [f"root({sent_name}, tok{head.i})."]
        for governor, dependent in iterate_dependencies(head):
            tree_facts.append(
                f"dep_{dependent.dep_}({sent_name}, tok{governor.i}, tok{dependent.i}).")

        facts += token_facts + lemma_facts + order_facts + tree_facts
        sent_counter += 1
    return facts, sent_counter

def formalise_premise_and_concepts_depr(nlp, premise, concepts, non_concepts):
    """
    Deprecated: encode a premise, its concepts and its non-concepts as three
    separate lists of facts.

    nlp - parser forwarded to formalise_text_depr.
    premise, concepts, non_concepts - the three texts to encode. Sentences
        are numbered consecutively across them, in that order.

    Returns (premise_terms, concept_terms, non_concept_terms). Each list holds
    the facts from formalise_text_depr followed by one premise/concept/
    non_concept fact per sentence of the corresponding text.
    """
    # Uses the deprecated encoder: it is the one that takes a running
    # sent_counter and returns (terms, sent_counter), as unpacked below.
    premise_terms, sent_counter = formalise_text_depr(nlp, premise)
    premise_terms.extend(f"premise(sent{i})." for i in range(sent_counter))

    first_concept = sent_counter
    concept_terms, sent_counter = formalise_text_depr(nlp, concepts, sent_counter=sent_counter)
    concept_terms.extend(
        f"concept(sent{i})." for i in range(first_concept, sent_counter))

    first_non_concept = sent_counter
    non_concept_terms, sent_counter = formalise_text_depr(nlp, non_concepts, sent_counter=sent_counter)
    # Start at the first non-concept sentence so that concept sentences are
    # not also labelled as non-concepts.
    non_concept_terms.extend(
        f"non_concept(sent{i})." for i in range(first_non_concept, sent_counter))
    return premise_terms, concept_terms, non_concept_terms
