import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from sklearn.utils import compute_class_weight
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, roc_curve, auc, confusion_matrix
from torchmetrics.classification import BinaryF1Score
from torch.utils.tensorboard import SummaryWriter
import json


class cfg:
    """Static hyper-parameters and switches shared by the whole script."""

    # --- reproducibility and data handling ---
    SEED = 42
    flag = 0  # 1: rebuild the sampled CSV from the raw dump; anything else: load the cached CSV
    test_size = 0.2
    batch_size = 16

    # --- model dimensions ---
    num_global_features = 309  # embedding vocabulary size; inputs are clipped to [0, 308]
    embedding_dim = 256
    n_hid = 256
    n_class = 2
    dropout = 0.3

    # --- attention heads ---
    n_heads = 16  # For the DGSM/SCAM attention and the pre-fusion Transformer layer
    n_gat_heads = 8  # For GATConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# cfg.SEED used to reach only train_test_split. Seeding torch as well makes
# weight initialisation, dropout and DataLoader shuffling repeatable (as far as
# the backend kernels themselves are deterministic).
torch.manual_seed(cfg.SEED)

# Data loading and preprocessing
if cfg.flag == 1:
    # Rebuild the cached subset from the raw dump: drop the hash column and
    # incomplete rows, keep at most 20000 rows per label, then shuffle.
    data = pd.read_csv("./data/dynamic_api_call_data/dynamic_api_call_sequence_per_malware_100_0_306.csv")
    data0 = data.drop(columns=['hash'])
    data0 = data0.dropna(how='any', axis=0)
    data_dict = {}
    for label, group in data0.groupby('malware'):
        # random_state keeps the sampled subset identical between runs.
        data_dict[label] = group.sample(min(group.shape[0], 20000), random_state=cfg.SEED)
    for label, group in data_dict.items():
        print(f"{label} Number of samples per class: {group.shape[0]}")
    data1 = pd.concat(data_dict.values()).sample(frac=1, random_state=cfg.SEED).reset_index(drop=True)
    # index=False: do not persist the row index, which would come back as an
    # 'Unnamed: 0' column on the next load.
    data1.to_csv('./data/dynamic_api_call_data/dynamic_api_call_sequence_20000.csv', index=False)
else:
    data1 = pd.read_csv("./data/dynamic_api_call_data/dynamic_api_call_sequence_20000.csv")
    print("data1_shape:", data1.shape)

# Cached files written by earlier runs may still carry the saved index column.
if 'Unnamed: 0' in data1.columns:
    data1 = data1.drop(columns=['Unnamed: 0'])
    print("Dropped 'Unnamed: 0' column")

# Every column except the 'malware' label is a feature.
x = data1.drop(['malware'], axis=1).values.astype(float)
y = data1['malware'].values.astype(int)
if x.shape[1] != 100:
    raise ValueError(f"Expected 100 features per sample, got {x.shape[1]}")
# Keep every value inside the embedding table's index range.
x = np.clip(x, 0, cfg.num_global_features - 1)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=cfg.test_size, random_state=cfg.SEED)

# Integer tensors: the features are later used as embedding indices.
x_train = torch.from_numpy(x_train).long().to(device)
x_test = torch.from_numpy(x_test).long().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
y_test = torch.from_numpy(y_test).long().to(device)

print("Train label distribution:", torch.bincount(y_train).cpu().numpy())
print("Test label distribution:", torch.bincount(y_test).cpu().numpy())

