import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import Counter
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt


# Simple classifier model (currently disabled; kept commented out below)
# class SimpleClassifier(nn.Module):
#     def __init__(self):
#         super(SimpleClassifier, self).__init__()
#         self.fc1 = nn.Linear(512, 256)
#         self.dropout1 = nn.Dropout(p=0.1)
#         self.fc2 = nn.Linear(256, 128)
#         self.dropout2 = nn.Dropout(p=0.1)
#         self.fc3 = nn.Linear(128, 1)
#         self.sigmoid = nn.Sigmoid()
#
#     def forward(self, x):
#         x = torch.relu(self.fc1(x))
#         x = self.dropout1(x)
#         x = torch.relu(self.fc2(x))
#         x = self.dropout2(x)
#         x = self.fc3(x)
#         x = self.sigmoid(x)
#         return x

class MultiModalClassifier(nn.Module):
    """Binary classifier that fuses three 512-dimensional feature inputs.

    Each input (weighted embedding, AST features, CFG features) goes through
    its own 512 -> 256 linear layer followed by ReLU. The three projections
    are concatenated, passed through dropout and a 768 -> 128 -> 64 -> 1
    fully connected head, and finally through a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # One projection layer per modality.
        self.fc1_weighted = nn.Linear(512, 256)
        self.fc1_ast = nn.Linear(512, 256)
        self.fc1_cfg = nn.Linear(512, 256)

        # Dropout applied to the concatenated projections.
        self.dropout = nn.Dropout(p=0.3)

        # Fully connected layers applied after fusion.
        self.fc2 = nn.Linear(256 * 3, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 1)

        # Activation functions.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_weighted, x_ast, x_cfg):
        """Return the sigmoid output for a batch of the three modalities."""
        # Project every modality with its dedicated layer, then ReLU.
        branches = (
            (self.fc1_weighted, x_weighted),
            (self.fc1_ast, x_ast),
            (self.fc1_cfg, x_cfg),
        )
        projected = [self.relu(layer(features)) for layer, features in branches]

        # Concatenate along the feature dimension and regularise.
        fused = self.dropout(torch.cat(projected, dim=1))

        # Fusion head.
        hidden = self.relu(self.fc2(fused))
        hidden = self.relu(self.fc3(hidden))
        logits = self.fc4(hidden)

        # Squash to a probability-like score.
        return self.sigmoid(logits)


# Custom dataset class
class CustomDataset(Dataset):
    """Dataset pairing each record with its precomputed weighted embedding.

    Indexing returns ``(X, y)``: ``X`` is a float32 tensor whose rows are, in
    order, the weighted embedding, ``item['ast_features']`` and
    ``item['cfg_features']``; ``y`` is ``item['y']`` as a float32 tensor.
    """

    def __init__(self, data, attention_weights):
        self.data = data
        self.attention_weights = attention_weights

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _as_row(values):
        """Return ``values`` as a float32 array with a new leading axis."""
        return np.expand_dims(values, axis=0).astype(np.float32)

    def __getitem__(self, idx):
        record = self.data[idx]
        embedding = self.attention_weights[idx]

        # A float 0.0 is treated as "no embedding" and replaced by a zero row
        # of width 512 so it can be stacked with the other feature rows.
        is_placeholder = isinstance(embedding, float) and embedding == 0.0
        if is_placeholder:
            embedding_row = np.zeros((1, 512), dtype=np.float32)
        else:
            embedding_row = self._as_row(np.array(embedding, dtype=np.float32))

        rows = [
            embedding_row,
            self._as_row(record['ast_features']),
            self._as_row(record['cfg_features']),
        ]
        X = torch.tensor(np.concatenate(rows, axis=0), dtype=torch.float32)
        y = torch.tensor(record['y'], dtype=torch.float32)
        return X, y


# Read the data file
def read_data(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's default text encoding, so non-ASCII content is read the same
    way on every system.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The value produced by ``json.load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Extract tokens and build the frequency table
def plot_token_frequency(data):
    """Count token occurrences across all records of ``data``.

    Tokens are read from ``item['token']`` of every record. Before counting,
    ``'math'`` is recorded as ``'safe math'``, ``'operations'`` as
    ``'math operations'``, and ``'title'`` / ``'safe'`` are skipped. Nothing
    is drawn here; the function only returns the counts.

    Returns:
        A ``Counter`` mapping each processed token to its frequency.
    """
    renamed = {'math': 'safe math', 'operations': 'math operations'}
    skipped = ('title', 'safe')

    return Counter(
        renamed.get(tok, tok)
        for record in data
        for tok in record['token']
        if tok not in skipped
    )


# Compute attention weights
def calculate_attention_weights(data, token_freq):
    """Turn token frequencies into one weight list per record.

    The frequency table is first extended with alias entries so that the raw
    tokens found in ``data`` can be looked up: ``'safe'`` and ``'math'`` take
    the count of ``'safe math'``, ``'operations'`` takes the count of
    ``'math operations'``, and ``'title'`` is fixed at 1. Each weight is the
    token's count divided by the sum of all counts in this extended table
    (alias entries included).

    Args:
        data: Iterable of records, each with a ``'token'`` sequence.
        token_freq: Mapping from token to count. It is copied, not modified.

    Returns:
        A list with one entry per record; each entry lists the weights of the
        record's tokens in order, skipping tokens absent from the table.
    """
    # Work on a copy so the caller's frequency table is left untouched.
    freq = Counter(token_freq)
    freq['safe'] = freq['safe math']
    freq['title'] = 1
    freq['math'] = freq['safe math']
    freq['operations'] = freq['math operations']

    # Compute the normaliser once instead of once per token.
    total = sum(freq.values())
    token_to_weight = {token: count / total for token, count in freq.items()}

    return [
        [token_to_weight[token] for token in item['token'] if token in token_to_weight]
        for item in data
    ]


# Compute the attention-weighted embedding
def calculate_weighted_embedding(data, attention_weights):
    """Combine each record's token embeddings using its attention weights.

    For every record, ``item['comments_feature']`` is converted to an array
    and ``np.dot(weights, embeddings)`` is taken with the matching entry of
    ``attention_weights``.

    Returns:
        A list with one ``np.dot`` result per record, in input order.
    """
    return [
        np.dot(weights, np.array(record['comments_feature']))
        for record, weights in zip(data, attention_weights)
    ]


def calculate_average_embedding(data):
    """Return the mean of ``item['comments_feature']`` along axis 0 per record.

    Returns:
        A list with one averaged embedding per record, in input order.
    """
    return [
        np.mean(np.array(record['comments_feature']), axis=0)
        for record in data
    ]


# Prepare the DataLoader
def prepare_dataloader(dataset, batch_size=1024, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader`` with the given batching options."""
    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)


