
import argparse
import json

import pandas as pd
from sklearn.metrics import f1_score

from utils import config
from utils.utils import print_args


def parse_args():
    """Parse the command-line options, echo them via ``print_args``, and return them.

    Options (all optional):
        --model_name, --dataset_name, --dataset_file_dic, --save_file_dic:
            strings whose defaults come from ``utils.config``.
        --start, --end: integer position window (defaults 0 and 1000).
        --both_nb_cbn: string flag, default "true".
    """
    cli = argparse.ArgumentParser()
    # (flag, type, default) in registration order, so --help output is unchanged.
    option_specs = (
        ("--model_name", str, config.model_name),
        ("--dataset_name", str, config.dataset_name),
        ("--start", int, 0),
        ("--end", int, 1000),
        ("--dataset_file_dic", str, config.dataset_file_dic),
        ("--save_file_dic", str, config.save_file_dic),
        ("--both_nb_cbn", str, "true"),
    )
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)

    parsed = cli.parse_args()
    print_args(parsed)
    return parsed

if __name__ == '__main__':
    args = parse_args()
    suffix = ''

    eval_path =  args.dataset_file_dic + "common2sense_human_annotation.csv"
    # load human evaluation data
    df_human = pd.read_csv(eval_path)
    factor_mapping_path = 'common2sense_qwen2.5-2b_0_1000_factors.json'
    with open(factor_mapping_path) as f:
        df_factor_mapping = json.load(f)
    df_factor_mapping_dict= dict()
    for i, df in enumerate(df_factor_mapping):
        df_factor_mapping_dict[df['scenario'] + df['statement']] = df
    input_path = f"{args.save_file_dic}{args.dataset_name}_{args.model_name}_{args.model_name.replace(':', '-')}_bn_compare{suffix}.json"
 
    print(f"Loading input from------------>: {input_path}")
    huamn_annotation_path = args.dataset_file_dic + "common2sense_human_annotation.json"
    print(f"Loading human annotation from------------>: {huamn_annotation_path}")
    with open(huamn_annotation_path) as f:
        df_human_json = json.load(f)

    df_human_json_dict = dict()
    for i, df in enumerate(df_human_json):
        df_human_json_dict[df['scenario'] + df['statement']] = df

    with open(input_path) as f:
        df_bn_cbn = json.load(f)

    df_bn_cbn_dict = dict()
    print('len:',len(df_bn_cbn))
    for i, res_df in enumerate(df_bn_cbn):
        if i < args.start or i >= args.end:
             continue
        factor_mapping = df_factor_mapping_dict[res_df['scenario'] + res_df['statement']]['factor_statement_mapping']
        if res_df['scenario'] + res_df['statement'] not in df_human_json_dict:
           continue

        else:
            df_human_json = df_human_json_dict[res_df['scenario'] + res_df['statement']]
        new_addition_sentence_probability = dict()
        for sent in df_human_json["additional_sentences"]:
            temp = {
                "mapped_values": {},
                "values_mapping":{},
                "nb": 0.5,
                "cbn": 0.5
            }
            if sent in res_df["results"].keys():
                # temp["mapped_values"] = res_df["results"][sent]["mapped_factors"]
                for key,value in res_df["para_prob"].items():
                    if key in res_df["results"][sent]["mapped_factors"]:
                        temp["mapped_values"][key]=value
                        temp["values_mapping"][ key]= factor_mapping[key]
                temp["nb"] = res_df["results"][sent]["nb"]
                if args.both_nb_cbn == "true":
                    temp["cbn"] = res_df["results"][sent]["cbn"]


            new_addition_sentence_probability[sent] = temp
        res_df["addition_sentence_probability"] = new_addition_sentence_probability
        df_bn_cbn_dict[res_df['scenario'] + res_df['statement']] = res_df

    # Performance analysis
    human_prediction = []
    count_unknown = 0
    nb_prediction = []
    cbn_prediction = []


    predict_res = []
    out_path = f"{args.save_file_dic}{args.dataset_name}_{args.model_name}_{args.model_name.replace(':', '-')}_bn_cbn_compare_res{suffix}_state.json"
    for i in range(len(df_human)):

        scenario = df_human['scenario'][i]
        statement_1 = df_human['statement_1'][i]
        statement_2 = df_human['statement_2'][i]
        condition_1 = df_human['sentence_1'][i]
        condition_2 = df_human['sentence_2'][i]
        statement = df_human['gold_statement'][i]
        res = {
        'scenario':scenario,
        'statement_1':statement_1,
        'statement_2':statement_2,
        'gold_statement':statement,
        'sentence1': {},
        'sentence2': {}
        }

        if scenario + statement_1 not in df_bn_cbn_dict:
            print(f"{scenario} {statement_1} not found")
            count_unknown += 1
            continue
        df_bn_cbn = df_bn_cbn_dict[scenario + statement_1]
        if condition_1 not in df_bn_cbn['addition_sentence_probability']:
            print(f"{scenario} {statement_1} {condition_1} not found")
            print(df_bn_cbn['addition_sentence_probability'])
        origin_probability_nb_1 = df_bn_cbn['addition_sentence_probability'][condition_1]["nb"]
        origin_probability_nb_2 = df_bn_cbn['addition_sentence_probability'][condition_2]["nb"]
        if args.both_nb_cbn == "true":
            origin_probability_cbn_1 = df_bn_cbn['addition_sentence_probability'][condition_1]["cbn"]
            origin_probability_cbn_2 = df_bn_cbn['addition_sentence_probability'][condition_2]["cbn"]

        if statement == statement_1:
            probability_nb_1 = origin_probability_nb_1
            probability_nb_2 = origin_probability_nb_2
            if args.both_nb_cbn == "true":
                probability_cbn_1 = origin_probability_cbn_1
                probability_cbn_2 = origin_probability_cbn_2
        elif statement == statement_2:
            probability_nb_1 = 1.0 - origin_probability_nb_1
            probability_nb_2 = 1.0 - origin_probability_nb_2
            if args.both_nb_cbn == "true":
                probability_cbn_1 = 1.0 - origin_probability_cbn_1
                probability_cbn_2 = 1.0 - origin_probability_cbn_2

        res['nb_probability_1'] = probability_nb_1
        res['nb_probability_2'] = probability_nb_2
        if args.both_nb_cbn == "true":
            res['cbn_probability_1'] = probability_cbn_1
            res['cbn_probability_2'] = probability_cbn_2

        mapped_factors_1 = df_bn_cbn['addition_sentence_probability'][condition_1]["mapped_values"]
        mapped_factors_2 = df_bn_cbn['addition_sentence_probability'][condition_2]["mapped_values"]
        factor_mapping_1 = df_bn_cbn['addition_sentence_probability'][condition_1]["values_mapping"]
        factor_mapping_2 = df_bn_cbn['addition_sentence_probability'][condition_2]["values_mapping"]

        res['sentence1']['sentence'] = condition_1
        res['sentence2']['sentence'] = condition_2

        res['sentence1']['mapped_values'] = mapped_factors_1
        res['sentence2']['mapped_values'] = mapped_factors_2
        res['sentence1']['values_mapping'] = factor_mapping_1
        res['sentence2']['values_mapping'] = factor_mapping_2

        if probability_nb_1 > probability_nb_2 :
            nb_prediction.append(1)
            res['nb_pred'] = '1'
        elif probability_nb_1 < probability_nb_2:
            nb_prediction.append(2)
            res['nb_pred'] = '2'
        else:
            nb_prediction.append(3)
            res['nb_pred'] = '3'
        if args.both_nb_cbn == "true":
            if probability_cbn_1 > probability_cbn_2:
                cbn_prediction.append(1)
                res['cbn_pred'] = '1'
            elif probability_cbn_1 < probability_cbn_2:
                cbn_prediction.append(2)
                res['cbn_pred'] = '2'
            else:
                cbn_prediction.append(3)
                res['cbn_pred'] = '3'

        gt_label = df_human['human_prediction'][i]
        human_prediction.append(gt_label)
        res['label'] = str(gt_label)
        res['nb_res'] = 'succeeded' if res['nb_pred'] == res['label'] else 'failed'
        if args.both_nb_cbn == "true":
            res['cbn_res'] = 'succeeded' if res['cbn_pred'] == res['label'] else 'failed'
        predict_res.append(res)
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(predict_res, f, indent=4, ensure_ascii=False)
    # Evaluation for all models
    print("\nEvaluation Results:")
    print('unknow count:', count_unknown)
    print("output path:------------>", out_path)
    print("nb Individual F1 score for each category: ", f1_score(nb_prediction, human_prediction, average=None))
    print("nb Average F1 score: ", f1_score(nb_prediction, human_prediction, average='micro'))
    if args.both_nb_cbn == "true":
        print("cbn Individual F1 score for each category: ", f1_score(cbn_prediction, human_prediction, average=None))
        print("cbn Average F1 score: ", f1_score(cbn_prediction, human_prediction, average='micro'))
