import nltk  
import random  
from nltk.sentiment import SentimentIntensityAnalyzer  
  
# Module-level SentimentIntensityAnalyzer, created once at import time and
# reused by extract_yes_no() for every sentence it scores.
# NOTE(review): constructing the analyzer presumably needs the NLTK
# `vader_lexicon` resource to be installed locally — confirm on a fresh setup.
sia = SentimentIntensityAnalyzer()
  
def extract_yes_no(text):
    """Collapse a free-text response into 'yes' or 'no' using sentence sentiment.

    The text is split into sentences and each sentence is scored with the
    module-level analyzer ``sia``. A sentence votes 'yes' when its ``pos``
    score exceeds 0.5, 'no' when its ``neg`` score exceeds 0.5, and
    'unclear' otherwise.

    Args:
        text: The response text to classify.

    Returns:
        'yes' or 'no'. A label wins only if its vote count is strictly
        greater than each of the other two counts (including 'unclear');
        in every other case the answer is drawn at random from
        ['yes', 'no'], so the result is not deterministic.
    """
    # Running vote count per label, filled in a single pass over sentences.
    tally = {'yes': 0, 'no': 0, 'unclear': 0}
    for sentence in nltk.sent_tokenize(text):
        polarity = sia.polarity_scores(sentence)
        if polarity['pos'] > 0.5:
            label = 'yes'
        elif polarity['neg'] > 0.5:
            label = 'no'
        else:
            label = 'unclear'
        tally[label] += 1

    # A label must strictly beat both competitors to be returned outright.
    if tally['yes'] > max(tally['no'], tally['unclear']):
        return 'yes'
    if tally['no'] > max(tally['yes'], tally['unclear']):
        return 'no'

    # No strict winner (ties, or 'unclear' dominates): fall back to a coin flip.
    return random.choice(['yes', 'no'])


# Default input locations passed to read_file() when this module is run as a
# script: the predictions JSON and the gold task JSON (whose "examples" list
# is what gets scored).
# NOTE(review): these are absolute, machine-specific Windows paths; they must
# be edited before running on any other machine.
pred_file_address = "F:\\user-repos\\saurasrivastava\\prompt_gen\\outputs\\implicatures_final.json_final.json"
gold_file_address = "F:\\user-repos\\saurasrivastava\\prompt_gen\\data\\implicatures\\data\\task.json"
from collections import Counter
import json
def read_file(pred_file_address, gold_file_address):
    """Score yes/no predictions against gold labels and print a summary.

    The predictions file is a JSON object whose values each carry a
    "question" string and a "final_response" list of response texts. The
    gold file is a JSON object with an "examples" list; every example has an
    "input" string and a "target_scores" mapping. Predictions and gold
    examples are paired positionally, so the key order of the predictions
    object must match the order of the gold examples.

    For each pair, every response is reduced to 'yes'/'no' with
    extract_yes_no(). The first three responses are scored individually, and
    a majority vote over all responses is scored as well ('yes' only when it
    strictly outnumbers 'no').

    Args:
        pred_file_address: Path to the predictions JSON file.
        gold_file_address: Path to the gold task JSON file.

    Returns:
        None. The four correct-counts and their ratios are printed to stdout.
        With zero examples the ratios are reported as 0.0.

    Raises:
        AssertionError: If the two files hold a different number of entries,
            or the last line of a prediction's question differs from the
            paired gold input.
        IndexError: If an entry has fewer than three responses.
    """
    # Context managers close both handles promptly, even if parsing fails.
    with open(pred_file_address) as pred_file:
        pred_jsn = json.load(pred_file)
    with open(gold_file_address) as gold_file:
        gold_json = json.load(gold_file)

    gold_data = gold_json["examples"]
    total = len(pred_jsn)
    assert len(gold_data) == total, (
        f"gold file has {len(gold_data)} examples but predictions file has {total}"
    )

    attempt1_correct, attempt2_correct, attempt3_correct, majority_correct = 0, 0, 0, 0
    for (pred_key, pred_entry), gold_d in zip(pred_jsn.items(), gold_data):
        # Only the last line of the stored question is compared with the gold
        # input; this guards against the two files being out of order.
        assert pred_entry["question"].split("\n")[-1] == gold_d["input"], (
            f"prediction {pred_key!r} is not aligned with its gold example"
        )
        gold_answer = "yes" if gold_d["target_scores"]["yes"] == 1.0 else "no"
        predicted_answers = [extract_yes_no(x) for x in pred_entry["final_response"]]

        # bool adds as 0/1, so the counters stay plain ints.
        attempt1_correct += (gold_answer == predicted_answers[0])
        attempt2_correct += (gold_answer == predicted_answers[1])
        attempt3_correct += (gold_answer == predicted_answers[2])

        dict_count = Counter(predicted_answers)
        # Ties (possible only with an even number of responses) go to "no".
        majority_vote = "yes" if dict_count["yes"] > dict_count["no"] else "no"
        majority_correct += majority_vote == gold_answer

    def _ratio(correct):
        # Avoid ZeroDivisionError when both input files are empty.
        return correct / total if total else 0.0

    # NOTE: the values after "%" are fractions in [0, 1], not percentages.
    # The output text (including the attempt-3 line's different punctuation)
    # is kept exactly as before so existing consumers see the same format.
    print(
        f"Total correct 1 count: {attempt1_correct},  %: {_ratio(attempt1_correct)}\n"
        f"Total correct 2 count: {attempt2_correct},  %: {_ratio(attempt2_correct)}\n"
        f"Total correct 3 count: {attempt3_correct}, %{_ratio(attempt3_correct)}\n"
        f"Majority Vote count: {majority_correct},  %: {_ratio(majority_correct)}"
    )
    

if __name__ == "__main__":
    # Script entry point: score the default predictions file against the
    # default gold file and print the accuracy summary.
    read_file(pred_file_address, gold_file_address)


