import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import ViTModel, ViTImageProcessor
import cv2
import numpy as np
from torch.amp import GradScaler, autocast
from sklearn.model_selection import train_test_split
import pickle
import pandas as pd
from collections import Counter
from sklearn.metrics import precision_score, recall_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassF1Score
import os
from torchvision import transforms
import json
from torchviz import make_dot

# Configuration class: plain namespace of constants read by the rest of the script.
class cfg:
    """Hyper-parameters and data-split settings shared across the script."""
    SEED = 42  # random_state passed to train_test_split
    batch_size = 8  # batch size used for both the train and test DataLoader
    test_size = 0.2  # fraction of samples held out by train_test_split
    max_seq_len = 2048  # Instruction sequences are truncated to this length; wavelet sequences are truncated/padded to it
    class_num = 9  # number of output classes (CSV 'Class' values are shifted by -1 after loading)
    num_epochs = 100  # number of epochs passed to train_evaluate

# Instruction set: the tokens that receive their own vocabulary id; anything
# else found in a sample is mapped to id 0 by the dataset.
# NOTE: this is a set, so its iteration order is not guaranteed to be the same
# across interpreter runs; code deriving token ids from it should sort it first.
instruction_set = {
    # Data movement
    'mov', 'movsx', 'movzx', 'push', 'pop', 'lea', 'xchg', 'cmpxchg', 'xadd', 'movs', 'movsb', 'movsw', 'movsd',
    # Integer arithmetic
    'add', 'sub', 'inc', 'dec', 'mul', 'imul', 'div', 'idiv', 'adc', 'sbb', 'neg',
    # Bitwise logic
    'and', 'or', 'xor', 'not', 'test',
    # Shifts, rotates and bit tests/scans
    'shl', 'shr', 'sal', 'sar', 'rol', 'ror', 'rcl', 'rcr', 'bt', 'bts', 'btr', 'btc', 'bsf', 'bsr',
    # Control flow: jumps, calls, returns, interrupts, loops
    'jmp', 'jz', 'jnz', 'je', 'jne', 'jg', 'jge', 'jl', 'jle', 'jo', 'jno', 'js', 'jns', 'jp', 'jnp',
    'jc', 'jnc', 'ja', 'jae', 'jb', 'jbe', 'call', 'ret', 'retn', 'int', 'into', 'iret', 'loop', 'loope', 'loopne',
    # Comparisons
    'cmp', 'cmps', 'cmpsb', 'cmpsw', 'cmpsd',
    # Stack-frame helpers
    'pusha', 'pushad', 'popa', 'popad', 'enter', 'leave',
    # String store/load/scan
    'stos', 'stosb', 'stosw', 'stosd', 'lods', 'lodsb', 'lodsw', 'lodsd', 'scas', 'scasb', 'scasw', 'scasd',
    # Flag control
    'clc', 'stc', 'cli', 'sti', 'cld', 'std', 'cmc',
    # Miscellaneous, port I/O and system instructions
    'nop', 'hlt', 'wait', 'rdtsc', 'cpuid', 'in', 'out', 'ins', 'outs', 'int3', 'syscall', 'sysenter', 'sysexit',
    # x87 floating point
    'fld', 'fst', 'fstp', 'fadd', 'fsub', 'fmul', 'fdiv', 'fcom', 'fcomp', 'fxch', 'fild', 'fist', 'fistp',
    # SSE packed moves and arithmetic
    'movaps', 'movups', 'movdqa', 'movdqu', 'addps', 'subps', 'mulps', 'divps', 'xorps', 'andps', 'orps',
    # Conditional set, flag transfer and sign extension
    'sete', 'setne', 'setg', 'setge', 'setl', 'setle', 'seto', 'setno', 'sets', 'setns', 'setp', 'setnp',
    'seta', 'setae', 'setb', 'setbe', 'lahf', 'sahf', 'cbw', 'cwd', 'cdq', 'cwde', 'cdqe',
    # Windows API names (not CPU instructions) that are also treated as tokens
    'VirtualAlloc', 'VirtualProtect', 'ResumeThread', 'IsDebuggerPresent', 'CheckRemoteDebuggerPresent',
    # Spin-wait hint and repeat prefixes
    'pause', 'rep', 'repe', 'repne', 'repnz', 'repz'
}

