import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, roc_curve, auc, confusion_matrix
from torchmetrics.classification import BinaryF1Score
from torch.utils.tensorboard import SummaryWriter
from sklearn.utils import compute_class_weight
import torch.serialization

class cfg:
    """Hyper-parameters shared by the model definition and the fine-tuning run."""

    # --- data handling ---
    SEED = 42  # random_state for the train/test split
    test_size = 0.2  # fraction of samples held out for testing
    batch_size = 16  # graphs per DataLoader batch

    # --- model dimensions ---
    num_global_features = 309  # embedding vocabulary size; inputs are clipped to [0, 308]
    embedding_dim = 256
    n_hid = 256
    n_class = 2
    n_heads = 16  # For Transformer and SCAM
    n_gat_heads = 8  # For GATConv
    dropout = 0.3

    # --- fine-tuning schedule ---
    fine_tune_epochs = 10
    learning_rate = 0.00005

    flag = 0  # not referenced in the visible part of this script

# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

class DGSM_SCAM_GAT_Enhanced(nn.Module):
    """Sequence-as-graph classifier combining GAT layers with DGSM and SCAM branches.

    Pipeline implemented by ``forward``:
    embedding -> GAT -> LayerNorm -> (DGSM branch || SCAM branch) -> concat ->
    TransformerEncoderLayer -> Linear projection -> LayerNorm -> GAT -> LayerNorm ->
    global mean pooling -> Linear -> log_softmax.

    Every sub-module is moved to the module-level ``device`` when it is created.

    Args:
        num_features: Size of the embedding vocabulary (number of distinct token ids).
        embedding_dim: Width of the token embeddings.
        n_hid: Hidden width used by the GAT layers and both branches.
        n_class: Number of output classes.
        dropout: Dropout probability passed to the GAT layers, both branches and
            the fusion transformer layer.
        n_heads: Attention heads for the DGSM/SCAM branches and the fusion layer.
        n_gat_heads: Attention heads of each GATConv; must divide ``n_hid``.
        seq_len: Number of nodes (tokens) per graph. Defaults to 100.

    Raises:
        ValueError: If ``n_hid`` is not divisible by ``n_gat_heads``.
    """

    def __init__(self, num_features, embedding_dim, n_hid, n_class, dropout, n_heads, n_gat_heads=8,
                 seq_len=100):
        super().__init__()
        # Each GAT concatenates n_gat_heads outputs of width n_hid // n_gat_heads; the
        # following LayerNorm(n_hid) only matches when the division is exact.
        if n_hid % n_gat_heads != 0:
            raise ValueError(
                f"n_hid ({n_hid}) must be divisible by n_gat_heads ({n_gat_heads})"
            )

        self.dropout = dropout
        self.n_hid = n_hid
        self.seq_len = seq_len
        self.n_gat_heads = n_gat_heads

        # Embedding layer
        self.embedding = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim).to(device)
        nn.init.xavier_uniform_(self.embedding.weight)

        # Initial GAT layer
        self.gat1 = GATConv(embedding_dim, n_hid // n_gat_heads, heads=n_gat_heads, dropout=dropout).to(device)
        self.ln_gat1 = nn.LayerNorm(n_hid).to(device)

        # DGSM branch
        self.dgsm = DynamicGatedSequenceModule(n_hid, n_hid, kernel_size=5, num_heads=n_heads, dropout=dropout).to(device)

        # SCAM branch
        self.scam = SequenceContextAggregationModule(n_hid, n_hid, window_size=7, num_heads=n_heads, dropout=dropout).to(device)

        # Pre-fusion layer: operates on the concatenation of both branches (width 2 * n_hid)
        self.pre_fusion = nn.TransformerEncoderLayer(
            d_model=n_hid * 2, nhead=n_heads, dim_feedforward=n_hid * 4, dropout=dropout, batch_first=True
        ).to(device)
        self.fusion_proj = nn.Linear(n_hid * 2, n_hid).to(device)
        self.ln_pre_fusion = nn.LayerNorm(n_hid).to(device)
        nn.init.xavier_uniform_(self.fusion_proj.weight)

        # Second GAT layer
        self.gat2 = GATConv(n_hid, n_hid // n_gat_heads, heads=n_gat_heads, dropout=dropout).to(device)
        self.ln_gat2 = nn.LayerNorm(n_hid).to(device)

        # Graph pooling
        self.pool = global_mean_pool

        # Final classification layer
        self.fc = nn.Linear(n_hid, n_class).to(device)
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x, edge_index, batch):
        """Classify a batch of graphs.

        Args:
            x: Integer token ids for all nodes of the batch; it is reshaped to
                ``[B, seq_len]``, so it must hold exactly ``B * seq_len`` elements.
            edge_index: ``[2, E]`` edge list indexing the flattened ``B * seq_len`` nodes.
            batch: Graph id of every node; ``batch.max() + 1`` is taken as ``B``.

        Returns:
            Log-probabilities of shape ``[B, n_class]``.
        """
        batch_size = batch.max().item() + 1

        # Reshape and embed
        x = x.reshape(batch_size, self.seq_len)
        x_embed = self.embedding(x)  # [B, L, E]

        # Initial GAT works on the flattened node list
        x_flat = x_embed.view(-1, x_embed.size(-1))  # [B * L, E]
        x_gat = self.gat1(x_flat, edge_index)  # [B * L, n_hid]
        x_gat = self.ln_gat1(x_gat).view(batch_size, self.seq_len, -1)  # [B, L, n_hid]

        # DGSM branch
        x_dgsm = self.dgsm(x_gat)

        # SCAM branch
        x_scam = self.scam(x_gat)

        # Pre-fusion
        x_fused_input = torch.cat([x_dgsm, x_scam], dim=-1)  # [B, L, 2 * n_hid]
        x_fused = self.pre_fusion(x_fused_input)
        x_fused = self.fusion_proj(x_fused)  # [B, L, n_hid]
        x_fused = self.ln_pre_fusion(x_fused)

        # Second GAT (replacing AGAM)
        x_flat = x_fused.view(-1, x_fused.size(-1))  # [B * L, n_hid]
        x_gat2 = self.gat2(x_flat, edge_index)  # [B * L, n_hid]
        x_gat2 = self.ln_gat2(x_gat2)  # [B * L, n_hid]

        # Graph pooling
        x_pooled = self.pool(x_gat2, batch)  # [B, n_hid]

        # Classification
        x_out = self.fc(x_pooled)
        return F.log_softmax(x_out, dim=-1)

# Multi-head self-attention module
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention followed by LayerNorm.

    Queries, keys and values are all projected from the same input. ``out_dim``
    must be a multiple of ``num_heads`` (enforced with an ``assert``).
    """

    def __init__(self, in_dim, out_dim, num_heads, dropout=0.3):
        super().__init__()
        assert out_dim % num_heads == 0
        self.d_k = out_dim // num_heads  # width of a single head
        self.num_heads = num_heads
        self.out_dim = out_dim

        self.W_q = nn.Linear(in_dim, out_dim, bias=False)
        self.W_k = nn.Linear(in_dim, out_dim, bias=False)
        self.W_v = nn.Linear(in_dim, out_dim, bias=False)
        self.W_o = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(out_dim)

        for projection in (self.W_q, self.W_k, self.W_v, self.W_o):
            nn.init.xavier_uniform_(projection.weight)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape ``[B, L, out_dim]`` into ``[B, num_heads, L, d_k]``."""
        return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, L, in_dim]``; returns ``[B, L, out_dim]``."""
        batch_size, seq_len, _ = x.shape

        queries = self._split_heads(self.W_q(x), batch_size, seq_len)
        keys = self._split_heads(self.W_k(x), batch_size, seq_len)
        values = self._split_heads(self.W_v(x), batch_size, seq_len)

        # Scaled dot-product attention; dropout acts on the attention weights.
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_k ** 0.5)
        weights = self.dropout(F.softmax(logits, dim=-1))
        per_head = torch.matmul(weights, values)

        # Merge the heads back into one feature axis.
        merged = per_head.transpose(1, 2).contiguous().view(batch_size, seq_len, self.out_dim)
        return self.ln(self.W_o(merged))

# Dynamic Gated Sequence Module (DGSM)
class DynamicGatedSequenceModule(nn.Module):
    """Blend a 1-D convolution with self-attention through a learned sigmoid gate.

    The convolution output is fed to multi-head self-attention; a gate computed
    from both mixes them element-wise before dropout and LayerNorm.
    """

    def __init__(self, input_dim, output_dim, kernel_size=5, num_heads=8, dropout=0.3):
        super().__init__()
        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size, padding=kernel_size // 2)
        self.attn = nn.MultiheadAttention(
            embed_dim=output_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.gate = nn.Linear(output_dim * 2, output_dim)
        self.ln = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(dropout)

        for layer in (self.conv, self.gate):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Map ``x`` of shape ``[B, L, input_dim]`` to a tensor with ``output_dim`` features."""
        # The unpacking doubles as a rank check: a non-3-D input fails right here.
        _n_batch, _n_steps, _n_channels = x.size()

        # Conv1d wants channels first, so swap the last two axes around the call.
        local_feats = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)
        attended, _ = self.attn(local_feats, local_feats, local_feats)

        # Per-feature gate in (0, 1): 1 keeps the convolution path, 0 the attention path.
        mix = torch.sigmoid(self.gate(torch.cat([local_feats, attended], dim=-1)))
        blended = mix * local_feats + (1 - mix) * attended
        return self.ln(self.dropout(blended))

