import pickle
import torch
import numpy as np
from PIL import Image
import os
import re
from tqdm import tqdm
import pywt
from sklearn.model_selection import train_test_split
import pandas as pd
from multiprocessing import Pool
from pathlib import Path

def get_file_path(dataset_name, suffix='.bytes', folder_path='./data/big2015/dataset_big2015'):
    """Build the file path of every sample id in ``dataset_name``.

    Args:
        dataset_name: iterable of sample ids (file names without extension).
        suffix: extension appended to each id, e.g. ``'.bytes'`` or ``'.asm'``.
        folder_path: directory holding the sample files; defaults to the
            BIG2015 dataset folder used throughout this script.

    Returns:
        List of ``os.path.join(folder_path, f"{id}{suffix}")``, in input order.
    """
    return [os.path.join(folder_path, f"{name}{suffix}") for name in dataset_name]

def read_bytes(file_path):
    """Parse a hex-dump ``.bytes`` file into a list of byte values.

    Every whitespace-separated token is parsed as a hexadecimal integer and
    kept when it lies in ``0..0xFF``. Tokens beginning with ``?`` (bytes the
    dump could not read) and tokens that are not valid hex are skipped.
    Larger values, such as a wide address column at the start of each line,
    are dropped by the range check.

    Args:
        file_path: path of the ``.bytes`` text file.

    Returns:
        List of ints in file order. ``[]`` (after printing the error) if the
        file cannot be opened or read.
    """
    res_bytes = []
    try:
        with open(file_path, mode='r', encoding='utf-8', errors='ignore') as fp:
            # Stream line by line rather than materialising readlines().
            for byte_line in fp:
                # split() with no separator collapses whitespace runs and never
                # yields an empty token. split(" ") does (double or trailing
                # spaces), and indexing token[0] on one raised IndexError,
                # which discarded the whole file.
                for str_byte in byte_line.split():
                    if str_byte.startswith('?'):
                        continue
                    try:
                        byte = int(str_byte, 16)
                    except ValueError:
                        continue
                    # NOTE(review): an address token whose value is <= 0xFF would
                    # still be kept as data; harmless for dumps with high addresses.
                    if 0 <= byte <= 0xFF:
                        res_bytes.append(byte)
        return res_bytes
    except Exception as e:
        # Best-effort: an unreadable file yields no bytes instead of raising.
        print(f"Error reading {file_path}: {e}")
        return []

def dealwith_data(path_src, path_des, save_name):
    """Return per-file byte tensors, loading them from a pickle cache if present.

    If ``path_des`` exists it is unpickled and returned as-is, and
    ``save_name`` is not touched. Otherwise each file in ``path_src`` is parsed
    with ``read_bytes``, converted to a 1-D float32 CPU tensor and appended to
    ``save_name``. The resulting list is then pickled to ``path_des``.

    Args:
        path_src: iterable of ``.bytes`` file paths, in output order.
        path_des: cache file path. Its parent directory is created if needed.
        save_name: list that newly built tensors are appended to (mutated).

    Returns:
        The list of tensors: either the loaded cache or ``save_name``.
    """
    cache_dir = os.path.dirname(path_des)
    # os.makedirs('') raises, so only create a directory when the path names one.
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    if os.path.exists(path_des):
        print(f"Loading processed data from {path_des}...")
        # NOTE: pickle can execute arbitrary code; only load caches this script wrote.
        with open(path_des, "rb") as f:
            return pickle.load(f)

    print("Processing data...")
    for path in tqdm(path_src, desc="Reading bytes"):
        # Build directly on the CPU. There is nothing on the GPU, so no
        # per-file transfer or CUDA cache flush is needed.
        save_name.append(torch.tensor(read_bytes(path), dtype=torch.float32, device='cpu'))
    print(f"Saving processed data to {path_des}...")
    with open(path_des, "wb") as f:
        pickle.dump(save_name, f)
    return save_name