# Device selection: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load data: sample ids and class labels; labels are shifted by -1 so they
# start at 0.
labels = pd.read_csv("./data/big2015/big2015_Labels.csv")
x = labels['Id'].values
y = labels['Class'].values.astype(int) - 1
# NOTE(review): the preprocessed train/test files loaded below are indexed by
# position alongside these split arrays, so they were presumably generated
# with this exact split (same test_size and random_state) -- confirm before
# changing either value. The 'lable' spelling is kept because other code in
# this file refers to these names.
x_train_name, x_test_name, y_train_lable, y_test_lable = train_test_split(
    x, y, test_size=cfg.test_size, random_state=cfg.SEED
)

# Data paths: directories holding one grayscale image / wavelet pickle per sample.
img_train_path = './data/big2015/gray_images_file_name_train_big'
img_test_path = './data/big2015/gray_images_file_name_test_big'
wavelet_seq_train_path = './data/big2015/dealwith_data/wavelet_sequences_train'
wavelet_seq_test_path = './data/big2015/dealwith_data/wavelet_sequences_test'

# Instruction set and vocabulary: one id per known instruction plus id 0,
# which the dataset uses for padding and for tokens outside the set.
vocab_size = len(instruction_set) + 1
print(f"Instruction set size: {len(instruction_set)}")
print(f"Vocabulary size: {vocab_size}")

# Load preprocessed instructions
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only load files produced by a trusted preprocessing step.
train_instructions_save_path = "./data/big2015/dealwith_data/instructions_train_remove_the_same.pkl"
test_instructions_save_path = "./data/big2015/dealwith_data/instructions_test_remove_the_same.pkl"
with open(train_instructions_save_path, "rb") as f:
    train_instructions = pickle.load(f)
with open(test_instructions_save_path, "rb") as f:
    test_instructions = pickle.load(f)

# Data augmentation (used during training): small random rotation and random
# horizontal flip applied to the image tensor.
train_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
])

# No transformation for testing: an empty Compose returns its input unchanged.
test_transform = transforms.Compose([])