# Draw the ROC curve for evaluate_model
def _plot_roc(all_labels, all_probabilities, auc, roc_path, show_plot):
    """Plot the ROC curve, save it to ``roc_path`` and optionally display it."""
    fpr, tpr, _ = roc_curve(all_labels, all_probabilities)
    fig = plt.figure()
    try:
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {auc:.4f})')
        # Diagonal reference line.
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        # Label and title font sizes.
        plt.xlabel('False Positive Rate', fontsize=18)
        plt.ylabel('True Positive Rate', fontsize=18)
        plt.title('Receiver Operating Characteristic', fontsize=20)
        plt.legend(loc="lower right", fontsize=14)
        # Axis tick font sizes.
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)

        plt.savefig(roc_path)
        if show_plot:
            plt.show()
    finally:
        # Release the figure so repeated evaluations do not accumulate them.
        plt.close(fig)


# Evaluate the model
def evaluate_model(model, dataloader, criterion, threshold=0.95,
                   roc_path='roc_curve.pdf', show_plot=True):
    """Evaluate ``model`` on ``dataloader`` and compute classification metrics.

    The model is moved to ``cuda:0`` when CUDA is available (CPU otherwise)
    and put in eval mode. Every batch ``inputs`` is indexed as
    ``inputs[:, 0, :]``, ``inputs[:, 1, :]`` and ``inputs[:, 2, :]`` to obtain
    the weighted-embedding, AST and CFG inputs of the model. Two summary
    lines are printed and the ROC curve is saved to ``roc_path``.

    Args:
        model: Module called as ``model(weighted, ast, cfg)``.
        dataloader: Yields ``(inputs, labels)`` batches; must expose
            ``dataloader.dataset`` with a length.
        criterion: Loss called as ``criterion(outputs, labels)`` with labels
            unsqueezed to a trailing dimension of size 1.
        threshold: An output strictly greater than this value is counted as
            a positive prediction.
        roc_path: File the ROC figure is written to.
        show_plot: Whether to call ``plt.show()`` after saving the figure.

    Returns:
        Tuple ``(loss, accuracy, precision, recall, f1, auc)`` where ``loss``
        is the per-sample average of ``criterion`` over the dataset.
    """
    model.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    epoch_loss = 0.0
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device).unsqueeze(1)

            # Dim 1 holds the three modalities in a fixed order.
            weighted_embedding = inputs[:, 0, :]
            ast_features = inputs[:, 1, :]
            cfg_features = inputs[:, 2, :]

            outputs = model(weighted_embedding, ast_features, cfg_features)

            # Weight the batch loss by batch size; it is averaged below.
            loss = criterion(outputs, labels)
            epoch_loss += loss.item() * inputs.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(outputs.cpu().numpy())

    epoch_loss /= len(dataloader.dataset)

    all_labels = np.array(all_labels).flatten()
    all_probabilities = np.array(all_probabilities).flatten()
    all_predictions = (all_probabilities > threshold).astype(int)

    # Number of ground-truth positives.
    positive_label_count = np.sum(all_labels == 1)
    print(f"Number of positive labels (label=1): {positive_label_count}; all label num : {len(all_labels)}")

    # Number of samples the model classified as positive.
    predicted_positive_count = int(np.sum(all_predictions == 1))
    print(f"Number of predicted positive labels (prediction=1): {predicted_positive_count}")

    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, average='binary')
    recall = recall_score(all_labels, all_predictions, average='binary')
    f1 = f1_score(all_labels, all_predictions, average='binary')
    auc = roc_auc_score(all_labels, all_probabilities)

    _plot_roc(all_labels, all_probabilities, auc, roc_path, show_plot)

    return epoch_loss, accuracy, precision, recall, f1, auc


