import json
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


# Simple multi-modal binary classifier
class MultiModalClassifier(nn.Module):
    """Binary classifier that fuses three equally sized feature vectors.

    Each modality is projected by its own linear layer, the three
    projections are concatenated, passed through dropout and a small
    fully connected stack, and squashed by a sigmoid so the output can be
    fed directly to ``nn.BCELoss``.

    Args:
        input_dim: Width of every modality's input vector.
        hidden_dim: Width of each per-modality projection.
        dropout: Drop probability applied to the fused representation.

    The defaults reproduce the original fixed architecture (512 -> 256 per
    modality, dropout 0.3), so existing checkpoints keep loading.
    """

    def __init__(self, input_dim=512, hidden_dim=256, dropout=0.3):
        super().__init__()

        # One projection per modality. The creation order is kept as-is so
        # that seeded parameter initialization does not change.
        self.fc1_weighted = nn.Linear(input_dim, hidden_dim)
        self.fc1_ast = nn.Linear(input_dim, hidden_dim)
        self.fc1_cfg = nn.Linear(input_dim, hidden_dim)

        # Dropout on the fused representation
        self.dropout = nn.Dropout(p=dropout)

        # Fully connected layers after fusion
        self.fc2 = nn.Linear(hidden_dim * 3, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 1)

        # Activations
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_weighted, x_ast, x_cfg):
        """Return one probability per sample.

        Args:
            x_weighted: Tensor of expected shape ``(batch, input_dim)``.
            x_ast: Tensor of expected shape ``(batch, input_dim)``.
            x_cfg: Tensor of expected shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid outputs.
        """
        # Per-modality feature extraction
        x_weighted = self.relu(self.fc1_weighted(x_weighted))
        x_ast = self.relu(self.fc1_ast(x_ast))
        x_cfg = self.relu(self.fc1_cfg(x_cfg))

        # Fuse by concatenating along the feature axis
        x = torch.cat([x_weighted, x_ast, x_cfg], dim=1)

        x = self.dropout(x)

        # Fully connected stack
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.fc4(x)

        # Convert the logit to a probability
        return self.sigmoid(x)


class CustomDataset(Dataset):
    """Dataset yielding a ``(3, D)`` feature tensor and a float label per sample.

    Args:
        data: Sequence of dicts providing ``'ast_features'``,
            ``'cfg_features'`` and ``'y'``.
        attention_weights: Per-sample comment embedding, index-aligned with
            ``data``. Despite the parameter name, ``main`` passes the
            weighted embeddings here, not the raw attention weights. An
            entry may be a scalar placeholder instead of a vector (see
            ``__getitem__``).
    """

    def __init__(self, data, attention_weights):
        self.data = data
        self.attention_weights = attention_weights

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(X, y)`` for sample ``idx``.

        ``X`` stacks the comment embedding, AST features and CFG features
        as rows 0, 1 and 2. ``y`` is a 0-d float32 tensor.

        A scalar embedding equal to 0 or NaN is treated as "no embedding"
        and replaced by a zero row. ``np.dot`` of two empty arrays yields
        the scalar 0.0 and ``np.mean`` of an empty array yields NaN, which
        is how such placeholders arise. The zero row takes its width from
        the sample's own features rather than a fixed 512.
        """
        item = self.data[idx]

        ast_features = np.expand_dims(item['ast_features'], axis=0).astype(np.float32)
        cfg_features = np.expand_dims(item['cfg_features'], axis=0).astype(np.float32)

        embedding = np.asarray(self.attention_weights[idx], dtype=np.float32)
        if embedding.ndim == 0 and (embedding == 0 or np.isnan(embedding)):
            # Placeholder of any scalar type (float, int, NumPy scalar)
            embedding = np.zeros_like(ast_features)
        else:
            # A non-placeholder scalar still ends up with a mismatching
            # shape and fails in np.concatenate below, as before.
            embedding = np.expand_dims(embedding, axis=0)

        # Stack the three (1, D) rows into a single (3, D) array
        stacked = np.concatenate([embedding, ast_features, cfg_features], axis=0)
        X = torch.tensor(stacked, dtype=torch.float32)
        y = torch.tensor(item['y'], dtype=torch.float32)
        return X, y


# class CustomDataset(Dataset):
#     def __init__(self, data, average_embeddings):
#         self.data = data
#         self.average_embeddings = average_embeddings
#
#     def __len__(self):
#         return len(self.data)
#
#     def __getitem__(self, idx):
#         item = self.data[idx]
#         average_embedding = self.average_embeddings[idx]
#
#         # Check whether average_embedding contains NaN and, if so, replace it with zeros
#         if np.isnan(average_embedding).any():
#             average_embedding = np.zeros((1, 512), dtype=np.float32)
#         else:
#             average_embedding = np.expand_dims(average_embedding, axis=0).astype(np.float32)
#
#         # Concatenate average_embedding, ast_features and cfg_features
#         ast_features = np.expand_dims(item['ast_features'], axis=0).astype(np.float32)
#         cfg_features = np.expand_dims(item['cfg_features'], axis=0).astype(np.float32)
#         X = torch.tensor(np.concatenate([average_embedding, ast_features, cfg_features], axis=0), dtype=torch.float32)
#         y = torch.tensor(item['y'], dtype=torch.float32)
#         return X, y


# Load the dataset file
def read_data(file_path):
    """Parse the JSON file at ``file_path`` and return the decoded object.

    The file is opened explicitly as UTF-8 so that non-ASCII content is
    decoded the same way on every platform instead of depending on the
    locale's default encoding.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Count comment tokens (the plotting code below is currently disabled)