# Dataset definition
class MultimodalDataset(Dataset):
    """Dataset yielding one (instruction ids, wavelet sequence, ViT inputs, label) tuple per sample.

    Samples are addressed by position: ``preprocessed_instructions[i]``,
    ``fl_names[i]`` and ``y_label[i]`` are treated as belonging to the same
    sample.

    Args:
        preprocessed_instructions: Indexable collection of per-sample
            instruction-token sequences.
        wavelet_seq_path: Directory containing ``<name>.pkl`` wavelet sequences.
        gray_img_path: Directory containing ``<name>.png`` images.
        fl_names: Sample names used to build both file paths.
        y_label: Integer class label per sample.
        is_train: When true, enables the random augmentations (occasional
            token shuffle and ``train_transform``).
    """

    def __init__(self, preprocessed_instructions, wavelet_seq_path, gray_img_path, fl_names, y_label, is_train=True):
        self.preprocessed_instructions = preprocessed_instructions
        self.wavelet_seq_path = wavelet_seq_path
        self.gray_root_path = gray_img_path
        self.fl_names = fl_names
        self.wavelet_seq_file_path = [os.path.join(self.wavelet_seq_path, f"{name}.pkl") for name in self.fl_names]
        self.gray_img_file_path = [os.path.join(self.gray_root_path, f"{name}.png") for name in self.fl_names]
        self.y_data = y_label
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
        self.vocab = self._build_vocab(instruction_set)
        self.is_train = is_train
        self.transform = train_transform if is_train else test_transform

    @staticmethod
    def _build_vocab(instructions):
        """Map every token to a stable id starting at 1 (id 0 is padding/unknown).

        The tokens are sorted first: enumerating a ``set`` directly yields an
        order that can differ between interpreter runs (string hash
        randomization), which would silently change token ids and make a
        saved embedding table meaningless in a later process.
        """
        return {instr: idx for idx, instr in enumerate(sorted(instructions), start=1)}

    @staticmethod
    def _fit_length(seq, max_len):
        """Truncate or zero-pad ``seq`` along dim 0 so it has exactly ``max_len`` rows."""
        if seq.size(0) > max_len:
            return seq[:max_len]
        return F.pad(seq, (0, 0, 0, max_len - seq.size(0)))

    def __getitem__(self, index):
        """Load and return ``(instr_tensor, wavelet_seq, inputs, label)`` for ``index``.

        Raises:
            FileNotFoundError: If the grayscale image cannot be read.
        """
        # Instruction sequence. Copy into a fresh list so the in-place shuffle
        # below can never modify the stored sequence (a slice of a NumPy array
        # would be a view onto it).
        instructions = list(self.preprocessed_instructions[index][:cfg.max_seq_len])
        if self.is_train and np.random.rand() < 0.1:
            np.random.shuffle(instructions)
        instr_tensor = torch.tensor([self.vocab.get(instr, 0) for instr in instructions], dtype=torch.long)

        # Wavelet sequence. Loading is best-effort: on any failure the sample
        # gets an all-zero sequence, but the failure is reported rather than
        # hidden.
        wavelet_file = self.wavelet_seq_file_path[index]
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load files produced by a trusted preprocessing step.
            with open(wavelet_file, 'rb') as f:
                wavelet_seq = pickle.load(f)  # [seq_len, 4]
            wavelet_seq = self._fit_length(torch.tensor(wavelet_seq, dtype=torch.float32), cfg.max_seq_len)
        except Exception as e:
            import warnings
            warnings.warn(f"Using zeros for wavelet sequence {wavelet_file}: {e}")
            wavelet_seq = torch.zeros((cfg.max_seq_len, 4), dtype=torch.float32)

        # Grayscale image: cv2 returns None (no exception) for unreadable files.
        gray_image = cv2.imread(self.gray_img_file_path[index])
        if gray_image is None:
            raise FileNotFoundError(f"Grayscale image not found: {self.gray_img_file_path[index]}")
        gray_image = cv2.cvtColor(gray_image, cv2.COLOR_BGR2RGB)
        # The transforms operate on channel-first tensors; the processor is
        # handed a channel-last array again afterwards.
        gray_image = self.transform(torch.from_numpy(gray_image).permute(2, 0, 1)).permute(1, 2, 0).numpy()
        inputs = self.processor(images=gray_image, return_tensors="pt")

        label = torch.tensor(self.y_data[index], dtype=torch.long)
        return instr_tensor, wavelet_seq, inputs, label

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.y_data)

# Custom batch collation function
def collate_fn(batch):
    """Merge a list of dataset samples into a single batch.

    Each sample is an ``(instr_tensor, wavelet_seq, inputs, label)`` tuple.
    Instruction tensors are right-padded with 0 up to the longest one in the
    batch, wavelet tensors are stacked, the per-sample ``pixel_values`` are
    concatenated along dim 0, and the labels become a 1-D long tensor.

    Returns:
        ``(instr_seqs, wavelet_seqs, {"pixel_values": tensor}, labels)``.
    """
    instr_items, wavelet_items, vit_inputs, label_items = zip(*batch)

    padded_instr = torch.nn.utils.rnn.pad_sequence(
        instr_items, batch_first=True, padding_value=0
    )
    stacked_wavelets = torch.stack(wavelet_items)

    # Every sample carries its own leading batch dimension, so concatenating
    # along dim 0 builds the image batch.
    pixel_chunks = []
    for sample_inputs in vit_inputs:
        pixel_chunks.append(sample_inputs["pixel_values"])
    image_batch = {"pixel_values": torch.cat(pixel_chunks, dim=0)}

    label_tensor = torch.tensor(label_items, dtype=torch.long)
    return padded_instr, stacked_wavelets, image_batch, label_tensor