def get_edge_index(num_features, batch_size):
    """Build the edge index for a batch of identical chain graphs.

    Every graph has ``num_features`` nodes joined by the directed edges
    ``j -> j + 1``; the node ids of graph ``i`` are shifted by
    ``i * num_features``.

    Args:
        num_features: Number of nodes per graph.
        batch_size: Number of graphs in the batch.

    Returns:
        A ``torch.long`` tensor of shape
        ``[2, batch_size * (num_features - 1)]`` on the module-level
        ``device``, ordered graph by graph. The result is empty
        (shape ``[2, 0]``) when there are no graphs or fewer than two nodes
        per graph, instead of raising.
    """
    edges_per_graph = max(num_features - 1, 0)
    graph_count = max(batch_size, 0)
    # Source node of every edge inside a single chain: 0 .. num_features - 2.
    chain_src = torch.arange(edges_per_graph, dtype=torch.long, device=device)
    # First node id of each graph in the batch.
    offsets = torch.arange(graph_count, dtype=torch.long, device=device) * num_features
    # Broadcast to [graph_count, edges_per_graph]; flattening row by row keeps
    # the graph-by-graph edge order.
    src = (offsets.unsqueeze(1) + chain_src.unsqueeze(0)).reshape(-1)
    return torch.stack([src, src + 1], dim=0)

def get_graph_data(x, y):
    """Wrap every sample as a PyG ``Data`` chain graph.

    Args:
        x: 2-D tensor; row ``i`` becomes the node-feature tensor of graph ``i``.
        y: Label tensor indexed along dim 0 in step with ``x``.

    Returns:
        A list with one ``Data`` object per row of ``x``. All graphs share the
        same ``edge_index`` tensor (directed edges ``j -> j + 1`` over the row
        length), placed on the module-level ``device``.
    """
    # Derive the chain length from the data instead of hard-coding 100.
    num_nodes = x.size(1)
    src = torch.arange(max(num_nodes - 1, 0), dtype=torch.long, device=device)
    edge_index = torch.stack([src, src + 1], dim=0)
    return [Data(x=x[i], edge_index=edge_index, y=y[i]) for i in range(x.size(0))]

# Wrap both splits as lists of chain graphs, then batch them with PyG loaders.
# Only the training loader shuffles.
train_data_list, test_data_list = (
    get_graph_data(features, labels)
    for features, labels in ((x_train, y_train), (x_test, y_test))
)
train_loader = DataLoader(train_data_list, shuffle=True, batch_size=cfg.batch_size)
test_loader = DataLoader(test_data_list, shuffle=False, batch_size=cfg.batch_size)

# Multi-head attention module (for fusion)
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention followed by LayerNorm.

    Queries, keys and values are all projected from the same input. The
    concatenated heads go through an output projection and a LayerNorm;
    there is no residual connection.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, in_dim, out_dim, num_heads, dropout=0.3):
        super().__init__()
        assert out_dim % num_heads == 0
        self.d_k = out_dim // num_heads
        self.num_heads = num_heads
        self.out_dim = out_dim

        self.W_q = nn.Linear(in_dim, out_dim, bias=False)
        self.W_k = nn.Linear(in_dim, out_dim, bias=False)
        self.W_v = nn.Linear(in_dim, out_dim, bias=False)
        self.W_o = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(out_dim)

        # Xavier-initialise the four projections in q, k, v, o order.
        for projection in (self.W_q, self.W_k, self.W_v, self.W_o):
            nn.init.xavier_uniform_(projection.weight)

    def _split_heads(self, projected):
        """Reshape ``[B, S, out_dim]`` into ``[B, heads, S, d_k]``."""
        n_batch, n_tokens, _ = projected.size()
        return projected.view(n_batch, n_tokens, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, S, in_dim]``.

        Returns:
            Tensor of shape ``[B, S, out_dim]``.
        """
        n_batch, n_tokens, _ = x.size()
        queries = self._split_heads(self.W_q(x))
        keys = self._split_heads(self.W_k(x))
        values = self._split_heads(self.W_v(x))

        similarity = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_k ** 0.5)
        weights = self.dropout(F.softmax(similarity, dim=-1))

        # Merge the heads back into a single feature dimension.
        merged = torch.matmul(weights, values).transpose(1, 2).contiguous()
        merged = merged.view(n_batch, n_tokens, self.out_dim)
        return self.ln(self.W_o(merged))

# Dynamic Gated Sequence Module (DGSM)
class DynamicGatedSequenceModule(nn.Module):
    """Blend a 1-D convolution with self-attention through a learned gate.

    The input is convolved along the sequence axis, self-attention is run on
    the convolved features, and a sigmoid gate computed from both mixes them
    element-wise before dropout and LayerNorm.

    Args:
        input_dim: Channel size of the input.
        output_dim: Channel size of the output (also the attention width).
        kernel_size: Convolution kernel width; padding is ``kernel_size // 2``,
            which preserves the sequence length for odd kernels.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention and the output.
    """

    def __init__(self, input_dim, output_dim, kernel_size=5, num_heads=8, dropout=0.3):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size, padding=half_kernel)
        self.attn = nn.MultiheadAttention(
            embed_dim=output_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.gate = nn.Linear(output_dim * 2, output_dim)
        self.ln = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(dropout)

        for layer in (self.conv, self.gate):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Transform ``x`` of shape ``[B, S, input_dim]`` into ``[B, S, output_dim]``."""
        # The unpacking doubles as a rank check: x must be 3-D.
        _, _, _ = x.size()
        # Conv1d expects channels first, so swap the last two axes around it.
        local = self.conv(x.transpose(1, 2)).transpose(1, 2)
        attended, _ = self.attn(local, local, local)
        mix = torch.sigmoid(self.gate(torch.cat((local, attended), dim=-1)))
        # mix -> 1 keeps the convolution features, mix -> 0 keeps the attention.
        blended = mix * local + (1 - mix) * attended
        return self.ln(self.dropout(blended))

