# We will provide the complete code if the paper is accepted.

import torch
import numpy as np
from sklearn.cluster import KMeans




# Per-channel mean / std used to normalise patch colours in calculate_fitness
# (the values match the widely used ImageNet statistics).
# NOTE: both tensors are moved to the GPU at import time, so importing this
# module requires a CUDA device.
mu = torch.tensor([0.485, 0.456, 0.406]).cuda()
std = torch.tensor([0.229, 0.224, 0.225]).cuda()

# IMG_SIZE: side length in pixels of the (assumed square) input images.
# PATCH_SIZE: side length in pixels of one square mini-patch cell.
IMG_SIZE = 224
PATCH_SIZE = 4
SEQ_SIZE = int((IMG_SIZE / PATCH_SIZE) * (IMG_SIZE / PATCH_SIZE)) # number of patch cells per image: 56 * 56 = 3136 for the values above
# PATCH_NUM = 64
# CLUSTER_NUM = 3

###########################################
#   parameters setting
#   F  : factor applied to the difference of two individuals in mutation()
#   CR : per-gene probability of taking the mutant's value in crossover()
F = 1
CR = 0.8
###########################################





def calculate_fitness(model, minipatch_num, target_images, population, tru_labels):
    """Score every individual by the loss of the images it perturbs.

    For each individual, the active cells of its binary layout are painted
    with the individual's RGB colours (normalised with the module-level
    ``mu`` / ``std``), the result is pasted onto ``target_images`` and the
    per-sample cross-entropy against ``tru_labels`` is used as fitness
    (higher loss = fitter individual).

    Args:
        model: Callable returning class logits for a batch of images. For
            models that return a ``(logits, attention)`` tuple, the forward
            call below has to unpack the first element instead.
        minipatch_num: Maximum number of active cells painted per individual;
            the p-th active cell (in flat index order) uses the p-th colour.
        target_images: Image batch; assumed to be
            ``[batch, 3, IMG_SIZE, IMG_SIZE]`` and already normalised.
        population: ``[binary_population, rgb_population]`` where the binary
            array has shape ``[batch, population, seq]`` and the RGB array is
            indexed as ``[batch, population, patch]`` yielding 3 values in
            the 0-255 range.
        tru_labels: Ground-truth class indices, one per batch element.

    Returns:
        ``np.ndarray`` of shape ``[batch, population]`` holding the loss of
        every individual for every image.
    """
    binary_population = population[0]
    rgb_population = population[1]

    batch_size, population_size, _ = binary_population.shape

    target_images = target_images.cuda()
    tru_labels = tru_labels.cuda()

    grid_size = IMG_SIZE // PATCH_SIZE  # number of patch cells per image side
    tile = np.ones((PATCH_SIZE, PATCH_SIZE), dtype=int)  # one cell, in pixels

    fitness = np.zeros((batch_size, population_size))

    # Only the numeric loss values are kept, so skip building autograd graphs.
    with torch.no_grad():
        for i in range(population_size):
            # Blow the cell-level layout up to a pixel-level mask.
            mask_array = binary_population[:, i].reshape(batch_size, grid_size, grid_size)
            masks = torch.tensor(np.kron(mask_array, tile), dtype=torch.float32).cuda()
            masks = masks.unsqueeze(1)  # [batch, 1, height, width]

            perturbation = torch.zeros_like(target_images)

            for b in range(batch_size):
                # Active cells of this individual, computed once per image.
                one_indices = np.where(binary_population[b, i] == 1)[0]
                for p in range(min(minipatch_num, len(one_indices))):
                    x, y = divmod(one_indices[p], grid_size)
                    patch_rgb = torch.tensor(rgb_population[b, i, p], dtype=torch.float32, device='cuda')
                    # Map 0-255 colours into the normalised image space.
                    patch_rgb = (patch_rgb / 255.0 - mu) / std
                    perturbation[
                        b,
                        :,
                        x * PATCH_SIZE:(x + 1) * PATCH_SIZE,
                        y * PATCH_SIZE:(y + 1) * PATCH_SIZE,
                    ] = patch_rgb.view(3, 1, 1)

            # Keep the original pixels outside the mask, the patch colours inside.
            adv_images = target_images * (1 - masks) + perturbation * masks

            out = model(adv_images)

            loss = torch.nn.functional.cross_entropy(out, tru_labels, reduction='none')
            fitness[:, i] = loss.cpu().numpy()

    return fitness



