import json

def regression(task_type, regress_category):
    """Fit per-category checklist weights with bounded least squares.

    Reads ``ckwise_results_hm/stats_{task_type}.jsonl`` (one JSON object per
    line). Every record must provide ``category``, ``checklist`` and
    ``evaluation_results``. For each record the first ``len(checklist)``
    entries of ``evaluation_results`` are used as features (each divided by
    4) and the last entry is used as the regression target.

    Args:
        task_type: Name inserted into the stats file path.
        regress_category: ``"ALL"`` to fit every category found in the file
            (in first-seen order), otherwise the single category to fit.

    Returns:
        A dict mapping each fitted category to a 1-D numpy array of weights,
        rescaled so that the weights of one category sum to 100.

    Raises:
        ValueError: If a requested category has no records in the stats file.
    """
    # Imported once per call rather than once per category; kept local so the
    # module itself can still be imported without numpy/scipy installed.
    import numpy as np
    from scipy.optimize import minimize

    load_path = f"ckwise_results_hm/stats_{task_type}.jsonl"

    with open(load_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        data_list = [json.loads(line) for line in f if line.strip()]

    if regress_category == "ALL":
        # dict.fromkeys de-duplicates while preserving first-seen order.
        category_list = list(dict.fromkeys(d["category"] for d in data_list))
    else:
        category_list = [regress_category]

    # Every weight is constrained to be at least this value.
    min_weight = 0.5

    def loss_function(w, X, y):
        """Mean squared error of the linear prediction ``X @ w``."""
        return np.mean((y - X @ w) ** 2)

    return_dict = {}

    for category in category_list:
        rows = [d for d in data_list if d["category"] == category]
        if not rows:
            raise ValueError(
                f"no records for category {category!r} in {load_path}"
            )

        # Features: per-checklist scores divided by 4
        # (presumably scores are on a 0-4 scale -- TODO confirm).
        X = np.array(
            [
                [score / 4 for score in d["evaluation_results"][:len(d["checklist"])]]
                for d in rows
            ]
        )
        # Target: the final entry of evaluation_results.
        y = np.array([d["evaluation_results"][-1] for d in rows])

        n_weights = X.shape[1]
        # Start at the lower bound so the initial guess is feasible; a zero
        # vector would lie outside the bounds given to the optimizer.
        initial_w = np.full(n_weights, min_weight)
        bounds = [(min_weight, None)] * n_weights

        result = minimize(loss_function, initial_w, args=(X, y), bounds=bounds)

        optimal_w = result.x
        # Express the weights as shares of 100.
        return_dict[category] = optimal_w / np.sum(optimal_w) * 100

    return return_dict

def compute_overall_score(model_name, task_type, length=None):
    """Return a model's mean weighted checklist score for one task type.

    Weights come from ``regression(task_type, regress_category="ALL")``.
    Per-checklist evaluations are read from
    ``ckwise_results/{model_name}/{task_type}[_{length}]_results.jsonl`` and
    the matching source items from ``data/{task_type}.jsonl`` (or
    ``data/length_constrained_data/{task_type}_{length}.jsonl`` when
    ``length`` is given). Records are matched to source items by ``id``.

    Args:
        model_name: Directory name of the model under ``ckwise_results``.
        task_type: Task name used in the file paths and for the regression.
        length: Optional suffix selecting the length-constrained files.

    Returns:
        The per-record weighted score sum, averaged over all result records.

    Raises:
        ValueError: If the results file contains no records, or a record's
            number of checklist evaluations differs from the number of
            checklists of its source item.
    """
    weight_dict = regression(task_type, regress_category="ALL")

    if length is not None:
        ckwise_results_path = f"ckwise_results/{model_name}/{task_type}_{length}_results.jsonl"
        data_path = f"data/length_constrained_data/{task_type}_{length}.jsonl"
    else:
        ckwise_results_path = f"ckwise_results/{model_name}/{task_type}_results.jsonl"
        data_path = f"data/{task_type}.jsonl"

    with open(ckwise_results_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        data_list = [json.loads(line) for line in f if line.strip()]

    if not data_list:
        # Fail with a clear message rather than a bare ZeroDivisionError below.
        raise ValueError(f"no result records found in {ckwise_results_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        hl_data_list = [json.loads(line) for line in f if line.strip()]

    # Index the source items by id; on duplicate ids the last one wins.
    id_dict = {item["id"]: item for item in hl_data_list}

    overall_score = 0

    for data_dict in data_list:
        source_item = id_dict[data_dict["id"]]
        checklists = source_item["checklists"]
        checklist_wise_evaluation = data_dict["checklist_wise_evaluation"]

        # Explicit check instead of `assert`, which is stripped under -O.
        if len(checklists) != len(checklist_wise_evaluation):
            raise ValueError(
                f"record {data_dict['id']!r}: {len(checklist_wise_evaluation)} "
                f"checklist evaluations but {len(checklists)} checklists"
            )

        # NOTE(review): weights are keyed by the stats file's "category"
        # values but looked up by the item's "subcategory" -- confirm the two
        # share the same vocabulary, otherwise this raises KeyError.
        weight = weight_dict[source_item["subcategory"]]

        tmp_score = 0
        for checklist_index, checklist_dict in enumerate(checklist_wise_evaluation):
            tmp_score += checklist_dict["evaluation_score"] * weight[checklist_index]

        overall_score += tmp_score

    return overall_score / len(data_list)



def format_csv(model_name_list, task_type = "ALL"):
    """Write rescaled overall scores for several models to ``overall_score.csv``.

    Args:
        model_name_list: Model names; they become the row index of the CSV.
        task_type: ``"ALL"`` to score the five built-in task types, otherwise
            the single task type to score. Task types become the columns.

    Returns:
        None. The table is written to ``overall_score.csv`` in the current
        working directory.
    """
    import pandas as pd

    if task_type == "ALL":
        task_type_list = ["open_ended_qa", "summarization", "chat", "text_completion", "heuristic_text_generation"]
    else:
        task_type_list = [task_type]

    # One column per task type, one entry per model (same order as
    # model_name_list). The loop variable has its own name so it no longer
    # shadows the `task_type` parameter.
    # (score - 75) * 4 maps a raw score of 75 to 0 and 100 to 100.
    overall_score_dict = {
        current_task: [
            (compute_overall_score(model_name, current_task) - 75) * 4
            for model_name in model_name_list
        ]
        for current_task in task_type_list
    }

    # Columns are task types, rows are model names.
    df = pd.DataFrame(overall_score_dict, index=model_name_list)
    df.to_csv("overall_score.csv")


if __name__ == "__main__":
    # Script entry point: score every task type for the models listed below
    # and write the resulting table via format_csv.
    category = "ALL"  # not read by anything visible in this file
    task_type = "ALL"

    model_name_list = [
        "llama31_70b",
        "llama31_8b",
        "qwen2_7b",
        "qwen2_72b",
    ]

    format_csv(model_name_list, task_type=task_type)