import argparse
from tqdm import tqdm

from dsg.csv_loader import load_dsg_file

def show_result(
    dataset_category,
    id2score,
):
    """Print the item count and average score for each dataset category.

    An item's category is the prefix of its id before the first "_"
    (e.g. "abc_001" -> "abc"). Categories are reported in the order given by
    ``dataset_category``, followed by any category that appears in
    ``id2score`` but was not listed, and finally an aggregate "all" row.
    A category with no items is reported with an average of "N/A" instead of
    raising ZeroDivisionError.

    Args:
        dataset_category: Iterable of category names to report, in order.
        id2score: Mapping from item id to its numeric score.
    """
    category_count = {name: 0 for name in dataset_category}
    category_sum = {name: 0.0 for name in dataset_category}
    total_count = 0
    total_sum = 0.0

    for item, score in id2score.items():
        category_id = item.split("_")[0]
        # Unlisted categories are added on the fly rather than raising KeyError.
        category_count[category_id] = category_count.get(category_id, 0) + 1
        category_sum[category_id] = category_sum.get(category_id, 0.0) + score
        total_count += 1
        total_sum += score

    rows = [(name, category_count[name], category_sum[name]) for name in category_count]
    rows.append(("all", total_count, total_sum))

    for name, count, score_sum in rows:
        print(f"Dataset category: {name}, items num: {count}")
        # Guard empty categories (and an empty id2score for the "all" row).
        avg = score_sum / count if count else "N/A"
        print(f"Avg. score: {avg}")


def _item_score(qid2answer, qid2valid):
    """Return the mean over questions of ``int(answer == "yes") * valid``.

    Each question id in ``qid2answer`` contributes its "yes" indicator
    multiplied by the matching entry of ``qid2valid``. An item with no
    questions scores 0.0 instead of raising ZeroDivisionError.

    Args:
        qid2answer: Mapping from question id to its answer string.
        qid2valid: Mapping from question id to a validity value; must contain
            every key of ``qid2answer`` (presumably 0/1 or bool — confirm
            against load_dsg_file).
    """
    if not qid2answer:
        return 0.0
    total = sum(int(ans == "yes") * qid2valid[qid] for qid, ans in qid2answer.items())
    return total / len(qid2answer)


def main():
    """Score DSG answers from a CSV file and print per-category averages.

    For every item, two scores are computed with ``_item_score``: one from the
    original answers ("answer"/"valid") and one from the rewritten answers
    ("rewritten_answer"/"rewritten_valid"). The "refined" score keeps a
    perfect original score and otherwise takes the rewritten score. Items that
    reach a perfect score only through the rewritten answers are printed as
    examples. Finally, per-category results are printed for the original and
    the refined scores.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_file",
        type=str,
        default="./sdv21_new_output.csv",
        help="Path to the DSG output CSV file.",
    )
    args = parser.parse_args()

    # Validate explicitly: `assert` statements are stripped under `python -O`.
    if not args.data_file.endswith(".csv"):
        parser.error(f"--data_file must be a .csv file, got {args.data_file!r}")
    dsg_data = load_dsg_file(args.data_file)

    raw_answer = dsg_data["answer"]
    raw_valid = dsg_data["valid"]
    rewritten_answer = dsg_data["rewritten_answer"]
    rewritten_valid = dsg_data["rewritten_valid"]

    # Category = item-id prefix before the first "_", kept in first-seen order.
    dataset_category = list(dict.fromkeys(item.split("_")[0] for item in raw_answer))

    id2score = {}
    id2score_rewritten = {}
    id2score_rewritten_refined = {}

    for item in tqdm(raw_answer.keys()):
        id2score[item] = _item_score(raw_answer[item], raw_valid[item])
        id2score_rewritten[item] = _item_score(rewritten_answer[item], rewritten_valid[item])
        # Keep a perfect original score; otherwise use the rewritten one.
        # NOTE(review): this can be lower than the original score when the
        # rewritten answers do worse — confirm that is intended.
        id2score_rewritten_refined[item] = 1 if id2score[item] == 1 else id2score_rewritten[item]

        # Show items that become perfect only via the rewritten answers.
        # tqdm.write keeps the progress bar intact, unlike a bare print().
        if id2score_rewritten_refined[item] == 1 and id2score[item] < 1:
            tqdm.write(item)

    ##### output info #####
    print("********************")
    show_result(dataset_category, id2score)
    print("********************")
    show_result(dataset_category, id2score_rewritten_refined)




if __name__ == "__main__": 
    # Run the evaluation only when executed as a script, not on import.
    main()

