import torch.nn.modules.linear as nn_linear
from sympy import false
import torch.serialization
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, roc_curve, auc, confusion_matrix
from torchmetrics.classification import BinaryF1Score
from torch.utils.tensorboard import SummaryWriter

class cfg:
    """Static configuration namespace; read as ``cfg.<name>``, never instantiated."""

    # --- reproducibility / data split (neither is applied in this script) ---
    SEED = 42
    # --- input vocabulary and representation ---
    num_global_features = 309   # rows of the embedding table (API-call index vocabulary)
    embedding_dim = 256
    # --- batching / split ---
    batch_size = 16
    test_size = 0.2
    # --- attention heads ---
    n_heads = 16      # DGSM, SCAM and the fusion TransformerEncoderLayer
    n_gat_heads = 8   # GATConv layers
    # --- model width / output ---
    n_hid = 256
    n_class = 2
    dropout = 0.4
    flag = 0          # not referenced in this script

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Each row is expected to hold 100 API-call index columns plus the binary
# 'malware' label (enforced by the column-count check below).
data = pd.read_csv("./data/mal_api_2019/merged_api_index_data.csv")
if 'Unnamed: 0' in data.columns:
    # Presumably a pandas index written by an earlier to_csv call; not a feature.
    data = data.drop(columns=['Unnamed: 0'])
    print("Dropped 'Unnamed: 0' column")

print("Data shape:", data.shape)
if 'malware' not in data.columns:
    raise ValueError("Expected 'malware' column in the dataset")
if data.shape[1] != 101:
    raise ValueError(f"Expected 101 columns (100 features + 1 label), got {data.shape[1]}")
if not data['malware'].isin([0, 1]).all():
    raise ValueError("Expected 'malware' column to contain only 0 or 1")

x = data.drop(['malware'], axis=1).values.astype(float)
y = data['malware'].values.astype(int)

# NaN passes through np.clip unchanged and has no meaningful integer value, so
# it would reach nn.Embedding as a garbage index; reject it up front.
missing = np.isnan(x)
if missing.any():
    raise ValueError(
        f"Found {int(missing.sum())} missing feature values; expected integer API-call indices"
    )

print("Feature value range before clipping - Max:", x.max(), "Min:", x.min())
# The model's embedding table has cfg.num_global_features rows, so indices
# outside [0, num_global_features - 1] are clamped onto the boundary rows.
# Report how many values that affects instead of doing it silently.
out_of_range = int(((x < 0) | (x > cfg.num_global_features - 1)).sum())
if out_of_range:
    print(
        f"Warning: clipping {out_of_range} out-of-range feature values "
        f"to [0, {cfg.num_global_features - 1}]"
    )
x = np.clip(x, 0, cfg.num_global_features - 1)
print(f"After clipping - Max value in x: {x.max()}, Min value in x: {x.min()}")

# The whole file is evaluated; this script performs no train/test split.
x_test = torch.from_numpy(x).long().to(device)
y_test = torch.from_numpy(y).long().to(device)

print("Test label distribution:", torch.bincount(y_test).cpu().numpy())

def get_edge_index(num_features, batch_size, target_device=None):
    """Build the batched ``edge_index`` for ``batch_size`` disjoint path graphs.

    Graph ``i`` owns nodes ``[i * num_features, (i + 1) * num_features)`` and
    has directed edges ``n -> n + 1`` between consecutive nodes.

    Args:
        num_features: number of nodes per graph (the sequence length).
        batch_size: number of graphs in the batch.
        target_device: device for the returned tensor; defaults to the
            module-level ``device``.

    Returns:
        LongTensor of shape ``[2, batch_size * max(num_features - 1, 0)]``,
        with graph 0's edges first, then graph 1's, and so on.
    """
    if target_device is None:
        target_device = device
    num_edges = max(num_features - 1, 0)
    src = torch.arange(num_edges, dtype=torch.long)
    chain = torch.stack([src, src + 1])  # [2, E] edges of a single graph
    offsets = torch.arange(batch_size, dtype=torch.long) * num_features  # [B]
    # Broadcast to [2, B, E]; flattening the last two dims keeps graph-major order.
    batched = chain.unsqueeze(1) + offsets.view(1, -1, 1)
    # Explicit size (not -1) so empty results reshape cleanly.
    return batched.reshape(2, batch_size * num_edges).to(target_device)

