"""
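EM-based preference pair construction.

Labels each model prediction as correct or incorrect with substring exact match
against the gold answers, then samples one (positive, negative) reply pair per row.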
"""

from metrics.answer import AnswerMetric, compute_exact, metric_max_over_ground_truths, substring_exact_match_score
from utils import *
from tqdm import tqdm
from datasets import Dataset
from functools import partial
from datasets import load_from_disk

import argparse
import json
import random
import os

def extract_ans_from_response(raw_response, split_seg_list = ["The answer is:", "The answer is", "the answer is:", "the answer is"]):
    """Pull the final answer span out of a free-form model reply.

    Markers in ``split_seg_list`` are tried in order. Whenever a marker occurs in
    the text kept so far, the text is narrowed as follows:

    1. Keep only what follows the marker's last occurrence.
    2. Cut that at the first blank line (double newline).
    3. Cut it again at the first period.

    Later markers operate on the already-narrowed text. If no marker matches,
    the reply is returned unchanged.
    """
    candidate = raw_response
    for marker in split_seg_list:
        if marker not in candidate:
            continue
        after_marker = candidate.split(marker)[-1]
        first_paragraph = after_marker.split("\n\n")[0]
        # Splitting on a missing separator returns the whole string, so this
        # is a no-op when the fragment contains no period.
        candidate = first_paragraph.split(".")[0]
    return candidate

def em_check_correctness(question, reply, ground_truth_answers, split_seg_list = ["The answer is:", "The answer is", "the answer is:", "the answer is"]):
    """Return True if the answer extracted from ``reply`` matches any gold answer.

    The answer span is extracted with ``extract_ans_from_response`` using
    ``split_seg_list``. It is scored against every entry of
    ``ground_truth_answers`` with ``substring_exact_match_score``, and the best
    score is reduced to a bool.

    ``question`` is unused here; it is accepted to match the call signature used
    by ``selecte_by_golden_answer``.
    """
    extracted = extract_ans_from_response(reply, split_seg_list)
    best_score = metric_max_over_ground_truths(
        substring_exact_match_score, extracted, ground_truth_answers
    )
    return bool(best_score)


def selecte_by_golden_answer(data,correctness_cheking="EM"):
    """Label each prediction against the gold answers and sample a preference pair.

    Args:
        data: One dataset row (dict). Reads ``question``, ``predictions`` (list of
            reply strings), and optionally ``answer`` and ``answer_aliases``.
        correctness_cheking: Name of the correctness check. Only ``"EM"``
            (substring exact match via ``em_check_correctness``) is supported.

    Returns:
        The same ``data`` dict, updated in place with:

        - ``positive_sample_list`` / ``negative_sample_list``: all replies
          judged correct / incorrect.
        - ``predicted_answer_list``: the extracted answer for each reply.
        - ``positive_sample`` / ``negative_sample``: one random pick from each
          list, or None if that list is empty.
        - ``success``: both samples exist.
        - ``all_incorrect``: no positive sample.
        - ``all_correct``: no negative sample.

    Raises:
        ValueError: If ``correctness_cheking`` is not a supported check.
    """
    # Fail fast instead of hitting an unbound local inside the loop.
    if correctness_cheking != "EM":
        raise ValueError(
            f"Unsupported correctness_cheking={correctness_cheking!r}; expected 'EM'"
        )

    question = data["question"]
    predictions = data["predictions"]
    answer = data.get("answer", None)
    # The column may exist but hold None for a row.
    answer_aliases = data.get("answer_aliases", None) or []

    # Gold answer first, then every alias, explicitly parenthesised. The previous
    # form `[answer] if answer else [] + answer_aliases` parsed as
    # `[answer] if answer else ([] + answer_aliases)` and silently dropped the
    # aliases whenever `answer` was set.
    # Empty / None entries are skipped: None has no `.lower()`, and "" is a
    # substring of every reply, so it would suppress every negative below.
    possible_answers = [
        ans for ans in ([answer] if answer else []) + list(answer_aliases) if ans
    ]

    correct_replies = []
    incorrect_replies = []
    predicted_answer_list = []
    for reply in predictions:
        # Recorded with the default marker list. The correctness check below
        # deliberately uses only the capitalised markers.
        predicted_answer_list.append(extract_ans_from_response(reply))

        if em_check_correctness(question, reply, possible_answers, split_seg_list=["The answer is:", "The answer is"]):
            correct_replies.append(reply)
            continue

        # Only keep a clear negative: a reply that mentions a gold answer
        # anywhere (but failed extraction/matching) lands in neither list.
        reply_lower = reply.lower()
        if not any(ans.lower() in reply_lower for ans in possible_answers):
            incorrect_replies.append(reply)

    data["positive_sample_list"] = correct_replies
    data["negative_sample_list"] = incorrect_replies
    data["predicted_answer_list"] = predicted_answer_list

    positive_sample = random.choice(correct_replies) if correct_replies else None
    negative_sample = random.choice(incorrect_replies) if incorrect_replies else None
    data["positive_sample"] = positive_sample
    data["negative_sample"] = negative_sample

    data["success"] = positive_sample is not None and negative_sample is not None
    data["all_incorrect"] = positive_sample is None
    data["all_correct"] = negative_sample is None

    return data


def main_pair_selection(args):
    """Build preference pairs from a predictions file and save the selected subsets.

    Reads ``{args.src_data_dir_path}/{args.src_file_name}``: a ``.jsonl`` file via
    ``load_custom_dataset``, otherwise a dataset directory via ``load_from_disk``.
    Every row is labelled with ``selecte_by_golden_answer`` and summary counts are
    printed. Then, for each comma-separated entry of ``args.selection_type``
    (``success``, ``all_correct``, ``all_incorrect``), the rows whose flag of that
    name is true are written next to the source.

    The dataset is loaded and processed once and shared by all selection types,
    so every saved subset uses the same sampled positive/negative pair per row.
    """
    src_data_dir_path = args.src_data_dir_path
    data_name = args.src_file_name
    correctness_cheking = args.evaluation_metric
    # Tolerate spaces ("success, all_correct") and empty entries (trailing commas).
    selection_type = [part.strip() for part in args.selection_type.split(",") if part.strip()]
    # Boolean columns written by selecte_by_golden_answer.
    known_selections = {"success", "all_correct", "all_incorrect"}

    src_path = f"{src_data_dir_path}/{data_name}"
    is_jsonl = src_path.endswith(".jsonl")

    # load data
    if is_jsonl:
        src_dataset = load_custom_dataset(src_path)
    else:
        src_dataset = load_from_disk(src_path)

    ## If the pair originally exists, it is considered a pair sampled from short context by default.
    if "positive_sample" in src_dataset.column_names:
        src_dataset = src_dataset.rename_column("positive_sample", "positive_sample_on_short")
    if "negative_sample" in src_dataset.column_names:
        src_dataset = src_dataset.rename_column("negative_sample", "negative_sample_on_short")

    print(src_dataset)
    # data format: 'question', 'answer'/"answer_aliases", 'predictions'
    if "outputs" in src_dataset.column_names and "answer_aliases" not in src_dataset.column_names:
        src_dataset = src_dataset.rename_columns({"outputs": "answer_aliases"})

    # data processing (once, shared by every selection type)
    process_func = partial(selecte_by_golden_answer,correctness_cheking=correctness_cheking)
    processed_dataset = src_dataset.map(process_func, num_proc=1)
    print("pair construction success: ", sum(processed_dataset["success"]))
    print("final answer is all correct: ", sum(processed_dataset["all_correct"]))
    print("final answer is all incorrect: ", sum(processed_dataset["all_incorrect"]))

    processed_dataset = processed_dataset.remove_columns(["predictions"])

    for select_parts in selection_type:
        if select_parts in known_selections:
            # filter() runs eagerly inside this iteration, so the closure sees
            # the current select_parts.
            selected = processed_dataset.filter(lambda x: x[select_parts])
        else:
            print(f"WARNING: unknown selection type {select_parts!r}; saving all rows unfiltered")
            selected = processed_dataset
        print(selected)

        # save
        if is_jsonl:
            # Replace only the trailing suffix, not every ".jsonl" in the path.
            tgt_path = src_path[: -len(".jsonl")] + f"_all_{correctness_cheking}_selector_{select_parts}.jsonl"
            selected.to_json(tgt_path, force_ascii=False)
        else:
            tgt_path = f"{src_data_dir_path}/{data_name}_all_{correctness_cheking}_selector_{select_parts}"
            selected.save_to_disk(tgt_path)

def get_args():
    """Parse command-line options for preference pair construction.

    Returns:
        argparse.Namespace with ``src_data_dir_path``, ``src_file_name``,
        ``evaluation_metric`` and ``selection_type``.
    """
    parser = argparse.ArgumentParser(description="EM-based preference pair construction")
    # Required: without them the source path becomes "None/None" and loading
    # fails much later with an unhelpful error.
    parser.add_argument("--src_data_dir_path", type=str, required=True,
                        help="Directory containing the source data.")
    parser.add_argument("--src_file_name", type=str, required=True,
                        help="A .jsonl file, or a dataset directory saved with save_to_disk, "
                             "inside --src_data_dir_path.")
    parser.add_argument("--evaluation_metric", type=str, default="EM",
                        help="Correctness check used to label replies (only 'EM' is supported).")
    parser.add_argument("--selection_type", type=str, default="success",
                        help="Comma-separated subsets to save: success,all_correct,all_incorrect")
    return parser.parse_args()

def main():
    """Command-line entry point: parse arguments and run pair construction."""
    main_pair_selection(get_args())

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
    