# Main entry point
def main():
    """Evaluate the saved model on ``dataset/evaluate.json`` and print metrics.

    Builds attention-weighted embeddings from the evaluation data, loads the
    weights stored under ``'model_state_dict'`` in ``best_model.pth`` and runs
    ``evaluate_model`` with an MSE criterion.

    Returns:
        Tuple ``(eval_loss, accuracy, precision, recall, f1, auc)``.
    """
    # Load the evaluation data.
    evaluate_data = read_data('dataset/evaluate.json')

    # Token frequency table.
    token_freq = plot_token_frequency(evaluate_data)

    # Per-record attention weights and the resulting weighted embeddings.
    evaluate_attention_weights = calculate_attention_weights(evaluate_data, token_freq)
    # Alternative kept for reference:
    # evaluate_weighted_embeddings = calculate_average_embedding(evaluate_data)
    evaluate_weighted_embeddings = calculate_weighted_embedding(evaluate_data, evaluate_attention_weights)

    # Disabled experiment kept for reference:
    # for data in evaluate_data:
    #     data['ast_features'] = [0] * 512
    #     data['cfg_features'] = [0] * 512
    #     data['token'] = [0] * len(data['token'])
    #     data['comments_feature'] = [0] * 512

    # Dataset and loader; order is kept fixed for evaluation.
    evaluate_dataset = CustomDataset(evaluate_data, evaluate_weighted_embeddings)
    evaluate_dataloader = prepare_dataloader(evaluate_dataset, shuffle=False)

    # Model instance.
    model = MultiModalClassifier()

    # Load the best checkpoint onto the CPU first so a checkpoint saved on a
    # GPU can still be read on a CPU-only machine; evaluate_model moves the
    # model to the selected device afterwards.
    checkpoint = torch.load('best_model.pth', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Loss function.
    criterion = nn.MSELoss()

    # Run the evaluation.
    eval_loss, accuracy, precision, recall, f1, auc = evaluate_model(model, evaluate_dataloader, criterion)

    print(f"Evaluation Loss: {eval_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")

    return eval_loss, accuracy, precision, recall, f1, auc


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