def plot_token_frequency(data):
    """Return a ``Counter`` of the normalized tokens found in ``data``.

    Every item's ``'token'`` list is scanned. ``'math'`` is counted as
    ``'safe math'``, ``'operations'`` as ``'math operations'``, and the
    tokens ``'title'`` and ``'safe'`` are skipped. All other tokens are
    counted unchanged.

    Despite the name, nothing is plotted at the moment: the plotting code
    is kept below as comments only.
    """
    renamed = {'math': 'safe math', 'operations': 'math operations'}
    skipped = ('title', 'safe')

    token_freq = Counter()
    for item in data:
        for token in item['token']:
            if token in renamed:
                token_freq[renamed[token]] += 1
            elif token not in skipped:
                token_freq[token] += 1

    # Disabled plotting code:
    #
    # # Keep only the 80 most frequent tokens
    # most_common_tokens = token_freq.most_common(80)
    # tokens, freqs = zip(*most_common_tokens)
    #
    # plt.figure(figsize=(36, 20))  # figure size
    #
    # # Bar chart with enlarged fonts
    # plt.bar(tokens, freqs)
    # plt.xlabel('Comments', fontsize=32)
    # plt.ylabel('Frequency', fontsize=32)
    # plt.title('Comments Token Frequency', fontsize=36)
    # plt.xticks(rotation=90, fontsize=28)
    # plt.yticks(fontsize=28)
    # plt.tight_layout()
    #
    # # Save the figure as a PDF file
    # plt.savefig('Comments_Frequency.pdf')
    # plt.show()

    return token_freq


def calculate_attention_weights(data, token_freq):
    """Map every item's tokens to their relative corpus frequency.

    The raw tokens that ``plot_token_frequency`` renames or drops get a
    frequency here so they can be looked up:

    * ``'safe'`` and ``'math'`` take the count of ``'safe math'``
    * ``'operations'`` takes the count of ``'math operations'``
    * ``'title'`` is fixed at 1

    A token's weight is its count divided by the sum of all counts,
    including the entries added above.

    Args:
        data: Iterable of dicts, each with a ``'token'`` list.
        token_freq: Mapping of token to count. It is not modified.

    Returns:
        One list of weights per item, in token order. Tokens missing from
        the frequency table are skipped, so a list can be shorter than the
        item's ``'token'`` list.
    """
    # Work on a private copy so the caller's counter is left untouched.
    # Counter also returns 0 for a missing key instead of raising KeyError.
    freq = Counter(token_freq)
    freq['safe'] = freq['safe math']
    freq['title'] = 1
    freq['math'] = freq['safe math']
    freq['operations'] = freq['math operations']

    # Compute the normalizer once rather than once per token.
    total = sum(freq.values())
    token_to_weight = {token: count / total for token, count in freq.items()}

    attention_weights = []
    for item in data:
        attention_weights.append(
            [token_to_weight[token] for token in item['token'] if token in token_to_weight]
        )
    return attention_weights


# Compute the attention-weighted embedding of each sample
def calculate_weighted_embedding(data, attention_weights):
    """Return one ``np.dot(weights, embeddings)`` result per sample.

    ``data`` and ``attention_weights`` are walked in lockstep. For each
    pair, the item's ``'comments_feature'`` is converted to an array and
    dotted with the weights. When both are empty, ``np.dot`` returns the
    scalar ``0.0`` rather than a vector.
    """
    return [
        np.dot(weights, np.array(item['comments_feature']))
        for item, weights in zip(data, attention_weights)
    ]


