import argparse

from tqdm import tqdm

from src.dataset_utils.dataset import get_dataset

from src.wandb_uitls.WandbUtils import init_wandb_parser, close_wandb, get_wandb, wandb_log


def _calculate_f1(label, p_label):
    """Return set-based ``(precision, recall, f1)`` for one sample.

    Both arguments are treated as sets, so duplicates are ignored. Each
    metric falls back to 0 when its denominator would be zero.
    """
    gold = set(label)
    predicted = set(p_label)
    true_positive = len(gold & predicted)

    # TP + FP == |predicted| and TP + FN == |gold| for set-based matching.
    precision = true_positive / len(predicted) if predicted else 0
    recall = true_positive / len(gold) if gold else 0

    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0
    return precision, recall, f1


def _normalize_label(text):
    """Strip spaces and lowercase a label for fuzzy comparison."""
    return text.replace(' ', '').lower()


def _score_predictions(prediction, normalize=None):
    """Score every record and return averaged metrics as percentages.

    Args:
        prediction: Non-empty sized iterable of records with ``'label'`` and
            ``'prediction'`` entries; a falsy ``'prediction'`` counts as no
            predicted labels.
        normalize: Optional callable applied to every label string on both
            sides before comparison. ``None`` compares labels as-is.

    Returns:
        ``(precision, recall, f1, included_rate)`` where the first three are
        per-sample averages and ``included_rate`` is the share of samples with
        at least one predicted label present in the gold labels.
    """
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    included_count = 0

    for pred in tqdm(prediction):
        label = pred['label']
        p_label = pred['prediction'] or []
        if normalize is not None:
            label = [normalize(item) for item in label]
            p_label = [normalize(item) for item in p_label]

        if set(label) & set(p_label):
            included_count += 1

        precision, recall, f1 = _calculate_f1(label, p_label)
        total_precision += precision
        total_recall += recall
        total_f1 += f1

    sample_count = len(prediction)
    return (
        total_precision / sample_count * 100,
        total_recall / sample_count * 100,
        total_f1 / sample_count * 100,
        included_count / sample_count * 100,
    )


def main(args):
    """Evaluate label predictions with exact and fuzzy set matching.

    Loads the dataset named by ``args.dataset`` / ``args.data_type``, reads the
    predictions stored at ``args.prediction_path`` and prints per-sample
    averaged precision, recall, F1 and the "correct label included" rate, once
    on the raw labels (exact) and once on labels with spaces removed and
    lowercased (fuzzy). The mean number of predicted labels per sample is
    reported as ``#p``. When ``args.wandb`` is truthy the same numbers are sent
    through ``wandb_log``.

    Args:
        args: Parsed command-line namespace providing ``dataset``,
            ``data_type``, ``prediction_path`` and ``wandb``.

    Raises:
        ValueError: If the prediction file contains no records, since no
            average can be computed.
    """
    dataset = get_dataset(args.dataset, args.data_type)

    dataset.show_statistic_information()
    # The result is not used here; the call is retained in case it has side
    # effects on the dataset object. TODO confirm and drop if it has none.
    dataset.get_samples_for_default_reasoning(slice(None))

    print('开始运行...')
    prediction_path = args.prediction_path
    prediction = dataset.read_prediction(prediction_path)

    sample_count = len(prediction)
    if sample_count == 0:
        # Fail with an actionable message instead of a bare ZeroDivisionError.
        raise ValueError(
            f"No predictions found in {prediction_path!r}; cannot compute metrics."
        )

    # Exact matching: labels are compared as they appear in the file.
    (
        exact_average_precision,
        exact_average_recall,
        exact_average_f1,
        exact_correct_included_rate,
    ) = _score_predictions(prediction)

    # Average number of predicted labels per sample.
    p_len = sum(len(pred['prediction'] or []) for pred in prediction) / sample_count

    print("EMA-p EMA-r EMA-f1 E-C #p")
    print(f"{exact_average_precision:.1f} {exact_average_recall:.1f} {exact_average_f1:.1f} {exact_correct_included_rate:.1f} {p_len:.1f}")

    # Fuzzy matching: ignore spaces and letter case on both sides.
    (
        fuzzy_average_precision,
        fuzzy_average_recall,
        fuzzy_average_f1,
        fuzzy_correct_included_rate,
    ) = _score_predictions(prediction, normalize=_normalize_label)

    print("FMA-p FMA-r FMA-f1 F-C")
    print(f"{fuzzy_average_precision:.1f} {fuzzy_average_recall:.1f} {fuzzy_average_f1:.1f} {fuzzy_correct_included_rate:.1f}")

    if args.wandb:
        wandb_log({
            'EMA-p': exact_average_precision,
            'EMA-r': exact_average_recall,
            'EMA-f1': exact_average_f1,
            'E-C': exact_correct_included_rate,
            'FMA-p': fuzzy_average_precision,
            'FMA-r': fuzzy_average_recall,
            'FMA-f1': fuzzy_average_f1,
            'F-C': fuzzy_correct_included_rate,
            '#p': p_len,
        })

    close_wandb()



if __name__ == "__main__":
    # Command-line entry point: build the argument parser, let the wandb
    # helper register its own options, then run the evaluation.
    cli = argparse.ArgumentParser(description="llm predict labels")
    cli.add_argument(
        '--dataset',
        type=str,
        default='LabelClassification',
        help='The name of the dataset.',
    )
    cli.add_argument(
        '--data_dir_path',
        type=str,
        help='The path of dataset',
    )
    cli.add_argument(
        '--data_type',
        type=str,
        default='random_str_symbolic',
        help='The type of dataset',
    )
    cli.add_argument(
        "--prediction_path",
        type=str,
        required=True,
        help="prediction path",
    )
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        help="model name",
    )

    init_wandb_parser(cli)

    args = cli.parse_args()

    # Only set up wandb when it was requested on the command line.
    if args.wandb:
        print('init wandb by args:', args)
        w = get_wandb(args)
        print('wandb:', w)

    print(args)
    main(args)