import spacy
# Load the "en_core_web_sm" spaCy pipeline once at import time;
# get_entities_spacy() reuses this shared instance on every call.
nlp = spacy.load("en_core_web_sm")
from openai_call import get_default_kwargs_for_openai_call, get_openai_response

from transformers import (
    TokenClassificationPipeline,
    AutoModelForTokenClassification,
    AutoTokenizer,
)
from transformers.pipelines import AggregationStrategy

class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that loads its own model and tokenizer.

    The constructor takes a checkpoint identifier instead of ready-made
    model/tokenizer objects, and post-processing always uses
    ``AggregationStrategy.SIMPLE``.
    """

    def __init__(self, model, *args, **kwargs):
        """Load the checkpoint named by ``model`` and initialise the pipeline.

        Args:
            model: Identifier passed to ``from_pretrained`` for both the
                token-classification model and its tokenizer.
            *args: Extra positional arguments forwarded to the parent
                constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        loaded_model = AutoModelForTokenClassification.from_pretrained(model)
        loaded_tokenizer = AutoTokenizer.from_pretrained(model)
        super().__init__(
            *args,
            model=loaded_model,
            tokenizer=loaded_tokenizer,
            **kwargs,
        )

    def postprocess(self, all_outputs):
        """Post-process raw outputs with the SIMPLE aggregation strategy.

        Args:
            all_outputs: Raw outputs handed over by the pipeline machinery.

        Returns:
            Whatever the parent ``postprocess`` returns for
            ``aggregation_strategy=AggregationStrategy.SIMPLE``.
        """
        return super().postprocess(
            all_outputs=all_outputs,
            aggregation_strategy=AggregationStrategy.SIMPLE,
        )


# Checkpoint identifier that KeyphraseExtractionPipeline hands to
# ``from_pretrained`` for both the model and the tokenizer.
model_name = "ml6team/keyphrase-extraction-kbir-kpcrowd"
# Alternative checkpoint, currently disabled:
# model_name = "ml6team/keyphrase-extraction-kbir-inspec"
# Built once at import time (this loads the model and tokenizer) and shared by
# get_keyphrases_kbir().
extractor = KeyphraseExtractionPipeline(model=model_name)


def get_keyphrases_kbir(text):
    """Run the module-level ``extractor`` pipeline on ``text`` as one line.

    Every newline is swapped for a single space before extraction. The
    replacement is one character for one character, so the string length and
    all character offsets stay the same.

    Args:
        text: The string to extract keyphrases from.

    Returns:
        Whatever the ``extractor`` pipeline returns for the flattened text.
    """
    single_line = " ".join(text.split("\n"))
    return extractor(single_line)

# -----------------------


def get_entities_spacy(text):
    """Return the entities the shared spaCy pipeline ``nlp`` finds in ``text``.

    Args:
        text: The string to analyse.

    Returns:
        A list with one dict per entity in ``doc.ents`` order, each holding
        ``"start"`` (``start_char``), ``"end"`` (``end_char``) and ``"word"``
        (the entity text).
    """
    return [
        {"start": span.start_char, "end": span.end_char, "word": span.text}
        for span in nlp(text).ents
    ]


# -----------------

def find_positions_of_keywords(sentence, keywords):
    """Locate each keyword in ``sentence`` and return the hits by position.

    Each keyword is stripped of surrounding whitespace and searched for with
    ``str.find``, so only its first occurrence is reported. Keywords that are
    empty after stripping, or that do not occur in ``sentence``, are skipped.

    Args:
        sentence: The text to search in.
        keywords: Iterable of keyword strings (may carry stray whitespace,
            e.g. pieces of a comma-separated list).

    Returns:
        A list of dicts sorted by ``"start"``, each with:
            ``"word"``: the stripped keyword, equal to
                ``sentence[start:end]``;
            ``"start"``: index of the first character of the match;
            ``"end"``: index one past the last character of the match.
    """
    keywords_with_positions = []
    for keyword in keywords:
        needle = keyword.strip()
        if not needle:
            # "".find("") is 0, so an empty keyword (trailing comma, empty
            # model response) would otherwise be reported as a zero-length
            # match at the very start of the sentence.
            continue

        start_index = sentence.find(needle)
        if start_index == -1:
            continue

        keywords_with_positions.append({
            # Store the stripped text so the word agrees with its span.
            "word": needle,
            "start": start_index,
            "end": start_index + len(needle),
        })

    # sorted() is stable: keywords sharing a start keep their input order.
    return sorted(keywords_with_positions, key=lambda d: d["start"])


