# config.py
"""
Configuration file for malware classification models (DGSM_SCAM_GAT and MMT-ViT)
"""

import torch
import os

# General settings for all models and datasets
class GeneralConfig:
    """Experiment-wide settings shared by every model and dataset.

    The compute device is resolved once, at import time of this module.
    """

    # Reproducibility
    SEED = 42

    # Hardware: use a CUDA device when torch reports one is available, else CPU
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Data split: fraction of the dataset held out for testing
    TEST_SIZE = 0.2

    # Training schedule
    BATCH_SIZE = 8  # shared by training and evaluation
    NUM_EPOCHS = 100  # full training run
    FINE_TUNE_EPOCHS = 10  # fine-tuning run

# Dataset-specific configurations
class DatasetConfig:
    """Paths and constants for the datasets used in the experiments.

    Every path is derived from ``DATA_ROOT``. Sub-paths are passed to
    ``os.path.join`` one component at a time (never as pre-joined ``"a/b"``
    strings), so the platform's path separator is used consistently.
    """
    DATA_ROOT = "./data"  # Base directory for all datasets

    # big2015 dataset
    BIG2015_ROOT = os.path.join(DATA_ROOT, "big2015")
    # Sub-directory holding the wavelet-sequence and instruction files below
    BIG2015_DEALWITH_DIR = os.path.join(BIG2015_ROOT, "dealwith_data")
    BIG2015_LABELS = os.path.join(BIG2015_ROOT, "big2015_Labels.csv")
    BIG2015_GRAY_IMAGES_TRAIN = os.path.join(BIG2015_ROOT, "gray_images_file_name_train_big")
    BIG2015_GRAY_IMAGES_TEST = os.path.join(BIG2015_ROOT, "gray_images_file_name_test_big")
    BIG2015_WAVELET_SEQ_TRAIN = os.path.join(BIG2015_DEALWITH_DIR, "wavelet_sequences_train")
    BIG2015_WAVELET_SEQ_TEST = os.path.join(BIG2015_DEALWITH_DIR, "wavelet_sequences_test")
    BIG2015_INSTRUCTIONS_TRAIN = os.path.join(BIG2015_DEALWITH_DIR, "instructions_train_remove_the_same.pkl")
    BIG2015_INSTRUCTIONS_TEST = os.path.join(BIG2015_DEALWITH_DIR, "instructions_test_remove_the_same.pkl")
    BIG2015_NUM_CLASSES = 9  # Number of classes in big2015 dataset
    BIG2015_MAX_SEQ_LEN = 2048  # Maximum sequence length for wavelet and instruction sequences

    # malimg_25 dataset
    MALIMG_25_ROOT = os.path.join(DATA_ROOT, "big2015_yz", "malimg_25")
    MALIMG_25_INPUT = os.path.join(MALIMG_25_ROOT, "data_in")
    MALIMG_25_NUM_CLASSES = 25  # Number of classes in malimg_25 dataset

    # Malevis_malimg_31 dataset
    MALEVIS_31_ROOT = os.path.join(DATA_ROOT, "big2015_yz", "Malevis_malimg_31")
    MALEVIS_31_INPUT = os.path.join(MALEVIS_31_ROOT, "in")
    MALEVIS_31_NUM_CLASSES = 31  # Number of classes in Malevis_malimg_31 dataset

    # API call datasets (dynamic API calls for training, MAL-API-2019 for fine-tuning)
    DYNAMIC_API_CALLS = os.path.join(DATA_ROOT, "dynamic_api_call_data", "dynamic_api_call_sequence_20000.csv")
    MAL_API_2019 = os.path.join(DATA_ROOT, "mal_api_2019", "merged_api_index_data.csv")
    API_FEATURE_DIM = 100  # Number of API call features
    API_NUM_FEATURES = 309  # Vocabulary size for API calls