# Sequence Context Aggregation Module (SCAM)
class SequenceContextAggregationModule(nn.Module):
    """Windowed causal self-attention plus a learned context vector.

    Each position attends only to itself and the ``window_size - 1`` positions
    before it. The attention output goes through a per-position linear layer
    (``global_pool``), a learned ``context`` vector is added to every position,
    and the result is passed through dropout and LayerNorm.

    Args:
        input_dim: Feature width of the input (also the attention embed dim).
        output_dim: Feature width of the output.
        window_size: Number of positions (current one included) visible to each query.
        num_heads: Attention heads; must divide ``input_dim``.
        dropout: Dropout probability for the attention weights and the output.
    """

    def __init__(self, input_dim, output_dim, window_size=7, num_heads=16, dropout=0.3):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.global_pool = nn.Linear(input_dim, output_dim)
        self.context = nn.Parameter(torch.randn(output_dim))
        self.ln = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

        nn.init.xavier_uniform_(self.global_pool.weight)

    def _window_mask(self, seq_len, target_device):
        """Return a ``[seq_len, seq_len]`` bool mask where True blocks attention.

        Entry ``[i, j]`` is blocked when ``j`` lies in the future of ``i`` or more
        than ``window_size - 1`` steps in its past.
        """
        positions = torch.arange(seq_len, device=target_device)
        lag = positions.unsqueeze(1) - positions.unsqueeze(0)  # lag[i, j] = i - j
        return (lag < 0) | (lag >= self.window_size)

    def forward(self, x):
        """Map ``x`` of shape ``[B, L, input_dim]`` to ``[B, L, output_dim]``."""
        # The unpacking also enforces a 3-D input.
        _, seq_len, _ = x.size()
        # Built in one vectorised step instead of one in-place tensor write per row.
        mask = self._window_mask(seq_len, x.device)
        local_out, _ = self.local_attn(x, x, x, attn_mask=mask)
        global_out = self.global_pool(local_out)
        # The 1-D context vector broadcasts over the batch and sequence axes.
        x_out = global_out + self.context
        x_out = self.ln(self.dropout(x_out))
        return x_out

