from lm_understanding.metrics import GloVeEncoder, most_similar_idx
from lm_understanding.question_template import TemplateModelBehavior

from .explanations import LocalExplanationSet


def make_explanation(counterfactual_question: str, counterfactual_answer: float) -> str:
    """Render a counterfactual as a human-readable explanation string.

    Args:
        counterfactual_question: Text of the counterfactual question.
        counterfactual_answer: Answer associated with that question; it is
            rounded to 4 decimal places before being shown.

    Returns:
        A two-line string: a sentence stating the rounded answer, followed by
        the counterfactual question on its own line.
    """
    rounded_answer = round(counterfactual_answer, 4)
    header = f'If the question had been the following, the answer would have been {rounded_answer}:'
    return f'{header}\n{counterfactual_question}'


def create_counterfactual_explanations(
    model_behavior: TemplateModelBehavior,
    *,
    different_answer_threshold: float = 0.2,
) -> LocalExplanationSet:
    """Explain each training question with its nearest counterfactual question.

    For every question in the 'train' split, the candidate counterfactuals are
    the training questions whose answer differs from this question's answer by
    more than a threshold. Among those candidates, the one selected by
    ``most_similar_idx`` (called with ``k=1`` on the encoded questions) is
    turned into an explanation string via ``make_explanation``.

    Args:
        model_behavior: Source of the training questions, their answers and
            the template id. Its ``template_id`` must not be None.
        different_answer_threshold: Minimum absolute answer difference for a
            question to count as a counterfactual. It is capped at half of
            the observed answer range.

    Returns:
        A ``LocalExplanationSet`` built from the template id, the training
        questions, the answers as a list, and one explanation per question.

    Raises:
        AssertionError: If ``model_behavior.template_id`` is None.
        ValueError: If some question has no other question whose answer
            differs by more than the effective threshold (for example when
            all answers are identical).
    """
    assert model_behavior.template_id is not None
    questions = model_behavior.questions('train')
    encoded_questions = GloVeEncoder().encode(questions)
    answer_probs = model_behavior.answers('train')

    # Every answer is at least half the answer range away from either the
    # minimum or the maximum, so capping the threshold there keeps the
    # candidate set non-empty except in exact-tie cases (handled below).
    maximum_threshold_possible = (answer_probs.max() - answer_probs.min()) / 2
    distance_threshold = min(different_answer_threshold, maximum_threshold_possible)

    explanations = []
    for encoded_question, question_answer_prob in zip(encoded_questions, answer_probs):
        different_answers = abs(answer_probs - question_answer_prob) > distance_threshold
        # Positions (in the full training set) of the candidate counterfactuals.
        candidate_indices = [idx for idx, is_different in enumerate(different_answers) if is_different]
        if not candidate_indices:
            raise ValueError(
                'No counterfactual question found: no training answer differs from '
                f'{question_answer_prob} by more than {distance_threshold}'
            )
        counterfactual_options = encoded_questions[different_answers]
        nearest_option = most_similar_idx(encoded_question, counterfactual_options, k=1)[0]
        # Map the position within the candidates back to the full training
        # set directly. Searching for an equal encoding instead could pick a
        # non-candidate question that happens to share the same encoding.
        counterfactual_idx = candidate_indices[nearest_option]
        explanations.append(make_explanation(questions[counterfactual_idx], answer_probs[counterfactual_idx]))
    return LocalExplanationSet(model_behavior.template_id, questions, answer_probs.tolist(), explanations)