def get_graph_data(x, y):
    """Wrap each row of ``x`` as a path graph for PyTorch Geometric.

    Args:
        x: 2-D tensor, one row per sample and one column per node (here the
            API-call indices). Row ``i`` becomes the node features of graph ``i``.
        y: 1-D tensor with one label per row of ``x``.

    Returns:
        list of ``Data`` objects. Each graph has edges ``n -> n + 1`` between
        consecutive nodes, and all graphs share a single ``edge_index`` tensor.

    Raises:
        ValueError: if ``x`` and ``y`` disagree on the number of samples.
    """
    if x.size(0) != y.size(0):
        raise ValueError(f"x has {x.size(0)} rows but y has {y.size(0)} labels")
    # Derive the chain length from the data instead of assuming 100 nodes, and
    # keep edge_index on the same device as the node features.
    num_nodes = x.size(1)
    src = torch.arange(max(num_nodes - 1, 0), dtype=torch.long, device=x.device)
    edge_index = torch.stack([src, src + 1]).contiguous()  # always shape [2, E]
    return [
        Data(x=node_features, edge_index=edge_index, y=label)
        for node_features, label in zip(x, y)
    ]

# One path graph per sample. Order is kept (shuffle=False) so collected
# predictions line up with the input rows.
test_data_list = get_graph_data(x_test, y_test)
test_loader = DataLoader(
    test_data_list,
    batch_size=cfg.batch_size,
    shuffle=False,
)