# Objects that torch.load is allowed to unpickle when it runs in weights-only mode.
_checkpoint_safe_globals = [
    # Model classes defined in this file
    DGSM_SCAM_GAT_Enhanced,
    MultiHeadAttention,
    DynamicGatedSequenceModule,
    SequenceContextAggregationModule,
    # torch.nn building blocks
    nn.Embedding,
    nn.Linear,
    nn.Dropout,
    nn.LayerNorm,
    nn.LSTM,
    nn.TransformerEncoder,
    nn.TransformerEncoderLayer,
    nn.ModuleList,
    nn.MultiheadAttention,
    # torch_geometric pieces
    GATConv,
    global_mean_pool,
]
torch.serialization.add_safe_globals(_checkpoint_safe_globals)

model_path = './results/dgsm_scam_gat/models/dynamic_api_DGSM_SCAM_GAT_Improved.pth'

# Build the network outside the try block: a construction error must surface as
# such, and `model` has to exist before any fallback loading path runs.
model = DGSM_SCAM_GAT_Enhanced(
    num_features=cfg.num_global_features,
    embedding_dim=cfg.embedding_dim,
    n_hid=cfg.n_hid,
    n_class=cfg.n_class,
    dropout=cfg.dropout,
    n_heads=cfg.n_heads,
    n_gat_heads=cfg.n_gat_heads
).to(device)