# Model definition
class MultiModalModel(nn.Module):
    """Classifier fusing an instruction sequence, a wavelet sequence and an image.

    Each modality is encoded to a ``hidden_size`` vector: instruction ids and
    wavelet features go through their own Transformer encoder followed by
    self-attention and mean pooling; the image goes through a pretrained ViT
    whose token outputs are mean pooled. The three vectors are fused with
    multi-head attention and a Transformer encoder, combined with a linear
    projection of their concatenation, layer-normalized and classified.

    Args:
        num_classes: Number of output logits.
        vocab_size: Size of the instruction embedding table (id 0 is padding).
        wavelet_dim: Feature size of each wavelet time step.
        embedding_dim: Width of the two sequence encoders.
        hidden_size: Width of the fused representation.
        num_heads: Attention heads in the fusion attention and fusion encoder.
        num_layers: Number of layers in the fusion encoder.
    """

    def __init__(self, num_classes, vocab_size, wavelet_dim=4, embedding_dim=256, hidden_size=768, num_heads=8, num_layers=6):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
        # Read the width from the loaded checkpoint instead of hard-coding it,
        # so the projection below always matches the backbone.
        self.vit_dim = self.vit.config.hidden_size

        # Instruction branch
        self.instr_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.instr_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4, dim_feedforward=512, dropout=0.3, batch_first=True),
            num_layers=4
        )
        self.instr_attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=4, batch_first=True)
        self.instr_proj = nn.Linear(embedding_dim, hidden_size)

        # Wavelet branch
        self.wavelet_fc = nn.Linear(wavelet_dim, embedding_dim)
        self.wavelet_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4, dim_feedforward=512, dropout=0.3, batch_first=True),
            num_layers=4
        )
        self.wavelet_attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=4, batch_first=True)
        self.wavelet_proj = nn.Linear(embedding_dim, hidden_size)

        # Image branch
        self.gray_fc = nn.Linear(self.vit_dim, hidden_size)

        # Fusion
        self.modality_attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True)

        self.fusion_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads, dim_feedforward=2048, dropout=0.3, batch_first=True),
            num_layers=num_layers
        )
        self.norm = nn.LayerNorm(hidden_size)
        self.residual_fc = nn.Linear(hidden_size * 3, hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, instr_seq, wavelet_seq, images):
        """Return class logits for a batch.

        Args:
            instr_seq: Long tensor of instruction ids, batch first.
            wavelet_seq: Float tensor of wavelet features, batch first, last
                dimension ``wavelet_dim``.
            images: Mapping with a ``"pixel_values"`` tensor for the ViT.

        Returns:
            Tensor of logits with one row per sample and ``num_classes`` columns.
        """
        # Move the pixels to wherever this model lives rather than relying on
        # a module-level device variable.
        pixel_values = images["pixel_values"].to(self.fc.weight.device)

        # NOTE(review): no key-padding mask is passed to the encoder/attention
        # and the mean below runs over every position, so padded instruction
        # positions (id 0) contribute to the pooled feature. Confirm whether
        # that is intended before adding masking, since it changes the outputs
        # of already-trained checkpoints.
        instr_embedded = self.instr_embedding(instr_seq)
        instr_features = self.instr_transformer(instr_embedded)
        attn_output, _ = self.instr_attention(instr_features, instr_features, instr_features)
        instr_features = attn_output.mean(dim=1)
        instr_features = self.instr_proj(instr_features)

        wavelet_embedded = self.wavelet_fc(wavelet_seq)
        wavelet_features = self.wavelet_transformer(wavelet_embedded)
        wavelet_attn, _ = self.wavelet_attention(wavelet_features, wavelet_features, wavelet_features)
        wavelet_features = wavelet_attn.mean(dim=1)
        wavelet_features = self.wavelet_proj(wavelet_features)

        # Mean over all ViT output tokens.
        gray_features = self.vit(pixel_values).last_hidden_state.mean(dim=1)
        gray_features = self.gray_fc(gray_features)

        # Treat the three modality vectors as a length-3 sequence.
        modality_features = torch.stack((instr_features, wavelet_features, gray_features), dim=1)
        attn_output, _ = self.modality_attention(modality_features, modality_features, modality_features)

        fused_features = self.fusion_transformer(attn_output)
        fused_features = fused_features.mean(dim=1)

        # Residual path: linear projection of the concatenated modality vectors.
        concat_features = torch.cat((instr_features, wavelet_features, gray_features), dim=1)
        residual = self.residual_fc(concat_features)
        fused_features = self.norm(fused_features + residual)

        output = self.fc(fused_features)
        return output