# Sequence Context Aggregation Module (SCAM)
class SequenceContextAggregationModule(nn.Module):
    """Windowed causal self-attention plus a learned context vector.

    Each position attends only to itself and the ``window_size - 1`` positions
    before it. The attention output is linearly projected, a learned context
    vector is added to every position, and dropout plus LayerNorm follow.

    Args:
        input_dim: Channel size of the input (also the attention width).
        output_dim: Channel size of the output.
        window_size: Number of positions (current one included) each position
            may attend to.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention and the output.
    """

    def __init__(self, input_dim, output_dim, window_size=7, num_heads=16, dropout=0.3):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.global_pool = nn.Linear(input_dim, output_dim)
        self.context = nn.Parameter(torch.randn(output_dim))
        self.ln = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

        nn.init.xavier_uniform_(self.global_pool.weight)

    def _local_mask(self, seq_len, target_device):
        """Return a ``[seq_len, seq_len]`` bool mask; ``True`` means blocked."""
        positions = torch.arange(seq_len, device=target_device)
        # lag[i, j] = i - j: how far key j lies behind query i.
        lag = positions.unsqueeze(1) - positions.unsqueeze(0)
        visible = (lag >= 0) & (lag < self.window_size)
        return ~visible

    def forward(self, x):
        """Transform ``x`` of shape ``[B, S, input_dim]`` into ``[B, S, output_dim]``."""
        batch_size, seq_len, _ = x.size()
        # Built with tensor ops rather than a Python loop over positions on
        # every forward pass; the resulting mask is the same.
        mask = self._local_mask(seq_len, x.device)
        local_out, _ = self.local_attn(x, x, x, attn_mask=mask)
        global_out = self.global_pool(local_out)
        # The same learned vector is added at every position of every sample.
        context = self.context.unsqueeze(0).expand(batch_size, seq_len, -1)
        x_out = global_out + context
        x_out = self.ln(self.dropout(x_out))
        return x_out