try:
    checkpoint = torch.load(model_path, map_location=device)
    load_message = f"Pretrained model loaded from {model_path}"
except FileNotFoundError:
    # Retrying with different unpickling settings cannot help when the file is missing.
    raise
except Exception as e:
    print(f"Error loading model: {e}")
    # NOTE(security): weights_only=False runs the full pickle machinery and can
    # execute arbitrary code embedded in the file. Only use checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    load_message = f"Pretrained model loaded with weights_only=False from {model_path}"

# Accept both a pickled nn.Module and a plain state dict (the format this script
# itself writes after fine-tuning).
pretrained_state = checkpoint.state_dict() if isinstance(checkpoint, nn.Module) else checkpoint
load_result = model.load_state_dict(pretrained_state, strict=False)
# strict=False ignores mismatched keys; report them so a partially loaded model
# does not go unnoticed.
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Checkpoint key mismatch - missing: {load_result.missing_keys}, "
          f"unexpected: {load_result.unexpected_keys}")
print(load_message)

# ---- Load and validate the fine-tuning dataset ----
fine_tune_data = pd.read_csv("./data/mal_api_2019/merged_api_index_data.csv")

# A leftover index column from a previous DataFrame.to_csv call is discarded.
if 'Unnamed: 0' in fine_tune_data.columns:
    fine_tune_data = fine_tune_data.drop(columns=['Unnamed: 0'])
    print("Dropped 'Unnamed: 0' column from fine-tune data")

print("Fine-tune data shape:", fine_tune_data.shape)

# Schema checks: a 'malware' label column, 100 feature columns, binary labels.
if 'malware' not in fine_tune_data.columns:
    raise ValueError("Expected 'malware' column in the dataset")
n_columns = fine_tune_data.shape[1]
if n_columns != 101:
    raise ValueError(f"Expected 101 columns (100 features + 1 label), got {fine_tune_data.shape[1]}")
labels_are_binary = fine_tune_data['malware'].isin([0, 1]).all()
if not labels_are_binary:
    raise ValueError("Expected 'malware' column to contain only 0 or 1")

x_fine_all = fine_tune_data.drop(columns=['malware']).values.astype(float)
y_fine_all = fine_tune_data['malware'].values.astype(int)

# Keep every index inside the embedding table's valid range.
x_fine_all = np.clip(x_fine_all, 0, cfg.num_global_features - 1)
print(f"After clipping - Max value in x_fine_all: {x_fine_all.max()}, Min value in x_fine_all: {x_fine_all.min()}")