def get_gray_image(array_byte):
    """Lay a byte sequence out row-major as a 2-D uint8 grayscale image.

    The width is chosen from the input size in KiB: 32 below 10 KiB, rising
    through the ``file_size_array``/``image_width_array`` table, and 1024 at
    1000 KiB or more. The height is the number of complete rows, so trailing
    bytes that do not fill a full row are dropped. An input shorter than one
    row yields a single zero-padded row, so the image is never empty (an
    empty image cannot be saved).

    Args:
        array_byte: sequence of byte values 0..255, e.g. a list of ints or a
            1-D CPU tensor/array.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(height, width)``.
    """
    byte_len = len(array_byte)
    file_size = byte_len // 1024  # size in KiB, rounded down
    file_size_array = [10, 20, 40, 80, 160, 320, 480, 800, 1000]
    image_width_array = [32, 64, 128, 256, 384, 512, 640, 768, 896]
    image_width = next((w for f, w in zip(file_size_array, image_width_array) if file_size < f), 1024)
    image_height = max(1, byte_len // image_width)
    n_used = min(byte_len, image_height * image_width)
    image_bytes = np.zeros(image_height * image_width, dtype=np.uint8)
    if n_used:
        # Vectorised copy of the leading bytes, replacing a per-pixel Python loop.
        image_bytes[:n_used] = np.asarray(array_byte[:n_used]).astype(np.uint8)
    return image_bytes.reshape(image_height, image_width)

def get_file_name(filepath):
    """Return the last path component of ``filepath`` with its final extension removed."""
    base = os.path.basename(filepath)
    stem, _extension = os.path.splitext(base)
    return stem

def get_gray_images(base_dir, gray_images_name, path_name, data_bytes):
    """Render one grayscale PNG per sample into a sibling directory of ``base_dir``.

    Images go to ``<base_dir>/../<gray_images_name>/<stem>.png``, where
    ``<stem>`` is the matching ``path_name`` entry without its extension.
    ``data_bytes[i]`` supplies the bytes for ``path_name[i]``. PNGs that
    already exist are left untouched. An empty byte sequence is rendered from
    224*224 zero bytes instead.
    """
    out_dir = os.path.join(base_dir, "..", gray_images_name)
    os.makedirs(out_dir, exist_ok=True)
    written = 0
    for position, source_path in enumerate(tqdm(path_name, desc="Generating gray images")):
        target = os.path.join(out_dir, f"{get_file_name(source_path)}.png")
        if os.path.exists(target):
            # Rendered on an earlier run; data_bytes is not consulted for it.
            continue
        raw = data_bytes[position]
        if not len(raw):
            raw = [0] * (224 * 224)
        pixels = np.uint8(get_gray_image(raw))
        Image.fromarray(pixels).save(target)
        written += 1
    print(f"Get gray image ok! Generated {written} images")


def wavelet_transform(data, wavelet='db4', level=3, max_seq_len=2048):
    """Turn a byte sequence into a 4-channel multilevel wavelet feature sequence.

    The input is zero-padded to a multiple of ``2**level`` and decomposed with
    ``pywt.wavedec``. Four bands are kept, ordered coarse-to-fine: the
    approximation band followed by the three coarsest detail bands. For the
    default ``level=3`` this is ``[cA3, cD3, cD2, cD1]``. Missing bands (when
    ``level < 3``) are zeros. Each band is truncated or zero-padded to
    ``min(longest band, max_seq_len)`` and the bands are stacked on the last
    axis.

    Args:
        data: 1-D sequence of byte values.
        wavelet: pywt wavelet name.
        level: decomposition depth.
        max_seq_len: upper bound on the output length.

    Returns:
        Array of shape ``(L, 4)`` with ``L <= max_seq_len``. Empty input, or
        any exception during the transform, yields float32 zeros of shape
        ``(max_seq_len, 4)``.
    """
    data = np.array(data, dtype=np.float32)
    if len(data) == 0:
        print("Warning: Empty input data, returning default zeros")
        return np.zeros((max_seq_len, 4), dtype=np.float32)

    try:
        min_length = 2**level
        if len(data) < min_length:
            data = np.pad(data, (0, min_length - len(data)), mode='constant')
        else:
            pad_length = (min_length - (len(data) % min_length)) % min_length
            data = np.pad(data, (0, pad_length), mode='constant')

        coeffs = pywt.wavedec(data, wavelet=wavelet, level=level, mode='symmetric')
        # wavedec already returns coarse-to-fine: [cA_level, cD_level, ..., cD1].
        # Keep that order (no reversal) so cD3/cD2/cD1 really hold those bands.
        # Reversing here would stack them as cD1, cD2, cD3 under swapped names.
        cA3 = coeffs[0]
        detail_bands = list(coeffs[1:])
        cD3, cD2, cD1 = (detail_bands + [np.zeros_like(cA3)] * 3)[:3]

        bands = [cA3, cD3, cD2, cD1]
        # Bands differ in length (finer details are longer). Bring all of them
        # to a common length capped at max_seq_len. Truncation keeps the start
        # of each band, so when capped the channels cover different input spans.
        max_len = min(max(len(c) for c in bands), max_seq_len)
        norm_coeffs = []
        for c in bands:
            if len(c) > max_len:
                c = c[:max_len]
            elif len(c) < max_len:
                c = np.pad(c, (0, max_len - len(c)), mode='constant')
            norm_coeffs.append(c)

        return np.stack(norm_coeffs, axis=-1)
    except Exception as e:
        print(f"Wavelet transform error: {e}")
        return np.zeros((max_seq_len, 4), dtype=np.float32)

def process_wavelet_file(args):
    """Worker: compute and pickle the wavelet sequence for one ``.bytes`` file.

    ``args`` is the 5-tuple ``(file_path, output_dir, wavelet, level,
    max_seq_len)``, so the function can be mapped directly by
    ``multiprocessing.Pool.imap``. The sequence is written to
    ``<output_dir>/<file stem>.pkl``. For a missing file, an empty file, or any
    exception, a float32 zero sequence of shape ``(max_seq_len, 4)`` is written
    instead, so every input gets an output file.

    Returns:
        ``(sample_id, True, message)``. The flag is always ``True``.
        ``message`` is ``None`` on a clean run, otherwise it says why a
        zero/degenerate sequence was written.
    """
    file_path, output_dir, wavelet, level, max_seq_len = args
    sample_id = Path(file_path).stem
    output_path = os.path.join(output_dir, f"{sample_id}.pkl")

    def dump(sequence):
        # Write one sequence to this sample's pickle, creating the directory if needed.
        os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'wb') as handle:
            pickle.dump(sequence, handle)

    def blank():
        # Placeholder sequence used whenever no real transform is available.
        return np.zeros((max_seq_len, 4), dtype=np.float32)

    try:
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            dump(blank())
            return sample_id, True, "File not found, saved zero sequence"
        if os.path.getsize(file_path) == 0:
            print(f"Empty file: {file_path}")
            dump(blank())
            return sample_id, True, "Empty file, saved zero sequence"

        byte_values = read_bytes(file_path)
        dump(wavelet_transform(byte_values, wavelet=wavelet, level=level, max_seq_len=max_seq_len))
        message = "Invalid bytes (all ??), saved zero sequence" if len(byte_values) == 0 else None
        return sample_id, True, message
    except Exception as exc:
        print(f"Error processing {file_path}: {exc}")
        dump(blank())
        return sample_id, True, f"Error: {str(exc)}, saved zero sequence"