def init_population(batch_size, population_size, cluster_num, minipatch_num):
    """Create a random starting population for every image in the batch.

    Each individual consists of a binary layout over ``SEQ_SIZE`` cells with
    ``minipatch_num`` randomly chosen active cells (passed through
    ``squeeze_individual`` with ``cluster_num`` clusters) and one random RGB
    triple in ``[0, 255]`` per mini-patch.

    Returns:
        ``[binary_population, rgb_population]`` with shapes
        ``[batch, population, SEQ_SIZE]`` and
        ``[batch, population, minipatch_num, 3]`` (both integer arrays).
    """
    layouts = np.zeros((batch_size, population_size, SEQ_SIZE), dtype=int)
    colours = np.zeros((batch_size, population_size, minipatch_num, 3), dtype=int)

    # ndindex walks (b, p) in row-major order: batch outer, individual inner.
    for b, p in np.ndindex(batch_size, population_size):
        active_cells = np.random.choice(SEQ_SIZE, minipatch_num, replace=False)
        layout = np.zeros(SEQ_SIZE, dtype=int)
        layout[active_cells] = 1
        layouts[b, p] = squeeze_individual(layout, cluster_num, SEQ_SIZE)

        colours[b, p] = np.random.randint(0, 256, size=(minipatch_num, 3), dtype=int)

    return [layouts, colours]




def mutation(batch_size, population_size, population):
    """Build one mutant per individual from three other random individuals.

    For individual ``i`` three distinct indices ``r1, r2, r3`` (all different
    from ``i``) are drawn. The binary mutant is
    ``(x_r1 + F * (x_r2 - x_r3)) % 2``; the RGB mutant is
    ``x_r1 + 2 * F * (shuffle(x_r2) - shuffle(x_r3))`` clipped to
    ``[0, 255]``, where the shuffles permute the mini-patch order of the two
    difference operands.

    Args:
        batch_size: Number of images (first axis of both population arrays).
        population_size: Number of individuals per image.
        population: ``[binary_population, rgb_population]``.

    Returns:
        ``[M_binary_population, M_rgb_population]`` with the same shapes and
        dtypes as the inputs.

    Raises:
        ValueError: If ``population_size`` is 1, 2 or 3. Three donors distinct
            from the target individual cannot be drawn from fewer than four
            individuals.
    """
    if 0 < population_size < 4:
        raise ValueError(
            "mutation needs population_size >= 4 to draw three donors "
            f"different from the target individual, got {population_size}"
        )

    binary_population = population[0]
    rgb_population = population[1]

    M_binary_population = np.zeros_like(binary_population)
    M_rgb_population = np.zeros_like(rgb_population)

    # For every target index, the indices it may borrow from (everyone else).
    everyone = np.arange(population_size)
    donor_pools = [np.delete(everyone, i) for i in range(population_size)]

    for b in range(batch_size):
        for i in range(population_size):
            # Sampling from the pool directly avoids redrawing until i is excluded.
            r1, r2, r3 = np.random.choice(donor_pools[i], 3, replace=False)

            M_binary_population[b, i] = (
                binary_population[b, r1]
                + F * (binary_population[b, r2] - binary_population[b, r3])
            ) % 2

            # permutation() returns a row-shuffled copy; the population is untouched.
            shuffled_r2 = np.random.permutation(rgb_population[b, r2])
            shuffled_r3 = np.random.permutation(rgb_population[b, r3])
            mutant_rgb = rgb_population[b, r1] + 2 * F * (shuffled_r2 - shuffled_r3)

            M_rgb_population[b, i] = np.clip(mutant_rgb, 0, 255)

    return [M_binary_population, M_rgb_population]


