import csv
from sklearn.metrics import precision_score, recall_score, f1_score
import os

def calculate_metrics(csv_file):
    """Compute accuracy, precision, recall and F1 for a single CSV file.

    Each row must provide a ``State值`` column (the ground-truth label, parsed
    with ``int``) and a ``Predict值`` column (the prediction).  Any prediction
    that is not exactly the string ``'0'`` or ``'1'`` is mapped to the sentinel
    label ``2`` so malformed predictions are kept and counted as wrong instead
    of being dropped.

    Args:
        csv_file: Path to a UTF-8 CSV file (a leading BOM is tolerated).

    Returns:
        A tuple ``(accuracy, precision, recall, f1, true_labels, pred_labels)``.
        ``precision``/``recall``/``f1`` are support-weighted averages over the
        classes 0 and 1 only.  For a file with no data rows all four metrics
        are ``0.0`` and both label lists are empty.

    Raises:
        KeyError: if a required column is missing from the header.
        ValueError: if a ``State值`` cell cannot be converted with ``int``.
    """
    true_labels = []
    pred_labels = []

    # newline='' is what the csv module expects for correct handling of quoted
    # fields; 'utf-8-sig' strips a BOM (common in Excel exports) that would
    # otherwise be glued onto the first header name and break key lookups.
    # Files without a BOM are decoded exactly as with plain 'utf-8'.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            raw_pred = row['Predict值']
            # Anything other than the exact strings '0'/'1' becomes class 2.
            pred_label = int(raw_pred) if raw_pred in ('0', '1') else 2
            true_labels.append(int(row['State值']))
            pred_labels.append(pred_label)

    # An empty file would otherwise raise ZeroDivisionError in the accuracy.
    if not true_labels:
        return 0.0, 0.0, 0.0, 0.0, true_labels, pred_labels

    accuracy = sum(t == p for t, p in zip(true_labels, pred_labels)) / len(true_labels)

    # Only classes 0 and 1 are averaged (class 2 is NOT included).  A
    # prediction of 2 is never counted as a predicted 0 or 1, so it leaves
    # precision untouched, but it is a miss for the true class and therefore
    # lowers recall and F1.  'weighted' weights each class by its true count.
    score_kwargs = dict(labels=[0, 1], average='weighted', zero_division=0)
    precision = precision_score(true_labels, pred_labels, **score_kwargs)
    recall = recall_score(true_labels, pred_labels, **score_kwargs)
    f1 = f1_score(true_labels, pred_labels, **score_kwargs)

    return accuracy, precision, recall, f1, true_labels, pred_labels

def calculate_overall_metrics(all_true_labels, all_pred_labels):
    """Compute accuracy, precision, recall and F1 over pooled labels.

    Args:
        all_true_labels: Ground-truth labels, e.g. concatenated from several
            files.
        all_pred_labels: Predictions aligned element-wise with
            ``all_true_labels``; invalid predictions are presumably encoded as
            ``2`` by the caller (TODO confirm for every caller).

    Returns:
        A tuple ``(accuracy, precision, recall, f1)``.  ``precision``,
        ``recall`` and ``f1`` are support-weighted averages over classes 0 and
        1 only.  All four are ``0.0`` when there are no labels.
    """
    # len() rather than truthiness so array-like inputs also work.
    if len(all_true_labels) == 0:
        # Nothing to score; avoids ZeroDivisionError in the accuracy below.
        return 0.0, 0.0, 0.0, 0.0

    accuracy = sum(1 for t, p in zip(all_true_labels, all_pred_labels) if t == p) / len(all_true_labels)

    # Only classes 0 and 1 are averaged; any other predicted label (e.g. 2)
    # counts as a miss for recall but never as a predicted 0/1 for precision.
    score_kwargs = dict(labels=[0, 1], average='weighted', zero_division=0)
    precision = precision_score(all_true_labels, all_pred_labels, **score_kwargs)
    recall = recall_score(all_true_labels, all_pred_labels, **score_kwargs)
    f1 = f1_score(all_true_labels, all_pred_labels, **score_kwargs)

    return accuracy, precision, recall, f1

# Example run: score every subject file on its own, then pool all labels
# and score the combined set.
input_dir = ''  # directory holding the CSV files ('' = relative to the cwd)
# NOTE: the name is historical -- these entries are CSV files, one per subject.
json_files = ['math.csv', 'chemistry.csv', 'biology.csv', 'physics.csv']
all_true_labels = []
all_pred_labels = []

for json_file in json_files:
    base_name = os.path.splitext(json_file)[0]
    file_path = os.path.join(input_dir, json_file)

    (accuracy, precision, recall, f1,
     true_labels, pred_labels) = calculate_metrics(file_path)

    # Pool this file's labels into the running totals for the overall score.
    all_true_labels += true_labels
    all_pred_labels += pred_labels

    print(f'============================={base_name}=============')
    print(f"准确率 (accuracy): {accuracy:.4f}")
    print(f"精确率 (precision): {precision:.4f}")
    print(f"召回率 (Recall): {recall:.4f}")
    print(f"F1值 (F1 Score): {f1:.4f}")

# Metrics over the labels pooled from every file.
(overall_accuracy, overall_precision,
 overall_recall, overall_f1) = calculate_overall_metrics(all_true_labels, all_pred_labels)

# Report the pooled metrics.
print(f'=============================Overall======================')
print(f"总体准确率 (accuracy): {overall_accuracy:.4f}")
print(f"总体精确率 (precision): {overall_precision:.4f}")
print(f"总体召回率 (Recall): {overall_recall:.4f}")
print(f"总体F1值 (F1 Score): {overall_f1:.4f}")
