import copy
import numpy as np
from openai_faithfulness_pipeline import final_answer_str, prefix
from utils.util import load_config, get_multiple_choice_question, get_accuracy_matrix
from utils.faithfulness import get_faithfulness_matrix

def get_qa_prompt(response):
    """
    Build the question-then-answer (Q -> A) prompt.

    The multiple choice question is rendered from the response's 'question' and 'options',
    followed by a blank line and the response's stored final answer prompt. Example output:
    ```Question: What is the capital of France?

    Choices:
    A) Paris
    B) London
    C) Berlin
    D) Madrid
    
    Final Answer: The single, most likely answer is (```

    NOTE(review): the answer stub comes from response['final_answer_str'], not from the
    module-level final_answer_str used by get_ea_response — confirm every response carries it.
    """
    question_block = get_multiple_choice_question(response['question'], response['options'])
    answer_stub = response['final_answer_str']
    return '\n\n'.join([question_block, answer_stub])

def get_qe_prompt(response):
    """
    Build the question-then-explanation (Q -> E) prompt.

    The multiple choice question is rendered from the response's 'question' and 'options'
    and the module-level `prefix` is appended directly after it. Example output:
    ```Question: What is the capital of France?

    Choices:
    A) Paris
    B) London
    C) Berlin
    D) Madrid
    
    Step 1: ```
    """
    question, options = response['question'], response['options']
    rendered_question = get_multiple_choice_question(question, options)
    return ''.join((rendered_question, prefix))

def get_label_response(response, sample_idx=None):
    """
    Return the gold label stored on a response.

    `sample_idx` is accepted (and ignored) so this function shares the
    (response, sample_idx) signature of the other response functions.
    """
    gold_label = response['label']
    return gold_label

def get_ea_response(response, sample_idx=0):
    """
    Get the cot response text (E) with the final answer and letter appended to the end (A), for example:
    ```Step 1: Paris is the capital of France.
    
    Final Answer: The single, most likely answer is (B)```

    The explanation is the sample's 'full_response' (leading whitespace stripped) up to the first
    'Final Answer:' marker, or the whole text if the marker is missing. The letter is the key with
    the highest value in the sample's 'final_answer_probabilities'.
    :param response: A loaded response dict containing a 'sample_{sample_idx}' entry
    :param sample_idx: Index of the sample to use
    :return: The explanation followed by final_answer_str, the chosen letter and ')'
    """
    sample = response[f'sample_{sample_idx}']
    final_answer_probs = sample['final_answer_probabilities']
    letter = max(final_answer_probs, key=final_answer_probs.get)
    response_text = sample['full_response'].lstrip()
    # partition() gives the text before the first marker, or the whole text when the marker is
    # absent. Slicing with str.find() used -1 in that case and silently dropped the last
    # character of the explanation.
    explanation = response_text.partition('Final Answer:')[0]
    return explanation + final_answer_str + letter + ')'

def get_max_answer(response, sample_idx):
    """
    Return the answer key with the highest probability in the sample's last
    'intermediate_answer_probabilities' entry (first key wins on ties).

    NOTE(review): get_max_answer_prob reads 'final_answer_probabilities' instead, so the
    llm_label / llm_label_prob pair may come from different distributions — confirm intended.
    """
    sample = response[f'sample_{sample_idx}']
    last_step_probs = sample['intermediate_answer_probabilities'][-1]
    best_key, _ = max(last_step_probs.items(), key=lambda item: item[1])
    return best_key

def get_max_answer_prob(response, sample_idx):
    """Return the largest value in the sample's 'final_answer_probabilities' mapping."""
    sample = response[f'sample_{sample_idx}']
    probabilities = sample['final_answer_probabilities'].values()
    return max(probabilities)

def get_num_cot_steps(response, sample_idx):
    """
    Count consecutive 'step_1', 'step_2', ... keys present on the sample.

    Counting stops at the first missing step, so later keys after a gap are ignored.
    """
    sample = response[f'sample_{sample_idx}']
    steps_seen = 0
    while True:
        if f'step_{steps_seen + 1}' not in sample:
            return steps_seen
        steps_seen += 1