def crossover(batch_size, population_size, Mpopulation, population, cluster_num, minipatch_num):
    """Mix every parent with its mutant to produce a trial individual.

    Binary part: each of the ``SEQ_SIZE`` genes is taken from the mutant with
    probability ``CR`` (otherwise from the parent); one randomly chosen active
    gene of the mutant is always taken. The trial is then repaired to exactly
    ``minipatch_num`` active cells by ``fix_patch_num`` and passed through
    ``squeeze_individual`` with ``cluster_num`` clusters.

    RGB part: every colour channel of every mini-patch is taken from the
    mutant with probability ``CR``, otherwise from the parent.

    Args:
        batch_size: Number of images (first axis of the population arrays).
        population_size: Number of individuals per image.
        Mpopulation: ``[M_binary_population, M_rgb_population]`` mutants.
        population: ``[binary_population, rgb_population]`` parents.
        cluster_num: Cluster count forwarded to ``squeeze_individual``.
        minipatch_num: Required number of active cells per individual.

    Returns:
        ``[C_binary_population, C_rgb_population]`` with the parents' shapes
        and dtypes.
    """
    binary_population = population[0]
    rgb_population = population[1]
    M_binary_population = Mpopulation[0]
    M_rgb_population = Mpopulation[1]

    C_binary_population = np.zeros_like(binary_population)
    C_rgb_population = np.zeros_like(rgb_population)

    for b in range(batch_size):
        for i in range(population_size):
            crossover_mask = np.random.rand(SEQ_SIZE) < CR

            # Guarantee that at least one active mutant gene is inherited.
            # The mutant can be all zeros (the modulo-2 arithmetic in
            # mutation() can cancel every active position); np.random.choice
            # raises on an empty array, so skip the forced gene in that case
            # and let fix_patch_num restore the patch count.
            mutated_one_indices = np.where(M_binary_population[b, i] == 1)[0]
            if len(mutated_one_indices) > 0:
                crossover_mask[np.random.choice(mutated_one_indices)] = True

            trial_vector = np.where(crossover_mask, M_binary_population[b, i], binary_population[b, i])
            trial_vector = fix_patch_num(trial_vector, minipatch_num)
            trial_vector = squeeze_individual(trial_vector, cluster_num, SEQ_SIZE)
            C_binary_population[b, i] = trial_vector

            crossover_mask_rgb = np.random.rand(minipatch_num, 3) < CR
            C_rgb_population[b, i] = np.where(
                crossover_mask_rgb, M_rgb_population[b, i], rgb_population[b, i]
            )

    return [C_binary_population, C_rgb_population]


def fix_patch_num(vector, PATCH_NUM):
    """Force ``vector`` to contain exactly ``PATCH_NUM`` ones.

    Surplus ones are removed by keeping a random subset of the current ones;
    missing ones are added at randomly chosen zero positions. The array is
    modified in place and the same object is returned.
    """
    active = np.where(vector == 1)[0]
    inactive = np.where(vector == 0)[0]
    surplus = len(active) - PATCH_NUM

    if surplus > 0:
        # Too many active cells: keep a random PATCH_NUM of them.
        survivors = np.random.choice(active, PATCH_NUM, replace=False)
        vector[:] = 0
        vector[survivors] = 1
    elif surplus < 0:
        # Too few active cells: switch on random inactive ones.
        newcomers = np.random.choice(inactive, PATCH_NUM - len(active), replace=False)
        vector[newcomers] = 1

    return vector




def move_towards_nearby_point(point, target_point, image):
    """Propose a one-cell step from ``point`` towards ``target_point``.

    A step along the first axis and a step along the second axis are
    considered (each only if the target differs on that axis); which of the
    two is tried first is decided by a coin flip. The first candidate that
    lies inside ``image`` and whose cell equals 0 is returned.

    Returns:
        ``(row, col)`` of the cell to move to, or ``None`` if no step is
        possible.
    """
    row, col = point
    target_row, target_col = target_point

    # Step along axis 0, if the target is on a different row.
    if target_row > row:
        row_step = (row + 1, col)
    elif target_row < row:
        row_step = (row - 1, col)
    else:
        row_step = None

    # Step along axis 1, if the target is on a different column.
    if target_col > col:
        col_step = (row, col + 1)
    elif target_col < col:
        col_step = (row, col - 1)
    else:
        col_step = None

    # Random tie-break on which axis gets the first try.
    row_first = np.random.rand() < 0.5
    candidates = (row_step, col_step) if row_first else (col_step, row_step)

    for candidate in candidates:
        if candidate is None:
            continue
        r, c = candidate
        if 0 <= r < image.shape[0] and 0 <= c < image.shape[1] and image[r, c] == 0:
            return (r, c)

    return None