# Training and evaluation function
# Default directory for the per-epoch checkpoints written by train_evaluate.
save_dir = './results/mmt_vit/mmt_vit_results/mmt-ViT_epoch'
def train_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs, accum_steps=4, save_dir=save_dir, scheduler=None):
    """Train ``model``, evaluate it after every epoch and reload the best epoch.

    Every epoch runs one pass over ``train_loader`` with mixed precision and
    gradient accumulation, one evaluation pass over ``test_loader``, logs the
    four scalars to TensorBoard, steps the scheduler once and writes a
    checkpoint to ``save_dir``. Afterwards the checkpoint of the epoch with the
    highest test accuracy is loaded back into ``model`` and saved again as the
    final model file.

    NOTE(review): the best epoch is chosen on ``test_loader``, so metrics later
    reported on that same loader are optimistic; consider a separate
    validation split.

    Args:
        model: Module called as ``model(instr_seqs, wavelet_seqs, images)``.
        train_loader: Loader yielding ``(instr_seqs, wavelet_seqs, images, labels)``.
        test_loader: Loader with the same batch layout, used for evaluation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``.
        num_epochs: Number of epochs to run.
        accum_steps: Number of batches whose gradients are accumulated before
            each optimizer step.
        save_dir: Directory receiving ``epoch<N>.pth`` checkpoints.
        scheduler: Learning-rate scheduler stepped once per epoch. When
            omitted, a module-level variable named ``scheduler`` is used if it
            exists (this keeps older call sites working); if there is none,
            no scheduler is stepped.

    Returns:
        ``model`` with the best epoch's weights loaded.
    """
    if scheduler is None:
        # Backward compatibility: this function used to read a module-level
        # ``scheduler`` implicitly.
        scheduler = globals().get('scheduler')

    def scheduler_state():
        return scheduler.state_dict() if scheduler is not None else None

    writer = SummaryWriter('./results/mmt_vit/mmt_vit_results/runs/mmt-ViT_multimodal_model')
    scaler = GradScaler('cuda')
    os.makedirs(save_dir, exist_ok=True)
    train_losses, train_accs, test_accs, test_losses = [], [], [], []

    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_corrects = 0
            for i, (instr_seqs, wavelet_seqs, images, labels) in enumerate(train_loader):
                instr_seqs = instr_seqs.to(device)
                wavelet_seqs = wavelet_seqs.to(device)
                images = {k: v.to(device) for k, v in images.items()}
                labels = labels.to(device)

                with autocast('cuda'):
                    outputs = model(instr_seqs, wavelet_seqs, images)
                    loss = criterion(outputs, labels)

                # Scale the loss so the accumulated gradient is an average over
                # the accumulation window.
                scaler.scale(loss / accum_steps).backward()
                # Step every accum_steps batches and on the last batch, so no
                # gradients leak into the next epoch.
                if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader):
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                running_loss += loss.item() * labels.size(0)
                running_corrects += (outputs.argmax(dim=1) == labels).sum().item()

            # Plain Python floats keep the history lists and checkpoints free
            # of device tensors.
            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_acc = running_corrects / len(train_loader.dataset)
            train_losses.append(epoch_loss)
            train_accs.append(epoch_acc)

            model.eval()
            test_loss = 0.0
            test_corrects = 0
            with torch.no_grad():
                for instr_seqs, wavelet_seqs, images, labels in test_loader:
                    instr_seqs = instr_seqs.to(device)
                    wavelet_seqs = wavelet_seqs.to(device)
                    images = {k: v.to(device) for k, v in images.items()}
                    labels = labels.to(device)
                    with autocast('cuda'):
                        outputs = model(instr_seqs, wavelet_seqs, images)
                        loss = criterion(outputs, labels)
                    test_loss += loss.item() * labels.size(0)
                    test_corrects += (outputs.argmax(dim=1) == labels).sum().item()

            test_loss = test_loss / len(test_loader.dataset)
            test_acc = test_corrects / len(test_loader.dataset)
            test_losses.append(test_loss)
            test_accs.append(test_acc)

            writer.add_scalar("Train/Loss", epoch_loss, epoch)
            writer.add_scalar("Train/Accuracy", epoch_acc, epoch)
            writer.add_scalar("Test/Loss", test_loss, epoch)
            writer.add_scalar("Test/Accuracy", test_acc, epoch)

            print(f"Epoch {epoch}/{num_epochs}, Train Loss: {epoch_loss:.8f}, Train Acc: {epoch_acc:.8f}, Test Loss: {test_loss:.8f}, Test Acc: {test_acc:.8f}")
            if scheduler is not None:
                scheduler.step()

            checkpoint_path = os.path.join(save_dir, f'epoch{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler_state(),
                'train_losses': train_losses,
                'train_accs': train_accs,
                'test_losses': test_losses,
                'test_accs': test_accs
            }, checkpoint_path)

        # With zero epochs there is no checkpoint to reload; keep the weights
        # the model came in with.
        if test_accs:
            best_epoch = int(np.argmax(test_accs))
            best_checkpoint_path = os.path.join(save_dir, f'epoch{best_epoch}.pth')
            checkpoint = torch.load(best_checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Best model loaded from {best_checkpoint_path}, best test accuracy: {test_accs[best_epoch]:.8f}")

        save_path = './results/mmt_vit/mmt_vit_results/model'
        os.makedirs(save_path, exist_ok=True)
        model_path = os.path.join(save_path, 'mmt-ViT_multimodal_model.pth')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler_state()
        }, model_path)
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()

    return model