# Stratified split so both subsets keep the label ratio.
x_fine_train, x_fine_test, y_fine_train, y_fine_test = train_test_split(
    x_fine_all,
    y_fine_all,
    test_size=cfg.test_size,
    random_state=cfg.SEED,
    stratify=y_fine_all,
)
print(f"Train set shape: {x_fine_train.shape}, Test set shape: {x_fine_test.shape}")
print("Train label distribution:", np.bincount(y_fine_train))
print("Test label distribution:", np.bincount(y_fine_test))

# Convert every split to an int64 tensor living on the target device.
x_fine_train, x_fine_test, y_fine_train, y_fine_test = (
    torch.from_numpy(split).long().to(device)
    for split in (x_fine_train, x_fine_test, y_fine_train, y_fine_test)
)

def get_edge_index(num_features, batch_size, target_device=None):
    """Build the edge list of ``batch_size`` disjoint directed chain graphs.

    Graph ``i`` owns the node ids ``i * num_features .. (i + 1) * num_features - 1``
    and contains the edges ``j -> j + 1`` between consecutive nodes. Edges are
    ordered graph by graph, and within a graph by source node.

    Args:
        num_features: Number of nodes in each chain.
        batch_size: Number of chains in the batch.
        target_device: Device for the result; defaults to the module-level ``device``.

    Returns:
        Long tensor of shape ``[2, batch_size * (num_features - 1)]``. It is empty
        (``[2, 0]``) when there are no graphs or each chain has fewer than two nodes.
    """
    if target_device is None:
        target_device = device

    # One chain, built once: sources 0..n-2 over targets 1..n-1.
    sources = torch.arange(max(num_features - 1, 0), dtype=torch.long, device=target_device)
    chain = torch.stack([sources, sources + 1])  # [2, n - 1]

    # Shift the chain by each graph's node offset in a single broadcast.
    offsets = torch.arange(max(batch_size, 0), dtype=torch.long, device=target_device) * num_features
    edge_index = chain.unsqueeze(1) + offsets.view(1, -1, 1)  # [2, B, n - 1]
    return edge_index.reshape(2, -1)

def get_graph_data(x, y):
    """Wrap every row of ``x`` and its label from ``y`` in a PyG ``Data`` chain graph.

    All graphs share one edge tensor describing a directed chain over 100 nodes
    (edges ``j -> j + 1`` for ``j`` in ``0..98``), placed on the module-level
    ``device``. Row ``x[i]`` supplies the node features and ``y[i]`` the label.

    Returns:
        A list with one ``Data`` object per row of ``x``.
    """
    chain_sources = list(range(0, 100 - 1))
    chain_targets = list(range(1, 100))
    # Shape [2, 99]: first row holds the sources, second row the targets.
    edge_index = torch.tensor([chain_sources, chain_targets], dtype=torch.long).to(device)
    return [
        Data(x=x[row], edge_index=edge_index, y=y[row])
        for row in range(x.size(0))
    ]

# Turn both splits into graph lists, then into mini-batch loaders.
fine_tune_train_data_list, fine_tune_test_data_list = (
    get_graph_data(features, labels)
    for features, labels in ((x_fine_train, y_fine_train), (x_fine_test, y_fine_test))
)
# Only the training loader shuffles; the test order stays fixed.
fine_tune_train_loader = DataLoader(
    fine_tune_train_data_list,
    batch_size=cfg.batch_size,
    shuffle=True,
)
fine_tune_test_loader = DataLoader(
    fine_tune_test_data_list,
    batch_size=cfg.batch_size,
    shuffle=False,
)