def calculate_average_embedding(data):
    """Return the mean of each item's ``'comments_feature'`` along axis 0.

    The result is a list with one entry per item in ``data``, in order.
    """
    return [np.array(item['comments_feature']).mean(axis=0) for item in data]


# Build the DataLoader
def prepare_dataloader(dataset, batch_size=1024, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader`` with the given batch size and shuffle flag."""
    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)


def train_model(model, dataloader, num_epochs=500, learning_rate=0.01,
                save_path='one_factor_model/ast_cfg.pth'):
    """Train ``model`` with Adam and binary cross-entropy, then save a checkpoint.

    Args:
        model: Module called as ``model(weighted, ast, cfg)`` that returns
            probabilities of shape ``(batch, 1)``.
        dataloader: Yields ``(inputs, labels)`` where ``inputs`` is indexed
            as ``inputs[:, k, :]`` for k = 0, 1, 2 and ``labels`` is 1-D.
            Must expose ``.dataset`` with a length.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        save_path: Destination of the checkpoint file. Missing parent
            directories are created.

    Returns:
        List with the mean training loss of every epoch.

    The checkpoint holds ``model_state_dict`` (weights as they were at the
    end of the lowest-loss epoch), ``optimizer_state_dict``, ``min_loss``
    and ``train_losses``. The optimizer state is the one from the end of
    training, not from the best epoch.
    """
    import copy
    import os

    # Create the output directory up front, so a missing folder cannot make
    # torch.save fail after the whole training run has finished.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.BCELoss()  # standard loss for binary classification on probabilities
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_losses = []
    min_loss = float('inf')
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            # BCELoss needs targets shaped like the (batch, 1) outputs
            labels = labels.to(device).unsqueeze(1)

            # Split the stacked input into its three channels
            weighted_embedding = inputs[:, 0, :]  # channel 0: weighted embedding
            ast_features = inputs[:, 1, :]  # channel 1: AST features
            cfg_features = inputs[:, 2, :]  # channel 2: CFG features

            optimizer.zero_grad()
            outputs = model(weighted_embedding, ast_features, cfg_features)

            # Per-batch diagnostics: confident positives vs. true positives
            num_high_outputs = (outputs > 0.8).sum().item()
            num_high_labels = (labels == 1).sum().item()
            print("num_high_outputs", num_high_outputs)
            print("num_high_labels", num_high_labels)
            print("min_output", outputs.min().item())
            print("max_output", outputs.max().item())

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch mean by batch size to get a per-sample average
            epoch_loss += loss.item() * weighted_embedding.size(0)

        epoch_loss /= len(dataloader.dataset)
        train_losses.append(epoch_loss)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        if epoch_loss < min_loss:
            min_loss = epoch_loss
            # state_dict() returns references to the live parameters, so it
            # must be deep-copied; otherwise later optimizer steps would
            # overwrite the "best" snapshot with the final weights.
            best_model = copy.deepcopy(model.state_dict())
            print("Model saved with min loss:", min_loss)

    torch.save({
        'model_state_dict': best_model,
        'optimizer_state_dict': optimizer.state_dict(),
        'min_loss': min_loss,
        'train_losses': train_losses
    }, save_path)

    print("Training complete!")
    print("min loss", min_loss)
    return train_losses


# Entry point
def main():
    """Run the whole pipeline and return the per-epoch training losses.

    Steps: load ``dataset/train.json``, count token frequencies, turn them
    into attention weights, build the weighted comment embeddings, wrap
    everything in a dataset and loader, and train a fresh
    ``MultiModalClassifier``.
    """
    # Load the samples and derive the token statistics
    samples = read_data('dataset/train.json')
    frequencies = plot_token_frequency(samples)

    # Attention weights -> attention-weighted embeddings
    weights = calculate_attention_weights(samples, frequencies)
    embeddings = calculate_weighted_embedding(samples, weights)
    # embeddings = calculate_average_embedding(samples)

    # for da in samples:
    #     da['ast_features'] = [0] * 512
    #     # da['cfg_features'] = [0] * 512
    #     # if len(da['token']) != 0:
    #     #     da['token'] = [0] * len(da['token'])
    #     # da['comments_feature'] = [0] * 512

    # Dataset and loader first, then the model, then training
    loader = prepare_dataloader(CustomDataset(samples, embeddings))
    classifier = MultiModalClassifier()

    # The loss history is handed back for further analysis or plotting
    return train_model(classifier, loader)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
