import json
import pandas as pd
import yaml

# Run settings (model1, model2, persona, eval_model) are read from config.yml.
with open('config.yml', mode='r') as config_file:
    config = yaml.safe_load(config_file)

from tqdm import tqdm

import openai_api
import anthropic_api
from persona import persona_eval_sys_ambig

# System message that frames the evaluator's role.
# Adjacent string literals are concatenated with no separator, so each
# sentence-ending literal carries its own trailing space.
sys_message = (
    'A user asks an AI a question that may be ambiguous. '
    'The AI assistant tries to disambiguate and solve the question in a conversation. '
    'You are a helpful and precise evaluator for checking the quality of the AI assistant\'s responses in conversations.'
)

# Default rating rubric, appended after the rendered conversation.
# extract() replaces it with persona_eval_sys_ambig[...] when
# config['persona'] is not 'general'.
sys_prompt = (
    'Please evaluate the above conversations between user and AI assistant by using the following metrics:\n'
    'Fluency (5-point Likert): How clear (or fluent) were the responses from the AI Assistant?\n'
    'Helpfulness (5-point Likert): Independent of its fluency, how helpful was the AI Assistant in disambiguating and solving the problem?\n'
    'Ease of interaction (5-point Likert): How easy was it to interact with the AI Assistant?\n'
    'Helpfulness (free-form): Why did you find the AI Assistant helpful or unhelpful?\n'
    'Please output each of the above metrics line-by-line.'
)
#'how helpful was having access to the AI Assistant compared to not having access?\n'
#if 'general' != config['persona']:
#    sys_prompt = '\n'.join([persona_eval[config['persona']], sys_prompt])

#data = pd.read_csv('../data/event_blocks.csv')
#workers = pd.read_csv('../results/accuracy_by_id.csv')
# Conversations generated by the model1/model2 run. The persona suffix is
# left out entirely for the 'general' persona.
file_name = '../results/ambig_conversation_{}_{}{}_prompt-1.json'.format(
    config['model1'],
    config['model2'],
    '' if config['persona'] == 'general' else '_' + config['persona'],
)
with open(file_name, 'r') as conversation_file:
    data = json.load(conversation_file)

def extract_line(line):
    """Render one record as evaluator-facing text and store it on the record.

    The rendered text has two parts separated by a blank line:

    * ``Question: ...`` followed by ``All Acceptable Answers:`` and each
      distinct gold answer on its own line;
    * ``Conversation:`` followed by alternating ``User:`` / ``AI Assistant:``
      turns.

    Args:
        line: dict with ``'question'``, ``'annotations'``, ``'user_queries'``
            and ``'lm_responses'``. Each annotation may carry a ``'qaPairs'``
            list (each pair holding an ``'answer'`` list) or a flat
            ``'answer'`` list; annotations with neither key are skipped.

    Returns:
        The same dict, mutated in place, with the rendered text stored under
        ``'model_message'``.
    """
    question_text = 'Question: ' + line['question']

    answers = []
    for annotation in line['annotations']:
        if 'qaPairs' in annotation:
            for pair in annotation['qaPairs']:
                answers.extend(pair['answer'])
        elif 'answer' in annotation:
            answers.extend(annotation['answer'])
    # Strip before de-duplicating so 'x' and 'x ' collapse into one entry.
    # De-duplicate with dict.fromkeys (keeps first-seen order) rather than
    # set(): set iteration order for strings changes between runs under hash
    # randomization, which would make the evaluator prompt non-reproducible.
    answers = list(dict.fromkeys(answer.strip() for answer in answers))
    answer_golden = 'All Acceptable Answers:\n' + '\n'.join(answers)
    qa_pair = '\n'.join([question_text, answer_golden])

    # zip() stops at the shorter list; extra queries or responses are dropped.
    turns = [
        'User: {up}\nAI Assistant: {ar}'.format(up=query, ar=response)
        for query, response in zip(line['user_queries'], line['lm_responses'])
    ]
    conversation = '\n'.join(['Conversation:'] + turns)

    line['model_message'] = '\n\n'.join([qa_pair, conversation])
    return line

def extract(line):
    """Ask the evaluator model to rate one conversation record.

    Builds a two-message chat: the system framing (``sys_message``), then a
    user prompt made of the rendered question/answers/conversation followed by
    the rating rubric. It sends the chat to the backend selected by
    ``config['eval_model']``:

    * names starting with ``'claude'`` go to ``anthropic_api.call``;
    * names starting with ``'gpt-'`` go to ``openai_api.call_chat``;
    * anything else goes to ``openai_api.call_completion``, with both messages
      flattened into one string.

    Args:
        line: One conversation record. It must have ``'id'`` plus the keys
            ``extract_line`` reads. It is mutated in place, because
            ``extract_line`` adds ``'model_message'``.

    Returns:
        dict with ``'worker_id'`` (the record's ``'id'``; the key name is kept
        for compatibility with existing prediction files), ``'question'``,
        ``'prediction'`` (raw evaluator output) and ``'no_of_turns'``.
    """
    # Pick the rubric in a local variable so scoring a record never rebinds
    # the module-level ``sys_prompt``.
    if config['persona'] == 'general':
        eval_prompt = sys_prompt
    else:
        eval_prompt = persona_eval_sys_ambig[config['persona']]

    record = extract_line(line)
    prompt = '\n\n'.join([record['model_message'], eval_prompt])
    messages = [
        {'role': 'system', 'content': sys_message},
        {'role': 'user', 'content': prompt},
    ]

    eval_model = config['eval_model']
    if eval_model.startswith('claude'):
        prediction = anthropic_api.call(messages)
    elif eval_model.startswith('gpt-'):
        prediction = openai_api.call_chat(messages, eval_model)
    else:
        # Completion-style models take one string: system text, blank line, user prompt.
        completion_prompt = '\n\n'.join([sys_message, prompt])
        prediction = openai_api.call_completion(completion_prompt, eval_model)

    return {
        'worker_id': record['id'],
        'question': record['question'],
        'prediction': prediction,
        'no_of_turns': len(record['user_queries']),
    }

# Score every loaded conversation in order, with a tqdm progress bar.
predictions = [extract(record) for record in tqdm(data)]

#with open('../results/predictions_claude.json', 'w') as f:
#    json.dump(predictions, f, indent=2)
#file_name = '../results/predictions_{eval_model}_{persona}_prompt-2_first-60.json'.format(
#        eval_model=config['eval_model'],
#        persona=config['persona']
#        )
# The output name follows the input conversation file's pattern, plus the
# evaluator model name. The persona suffix is omitted for 'general'.
file_name = '../results/ambig_predictions_{}_{}_{}{}_prompt-1.json'.format(
    config['model1'],
    config['model2'],
    config['eval_model'],
    '' if config['persona'] == 'general' else '_' + config['persona'],
)
with open(file_name, 'w') as out_file:
    json.dump(predictions, out_file, indent=2)