# Warmup + Cosine Decay learning rate scheduler
class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up to ``max_lr`` followed by cosine decay to ``min_lr``.

    With ``e = last_epoch``:
      * ``e < warmup_epochs``: ``lr = max_lr * (e + 1) / warmup_epochs``
      * otherwise: cosine decay from ``max_lr`` (at ``e == warmup_epochs``) to
        ``min_lr`` (at ``e == total_epochs``); the rate then stays at ``min_lr``.

    Every parameter group receives the same learning rate; the optimizer's own
    base learning rates are not used.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of warm-up epochs (0 disables warm-up).
        total_epochs: Epoch index at which the decay reaches ``min_lr``.
        max_lr: Peak learning rate.
        min_lr: Final learning rate.
        last_epoch: Index of the last epoch, forwarded to the base class.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, max_lr, min_lr=1e-6, last_epoch=-1):
        # These must exist before the base constructor runs: it performs an
        # initial step() that calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(WarmupCosineScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``, once per param group."""
        if self.last_epoch < self.warmup_epochs:
            lr = self.max_lr * (self.last_epoch + 1) / self.warmup_epochs
        else:
            decay_epochs = self.total_epochs - self.warmup_epochs
            if decay_epochs <= 0:
                # No room left for a decay phase; avoids dividing by zero.
                lr = self.min_lr
            else:
                # Clamp so the cosine never runs past pi and starts rising again.
                progress = min((self.last_epoch - self.warmup_epochs) / decay_epochs, 1.0)
                # float() keeps numpy scalars out of the optimizer's param groups.
                cosine_decay = 0.5 * (1.0 + float(np.cos(np.pi * progress)))
                lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr for _ in self.base_lrs]


# ---- Fine-tuning setup ----
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=1e-3)

# Balanced class weights derived from the training labels.
class_weights = torch.tensor(
    compute_class_weight(
        class_weight='balanced',
        classes=np.unique(y_fine_train.cpu().numpy()),
        y=y_fine_train.cpu().numpy(),
    ),
    dtype=torch.float,
).to(device)
print("Class weights for fine-tuning:", class_weights)

criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)
scheduler = WarmupCosineScheduler(
    optimizer,
    warmup_epochs=2,
    total_epochs=cfg.fine_tune_epochs,
    max_lr=cfg.learning_rate,
    min_lr=1e-6,
)

# ---- Fine-tuning loop ----
for epoch in range(cfg.fine_tune_epochs):
    model.train()
    total_loss = total_correct = total_samples = grad_norm = 0

    for data in fine_tune_train_loader:
        data = data.to(device)
        batch_size = data.batch.max().item() + 1
        edge_index = get_edge_index(100, batch_size)

        optimizer.zero_grad()
        output = model(data.x, edge_index, data.batch)
        loss = criterion(output, data.y)
        loss.backward()
        # clip_grad_norm_ returns the total gradient norm; it is summed for logging.
        grad_norm += torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
        optimizer.step()

        # Running statistics; the loss is weighted by the number of graphs in the batch.
        predicted = output.max(dim=1).indices
        total_loss += loss.item() * batch_size
        total_correct += (predicted == data.y).sum().item()
        total_samples += data.y.size(0)

    avg_loss = total_loss / total_samples
    train_accuracy = total_correct / total_samples
    avg_grad_norm = grad_norm / len(fine_tune_train_loader)
    # The learning rate is reported before the scheduler advances to the next epoch.
    print(f"Epoch {epoch+1}/{cfg.fine_tune_epochs}, Loss: {avg_loss:.8f}, Train Acc: {train_accuracy:.8f}, "
          f"Grad Norm: {avg_grad_norm:.4f}, LR: {scheduler.get_last_lr()[0]:.8f}")

    scheduler.step()

# Persist only the weights (state dict) of the fine-tuned model.
fine_tune_path = './results/dgsm_scam_gat/models/mal_api_2019_DGSM_SCAM_GAT_Improved.pth'
fine_tune_dir = os.path.dirname(fine_tune_path)
os.makedirs(fine_tune_dir, exist_ok=True)  # create the target folder if it is missing
fine_tuned_weights = model.state_dict()
torch.save(fine_tuned_weights, fine_tune_path)
print(f"Fine-tuned model saved to {fine_tune_path}")