# model: DGSM_SCAM_GAT_Enhanced
class DGSM_SCAM_GAT_Enhanced(nn.Module):
    """Graph/sequence hybrid classifier.

    Pipeline: embedding -> GAT -> (DGSM branch || SCAM branch) -> Transformer
    encoder layer -> linear projection -> GAT -> mean pooling -> linear
    classifier. ``forward`` returns log-probabilities.

    Args:
        num_features: Size of the embedding vocabulary.
        embedding_dim: Width of the token embeddings.
        n_hid: Hidden width used by every stage after the embedding.
        n_class: Number of output classes.
        dropout: Dropout probability shared by all sub-modules.
        n_heads: Attention heads for DGSM, SCAM and the pre-fusion layer.
        n_gat_heads: Attention heads for both GATConv layers.
    """

    def __init__(self, num_features, embedding_dim, n_hid, n_class, dropout, n_heads, n_gat_heads=8):
        super().__init__()
        self.dropout = dropout
        self.n_hid = n_hid
        self.seq_len = 100  # every graph is expected to contain exactly this many nodes
        self.n_gat_heads = n_gat_heads

        # Per-head GAT width; the concatenated heads give n_hid when it divides evenly.
        gat_head_dim = n_hid // n_gat_heads

        # Token embedding
        self.embedding = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim).to(device)
        nn.init.xavier_uniform_(self.embedding.weight)

        # First graph-attention stage
        self.gat1 = GATConv(embedding_dim, gat_head_dim, heads=n_gat_heads, dropout=dropout).to(device)
        self.ln_gat1 = nn.LayerNorm(n_hid).to(device)

        # Two parallel sequence branches fed with the same GAT output
        self.dgsm = DynamicGatedSequenceModule(
            n_hid, n_hid, kernel_size=5, num_heads=n_heads, dropout=dropout
        ).to(device)
        self.scam = SequenceContextAggregationModule(
            n_hid, n_hid, window_size=7, num_heads=n_heads, dropout=dropout
        ).to(device)

        # Fusion: Transformer layer over the concatenated branches, then project back to n_hid
        self.pre_fusion = nn.TransformerEncoderLayer(
            d_model=n_hid * 2,
            nhead=n_heads,
            dim_feedforward=n_hid * 4,
            dropout=dropout,
            batch_first=True,
        ).to(device)
        self.fusion_proj = nn.Linear(n_hid * 2, n_hid).to(device)
        self.ln_pre_fusion = nn.LayerNorm(n_hid).to(device)
        nn.init.xavier_uniform_(self.fusion_proj.weight)

        # Second graph-attention stage
        self.gat2 = GATConv(n_hid, gat_head_dim, heads=n_gat_heads, dropout=dropout).to(device)
        self.ln_gat2 = nn.LayerNorm(n_hid).to(device)

        # Graph-level readout and classifier
        self.pool = global_mean_pool
        self.fc = nn.Linear(n_hid, n_class).to(device)
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x, edge_index, batch):
        """Classify a batch of graphs.

        Args:
            x: Integer node features of the whole batch; reshaped to
                ``[num_graphs, seq_len]`` and used as embedding indices.
            edge_index: Edge index covering all nodes of the batch.
            batch: Graph id of every node; its maximum plus one is taken as
                the number of graphs.

        Returns:
            Log-probabilities of shape ``[num_graphs, n_class]``.
        """
        num_graphs = batch.max().item() + 1

        # Embed the tokens: [B, seq_len] -> [B, seq_len, embedding_dim]
        tokens = x.reshape(num_graphs, self.seq_len)
        embedded = self.embedding(tokens)

        # First GAT works on the flat node list, then back to sequence layout.
        node_feats = embedded.view(-1, embedded.size(-1))
        node_feats = self.ln_gat1(self.gat1(node_feats, edge_index))
        sequence = node_feats.view(num_graphs, self.seq_len, -1)

        # Run DGSM first, then SCAM, and concatenate along the feature axis.
        dgsm_out = self.dgsm(sequence)
        scam_out = self.scam(sequence)
        both_branches = torch.cat([dgsm_out, scam_out], dim=-1)

        # Fuse and project back down to n_hid.
        fused = self.pre_fusion(both_branches)
        fused = self.ln_pre_fusion(self.fusion_proj(fused))

        # Second GAT on the flat node list again.
        node_feats = fused.view(-1, fused.size(-1))
        node_feats = self.ln_gat2(self.gat2(node_feats, edge_index))

        # One vector per graph, then class scores.
        graph_repr = self.pool(node_feats, batch)
        logits = self.fc(graph_repr)
        return F.log_softmax(logits, dim=-1)

