import argparse
import random

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm

from src.dataset_utils.dataset import get_dataset
import seaborn as sns

from src.wandb_uitls.WandbUtils import init_wandb_parser, close_wandb, get_wandb, wandb_log
import wandb


def _prf_scores(y_true, y_pred, average, sample_weight=None):
    """Return ``(precision, recall, f1)`` for two label-indicator matrices.

    Thin wrapper around ``precision_recall_fscore_support`` that always uses
    ``zero_division=0`` and drops the support value, which main() never reads.
    """
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true,
        y_pred,
        average=average,
        zero_division=0,
        sample_weight=sample_weight,
    )
    return precision, recall, f1


def main(args):
    """Score a prediction file against its gold labels and plot the results.

    Reads the predictions through the dataset object, drops records whose
    ``raw_response`` falls outside a fixed token-count range, one-hot encodes
    gold and predicted labels over the classes T / F / M, and then prints
    micro, macro and samples precision/recall/F1 (plain and class-balanced),
    shows a confusion-matrix heatmap and a jittered scatter plot, and, when
    ``args.wandb`` is truthy, logs the metrics through ``wandb_log``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.dataset``, ``args.data_type``, ``args.prediction_path``
            and ``args.wandb``.

    Raises:
        KeyError: If a record's ``label`` is not a key of the label table
            (only the predicted label falls back to 'M').
    """
    dataset = get_dataset(args.dataset, args.data_type)

    dataset.show_statistic_information()
    dr_samples = dataset.get_samples_for_default_reasoning(slice(None))

    print('开始运行...')
    prediction = dataset.read_prediction(args.prediction_path)

    # Keep only records whose space-separated token count of raw_response
    # lies in this inclusive range.
    min_tokens, max_tokens = -1, 1000
    prediction = [
        pred for pred in prediction
        if min_tokens <= len(pred['raw_response'].split(' ')) <= max_tokens
    ]
    print('filtered samples num:', len(prediction))

    label2idx = {'T': 0, 'F': 1, 'M': 2}
    # Derived from label2idx so tick labels always follow the index order
    # used by the confusion matrix (the previous hard-coded list had T and F
    # swapped).
    class_names = sorted(label2idx, key=label2idx.get)

    text2label = {
        'true': label2idx['T'],
        'T': label2idx['T'],
        'false': label2idx['F'],
        'F': label2idx['F'],
        'maybe': label2idx['M'],
        'M': label2idx['M'],
        None: label2idx['M'],
    }
    # Predictions that match no known label text are counted as 'M'.
    fallback_idx = text2label['M']

    y = np.zeros([len(prediction), 3], int)
    y_hat = np.zeros_like(y)
    # Wrap the list (not the enumerate object) so tqdm knows the total.
    for idx, pred in enumerate(tqdm(prediction)):
        y[idx, text2label[pred['label']]] = 1
        y_hat[idx, text2label.get(pred['prediction'], fallback_idx)] = 1

    # Per-sample weight inversely proportional to the size of its gold class,
    # so every class contributes equally to the weighted metrics.
    weight = ((1 / (np.maximum(y.sum(0), 1))) * y).sum(1)*3

    # overall f1 for all labels
    print('calculating micro metric...')
    micro_precision, micro_recall, micro_f1 = _prf_scores(y, y_hat, 'micro')

    print('calculating weighted micro metric...')
    w_micro_precision, w_micro_recall, w_micro_f1 = _prf_scores(
        y, y_hat, 'micro', sample_weight=weight)

    # average f1 for each label
    print('calculating macro metric...')
    macro_precision, macro_recall, macro_f1 = _prf_scores(y, y_hat, 'macro')

    print('calculating weighted macro metric...')
    w_macro_precision, w_macro_recall, w_macro_f1 = _prf_scores(
        y, y_hat, 'macro', sample_weight=weight)

    # calculate the f1 for each sample, and finally use the average of f1 as the result
    print('calculating samples metric...')
    samples_precision, samples_recall, samples_f1 = _prf_scores(y, y_hat, 'samples')

    print('calculating weighted samples metric...')
    w_samples_precision, w_samples_recall, w_samples_f1 = _prf_scores(
        y, y_hat, 'samples', sample_weight=weight)

    # Print the unweighted metrics as percentages.
    plain_scores = (
        macro_precision, macro_recall, macro_f1,
        micro_precision, micro_recall, micro_f1,
        samples_precision, samples_recall, samples_f1,
    )
    print("MA-p MA-r MA-f1 MI-p MI-r MI-f1 SA-p SA-r SA-f1")
    print(' '.join(f"{score*100:.1f}" for score in plain_scores))

    # Print the class-balanced metrics as percentages.
    weighted_scores = (
        w_macro_precision, w_macro_recall, w_macro_f1,
        w_micro_precision, w_micro_recall, w_micro_f1,
        w_samples_precision, w_samples_recall, w_samples_f1,
    )
    print("wMA-p wMA-r wMA-f1 wMI-p wMI-r wMI-f1 wSA-p wSA-r wSA-f1")
    print(' '.join(f"{score*100:.1f}" for score in weighted_scores))

    # Collapse the one-hot rows back to class indices.
    y_true = np.argmax(y, axis=1)
    y_pred = np.argmax(y_hat, axis=1)

    # Compute the confusion matrix (rows: true class, columns: predicted).
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

    # Print and draw the confusion matrix.
    print("Confusion Matrix:")
    print(cm)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    # Separate names for the scatter data so the one-hot matrix `y` is not
    # rebound to a list.
    scatter_x = []
    scatter_y = []
    colors = []
    idx2col = ['red', 'green', 'blue']

    # Walk over every prediction and build one scatter point per record.
    for pred in tqdm(prediction):
        sample_id = dataset.get_index_by_id(pred['Source_ID'])
        # NOTE(review): this facts+rules count is computed but never plotted;
        # the x coordinate below is pure jitter even though the plot title and
        # x-label mention raw_response length. Confirm the intended x axis.
        context_size = len(dr_samples[sample_id]['facts']+dr_samples[sample_id]['rules'])

        predicted_idx = text2label.get(pred['prediction'], fallback_idx)
        true_idx = text2label[pred['label']]

        # x is Gaussian jitter only; y is the gold class index plus jitter.
        scatter_x.append(0.1*random.gauss(0, 1))

        # Colour encodes the predicted class (index 0/1/2 -> red/green/blue).
        colors.append(idx2col[predicted_idx])
        scatter_y.append(true_idx + 0.1*random.gauss(0, 1))

    # Draw the scatter plot.
    plt.figure(figsize=(10, 6))
    plt.scatter(scatter_x, scatter_y, c=colors, alpha=0.5)
    plt.title('Scatter Plot of raw_response Length vs Label Value')
    plt.xlabel('raw_response Length')
    # Legend fixed to match label2idx (was 'T=1, M=2, F=0').
    plt.ylabel('Label Value (T=0, F=1, M=2)')
    plt.grid(True)
    plt.show()

    if args.wandb:
        wandb_log({
            'MA-p': macro_precision*100,
            'MA-r': macro_recall*100,
            'MA-f1': macro_f1*100,
            'MI-p': micro_precision*100,
            'MI-r': micro_recall*100,
            'MI-f1': micro_f1*100,
            'SA-p': samples_precision*100,
            'SA-r': samples_recall*100,
            'SA-f1': samples_f1 * 100,
            'wMA-p': w_macro_precision * 100,
            'wMA-r': w_macro_recall * 100,
            'wMA-f1': w_macro_f1 * 100,
            'wMI-p': w_micro_precision * 100,
            'wMI-r': w_micro_recall * 100,
            'wMI-f1': w_micro_f1 * 100,
            'wSA-p': w_samples_precision * 100,
            'wSA-r': w_samples_recall * 100,
            'wSA-f1': w_samples_f1 * 100,
            # Average number of predicted labels per sample.
            '#pre_label': (y_hat.sum() / y_hat.shape[0]),
            "conf_mat": wandb.plot.confusion_matrix(probs=None,
                                                    y_true=y_true, preds=y_pred,
                                                    class_names=class_names)
        })

    close_wandb()



if __name__ == "__main__":
    # Script entry point: build the CLI, optionally start wandb, run main().
    arg_parser = argparse.ArgumentParser(description="llm predict labels")

    # (flag, keyword arguments) pairs, registered in this order.
    option_table = (
        ('--dataset', {'type': str, 'help': 'The name of the dataset.', 'default': 'LabelClassification'}),
        ('--data_dir_path', {'type': str, 'help': 'The path of dataset'}),
        ('--data_type', {'type': str, 'help': 'The type of dataset', 'default': 'random_str_symbolic'}),
        ("--prediction_path", {'type': str, 'help': "prediction path", 'required': True}),
        ("--model", {'type': str, 'help': "model name", 'required': True}),
    )
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)

    # Let the wandb helper module register its own options on the parser.
    init_wandb_parser(arg_parser)

    args = arg_parser.parse_args()

    if args.wandb:
        print('init wandb by args:', args)
        wandb_run = get_wandb(args)
        print('wandb:', wandb_run)

    print(args)
    main(args)