'''
- bert.py
- Handles text annotation with BERT-family (ALBERT) models: token classification (NER) and SQuAD-style question answering
'''

# External imports
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AlbertForQuestionAnswering, AlbertTokenizerFast

# Internal imports
from src.core.configuration.annotation_conf import *
from src.utils.misc.io import *



'''
----------load_model----------
- Loads an ALBERT model and its tokenizer for later use in VIDS
-----Inputs-----
- use_default_tokenizer - Whether to load the stock "albert-base-v2" tokenizer instead of the
  tokenizer stored at the model location (Defaults to False)
- task - 'ner' for token classification or 'squad' for extractive question answering (Defaults to 'ner')
- path - Optional model location; when None, ALBERT_LOC ('ner') or ALBERT_LOC_SQUAD ('squad')
  is used (constants brought in by the star imports, presumably annotation_conf)
-----Output-----
- loaded_model - Dict with keys 'tokenizer' and 'model'
-----Raises-----
- ValueError - If task is not 'ner' or 'squad'
'''
def load_model(use_default_tokenizer=False, task='ner', path=None):
    # Pick the tokenizer/model classes and checkpoint location per task. The configured
    # default location is only looked up when no explicit path was supplied.
    if task == 'ner':
        tokenizer_cls = AutoTokenizer
        model_cls = AutoModelForTokenClassification
        model_location = ALBERT_LOC if path is None else path
    elif task == 'squad':
        tokenizer_cls = AlbertTokenizerFast
        model_cls = AlbertForQuestionAnswering
        model_location = ALBERT_LOC_SQUAD if path is None else path
    else:
        # Fail fast with a clear message instead of reaching the return with unbound names.
        raise ValueError(f"Unsupported task {task!r}; expected 'ner' or 'squad'")

    tokenizer_source = "albert-base-v2" if use_default_tokenizer else model_location
    tokenizer = tokenizer_cls.from_pretrained(tokenizer_source)
    model = model_cls.from_pretrained(model_location)
    return {'tokenizer': tokenizer, 'model': model}


'''
----------annotate----------
- Annotates the given input using BERT for the supplied feature
-----Inputs-----
- text - The user input to parse
- active_feature - The currently-active feature in the PeTEL expression (used by the 'ner' task)
- model - Dict with 'tokenizer' and 'model' entries, as returned by load_model
- debug - Whether to print the raw label sequence (Defaults to True)
- task - 'ner' or 'squad' (Defaults to 'ner')
- question - The question to answer against text (used by the 'squad' task)
-----Output-----
- annotated_text - List of {"text", "confidence"} dicts: the phrases labelled with
  active_feature for 'ner', or the single answer span and its score for 'squad'
-----Raises-----
- ValueError - If task is not 'ner' or 'squad'
'''
def annotate(text, active_feature, model, debug=True, task='ner', question=''):
    # Validate before running the model so an unsupported task fails fast and clearly.
    if task not in ('ner', 'squad'):
        raise ValueError(f"Unsupported task {task!r}; expected 'ner' or 'squad'")

    label_sequence = get_labels(text, model, task, question)
    if debug:
        print_list(label_sequence)

    if task == 'squad':
        # NOTE: no confidence cutoff is applied; an earlier variant only kept answers
        # whose score exceeded 0.3.
        return [{"text": label_sequence["answer"], "confidence": label_sequence["score"]}]

    # JOINT RANKING: the alternative is get_features_from_labels(label_sequence, active_feature),
    # which returns one concatenated string instead of a list of scored phrases.
    return get_phrases_from_labels(label_sequence, active_feature)


