from utils.scorer.test_score_functions import VALUE_LIST, ROOT_PATH
from utils.scorer.test_score_functions import load_single_file
from utils.scorer.score_functions import get_bert_score, get_bleu_score, get_batch_cosine_similarity, get_diag_gpt4_similarity_exp1

from joblib import Parallel, delayed
import json
import os
import concurrent.futures
import functools
import time

# Every entry of VALUE_LIST except the last one; these are the values that get
# a per-value `<value>_score` column. NOTE(review): the dropped final entry is
# presumably the "no induction" setting iterated as a value_type below — confirm.
NEW_VALUE_LIST = VALUE_LIST[0:-1]
# Base directory for all exp1 outputs (note the trailing slash: callers append
# sub-paths directly to this string).
EXP1_PATH = "{}/results/exp1/".format(ROOT_PATH)

def question_to_id(question_path=None):
    """Map each exp1 question text to its 0-based line index in the question file.

    Args:
        question_path: JSONL file with one object per line carrying a
            ``question`` field. Defaults to
            ``{ROOT_PATH}/dataset/exp1_question.jsonl``.

    Returns:
        dict mapping question text -> line index. Blank lines are skipped but
        still count toward the index, so indices always equal line numbers.
        If a question text appears more than once, the later line's index wins.
    """
    if question_path is None:
        question_path = f'{ROOT_PATH}/dataset/exp1_question.jsonl'
    res = {}
    with open(question_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            # A trailing/blank line would make json.loads raise; ignore it.
            if not line:
                continue
            entry = json.loads(line)
            res[entry.get('question')] = idx
    return res

def question_value_response(file_path=None):
    """Load the reference responses as a nested ``{question: {value: response}}`` dict.

    Args:
        file_path: JSONL file whose lines carry ``question``, ``value`` and
            ``response`` fields. Defaults to
            ``{ROOT_PATH}/dataset/exp1_question_answer.jsonl``.

    Returns:
        dict keyed by question text; each inner dict maps a value name to its
        response. Blank lines are skipped. A repeated (question, value) pair
        keeps the response from the later line.
    """
    if file_path is None:
        file_path = f'{ROOT_PATH}/dataset/exp1_question_answer.jsonl'
    output_dict = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # A blank line would make json.loads raise; ignore it.
            if not line:
                continue
            entry = json.loads(line)
            # Keyed by the question *text* (not an id), matching how
            # process_and_save looks entries up.
            question = entry.get('question')
            output_dict.setdefault(question, {})[entry.get('value')] = entry.get('response')
    return output_dict

def process_and_save(model_name, value_type="no", scorer=None, metric_name=""):
    """Score one model's exp1 answers against each value's reference response and save them.

    For every item returned by ``load_single_file``, the model answer is scored
    against the reference responses (from ``question_value_response``) of each
    value in ``NEW_VALUE_LIST`` that has a response for that question. One JSON
    object per item is written to
    ``{EXP1_PATH}{metric_name}/{value_type}/{model_name}.jsonl`` with keys
    ``question``, ``llm_answer`` and ``<value>_score``. Values without a
    reference response for a question get no ``<value>_score`` key.

    Args:
        model_name: Model identifier passed to ``load_single_file``; also the
            output file stem.
        value_type: Induction setting; forwarded to ``load_single_file`` and
            used as the output sub-directory.
        scorer: Callable ``scorer(answer, responses)`` returning one score per
            response, in the same order as ``responses``.
        metric_name: Output sub-directory naming the metric.

    Returns:
        None. Returns early without writing if the output file already exists
        or ``load_single_file`` returns ``None``.

    Raises:
        ValueError: if ``scorer`` is None, or if it returns a different number
            of scores than responses it was given.
        KeyError: if an item's question has no entry in the reference data.
    """
    save_path = f"{EXP1_PATH}{metric_name}/{value_type}"
    output_file = f'{save_path}/{model_name}.jsonl'
    if os.path.exists(output_file):
        print(f"{output_file} already exists!")
        return None
    if scorer is None:
        raise ValueError("process_and_save requires a scorer callable")

    qvr = question_value_response()
    input_list = load_single_file(model_name, value_type, is_parse=False, is_exp1=True)
    if input_list is None:
        return None

    new_list = []
    start = time.time()
    for item in input_list:
        question = item.get('question', 0)
        llm_answer = item.get('answer', 0)
        new_dict = {'question': question, 'llm_answer': llm_answer}

        # Keep each value paired with its response: some questions lack a
        # response for some values, so positions in NEW_VALUE_LIST cannot be
        # used to index the scorer's output.
        available = qvr[question]
        scored_values = [value for value in NEW_VALUE_LIST if value in available]
        responses = [available[value] for value in scored_values]
        scores = scorer(llm_answer, responses)
        if len(scores) != len(scored_values):
            raise ValueError(
                f"scorer returned {len(scores)} scores for {len(scored_values)} responses"
            )
        for value, score in zip(scored_values, scores):
            new_dict[f'{value}_score'] = float(score)
        new_list.append(new_dict)
    print(f"Time: {time.time() - start}s")

    # save results
    os.makedirs(save_path, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in new_list:
            f.write(json.dumps(item) + '\n')

def calculate_result(model_list, name_list, value_list, value_type, metric_name, exp_path=None):
    """Average each model's per-value scores and write them as a LaTeX table.

    Reads ``<exp_path>/<metric_name>/<value_type>/<model>.jsonl`` for every
    model (the files written by ``process_and_save``) and averages each
    ``<value>_score`` field over the file's entries; a missing field counts as
    0. The table (one row per value, one column per display name) is written
    to ``<exp_path>/<metric_name>/exp1_<value_type>.tex``.

    Args:
        model_list: Model identifiers used as result file stems.
        name_list: Display names, one per entry of ``model_list`` (same
            order), used as column headers.
        value_list: Values to report, one table row each.
        value_type: Induction setting; ``'no'`` is captioned "No induction".
        metric_name: Metric sub-directory.
        exp_path: Base results directory; defaults to ``EXP1_PATH``.

    Raises:
        ValueError: if ``model_list`` and ``name_list`` differ in length.
        FileNotFoundError: if a model's result file does not exist.
    """
    if len(model_list) != len(name_list):
        raise ValueError(
            f"model_list has {len(model_list)} entries but name_list has {len(name_list)}"
        )
    base = EXP1_PATH if exp_path is None else exp_path
    labels = [f"{value}_score" for value in value_list]

    average_scores = {}
    for model, name in zip(model_list, name_list):
        totals = dict.fromkeys(labels, 0)
        num_entries = 0
        result_file = os.path.join(base, metric_name, value_type, f'{model}.jsonl')
        with open(result_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                num_entries += 1
                for label in totals:
                    totals[label] += data.get(label, 0)
        # An empty result file has no mean; report NaN rather than dividing by zero.
        average_scores[name] = {
            label: (total / num_entries if num_entries else float('nan'))
            for label, total in totals.items()
        }

    # generate LaTeX code
    caption_setting = "No induction" if value_type == 'no' else value_type
    lines = [
        r"\begin{table}[!htbp]",
        r"\centering",
        r"\caption{" + f"Average Scores ({caption_setting}) for Different Models" + "}",
        r"\begin{tabular}{l" + "c" * len(name_list) + "}",
        r"\toprule",
        "Score & " + " & ".join(name_list) + r"\\",
        r"\midrule",
    ]
    for value in value_list:
        cells = " & ".join(
            f"{average_scores[name][f'{value}_score']:.2f}" for name in name_list
        )
        lines.append(f"{value} & {cells}" + r"\\")
    lines += [r"\bottomrule", r"\end{tabular}", r"\end{table}"]
    latex_code = "\n".join(lines) + "\n"

    save_path = os.path.join(base, metric_name)
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, f'exp1_{value_type}.tex'), 'w', encoding='utf-8') as f:
        f.write(latex_code)



if __name__ == '__main__':
    MODEL_LIST = ['Llama2-7B-chat']
    MODEL_NAME_LIST = ['Llama2-7B']
    TEST_METRIC = functools.partial(get_diag_gpt4_similarity_exp1, auto=False)
    METRIC_NAME = "gpt4_sim"

    calculate_score = True
    generate_latex = True

    if calculate_score:
        # One scoring task per (model, induction setting). The previous version
        # ran this whole pass a second time right afterwards, which only
        # re-printed "already exists!" for every file; that duplicate is gone.
        tasks = [
            (model, value, TEST_METRIC, METRIC_NAME)
            for model in MODEL_LIST
            for value in VALUE_LIST
        ]
        # n_jobs=1 runs the tasks sequentially in this process.
        Parallel(n_jobs=1)(delayed(process_and_save)(*task) for task in tasks)

    if generate_latex:
        for value in VALUE_LIST:
            calculate_result(MODEL_LIST, MODEL_NAME_LIST, NEW_VALUE_LIST, value, METRIC_NAME)