# Model-specific configurations
class DGSM_SCAM_GAT_Config:
    """Configuration for the DGSM_SCAM_GAT_Enhanced model.

    Holds architecture sizes, optimisation hyper-parameters and output
    locations. Judging by the checkpoint file names, the model is trained on
    the dynamic API-call data and fine-tuned on MAL-API-2019.

    NOTE(review): checkpoints are named ``*_Improved.pth`` while this
    docstring names the model ``..._Enhanced`` - confirm which class is saved.
    """
    # Architecture
    NUM_FEATURES = DatasetConfig.API_NUM_FEATURES  # Number of input features (API call vocabulary)
    EMBEDDING_DIM = 256  # Embedding dimension for API calls
    N_HID = 256  # Hidden dimension for GAT and other layers
    N_CLASS = 2  # Number of classes for binary classification
    DROPOUT = 0.3  # Dropout rate
    N_HEADS = 16  # Number of attention heads for Transformer and SCAM
    N_GAT_HEADS = 8  # Number of attention heads for GATConv
    SEQ_LEN = 100  # Sequence length for API calls

    # Optimisation
    LEARNING_RATE = 0.00005  # Learning rate for training
    FINE_TUNE_LR = 0.00005  # Learning rate for fine-tuning
    WEIGHT_DECAY = 1e-3  # Weight decay for AdamW optimizer
    MIN_LR = 1e-6  # Minimum learning rate for cosine scheduler
    LABEL_SMOOTHING = 0.05  # Label smoothing for CrossEntropyLoss
    GRAD_CLIP_NORM = 1.0  # Gradient clipping norm

    # Output locations (components joined individually for a consistent separator)
    RESULTS_DIR = "./results/dgsm_scam_gat"  # Directory for saving results
    MODELS_DIR = os.path.join(RESULTS_DIR, "models")
    MODEL_PATH = os.path.join(MODELS_DIR, "dynamic_api_DGSM_SCAM_GAT_Improved.pth")
    FINE_TUNE_MODEL_PATH = os.path.join(MODELS_DIR, "mal_api_2019_DGSM_SCAM_GAT_Improved.pth")
    METRICS_DIR = os.path.join(RESULTS_DIR, "metrics")
    EVALUATION_DIR = os.path.join(RESULTS_DIR, "evaluation")
    RUNS_DIR = os.path.join(RESULTS_DIR, "runs")
    EPOCH_DIR = os.path.join(RESULTS_DIR, "dynamic_api_epoch")

class MMT_ViT_Config:
    """Configuration for MMT-ViT and GrayOnlyModel.

    Holds class counts for the three image datasets, architecture sizes,
    optimisation hyper-parameters and output locations. Output sub-paths are
    joined one component at a time for a consistent path separator.
    """
    # Class counts (mirrored from DatasetConfig)
    NUM_CLASSES = DatasetConfig.BIG2015_NUM_CLASSES  # Number of classes for big2015
    MALIMG_25_NUM_CLASSES = DatasetConfig.MALIMG_25_NUM_CLASSES  # Number of classes for malimg_25
    MALEVIS_31_NUM_CLASSES = DatasetConfig.MALEVIS_31_NUM_CLASSES  # Number of classes for Malevis_malimg_31

    # Architecture
    EMBEDDING_DIM = 256  # Embedding dimension for instructions
    HIDDEN_SIZE = 768  # Hidden size for transformer and ViT outputs
    WAVELET_DIM = 4  # Dimension of wavelet sequences
    NUM_HEADS = 8  # Number of attention heads for fusion transformer
    NUM_LAYERS = 6  # Number of transformer layers
    DROPOUT = 0.3  # Dropout rate

    # Optimisation
    LEARNING_RATE = 0.0001  # Learning rate for training
    FINE_TUNE_LR = 1e-5  # Learning rate for fine-tuning
    WEIGHT_DECAY = 1e-5  # Weight decay for AdamW optimizer
    ACCUM_STEPS = 4  # Gradient accumulation steps
    MAX_LR = 0.001  # Maximum learning rate for OneCycleLR
    # Peak learning rate for fine-tuning. torch's CosineAnnealingLR takes no
    # max-LR argument (it anneals down from the optimizer's initial lr), so this
    # is presumably used as that initial lr - TODO confirm. Currently equal to
    # FINE_TUNE_LR.
    FINE_TUNE_MAX_LR = 1e-5

    # Output locations
    RESULTS_DIR = "./results/mmt_vit"  # Directory for saving results
    MMT_VIT_RESULTS_DIR = os.path.join(RESULTS_DIR, "mmt_vit_results")
    MODEL_DIR = os.path.join(MMT_VIT_RESULTS_DIR, "model")
    MODEL_PATH = os.path.join(MODEL_DIR, "mmt-ViT_multimodal_model.pth")
    GRAY_ONLY_MODEL_PATH = os.path.join(MODEL_DIR, "gray_only.pth")
    # Fine-tuned checkpoints are written next to their evaluation outputs
    MALIMG_25_EVAL_DIR = os.path.join(RESULTS_DIR, "big2015_yz", "yz_results_25")
    MALEVIS_31_EVAL_DIR = os.path.join(RESULTS_DIR, "big2015_yz", "yz_results_31")
    MALIMG_25_FINE_TUNE_PATH = os.path.join(MALIMG_25_EVAL_DIR, "malimg_25class_finetuned_best.pth")
    MALEVIS_31_FINE_TUNE_PATH = os.path.join(MALEVIS_31_EVAL_DIR, "Malevis_malimg_31class_finetuned_best.pth")
    METRICS_DIR = os.path.join(MMT_VIT_RESULTS_DIR, "metrics")
    RUNS_DIR = os.path.join(MMT_VIT_RESULTS_DIR, "runs")
    EPOCH_DIR = os.path.join(MMT_VIT_RESULTS_DIR, "mmt-ViT_epoch")