# Multi-head self-attention module. DGSM_SCAM_GAT_Enhanced does not instantiate
# it; it is kept so that it can be registered as a safe global for loading.
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention followed by LayerNorm.

    Args:
        in_dim: size of the input's last dimension.
        out_dim: total projection width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Input ``x`` is unpacked as ``[batch, seq_len, in_dim]``; the output is
    ``[batch, seq_len, out_dim]``.
    """

    def __init__(self, in_dim, out_dim, num_heads, dropout=0.4):
        super().__init__()
        assert out_dim % num_heads == 0
        self.d_k = out_dim // num_heads
        self.num_heads = num_heads
        self.out_dim = out_dim

        # Creation order is fixed so random initialisation stays reproducible.
        self.W_q = nn.Linear(in_dim, out_dim, bias=False)
        self.W_k = nn.Linear(in_dim, out_dim, bias=False)
        self.W_v = nn.Linear(in_dim, out_dim, bias=False)
        self.W_o = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(out_dim)

        for projection in (self.W_q, self.W_k, self.W_v, self.W_o):
            nn.init.xavier_uniform_(projection.weight)

    def _split_heads(self, t, batch_size, seq_len):
        # [B, L, out_dim] -> [B, heads, L, d_k]
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        q, k, v = [
            self._split_heads(projection(x), batch_size, seq_len)
            for projection in (self.W_q, self.W_k, self.W_v)
        ]
        weights = F.softmax(q @ k.transpose(-2, -1) / (self.d_k ** 0.5), dim=-1)
        weights = self.dropout(weights)
        # [B, heads, L, d_k] -> [B, L, out_dim]
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.out_dim)
        return self.ln(self.W_o(merged))

# Dynamic Gated Sequence Module (DGSM): blends a 1-D convolution with
# self-attention over the convolved sequence through a learned sigmoid gate.
class DynamicGatedSequenceModule(nn.Module):
    """Gated mix of local convolution features and self-attention features.

    Args:
        input_dim: channels of the input's last dimension.
        output_dim: channels produced by the convolution and the output.
        kernel_size: Conv1d kernel size, padded by ``kernel_size // 2``.
        num_heads: heads of the ``nn.MultiheadAttention`` layer.
        dropout: dropout for attention weights and for the gated output.

    ``forward`` takes ``[batch, seq_len, input_dim]``. The output is
    ``[batch, seq_len', output_dim]``, where ``seq_len' == seq_len`` for odd kernels.
    """

    def __init__(self, input_dim, output_dim, kernel_size=5, num_heads=8, dropout=0.4):
        super().__init__()
        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size, padding=kernel_size // 2)
        self.attn = nn.MultiheadAttention(embed_dim=output_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.gate = nn.Linear(output_dim * 2, output_dim)
        self.ln = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(dropout)

        for layer in (self.conv, self.gate):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        _b, _l, _c = x.size()  # unpacking enforces a rank-3 input
        # Conv1d wants channels first: [B, L, C] -> [B, C, L] -> back to [B, L, C].
        local = self.conv(x.transpose(1, 2)).transpose(1, 2)
        attended = self.attn(local, local, local)[0]
        g = torch.sigmoid(self.gate(torch.cat((local, attended), dim=-1)))
        mixed = g * local + (1 - g) * attended
        return self.ln(self.dropout(mixed))

# Sequence Context Aggregation Module (SCAM): causal sliding-window
# self-attention, a per-position projection and a learned additive context.
class SequenceContextAggregationModule(nn.Module):
    """Causal sliding-window self-attention with a learned context bias.

    Query position ``i`` may attend only to key positions
    ``[i - window_size + 1, i]``.

    Args:
        input_dim: feature size of the input (divisible by ``num_heads``).
        output_dim: feature size of the output.
        window_size: number of positions, including itself, that each query sees.
        num_heads: attention heads.
        dropout: dropout for attention weights and for the output.

    ``forward`` maps ``[batch, seq_len, input_dim]`` to ``[batch, seq_len, output_dim]``.
    """

    def __init__(self, input_dim, output_dim, window_size=7, num_heads=16, dropout=0.4):
        super(SequenceContextAggregationModule, self).__init__()
        self.local_attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        # Despite the name this is a per-position linear projection, not pooling.
        self.global_pool = nn.Linear(input_dim, output_dim)
        # Learned vector added to every position of every sequence.
        self.context = nn.Parameter(torch.randn(output_dim))
        self.ln = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

        nn.init.xavier_uniform_(self.global_pool.weight)

    def _window_mask(self, seq_len, mask_device):
        """Return a ``[seq_len, seq_len]`` bool mask; True marks keys a query must NOT see."""
        pos = torch.arange(seq_len, device=mask_device)
        offset = pos.unsqueeze(1) - pos.unsqueeze(0)  # query index minus key index
        allowed = (offset >= 0) & (offset < self.window_size)
        return ~allowed

    def forward(self, x):
        _, seq_len, _ = x.size()  # unpacking enforces a rank-3 input
        # Vectorised mask: identical to the former per-row Python loop, but built
        # with a few tensor ops instead of seq_len separate device writes.
        mask = self._window_mask(seq_len, x.device)
        local_out, _ = self.local_attn(x, x, x, attn_mask=mask)
        # Broadcasting adds the [output_dim] context to every (batch, position).
        x_out = self.global_pool(local_out) + self.context
        return self.ln(self.dropout(x_out))

# DGSM_SCAM_GAT_Enhanced: embedding -> GAT -> (DGSM || SCAM) -> Transformer
# fusion -> GAT -> mean pooling -> linear classifier.
class DGSM_SCAM_GAT_Enhanced(nn.Module):
    """Graph/sequence hybrid classifier over fixed-length API-call sequences.

    Args:
        num_features: size of the embedding vocabulary.
        embedding_dim: embedding width fed to the first GATConv.
        n_hid: hidden width. With GATConv's default head concatenation the GAT
            output is ``n_gat_heads * (n_hid // n_gat_heads)`` wide, which equals
            ``n_hid`` only if ``n_hid`` is divisible by ``n_gat_heads``.
        n_class: number of output classes.
        dropout: dropout used by the GAT, DGSM, SCAM and fusion layers.
        n_heads: heads for DGSM, SCAM and the fusion TransformerEncoderLayer.
        n_gat_heads: heads for both GATConv layers.

    ``forward(x, edge_index, batch)`` expects every graph in the batch to have
    exactly ``self.seq_len`` (100) nodes. It returns log-probabilities of shape
    ``[num_graphs, n_class]``.
    """

    def __init__(self, num_features, embedding_dim, n_hid, n_class, dropout, n_heads, n_gat_heads=8):
        super().__init__()
        self.dropout = dropout
        self.n_hid = n_hid
        self.seq_len = 100
        self.n_gat_heads = n_gat_heads
        per_head = n_hid // n_gat_heads
        fused_dim = n_hid * 2

        # Submodules are created in a fixed order: it determines random
        # initialisation and state_dict key order.
        self.embedding = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim).to(device)
        nn.init.xavier_uniform_(self.embedding.weight)

        self.gat1 = GATConv(embedding_dim, per_head, heads=n_gat_heads, dropout=dropout).to(device)
        self.ln_gat1 = nn.LayerNorm(n_hid).to(device)

        self.dgsm = DynamicGatedSequenceModule(n_hid, n_hid, kernel_size=5, num_heads=n_heads, dropout=dropout).to(device)
        self.scam = SequenceContextAggregationModule(n_hid, n_hid, window_size=7, num_heads=n_heads, dropout=dropout).to(device)

        self.pre_fusion = nn.TransformerEncoderLayer(
            d_model=fused_dim, nhead=n_heads, dim_feedforward=n_hid * 4, dropout=dropout, batch_first=True
        ).to(device)
        self.fusion_proj = nn.Linear(fused_dim, n_hid).to(device)
        self.ln_pre_fusion = nn.LayerNorm(n_hid).to(device)
        nn.init.xavier_uniform_(self.fusion_proj.weight)

        self.gat2 = GATConv(n_hid, per_head, heads=n_gat_heads, dropout=dropout).to(device)
        self.ln_gat2 = nn.LayerNorm(n_hid).to(device)

        self.pool = global_mean_pool

        self.fc = nn.Linear(n_hid, n_class).to(device)
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x, edge_index, batch):
        num_graphs = batch.max().item() + 1

        # Node tokens -> [B, L] -> embeddings [B, L, E]
        tokens = x.reshape(num_graphs, self.seq_len)
        embedded = self.embedding(tokens)

        # First GAT runs on flattened nodes [B*L, E]; the result is reshaped to [B, L, H].
        node_feats = self.ln_gat1(self.gat1(embedded.view(-1, embedded.size(-1)), edge_index))
        seq_feats = node_feats.view(num_graphs, self.seq_len, -1)

        # The two sequence branches see the same input; DGSM runs first.
        dgsm_out = self.dgsm(seq_feats)
        scam_out = self.scam(seq_feats)

        # [B, L, 2H] -> Transformer -> project back to [B, L, H]
        fused = self.pre_fusion(torch.cat((dgsm_out, scam_out), dim=-1))
        fused = self.ln_pre_fusion(self.fusion_proj(fused))

        # Second GAT over flattened nodes, then per-graph mean pooling -> [B, H]
        refined = self.ln_gat2(self.gat2(fused.view(-1, fused.size(-1)), edge_index))
        logits = self.fc(self.pool(refined, batch))
        return F.log_softmax(logits, dim=-1)

# Allow-list of classes a weights_only=True torch.load may reconstruct. The
# checkpoint load in this script passes weights_only as a falsy value, so the
# list is not consulted there. It matters only if loading is switched to
# weights_only=True; in that case the listed types must cover every module
# type inside the pickled model.
torch.serialization.add_safe_globals([
    # Project modules.
    DynamicGatedSequenceModule,
    SequenceContextAggregationModule,
    DGSM_SCAM_GAT_Enhanced,
    MultiHeadAttention,
    # torch.nn building blocks.
    nn.Embedding,
    nn.Linear,
    nn.Conv1d,  # held by DynamicGatedSequenceModule.conv
    nn.Dropout,
    nn.LayerNorm,
    nn.LSTM,
    nn.TransformerEncoder,
    nn.TransformerEncoderLayer,
    nn.ModuleList,
    nn.MultiheadAttention,
    nn_linear.NonDynamicallyQuantizableLinear,
    # PyTorch Geometric.
    GATConv
])

# Build the architecture from cfg. A state_dict checkpoint is loaded into this
# instance; a checkpoint holding a whole pickled module replaces it.
model = DGSM_SCAM_GAT_Enhanced(
    num_features=cfg.num_global_features,
    embedding_dim=cfg.embedding_dim,
    n_hid=cfg.n_hid,
    n_class=cfg.n_class,
    dropout=cfg.dropout,
    n_heads=cfg.n_heads,
).to(device)


model_path = './results/dgsm_scam_gat/models/dynamic_api_DGSM_SCAM_GAT_Improved.pth'
try:
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    elif isinstance(checkpoint, dict):
        # Without this branch a state_dict would replace `model` and make
        # model.eval() fail below.
        model.load_state_dict(checkpoint)
    else:
        raise TypeError(f"unsupported checkpoint type {type(checkpoint).__name__}")
    print(f"Model loaded from {model_path}")
except FileNotFoundError:
    print(f"Error: Model file not found at {model_path}")
    raise SystemExit(1)  # unlike the site-provided exit(), always available
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)

model.eval()

writer = SummaryWriter(log_dir='./results/dgsm_scam_gat/runs/mal_api_2019_validation')

total_correct = 0
total_samples = 0
y_pred_proba_1 = []  # P(class 1) per sample, for the ROC curve
y_pred = []
y_true = []
f1_metric = BinaryF1Score().to(device)

with torch.no_grad():
    for batch_data in test_loader:
        batch_data = batch_data.to(device)
        # The DataLoader already stacks each graph's edge_index with node offsets
        # (graph i's nodes start at i * 100), so there is no need to rebuild it per batch.
        output = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        _, predicted = output.max(dim=1)
        # The model returns log-probabilities; softmax(log_softmax(z)) == softmax(z).
        probabilities = F.softmax(output, dim=1)
        y_pred_proba_1.extend(probabilities[:, 1].cpu().numpy())
        # Collect predictions per batch rather than re-running argmax over
        # everything seen so far on every iteration.
        y_pred.extend(predicted.cpu().numpy())
        y_true.extend(batch_data.y.cpu().numpy())
        total_correct += (predicted == batch_data.y).sum().item()
        total_samples += batch_data.y.size(0)
        f1_metric.update(predicted, batch_data.y)

accuracy = total_correct / total_samples
# NOTE(review): precision/recall are support-weighted over both classes, while
# BinaryF1Score scores the positive class only, so F1 is not the harmonic mean
# of the two numbers printed next to it. Confirm which averaging is intended.
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_metric.compute().item()

print(f'Accuracy: {accuracy:.4f}')
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

results_dir = './results/dgsm_scam_gat/evaluation'
os.makedirs(results_dir, exist_ok=True)
results_file = os.path.join(results_dir, 'validation_results_on_merged_api.txt')
with open(results_file, 'w') as f:
    f.write("Test Results for DGSM_SCAM_GAT_Improved on merged_api_index_data\n")
    f.write(f"Accuracy: {accuracy:.4f}\n")
    f.write(f"Precision: {precision:.4f}\n")
    f.write(f"Recall: {recall:.4f}\n")
    f.write(f"F1 Score: {f1:.4f}\n")
print(f"Evaluation results saved to {results_file}")


writer.add_scalar('Test/Accuracy', accuracy, 0)
writer.add_scalar('Test/Precision', precision, 0)
writer.add_scalar('Test/Recall', recall, 0)
writer.add_scalar('Test/F1_Score', f1, 0)

fpr, tpr, _ = roc_curve(y_true, y_pred_proba_1)
roc_auc = auc(fpr, tpr)
roc_fig = plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (FPR)')
plt.ylabel('True Positive Rate (TPR)')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
roc_path = os.path.join(results_dir, 'roc_curve_merged_api.png')
roc_fig.savefig(roc_path)
# Log the explicit figure handle before plt.show(). With an interactive backend
# the window may be closed by the time show() returns, and plt.gcf() would then
# hand TensorBoard a new, empty figure. close=False keeps it displayable.
writer.add_figure('ROC Curve', roc_fig, global_step=0, close=False)
plt.show()
print(f"ROC curve saved to {roc_path}")

conf_matrix = confusion_matrix(y_true, y_pred, labels=range(cfg.n_class))
cm_fig = plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=[f'Class {i}' for i in range(cfg.n_class)],
            yticklabels=[f'Class {i}' for i in range(cfg.n_class)])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
conf_matrix_path = os.path.join(results_dir, 'confusion_matrix_merged_api.png')
cm_fig.savefig(conf_matrix_path)
writer.add_figure('Confusion Matrix', cm_fig, global_step=0, close=False)
plt.show()
print(f"Confusion matrix saved to {conf_matrix_path}")

writer.close()