# Compute class weights
_train_labels_np = y_train.cpu().numpy()
class_weights = torch.tensor(
    compute_class_weight(class_weight='balanced', classes=np.unique(_train_labels_np), y=_train_labels_np),
    dtype=torch.float,
).to(device)
print("Class weights:", class_weights)

# Instantiate model
model = DGSM_SCAM_GAT_Enhanced(
    cfg.num_global_features,
    cfg.embedding_dim,
    cfg.n_hid,
    cfg.n_class,
    cfg.dropout,
    cfg.n_heads,
    n_gat_heads=cfg.n_gat_heads,
)
model = model.to(device)

# Loss function and optimizer
# The model already returns log-probabilities; the log-softmax inside
# CrossEntropyLoss leaves normalised log-probabilities unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=1e-3)
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

# Warmup + Cosine Decay learning rate scheduler
class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up to ``max_lr`` followed by cosine decay to ``min_lr``.

    The schedule is driven by ``max_lr``/``min_lr`` only: every parameter
    group receives the same learning rate and the optimizer's own base
    learning rates are ignored.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of steps over which the rate rises linearly
            from ``max_lr / warmup_epochs`` to ``max_lr``.
        total_epochs: Step at which the cosine decay reaches ``min_lr``.
        max_lr: Peak learning rate.
        min_lr: Final learning rate; held once ``total_epochs`` is reached.
        last_epoch: Index of the last step, ``-1`` to start fresh.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, max_lr, min_lr=1e-6, last_epoch=-1):
        # These must exist before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of the current step for every group."""
        step = self.last_epoch
        if step < self.warmup_epochs:
            lr = self.max_lr * (step + 1) / self.warmup_epochs
        else:
            decay_epochs = self.total_epochs - self.warmup_epochs
            if decay_epochs > 0:
                # Clamp so the rate stays at min_lr past total_epochs instead
                # of following the cosine back up.
                progress = min((step - self.warmup_epochs) / decay_epochs, 1.0)
            else:
                # No decay phase configured: the schedule is already finished.
                progress = 1.0
            cosine_decay = 0.5 * (1.0 + float(np.cos(np.pi * progress)))
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr for _ in self.base_lrs]

# Training and evaluation
epochs = 100
# The schedule length follows the epoch count so the two cannot drift apart.
scheduler = WarmupCosineScheduler(optimizer, warmup_epochs=10, total_epochs=epochs, max_lr=0.00005, min_lr=1e-6)

# Per-epoch history used for the curves and for best-epoch selection.
train_losses, test_losses = [], []
train_acc, test_acc = [], []

best_acc = 0
patience = 20
patience_counter = 0

metrics_dict = {"epochs": [], "final_metrics": {}}

# TensorBoard log directory and the per-epoch checkpoint directory.
writer = SummaryWriter(log_dir='./results/dgsm_scam_gat/runs/DGSM_SCAM_GAT_runs')
epoch_dir = './results/dgsm_scam_gat/dynamic_api_epoch'
os.makedirs(epoch_dir, exist_ok=True)