def get_keyphrases_prompting(text):
    """Ask the OpenAI helper for keyphrases in ``text`` and locate them.

    The prompt is ``text``, a newline, and a fixed instruction asking for a
    comma-separated list of keyphrases. (An earlier variant of the
    instruction asked for "keywords" instead of "keyphrases".) The reply is
    split on commas and each piece is looked up in ``text``.

    Args:
        text: The sentence to extract keyphrases from.

    Returns:
        The result of ``find_positions_of_keywords(text, pieces)``: dicts
        with ``"word"``, ``"start"`` and ``"end"``, sorted by ``"start"``.
    """
    instruction = "Identify all the important keyphrases from the above sentence and return a comma separated list."

    request = get_default_kwargs_for_openai_call()
    request["prompt"] = text + "\n" + instruction

    # Presumably the plain completion text, since the complete response is
    # not requested -- verify against openai_call.
    completion = get_openai_response(request, return_complete_response=False)

    return find_positions_of_keywords(text, completion.split(","))


# -----------------

def find_tokens_for_keywords(probability_output):
    """Extract keyphrases from a generated sentence and map them to tokens.

    Args:
        probability_output: Mapping that provides
            ``"sentence"``: the generated text to extract keyphrases from;
            ``"char_positions"``: one ``(start, end)`` pair per token;
            ``"all_probs"``: one entry per token, parallel to
                ``"char_positions"``;
            ``"all_tokens"``: one entry per token, parallel to
                ``"char_positions"``.

    Returns:
        A list of dicts, one per keyphrase that could be aligned to a token,
        each with ``"keyword"``, ``"probs"`` and ``"tokens"``. Keyphrases
        whose start offset matches no token start are left out.
    """
    sentence_generated_by_model = probability_output["sentence"]

    # Find keyphrases in the generated sentence. get_keyphrases_kbir() was
    # previously tried here as an alternative extractor.
    keywords_output = get_keyphrases_prompting(sentence_generated_by_model)

    print("All Identified keywords: ", [entry["word"] for entry in keywords_output])

    return _align_keywords_to_tokens(
        keywords_output,
        probability_output["char_positions"],
        probability_output["all_probs"],
        probability_output["all_tokens"],
    )


def _align_keywords_to_tokens(keyword_spans, char_positions, token_probs, tokens):
    """Collect, for each keyword span, the tokens that cover it.

    Args:
        keyword_spans: Dicts with ``"word"``, ``"start"`` and ``"end"``,
            expected in ascending ``"start"`` order.
        char_positions: Per-token ``(start, end)`` pairs.
            NOTE(review): whether ``end`` is inclusive or exclusive is not
            visible here -- confirm against the producer.
        token_probs: Per-token values, parallel to ``char_positions``.
        tokens: Per-token values, parallel to ``char_positions``.

    Returns:
        A list of ``{"keyword", "probs", "tokens"}`` dicts for every span
        whose start offset equals some token's start offset.
    """
    token_count = len(char_positions)
    aligned = []
    # Index of the token where the previous keyword began. Spans arrive in
    # ascending start order, so the next keyword cannot begin earlier.
    search_from = 0

    for span in keyword_spans:
        span_start = span["start"]
        span_end = span["end"]

        # Scan with a separate probe so that a keyword with no matching
        # token start does not exhaust the search position and cause every
        # later keyword to be dropped.
        probe = search_from
        while probe < token_count and char_positions[probe][0] != span_start:
            probe += 1
        if probe == token_count:
            continue
        search_from = probe

        matched_probs = []
        matched_tokens = []
        cursor = probe
        # Take every token that ends before the keyword does ...
        while cursor < token_count and span_end > char_positions[cursor][1]:
            matched_probs.append(token_probs[cursor])
            matched_tokens.append(tokens[cursor])
            cursor += 1
        # ... plus the token whose end reaches or passes the keyword end.
        if cursor < token_count:
            matched_tokens.append(tokens[cursor])
            matched_probs.append(token_probs[cursor])

        aligned.append({
            "keyword": span["word"],
            "probs": matched_probs,
            "tokens": matched_tokens,
        })

    return aligned


