import argparse

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm

from src.dataset_utils.dataset import get_dataset
import seaborn as sns

from src.wandb_uitls.WandbUtils import init_wandb_parser, close_wandb, get_wandb, wandb_log
import wandb


def main(args):
    dataset = get_dataset(args.dataset, args.data_type)

    dataset.show_statistic_information()
    dr_samples = dataset.get_samples_for_default_reasoning(slice(None))

    print('开始运行...')
    prediction_path = args.prediction_path
    prediction = dataset.read_prediction(prediction_path)

    label2idx = {'T': 1, 'F': 0, 'M': 2}
    class_names = ['F', 'T', 'M']

    text2label = {
        'true': 1,
        'T' : 1,
        'false': 0,
        'F': 0,
        'maybe': 2,
        'M': 2
    }

    y = np.zeros([len(prediction), 3], int)
    y_hat = np.zeros_like(y)
    for idx, pred in tqdm(enumerate(prediction)):
        label = pred['label']
        p_label = pred['prediction']
        y[idx, text2label[label]] = 1

        if p_label in text2label:
            y_hat[idx, text2label[p_label]] = 1
        else:
            # if no prediction, set M
            y_hat[idx, text2label['M']] = 1

    weight = ((1 / y.sum(0)) * y).sum(1)*3
    # overall f1 for all labels
    print('calculating micro metric...')
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(y, y_hat, average='micro', zero_division=False)

    print('calculating weighted micro metric...')
    w_micro_precision, w_micro_recall, w_micro_f1, _ = precision_recall_fscore_support(y, y_hat, average='micro', zero_division=False, sample_weight=weight)

    # average f1 for each label
    print('calculating macro metric...')
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(y, y_hat, average='macro', zero_division=False)

    print('calculating weighted macro metric...')
    w_macro_precision, w_macro_recall, w_macro_f1, _ = precision_recall_fscore_support(y, y_hat, average='macro', zero_division=False, sample_weight=weight)

    # calculate the f1 for each sample, and finally use the average of f1 as the result
    print('calculating samples metric...')
    samples_precision, samples_recall, samples_f1, _ = precision_recall_fscore_support(y, y_hat, average='samples', zero_division=False)

    print('calculating weighted samples metric...')
    w_samples_precision, w_samples_recall, w_samples_f1, _ = precision_recall_fscore_support(y, y_hat, average='samples', zero_division=False, sample_weight=weight)

    # Print the metrics
    print(f"MA-p MA-r MA-f1 MI-p MI-r MI-f1 SA-p SA-r SA-f1")
    print(
        f"{macro_precision*100:.1f} {macro_recall*100:.1f} {macro_f1*100:.1f} {micro_precision*100:.1f} {micro_recall*100:.1f} {micro_f1*100:.1f} {samples_precision*100:.1f} {samples_recall*100:.1f} {samples_f1*100:.1f}")

    # Print the metrics
    print(f"wMA-p wMA-r wMA-f1 wMI-p wMI-r wMI-f1 wSA-p wSA-r wSA-f1")
    print(
        f"{w_macro_precision*100:.1f} {w_macro_recall*100:.1f} {w_macro_f1*100:.1f} {w_micro_precision*100:.1f} {w_micro_recall*100:.1f} {w_micro_f1*100:.1f} {w_samples_precision*100:.1f} {w_samples_recall*100:.1f} {w_samples_f1*100:.1f}")

    y_true = np.argmax(y, axis=1)
    y_pred = np.argmax(y_hat, axis=1)

    # 计算混淆矩阵
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

    # 输出混淆矩阵
    print("Confusion Matrix:")
    print(cm)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['F', 'T', 'M'], yticklabels=['F', 'T', 'M'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    if args.wandb:
        wandb_log({
            'MA-p': macro_precision*100,
            'MA-r': macro_recall*100,
            'MA-f1': macro_f1*100,
            'MI-p': micro_precision*100,
            'MI-r': micro_recall*100,
            'MI-f1': micro_f1*100,
            'SA-p': samples_precision*100,
            'SA-r': samples_recall*100,
            'SA-f1': samples_f1 * 100,
            'wMA-p': w_macro_precision * 100,
            'wMA-r': w_macro_recall * 100,
            'wMA-f1': w_macro_f1 * 100,
            'wMI-p': w_micro_precision * 100,
            'wMI-r': w_micro_recall * 100,
            'wMI-f1': w_micro_f1 * 100,
            'wSA-p': w_samples_precision * 100,
            'wSA-r': w_samples_recall * 100,
            'wSA-f1': w_samples_f1 * 100,
            '#pre_label': (y_hat.sum() / y_hat.shape[0]),
            "conf_mat": wandb.plot.confusion_matrix(probs=None,
                                                    y_true=y_true, preds=y_pred,
                                                    class_names=class_names)
        })


    close_wandb()



if __name__ == "__main__":
    # Script entry point: build the CLI, optionally start wandb, then evaluate.
    cli = argparse.ArgumentParser(description="llm predict labels")
    cli.add_argument(
        '--dataset',
        type=str,
        default='LabelClassification',
        help='The name of the dataset.',
    )
    cli.add_argument(
        '--data_dir_path',
        type=str,
        help='The path of dataset',
    )
    cli.add_argument(
        '--data_type',
        type=str,
        default='random_str_symbolic',
        help='The type of dataset',
    )
    cli.add_argument(
        "--prediction_path",
        type=str,
        required=True,
        help="prediction path",
    )
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        help="model name",
    )

    # `args.wandb` is read below but not declared above, so it is expected to
    # be registered by init_wandb_parser - TODO confirm against that helper.
    init_wandb_parser(cli)

    args = cli.parse_args()

    if args.wandb:
        print('init wandb by args:', args)
        print('wandb:', get_wandb(args))

    print(args)
    main(args)