# One pass over the training set followed by one pass over the test set per
# epoch. A checkpoint is written after every epoch so the best one can be
# restored afterwards.
# NOTE(review): early stopping and model selection both use the test split;
# a separate validation split would avoid tuning on the test data.
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_total_loss = 0
    train_correct = 0
    train_total = 0
    grad_norm = 0
    for data in train_loader:
        data = data.to(device)
        batch_size = data.batch.max().item() + 1
        edge_index = get_edge_index(100, batch_size)
        optimizer.zero_grad()
        output = model(data.x, edge_index, data.batch)
        loss = criterion(output, data.y)
        loss.backward()
        # clip_grad_norm_ returns the total norm measured before clipping.
        grad_norm += torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
        optimizer.step()
        # The criterion yields a per-batch mean, so weight it by batch size.
        train_total_loss += loss.item() * batch_size
        _, predicted = output.max(dim=1)
        train_correct += (predicted == data.y).sum().item()
        train_total += data.y.size(0)
    # The sum is weighted by sample count, so normalise by the number of
    # samples rather than the number of batches.
    train_total_loss /= train_total
    train_losses.append(train_total_loss)
    train_accuracy = train_correct / train_total
    train_acc.append(train_accuracy)
    avg_grad_norm = grad_norm / len(train_loader)

    # ---- evaluation pass ----
    model.eval()
    test_total_loss = 0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            batch_size = data.batch.max().item() + 1
            edge_index = get_edge_index(100, batch_size)
            output = model(data.x, edge_index, data.batch)
            loss = criterion(output, data.y)
            test_total_loss += loss.item() * batch_size
            _, predicted = output.max(dim=1)
            test_correct += (predicted == data.y).sum().item()
            test_total += data.y.size(0)
    test_accuracy = test_correct / test_total
    test_acc.append(test_accuracy)
    test_total_loss /= test_total
    test_losses.append(test_total_loss)

    scheduler.step()

    writer.add_scalar('Loss/Train', train_total_loss, epoch)
    writer.add_scalar('Loss/Test', test_total_loss, epoch)
    writer.add_scalar('Accuracy/Train', train_accuracy, epoch)
    writer.add_scalar('Accuracy/Test', test_accuracy, epoch)
    writer.add_scalar('Gradient/Norm', avg_grad_norm, epoch)

    # The printed LR is the one the scheduler has just set for the next epoch.
    print(f"Epoch {epoch}/{epochs}, Train Loss: {train_total_loss:.8f}, Test Loss: {test_total_loss:.8f}, "
          f"Train Acc: {train_accuracy:.8f}, Test Acc: {test_accuracy:.8f}, "
          f"LR: {scheduler.get_last_lr()[0]:.8f}, Grad Norm: {avg_grad_norm:.4f}")

    metrics_dict["epochs"].append({
        "epoch": epoch,
        "train_loss": train_total_loss,
        "test_loss": test_total_loss,
        "train_accuracy": train_accuracy,
        "test_accuracy": test_accuracy
    })

    epoch_path = os.path.join(epoch_dir, f'epoch{epoch}.pth')
    torch.save(model.state_dict(), epoch_path)

    # Early stopping: patience/patience_counter were declared but never
    # consulted, so training always ran for the full number of epochs.
    if test_accuracy > best_acc:
        best_acc = test_accuracy
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}: no test accuracy improvement for {patience} epochs")
            break

# Persist the per-epoch metrics as JSON, creating the target directory first.
metrics_path = './results/dgsm_scam_gat/metrics/dgsm_scam_gat_metrics.json'
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
with open(metrics_path, 'w') as f:
    json.dump(metrics_dict, f, indent=4)
print(f"Epoch metrics saved to {metrics_path}")

# Plot loss and accuracy curves
plt.figure(figsize=(12, 6))
# One tuple per panel: train curve + label, test curve + label, y label, title.
_curve_panels = (
    (train_losses, 'Train Loss', test_losses, 'Test Loss', 'Loss', 'Training and Test Loss'),
    (train_acc, 'Train Acc', test_acc, 'Test Acc', 'Accuracy', 'Training and Test Accuracy'),
)
for _slot, (_train_curve, _train_label, _test_curve, _test_label, _y_label, _title) in enumerate(_curve_panels, start=1):
    plt.subplot(1, 2, _slot)
    plt.plot(_train_curve, label=_train_label)
    plt.plot(_test_curve, label=_test_label)
    plt.xlabel('Epoch')
    plt.ylabel(_y_label)
    plt.title(_title)
    plt.legend()
    plt.grid(True)
plt.tight_layout()
plt.show()

# Select best model: the epoch with the highest recorded test accuracy
# (np.argmax returns the first one on ties).
best_epoch = np.argmax(test_acc)
print(f"The best epoch: {best_epoch}, Best Test accuracy: {test_acc[best_epoch]:.8f}")
best_epoch_path = os.path.join(epoch_dir, f'epoch{best_epoch}.pth')
best_state = torch.load(best_epoch_path)
model.load_state_dict(best_state)