# ---- Evaluation on the held-out test split ----
model.eval()
total_correct = 0
total_samples = 0
y_true, y_pred_all = [], []  # ground-truth and predicted labels
y_pred_proba, y_pred_proba_1 = [], []  # all class probabilities / probability of class 1
f1_metric = BinaryF1Score().to(device)

with torch.no_grad():
    for data in fine_tune_test_loader:
        data = data.to(device)
        batch_size = data.batch.max().item() + 1
        edge_index = get_edge_index(100, batch_size)
        output = model(data.x, edge_index, data.batch)

        predicted = output.max(dim=1).indices
        probabilities = F.softmax(output, dim=1)

        # Accumulate per-sample results across batches.
        y_pred_proba_1.extend(probabilities[:, 1].cpu().numpy())
        y_pred_proba.extend(probabilities.cpu().numpy())
        y_pred_all.extend(predicted.cpu().numpy())
        y_true.extend(data.y.cpu().numpy())

        total_correct += (predicted == data.y).sum().item()
        total_samples += data.y.size(0)
        f1_metric.update(predicted, data.y)

accuracy = total_correct / total_samples
# Precision and recall are support-weighted averages over the classes; the F1
# value comes from the torchmetrics BinaryF1Score accumulator.
precision = precision_score(y_true, y_pred_all, average='weighted')
recall = recall_score(y_true, y_pred_all, average='weighted')
f1 = f1_metric.compute()

print(f'Fine-tuned Accuracy on merged_api_index_data test set: {accuracy:.4f}')
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

# ---- TensorBoard logging and evaluation plots ----
writer = SummaryWriter(log_dir='./results/dgsm_scam_gat/runs/mal_api_2019_Fine-tuned')

try:
    writer.add_scalar('Fine-tuned/Test/Accuracy', accuracy, 0)
    writer.add_scalar('Fine-tuned/Test/Precision', precision, 0)
    writer.add_scalar('Fine-tuned/Test/Recall', recall, 0)
    writer.add_scalar('Fine-tuned/Test/F1_Score', f1, 0)

    # ROC curve computed from the predicted probability of class 1.
    fpr, tpr, thresholds = roc_curve(y_true, y_pred_proba_1)
    roc_auc = auc(fpr, tpr)
    roc_fig = plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate (FPR)')
    plt.ylabel('True Positive Rate (TPR)')
    plt.title('Receiver Operating Characteristic (ROC) Curve - Fine-tuned on merged_api_index_data')
    plt.legend(loc="lower right")
    roc_path = './results/dgsm_scam_gat/evaluation/roc_curve_fine_tuned_merged_api.png'
    # The evaluation folder is not created anywhere else; savefig fails without it.
    os.makedirs(os.path.dirname(roc_path), exist_ok=True)
    plt.savefig(roc_path)
    # Log the figure through its own handle and before plt.show(): once show()
    # returns the figure may be gone and plt.gcf() would hand back a blank one.
    # close=False keeps it alive for the show() call below.
    writer.add_figure('Fine-tuned/ROC Curve', roc_fig, global_step=0, close=False)
    plt.show()
    print(f"ROC curve saved to {roc_path}")

    conf_matrix = confusion_matrix(y_true, y_pred_all, labels=range(cfg.n_class))
    conf_fig = plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=[f'Class {i}' for i in range(cfg.n_class)],
                yticklabels=[f'Class {i}' for i in range(cfg.n_class)])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - Fine-tuned on merged_api_index_data')
    conf_matrix_path = './results/dgsm_scam_gat/evaluation/confusion_matrix_fine_tuned_merged_api.png'
    os.makedirs(os.path.dirname(conf_matrix_path), exist_ok=True)
    plt.savefig(conf_matrix_path)
    writer.add_figure('Fine-tuned/Confusion Matrix', conf_fig, global_step=0, close=False)
    plt.show()
    print(f"Confusion matrix saved to {conf_matrix_path}")
finally:
    # Flush and release the event file even when plotting fails.
    writer.close()