# generate_wavelet_sequences is defined once, immediately below.

def generate_wavelet_sequences(dataset_name, output_dir, wavelet='db4', level=3, max_seq_len=2048, num_processes=2):
    """Compute wavelet sequences for every sample and record their pickle paths.

    Each sample's ``.bytes`` file (located via ``get_file_path``) is processed
    by ``process_wavelet_file`` in a ``multiprocessing.Pool``. If the pool
    raises, all files are reprocessed in this process. Each sample whose worker
    reports success contributes ``<output_dir>/<id>.pkl`` to the result, in
    input order. That list is also pickled to
    ``<output_dir>/sequence_paths.pkl``. Samples that came back with a message
    (placeholder or degenerate sequence written) are printed and written to
    ``<output_dir>/invalid_files.log``.

    Returns:
        List of per-sample ``.pkl`` paths.
    """
    os.makedirs(output_dir, exist_ok=True)
    file_paths = get_file_path(dataset_name, suffix='.bytes')
    sequence_paths = []
    failed_files = []
    invalid_files = []

    print(f"Processing {len(file_paths)} files for wavelet sequences...")
    args = [(fp, output_dir, wavelet, level, max_seq_len) for fp in file_paths]

    try:
        with Pool(processes=num_processes) as pool:
            # imap yields results in input order, keeping sequence_paths aligned
            # with dataset_name.
            results = list(tqdm(pool.imap(process_wavelet_file, args), total=len(file_paths), desc="Generating sequences"))
            # Let workers exit normally instead of relying on the context
            # manager's terminate().
            pool.close()
            pool.join()
    except Exception as e:
        print(f"Multiprocessing error: {e}")
        print("Falling back to single process...")
        results = [process_wavelet_file(arg) for arg in tqdm(args, desc="Generating sequences (single process)")]

    for sample_id, status, error_msg in results:
        if status:
            sequence_paths.append(os.path.join(output_dir, f"{sample_id}.pkl"))
            # A message means a zero or degenerate sequence was saved. This
            # includes "Error: ..." from exceptions, which a substring check
            # for specific phrases would miss.
            if error_msg:
                invalid_files.append((sample_id, error_msg))
        else:
            failed_files.append((sample_id, error_msg))
    success = len(sequence_paths)

    print(f"Generated {success} wavelet sequences, {len(failed_files)} failed.")
    if failed_files:
        print("Failed files:", list(failed_files))
    if invalid_files:
        print("Invalid or missing byte files:", list(invalid_files))
        log_path = os.path.join(output_dir, 'invalid_files.log')
        with open(log_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{fid}: {msg}\n" for fid, msg in invalid_files)
        print(f"Invalid files logged to {log_path}")

    path_list_save = os.path.join(output_dir, 'sequence_paths.pkl')
    with open(path_list_save, 'wb') as f:
        pickle.dump(sequence_paths, f)

    return sequence_paths

# Mnemonics (plus a few Windows API names) recognised by extract_instructions.
# Every entry is lower-case because extract_instructions lower-cases each token
# before the membership test; a mixed-case entry could never match.
instruction_set = {
    # data movement
    'mov', 'movsx', 'movzx', 'push', 'pop', 'lea', 'xchg', 'cmpxchg', 'xadd', 'movs', 'movsb', 'movsw', 'movsd',
    # integer arithmetic
    'add', 'sub', 'inc', 'dec', 'mul', 'imul', 'div', 'idiv', 'adc', 'sbb', 'neg',
    # logic
    'and', 'or', 'xor', 'not', 'test',
    # shifts, rotates, bit tests/scans
    'shl', 'shr', 'sal', 'sar', 'rol', 'ror', 'rcl', 'rcr', 'bt', 'bts', 'btr', 'btc', 'bsf', 'bsr',
    # jumps, calls, returns, interrupts, loops
    'jmp', 'jz', 'jnz', 'je', 'jne', 'jg', 'jge', 'jl', 'jle', 'jo', 'jno', 'js', 'jns', 'jp', 'jnp',
    'jc', 'jnc', 'ja', 'jae', 'jb', 'jbe', 'call', 'ret', 'retn', 'int', 'into', 'iret', 'loop', 'loope', 'loopne',
    # comparisons
    'cmp', 'cmps', 'cmpsb', 'cmpsw', 'cmpsd',
    # stack frame / register save-restore
    'pusha', 'pushad', 'popa', 'popad', 'enter', 'leave',
    # string operations
    'stos', 'stosb', 'stosw', 'stosd', 'lods', 'lodsb', 'lodsw', 'lodsd', 'scas', 'scasb', 'scasw', 'scasd',
    # flag manipulation
    'clc', 'stc', 'cli', 'sti', 'cld', 'std', 'cmc',
    # system / miscellaneous
    'nop', 'hlt', 'wait', 'rdtsc', 'cpuid', 'in', 'out', 'ins', 'outs', 'int3', 'syscall', 'sysenter', 'sysexit',
    # x87 floating point
    'fld', 'fst', 'fstp', 'fadd', 'fsub', 'fmul', 'fdiv', 'fcom', 'fcomp', 'fxch', 'fild', 'fist', 'fistp',
    # SSE moves and packed arithmetic/logic
    'movaps', 'movups', 'movdqa', 'movdqu', 'addps', 'subps', 'mulps', 'divps', 'xorps', 'andps', 'orps',
    # conditional set, flag load/store, sign extension
    'sete', 'setne', 'setg', 'setge', 'setl', 'setle', 'seto', 'setno', 'sets', 'setns', 'setp', 'setnp',
    'seta', 'setae', 'setb', 'setbe', 'lahf', 'sahf', 'cbw', 'cwd', 'cdq', 'cwde', 'cdqe',
    # Windows API names (lower-cased so they can match)
    'virtualalloc', 'virtualprotect', 'resumethread', 'isdebuggerpresent', 'checkremotedebuggerpresent',
    # spin-wait hint and repeat prefixes
    'pause', 'rep', 'repe', 'repne', 'repnz', 'repz'
}

def extract_instructions(asm_file_path, instruction_set):
    """Extract the instruction-mnemonic sequence from an IDA-style ``.asm`` listing.

    For each line, the ``.segment:ADDRESS`` prefix and leading hex opcode
    bytes are removed, as are quoted string literals and ``;`` comments. The
    first remaining whitespace-separated token found in ``instruction_set``
    (compared case-insensitively) is taken as that line's instruction.
    Consecutive repeats of the same mnemonic are collapsed into one entry.

    Args:
        asm_file_path: path to the ``.asm`` file.
        instruction_set: collection of mnemonics to recognise.

    Returns:
        List of lower-case mnemonics in file order. Empty if the file does not
        exist.
    """
    instructions = []
    address_pattern = re.compile(r'^\.\w+:[0-9A-Fa-f]{8}\s+')
    bytecode_pattern = re.compile(r'^\s*[0-9A-Fa-f]{2}(?:\s+[0-9A-Fa-f]{2})*\s+')
    # Quoted literals (e.g. db 'Error in file',0) contain words such as
    # "in", "not" or "and" that collide with mnemonics.
    string_literal_pattern = re.compile(r"'[^']*'|\"[^\"]*\"")
    if not os.path.exists(asm_file_path):
        return instructions

    # Tokens are lower-cased before lookup, so the lookup set must be too.
    lookup = {name.lower() for name in instruction_set}

    with open(asm_file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            stripped = line.strip()
            if not stripped or stripped.startswith(';'):
                continue
            cleaned_line = address_pattern.sub('', line).strip()
            cleaned_line = bytecode_pattern.sub('', cleaned_line)
            # Strip strings before comments so a ';' inside a literal is not
            # mistaken for a comment start.
            cleaned_line = string_literal_pattern.sub(' ', cleaned_line)
            # Drop comments that follow the address, e.g. "; int __stdcall ...".
            cleaned_line = cleaned_line.split(';', 1)[0]
            for part in cleaned_line.split():
                candidate = part.strip(',').lower()
                if candidate in lookup:
                    if not instructions or instructions[-1] != candidate:
                        instructions.append(candidate)
                    break
    return instructions

def preprocess_all_instructions(asm_file_paths, instruction_set, save_path_instructions):
    """Extract instruction sequences for many ``.asm`` files and optionally pickle them.

    Args:
        asm_file_paths: iterable of ``.asm`` paths. Output order follows it.
        instruction_set: mnemonics passed through to ``extract_instructions``.
        save_path_instructions: pickle destination. Pass a falsy value to skip
            saving.

    Returns:
        List with one instruction list per input path.
    """
    all_instructions = [
        extract_instructions(asm_path, instruction_set)
        for asm_path in tqdm(asm_file_paths, desc="Extracting instructions")
    ]
    if save_path_instructions:
        save_dir = os.path.dirname(save_path_instructions)
        # os.makedirs('') raises, so a bare filename has nothing to create.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_path_instructions, "wb") as f:
            pickle.dump(all_instructions, f)
        print(f"Saved preprocessed instructions to {save_path_instructions}")
    return all_instructions

if __name__ == '__main__':
    class cfg:
        # Experiment settings. Only SEED, test_size and max_seq_len are read below.
        SEED = 42
        batch_size = 8
        test_size = 0.2
        max_seq_len = 2048
        class_num = 9

    # Not used further in this script.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Shift the CSV 'Class' values down by one (presumably 1-based in the CSV).
    label_table = pd.read_csv("./data/big2015/big2015_Labels.csv")
    x_train_name, x_test_name, y_train_lable, y_test_lable = train_test_split(
        label_table['Id'].values,
        label_table['Class'].values.astype(int) - 1,
        test_size=cfg.test_size,
        random_state=cfg.SEED,
    )
    split_ids = {'train': x_train_name, 'test': x_test_name}

    output_paths = {
        'train': {
            'wavelet': './data/big2015/dealwith_data/wavelet_sequences_train',
            'gray': './gray_images_file_name_train_big',
            'bytes': "./data/big2015/dealwith_data/train_bytes.pkl",
            'instructions': "./data/big2015/dealwith_data/instructions_train_remove_the_same.pkl",
        },
        'test': {
            'wavelet': './data/big2015/dealwith_data/wavelet_sequences_test',
            'gray': './gray_images_file_name_test_big',
            'bytes': "./data/big2015/dealwith_data/test_bytes.pkl",
            'instructions': "./data/big2015/dealwith_data/instructions_test_remove_the_same.pkl",
        },
    }

    # Stage 1: parse the .bytes dumps (cached as pickles).
    bytes_paths = {split: get_file_path(ids, suffix='.bytes') for split, ids in split_ids.items()}
    byte_data = {
        split: dealwith_data(bytes_paths[split], output_paths[split]['bytes'], [])
        for split in split_ids
    }

    # Stage 2: grayscale images.
    for split in split_ids:
        get_gray_images('./data/big2015/dataset_big2015', output_paths[split]['gray'],
                        bytes_paths[split], byte_data[split])

    # Stage 3: wavelet sequences.
    print("Generating wavelet sequences for training set...")
    train_seq_paths = generate_wavelet_sequences(
        x_train_name, output_paths['train']['wavelet'], max_seq_len=cfg.max_seq_len, num_processes=2
    )
    print("Generating wavelet sequences for test set...")
    test_seq_paths = generate_wavelet_sequences(
        x_test_name, output_paths['test']['wavelet'], max_seq_len=cfg.max_seq_len, num_processes=2
    )

    # Stage 4: instruction sequences from the .asm listings.
    asm_paths = {split: get_file_path(ids, suffix='.asm') for split, ids in split_ids.items()}
    instruction_data = {
        split: preprocess_all_instructions(asm_paths[split], instruction_set, output_paths[split]['instructions'])
        for split in split_ids
    }