# Save model
# NOTE(review): this pickles the whole module object, so loading the file
# later needs the model class definitions to be importable.
save_path = './results/dgsm_scam_gat/models'
model_path = os.path.join(save_path, 'dynamic_api_DGSM_SCAM_GAT_Improved.pth')
os.makedirs(save_path, exist_ok=True)
torch.save(model, model_path)
print(f"Model saved to {model_path}")

# Evaluate model
# Runs the restored model over the test loader once and collects the
# per-sample class probabilities, true labels and running correct counts.
model.eval()
total_correct = 0
total_samples = 0
y_pred_proba = []
y_pred_proba_1 = []
y_true = []
f1_metric = BinaryF1Score().to(device)

with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        batch_size = data.batch.max().item() + 1
        edge_index = get_edge_index(100, batch_size)
        output = model(data.x, edge_index, data.batch)
        _, predicted = output.max(dim=1)
        # The model returns log-probabilities; softmax maps them back to
        # probabilities.
        probabilities = F.softmax(output, dim=1)
        y_pred_proba_1.extend(probabilities[:, 1].cpu().numpy())
        y_pred_proba.extend(probabilities.cpu().numpy())
        y_true.extend(data.y.cpu().numpy())
        total_correct += (predicted == data.y).sum().item()
        total_samples += data.y.size(0)
        f1_metric.update(predicted, data.y)

# Derive the hard predictions once after the loop, instead of re-running
# argmax over the whole growing list on every batch.
y_pred = np.argmax(y_pred_proba, axis=1)

accuracy = total_correct / total_samples
# Precision and recall are class-frequency-weighted averages, while the F1
# score comes from the binary torchmetrics metric.
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_metric.compute()

# Console report of the final test metrics.
print(f'Accuracy: {accuracy:.4f}')
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

# Store the metrics as plain floats and rewrite the JSON file so it holds
# both the per-epoch history and the final summary.
metrics_dict["final_metrics"] = {
    key: float(value)
    for key, value in (
        ("accuracy", accuracy),
        ("precision", precision),
        ("recall", recall),
        ("f1_score", f1),
    )
}
with open(metrics_path, 'w') as f:
    json.dump(metrics_dict, f, indent=4)
print(f"Final metrics saved to {metrics_path}")

# Mirror the same four values into TensorBoard at step 0.
for _tag, _value in (
    ('Final/Accuracy', accuracy),
    ('Final/Precision', precision),
    ('Final/Recall', recall),
    ('Final/F1_Score', f1),
):
    writer.add_scalar(_tag, _value, 0)

# Plot ROC curve
# Built from the class-1 probabilities against the true labels.
fpr, tpr, thresholds = roc_curve(y_true, y_pred_proba_1)
roc_auc = auc(fpr, tpr)

plt.figure()
_roc_ax = plt.gca()
_roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
# Diagonal reference line: the expected curve of a random classifier.
_roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
_roc_ax.set_xlim([0.0, 1.0])
_roc_ax.set_ylim([0.0, 1.05])
_roc_ax.set_xlabel('False Positive Rate (FPR)')
_roc_ax.set_ylabel('True Positive Rate (TPR)')
_roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
_roc_ax.legend(loc="lower right")
plt.show()

# Plot confusion matrix
conf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(cfg.n_class)))
class_names = [f'Class {i}' for i in range(cfg.n_class)]
# Keep a handle on the figure so exactly this one is logged below.
conf_fig = plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')

# Log before plt.show(): with a blocking GUI backend the figure is gone once
# show() returns, and plt.gcf() would then hand TensorBoard a new empty figure.
# close=False keeps the figure alive so it can still be displayed.
writer.add_figure('Confusion Matrix', conf_fig, global_step=0, close=False)
# Close the writer before the possibly blocking show() call.
writer.close()

plt.show()