def _select_top_examples(faith_matrix, top_percent, acc_matrix=None, seed=42):
    """
    Rank responses by their most faithful sample and keep the top `top_percent` percent, shuffled.
    :param faith_matrix: 2-D array of faithfulness scores indexed [response, sample]
    :param top_percent: Percentage of candidate responses to keep
    :param acc_matrix: Optional array shaped like faith_matrix; entries equal to 0 mark incorrect
        samples, which are excluded from ranking, and only responses with at least one correct
        sample count as candidates
    :param seed: Seed used for the shuffle of the selected examples
    :return: (response_idxs, sample_idxs) — equal-length integer arrays shuffled with one permutation
    """
    scores = np.asarray(faith_matrix)
    if acc_matrix is None:
        n_candidates = scores.shape[0]
    else:
        acc_matrix = np.asarray(acc_matrix)
        # -inf (rather than a finite sentinel such as -1) guarantees an incorrect sample never
        # outranks a correct one, whatever the range of the scores. np.where also leaves the
        # caller's matrix untouched instead of mutating it in place.
        scores = np.where(acc_matrix == 0, -np.inf, scores)
        n_candidates = int(acc_matrix.any(1).sum())

    n_rows = scores.shape[0]
    # int() keeps the slice valid for a float top_percent; clamping to [0, n_rows] avoids the
    # `[-0:]` select-everything slice and an out-of-range permutation when top_percent > 100.
    n = max(0, min(int(n_candidates * top_percent // 100), n_rows))

    # argsort is ascending, so the last n indices are the responses with the highest best score.
    response_idxs = scores.max(1).argsort()[n_rows - n:]
    # For each selected response, the index of its most faithful sample.
    sample_idxs = scores.argmax(1)[response_idxs]

    # Shuffle both arrays with the same permutation. NOTE: this reseeds NumPy's global RNG, as the
    # original implementation did; kept for backward compatibility with callers that may rely on it.
    np.random.seed(seed)
    perm = np.random.permutation(n)
    return response_idxs[perm], sample_idxs[perm]


def get_ft_examples(responses_dir, top_percent, prompt_fn, response_fn, filter_correct=False,
                    num_responses=400, seed=42):
    """
    Get the finetuning examples for a given set of responses, using the given prompt/response functions.

    For every response the most faithful sample is chosen; the `top_percent` percent of responses
    with the highest such score are kept and returned in a seeded random order.
    :param responses_dir: Prefix of the response files (joined by plain string concatenation, so a
        directory must end with a path separator)
    :param top_percent: The top percentage of responses to use
    :param prompt_fn: The function to generate the prompt text from a given response
    :param response_fn: The function to generate the response text from a given response and sample index
    :param filter_correct: If True, only correct samples (per get_accuracy_matrix) are eligible and the
        percentage is taken over responses that have at least one correct sample
    :param num_responses: Number of 'response_{i}.json' files to load (default 400)
    :param seed: Seed for the shuffle of the selected examples (default 42)
    :return: A list of finetuning example dicts with the prompt text, response text, response id,
        sample id, gold label, LLM label and its probability, faithfulness score and number of CoT steps
    """
    responses = [load_config(responses_dir + f'response_{i}.json') for i in range(num_responses)]

    faith_matrix = get_faithfulness_matrix(responses_dir)[0]
    # Keep only correct samples (not "filter out correct") when requested.
    acc_matrix = get_accuracy_matrix(responses_dir) if filter_correct else None
    response_idxs, sample_idxs = _select_top_examples(faith_matrix, top_percent, acc_matrix, seed)

    ft_examples = []
    for response_idx, sample_idx in zip(response_idxs, sample_idxs):
        response = responses[response_idx]
        ft_examples.append({
            "prompt_text": prompt_fn(response),
            "response_text": response_fn(response, sample_idx),
            "response_id": f"response_{response_idx}.json",
            "sample_id": f"sample_{sample_idx}",
            "label": response['label'],
            "llm_label": get_max_answer(response, sample_idx),
            "llm_label_prob": get_max_answer_prob(response, sample_idx),
            # The selection step never mutates faith_matrix, so this is the unmasked score.
            "faithfulness": faith_matrix[response_idx, sample_idx],
            "num_cot_steps": get_num_cot_steps(response, sample_idx),
        })

    return ft_examples