import json
import re
import pandas as pd
import yaml
import hotpot_evaluate

# Run settings. The keys used further down include 'model1', 'model2',
# 'eval_model' and 'persona'.
with open('config.yml') as config_file:
    config = yaml.safe_load(config_file)

# Empty results table; one row is appended per evaluated prediction.
df = pd.DataFrame(
    columns=[
        'worker_id',
        'model1',
        'model2',
        'eval_model',
        'fluency',
        'helpfulness',
        'ease',
        'helpfulness_freetext',
        'no_of_turns',
        'accuracy',
        'f1',
        'recall',
    ]
)

# The persona only appears in the file name when it is not 'general'.
if config['persona'] == 'general':
    conv_persona_suffix = ''
else:
    conv_persona_suffix = '_' + config['persona']

conv_file_name = '../results/hotpot_conversation_{model1}_{model2}{persona}_prompt-2.json'.format(
    persona=conv_persona_suffix,
    model2=config['model2'],
    model1=config['model1'],
)

# Load the conversation records and index them by their 'id' field so that
# a record can be looked up directly.
with open(conv_file_name) as conv_file:
    conv_records = json.load(conv_file)
conv_data = {record['id']: record for record in conv_records}

# Alternative (presumably older) naming pattern, kept for reference:
#file_name = '../results/predictions_{model1}_{model2}_{eval_model}.json'.format(

# The persona only appears in the file name when it is not 'general'.
if config['persona'] == 'general':
    pred_persona_suffix = ''
else:
    pred_persona_suffix = '_' + config['persona']

file_name = '../results/hotpot_predictions_{model1}_{model2}_{eval_model}{persona}_prompt-2.json'.format(
    persona=pred_persona_suffix,
    eval_model=config['eval_model'],
    model2=config['model2'],
    model1=config['model1'],
)

# Evaluator predictions, parsed from JSON; iterated over below.
with open(file_name) as predictions_file:
    predictions = json.load(predictions_file)

def parse_line(line):
    """Extract the integer rating from one line of an evaluator prediction.

    Three layouts are recognized, tried in this order:

    * The line contains a colon and a single digit in parentheses that is
      preceded by a space (e.g. ``"Fluency: good (4)"``): that digit is
      returned.
    * The line contains a colon but no such parenthesized digit
      (e.g. ``"Fluency: 4"``): the text after the last colon is converted
      with ``int``.
    * The line contains no colon (e.g. ``"Fluency 4."``): the
      second-to-last character of the stripped line is converted with
      ``int``.

    Args:
        line: One line of prediction text.

    Returns:
        The rating as an ``int``.

    Raises:
        ValueError: If the selected text is not a valid integer literal.
        IndexError: If a colon-free line has fewer than two characters
            after stripping.
    """
    if ':' not in line:
        # The last character is skipped (e.g. a trailing "." or ")"), so the
        # rating is read from the position just before it.
        return int(line.strip()[-2])

    # Raw string: "\(" and "\d" are invalid escape sequences in a normal
    # string literal and trigger a SyntaxWarning on recent Python versions.
    match = re.search(r' \((\d)\)', line)
    if match is not None:
        return int(match.group(1))

    return int(line.split(':')[-1].strip())

# Score every evaluator prediction and append one row per prediction to `df`.
for idx, pred in enumerate(predictions):
    # Collapse every run of consecutive newlines into a single one, so that
    # blank lines (however many) never take the place of a rating line.
    pred_text = re.sub(r'\n+', '\n', pred['prediction'].strip())
    pred_text = pred_text.split('\n')

    # The first three lines are expected to carry the fluency, helpfulness
    # and ease ratings, in that order.
    try:
        fluency, helpful, ease = pred_text[:3]
        print(fluency)
        print(helpful)
        print(ease)
        fluency = parse_line(fluency)
        helpful = parse_line(helpful)
        ease = parse_line(ease)
    except (ValueError, IndexError):
        # ValueError: fewer than three lines, or a rating that is not an
        # integer; IndexError: a colon-free line too short to index.
        print(pred_text)
        if not ''.join(pred_text).strip():
            # Empty prediction: record zeros instead of asking the operator.
            fluency, helpful, ease = 0, 0, 0
        else:
            # Unparseable but non-empty: ask the operator to enter the
            # ratings by hand after reading the text printed above.
            fluency = int(input('fluency: '))
            helpful = int(input('helpful: '))
            ease = int(input('ease: '))

    # Everything after the three rating lines is the free-text helpfulness
    # explanation; drop a leading label (matched case-insensitively) if any.
    helpful_text = '\n'.join(pred_text[3:])
    print(helpful_text)
    lowered_text = helpful_text.lower()
    for label in ('helpfulness:', 'helpfulness (free-form): '):
        if lowered_text.startswith(label):
            helpful_text = helpful_text[len(label):]
            break
    helpful_text = helpful_text.strip()

    # number of turns
    worker_id = pred['worker_id']
    line = conv_data[worker_id]
    no_of_turns = len(line['lm_responses'])
    print('# of turns:', no_of_turns)

    # accuracy: 1 when the normalized gold answer is contained in the
    # normalized user answer, 0 otherwise.
    golden_answer = line['answer']
    user_answer = line['user_answer']
    flag = (hotpot_evaluate.normalize_answer(golden_answer) in
            hotpot_evaluate.normalize_answer(user_answer))
    # f1_score returns a 3-tuple; the middle element (presumably precision,
    # TODO confirm against hotpot_evaluate) is not used here.
    f1, _, recall = hotpot_evaluate.f1_score(user_answer, golden_answer)
    acc = int(flag)
    print('Accuracy:', acc)
    print('F1:', f1)
    print('Recall:', recall)

    # NOTE(review): the 'worker_id' column receives the 1-based position of
    # the prediction, not pred['worker_id'] -- confirm this is intended.
    row = [idx+1, config['model1'], config['model2'], config['eval_model'], fluency, helpful, ease, helpful_text, no_of_turns, acc, f1, recall]
    df.loc[len(df.index)] = row

# Alternative (presumably older) naming pattern, kept for reference:
#file_name = '../results/predictions_{model1}_{model2}_{eval_model}.csv'.format(

# The persona only appears in the file name when it is not 'general'.
if config['persona'] == 'general':
    csv_persona_suffix = ''
else:
    csv_persona_suffix = '_' + config['persona']

file_name = '../results/hotpot_predictions_{model1}_{model2}_{eval_model}{persona}_prompt-2.csv'.format(
    persona=csv_persona_suffix,
    eval_model=config['eval_model'],
    model2=config['model2'],
    model1=config['model1'],
)

# Save the results table as CSV, leaving out the DataFrame index column.
df.to_csv(file_name, index=False)