# Testing function
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, then save, print and plot the metrics.

    Computes accuracy, weighted precision and recall (scikit-learn), an F1
    score (torchmetrics) and raw plus row-normalized confusion matrices. The
    metrics are written to a JSON file, printed, and both matrices are saved
    as heatmap PNGs and shown.

    All metrics are derived from the same per-sample predictions (argmax of
    the logits), so they cannot disagree with one another.

    NOTE(review): ``MulticlassF1Score`` is created with its default averaging
    while precision and recall use ``average='weighted'``; per the torchmetrics
    documentation the default is macro averaging -- confirm that this mix is
    intended.

    Args:
        model: Module called as ``model(instr_seqs, wavelet_seqs, images)``.
        test_loader: Loader yielding ``(instr_seqs, wavelet_seqs, images, labels)``.

    Returns:
        The dictionary of results that is written to the JSON file.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    y_true = []
    y_pred = []
    f1_metric = MulticlassF1Score(num_classes=cfg.class_num).to(device)

    with torch.no_grad():
        for instr_seqs, wavelet_seqs, images, labels in test_loader:
            instr_seqs = instr_seqs.to(device)
            wavelet_seqs = wavelet_seqs.to(device)
            images = {k: v.to(device) for k, v in images.items()}
            labels = labels.to(device)
            with autocast('cuda'):
                outputs = model(instr_seqs, wavelet_seqs, images)
            preds = outputs.argmax(dim=1)
            f1_metric.update(preds, labels)
            y_pred.extend(preds.cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    if not y_true:
        raise ValueError("test_loader yielded no samples")

    total_correct = sum(1 for truth, pred in zip(y_true, y_pred) if truth == pred)
    accuracy = total_correct / len(y_true)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_metric.compute()

    # Compute confusion matrix (raw counts and normalized per true class)
    conf_matrix = confusion_matrix(y_true, y_pred, labels=range(cfg.class_num))
    conf_matrix_norm = confusion_matrix(y_true, y_pred, labels=range(cfg.class_num), normalize='true')

    class_labels = [f"Class {i}" for i in range(cfg.class_num)]

    # Save results to JSON
    results = {
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
        "f1_score": float(f1),
        "confusion_matrix": conf_matrix.tolist(),
        "normalized_confusion_matrix": conf_matrix_norm.tolist(),
        "class_labels": class_labels
    }
    json_path = './results/mmt_vit/mmt_vit_results/metrics/mmt-ViT_multimodal_model.json'
    # The figures below are written to the same directory.
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    try:
        with open(json_path, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Test results saved to {json_path}")
    except (OSError, TypeError, ValueError) as e:
        # Saving is best-effort: the metrics are still printed and plotted.
        print(f"Error saving JSON results: {e}")

    # Print evaluation metrics
    print(f'Accuracy: {accuracy:.4f}')
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # Visualize confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_labels,
                yticklabels=class_labels)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.savefig('./results/mmt_vit/mmt_vit_results/metrics/mmt-ViT_confusion_matrix.png')
    plt.show()
    plt.close()  # release the figure once it has been saved and shown

    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix_norm, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=class_labels,
                yticklabels=class_labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Normalized Confusion Matrix')
    plt.savefig('./results/mmt_vit/mmt_vit_results/metrics/mmt-ViT_Normalized_confusion_matrix.png')
    plt.show()
    plt.close()

    return results



# Main program
if __name__ == '__main__':
    # Seed the RNGs used for weight init, loader shuffling and the dataset's
    # random augmentation; cfg.SEED was otherwise only applied to the split.
    torch.manual_seed(cfg.SEED)
    np.random.seed(cfg.SEED)

    # Initialize datasets
    dataset_train = MultimodalDataset(
        train_instructions, wavelet_seq_train_path, img_train_path, x_train_name, y_train_lable, is_train=True
    )
    dataset_test = MultimodalDataset(
        test_instructions, wavelet_seq_test_path, img_test_path, x_test_name, y_test_lable, is_train=False
    )
    train_loader = DataLoader(dataset_train, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(dataset_test, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    # Initialize model
    model = MultiModalModel(num_classes=cfg.class_num, vocab_size=vocab_size, wavelet_dim=4, hidden_size=768).to(device)

    # Fine-tune only ViT encoder layers 8 and above; every other ViT parameter
    # is frozen. Layer parameters are named "encoder.layer.<index>....".
    for name, param in model.vit.named_parameters():
        if "encoder.layer" in name and int(name.split('.')[2]) >= 8:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Loss function and optimizer: classes are weighted by inverse training
    # frequency (a class absent from the training labels gets weight 1).
    class_weights = torch.tensor(
        [1.0 / max(Counter(y_train_lable).get(i, 1), 1) for i in range(cfg.class_num)], dtype=torch.float
    ).to(device)
    print(f"Class weights: {class_weights}")
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-5)
    # train_evaluate steps the scheduler once per epoch, so the one-cycle
    # schedule must span exactly cfg.num_epochs steps. It was previously sized
    # for per-batch stepping (50 * len(train_loader) steps), which left the
    # learning rate stuck at the very start of the warm-up for the whole run.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=cfg.num_epochs)

    # Debug initial forward pass: one training batch through the untrained
    # model to check that shapes and devices line up.
    model.eval()
    print("Debugging initial forward pass...")
    with torch.no_grad():
        for instr_seqs, wavelet_seqs, images, labels in train_loader:
            instr_seqs = instr_seqs.to(device)
            wavelet_seqs = wavelet_seqs.to(device)
            images = {k: v.to(device) for k, v in images.items()}
            outputs = model(instr_seqs, wavelet_seqs, images)
            probs = F.softmax(outputs, dim=1)
            print(f"Initial sample probabilities: {probs[0]}")
            print(f"Initial predictions: {torch.argmax(probs, dim=1)[:5]}")
            print(f"True labels: {labels[:5]}")
            break

    # Start training (train_evaluate picks up the module-level ``scheduler``)
    print("Starting training...")
    model = train_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs=cfg.num_epochs)

    # Test the best model
    print("Testing the best model...")
    test(model, test_loader)