def squeeze_individual(individual, num_clusters, SEQ_SIZE, max_iterations=100, neighborhood_radius=1):
    """Pull the active cells of a layout together into compact clusters.

    The flat layout is viewed as a square grid, its active cells (value 1)
    are grouped with KMeans, and within every cluster the cells repeatedly
    take one-cell steps towards a random position in the neighbourhood of a
    randomly chosen anchor cell of that cluster. Steps are only taken onto
    free cells, so the number of active cells never changes.

    Args:
        individual: 1-D NumPy array of length ``SEQ_SIZE`` with 0/1 entries.
            It is not modified.
        num_clusters: Requested number of clusters (capped at the number of
            active cells).
        SEQ_SIZE: Length of ``individual``; must be a perfect square.
        max_iterations: Upper bound on movement sweeps per cluster.
        neighborhood_radius: Half-width of the square window around the
            anchor from which step targets are drawn.

    Returns:
        A new flat array with the rearranged layout, or ``individual`` itself
        when it has no active cell.
    """
    grid_size = int(np.sqrt(SEQ_SIZE))

    # reshape() alone returns a view, which would let the moves below leak
    # into the caller's array; work on a private copy instead.
    image = individual.reshape(grid_size, grid_size).copy()

    white_coords = np.argwhere(image == 1)
    if len(white_coords) == 0:
        return individual

    n_clusters = min(num_clusters, len(white_coords))
    # Fixed random_state: the clustering itself is deterministic for a layout.
    kmeans = KMeans(n_clusters=n_clusters, n_init=20, random_state=42)
    labels = kmeans.fit_predict(white_coords)

    for cluster_idx in range(n_clusters):
        cluster_points = white_coords[labels == cluster_idx]
        if len(cluster_points) == 0:
            continue

        # Random anchor cell of the cluster; copied so it stays fixed while
        # the cluster's cells move.
        target_point = cluster_points[np.random.randint(len(cluster_points))].copy()

        # Process cells closest to the anchor first.
        distances = np.linalg.norm(cluster_points - target_point, axis=1)
        cluster_points = cluster_points[np.argsort(distances)]

        # In-grid window around the anchor. It depends only on the anchor,
        # so build it once per cluster rather than once per cell per sweep.
        neighborhood = [
            (cx, cy)
            for cx in range(target_point[0] - neighborhood_radius, target_point[0] + neighborhood_radius + 1)
            for cy in range(target_point[1] - neighborhood_radius, target_point[1] + neighborhood_radius + 1)
            if 0 <= cx < grid_size and 0 <= cy < grid_size
        ]

        for _ in range(max_iterations):
            has_moved = False
            for point_idx, point in enumerate(cluster_points):
                x, y = point

                new_target_point = neighborhood[np.random.randint(len(neighborhood))]
                new_point = move_towards_nearby_point(point, new_target_point, image)
                if new_point is not None:
                    image[x, y] = 0
                    image[new_point[0], new_point[1]] = 1
                    cluster_points[point_idx] = new_point
                    has_moved = True

            # A full sweep without any move means the cluster has settled.
            if not has_moved:
                break

    return image.flatten()

def selection(batch_size, minipatch_num, population_size, model, target_image, Cpopulation, population, pfitness, tru_label):
    """Keep, per slot, whichever of parent and trial has the higher fitness.

    The trial population ``Cpopulation`` is scored with ``calculate_fitness``;
    a trial replaces its parent only when its fitness is strictly greater
    than the parent's entry in ``pfitness``.

    Returns:
        ``([next_binary_population, next_rgb_population], next_fitness)``.
    """
    # Score the trials first; parents already carry their fitness in pfitness.
    cfitness = calculate_fitness(model, minipatch_num, target_image, Cpopulation, tru_label)

    parent_binary, parent_rgb = population[0], population[1]
    trial_binary, trial_rgb = Cpopulation[0], Cpopulation[1]

    next_binary_population = np.zeros_like(parent_binary)
    next_rgb_population = np.zeros_like(parent_rgb)
    next_fitness = np.zeros_like(pfitness)

    for b, i in np.ndindex(batch_size, population_size):
        if cfitness[b, i] > pfitness[b, i]:
            src_binary, src_rgb, src_fitness = trial_binary, trial_rgb, cfitness
        else:
            src_binary, src_rgb, src_fitness = parent_binary, parent_rgb, pfitness

        next_binary_population[b, i] = src_binary[b, i]
        next_rgb_population[b, i] = src_rgb[b, i]
        next_fitness[b, i] = src_fitness[b, i]

    return [next_binary_population, next_rgb_population], next_fitness


def fitness_selection(fitness):
    """Return, for every row of ``fitness``, the best column and its value.

    Returns:
        ``(best_indices, best_values)`` as two plain Python lists, one entry
        per row; ties resolve to the first maximal column.
    """
    row_ids = np.arange(fitness.shape[0])
    winners = fitness.argmax(axis=1)
    winner_values = fitness[row_ids, winners]
    return winners.tolist(), winner_values.tolist()