'''
----------get_labels----------
- Runs the loaded model over the input text and returns its raw predictions
-----Inputs-----
- text - The user input to run the model over (the context passage for 'squad')
- model - Dict with 'tokenizer' and 'model' entries, as returned by load_model
- task - 'ner' for grouped token classification or 'squad' for question answering (Defaults to 'ner')
- question - The question to answer against text (only used by 'squad')
-----Output-----
- label_seq - For 'ner', the NER pipeline's grouped entity list; for 'squad', the
  question-answering pipeline's prediction
-----Raises-----
- ValueError - If task is not 'ner' or 'squad'
'''
def get_labels(text, model, task='ner', question=''):
    if task == 'ner':
        # grouped_entities=True asks the pipeline to group tokens of the same entity;
        # callers read the 'entity_group', 'word' and 'score' keys of each group.
        # NOTE(review): newer transformers releases deprecate grouped_entities in favour of
        # aggregation_strategy — confirm against the pinned version before changing.
        ner_model = pipeline(
            'ner',
            model=model["model"],
            tokenizer=model["tokenizer"],
            grouped_entities=True,
        )
        return ner_model(text)

    if task == 'squad':
        qa_pipeline = pipeline(
            "question-answering",
            model=model["model"],
            tokenizer=model["tokenizer"],
        )
        return qa_pipeline({'context': text, 'question': question})

    # Unknown task: previously this reached the return with `label_seq` unbound.
    raise ValueError(f"Unsupported task {task!r}; expected 'ner' or 'squad'")


'''
----------get_features_from_labels----------
- Extracts the appropriate features from the label sequence based on a given feature
-----Inputs-----
- label_sequence - Iterable of grouped entity dicts with 'entity_group' and 'word' keys
- feature - The feature to extract: 'entity', 'attribute', 'aggregator', 'filter',
  'filter_operation', 'prediction_window' or 'number'
-----Output-----
- annotated_text - The words of every group labelled with the feature, concatenated in order
  without separators ('' when nothing matches or the feature is unknown)
'''
def get_features_from_labels(label_sequence, feature):
    # Each feature owns a pair of model labels (presumably begin/inside tags — TODO confirm
    # against the label map used in training).
    feature_labels = {
        'entity': ('LABEL_9', 'LABEL_10'),
        'attribute': ('LABEL_11', 'LABEL_12'),
        'aggregator': ('LABEL_13', 'LABEL_14'),
        'filter': ('LABEL_15', 'LABEL_16'),
        'filter_operation': ('LABEL_17', 'LABEL_18'),
        'prediction_window': ('LABEL_19', 'LABEL_20'),
        'number': ('LABEL_21', 'LABEL_22'),
    }
    labels = feature_labels.get(feature)
    if labels is None:
        # Unknown feature: no label can match (mirrors get_phrases_from_labels returning []).
        return ''

    # Concatenate the matching words in sequence order.
    return ''.join(x['word'] for x in label_sequence if x['entity_group'] in labels)


'''
----------get_phrases_from_labels----------
- Extracts the appropriate phrases from the label sequence based on a given feature
-----Inputs-----
- label_sequence - Iterable of grouped entity dicts with 'entity_group', 'word' and 'score' keys
- feature - The feature whose phrases are wanted: 'entity', 'attribute', 'aggregator',
  'filter', 'filter_operation', 'prediction_window' or 'number'
-----Output-----
- phrases - List of {"text": word, "confidence": score} dicts, one per group labelled with
  the feature, in input order ([] when nothing matches or the feature is unknown)
'''
def get_phrases_from_labels(label_sequence, feature):
    # Each feature owns a pair of model labels (presumably begin/inside tags — TODO confirm
    # against the label map used in training).
    feature_labels = {
        'entity': ('LABEL_9', 'LABEL_10'),
        'attribute': ('LABEL_11', 'LABEL_12'),
        'aggregator': ('LABEL_13', 'LABEL_14'),
        'filter': ('LABEL_15', 'LABEL_16'),
        'filter_operation': ('LABEL_17', 'LABEL_18'),
        'prediction_window': ('LABEL_19', 'LABEL_20'),
        'number': ('LABEL_21', 'LABEL_22'),
    }
    # Resolve the wanted labels once, rather than re-testing every feature per item.
    labels = feature_labels.get(feature)
    if labels is None:
        return []

    return [
        {"text": x["word"], "confidence": x["score"]}
        for x in label_sequence
        if x["entity_group"] in labels
    ]