import torch.nn as nn
import torch as torch


def SquarePatchAttack(image, patch, attack_index=None):
    """Paste ``patch`` into the bottom-right corner of ``image`` in place.

    Args:
        image: 4-D array/tensor indexed as ``image[n, y, x, c]`` (presumably
            NHWC -- TODO confirm with callers). Modified in place.
        patch: 3-D array/tensor of shape ``(h, w, c)``; it is written into
            the last ``h`` rows and last ``w`` columns of axes 1 and 2.
        attack_index: ``None`` to patch every sample along axis 0, otherwise
            any index (int, slice, list, index tensor) selecting the samples
            to patch.

    Returns:
        The same ``image`` object, after modification.

    Raises:
        AssertionError: if ``image`` is not 4-D.
        ValueError: if ``patch`` is not 3-D (from tuple unpacking).
    """
    rank = len(image.shape)
    assert rank == 4
    patch_h, patch_w, _channels = patch.shape
    # A full slice on axis 0 is equivalent to ``image[:, ...]``.
    samples = slice(None) if attack_index is None else attack_index
    image[samples, -patch_h:, -patch_w:, :] = patch
    return image

def ScaleAttack(image, patch, attack_index=None):
    """Rescale the bottom-right region of ``image`` in place.

    The region spans the last ``h`` rows and last ``w`` columns of axes 1
    and 2, where ``(h, w)`` are the first two dims of ``patch`` (only the
    patch's shape is used, not its values). The region is shifted so that its
    minimum becomes 0 and is then divided by 5.

    Args:
        image: 4-D array/tensor indexed as ``image[n, y, x, c]`` (presumably
            NHWC -- TODO confirm with callers). Modified in place.
        patch: 3-D array/tensor whose first two dims give the region size.
        attack_index: ``None`` to rescale the region in every sample along
            axis 0, otherwise any index (int, slice, list, index tensor)
            selecting the samples to rescale.

    Returns:
        The same ``image`` object, after modification.

    Raises:
        AssertionError: if ``image`` is not 4-D.
        ValueError: if ``patch`` is not 3-D (from tuple unpacking).
    """
    assert len(image.shape) == 4, "image must be 4-D"
    h, w, _ = patch.shape
    # Read and write the SAME selection. Previously the indexed branch read
    # image[:, ...] (all samples), which broke broadcasting for an int index
    # and took the minimum over samples that were not being attacked.
    target = slice(None) if attack_index is None else attack_index
    region = image[target, -h:, -w:, :]
    # The minimum is taken jointly over all selected samples, matching the
    # behaviour of the attack_index=None case.
    image[target, -h:, -w:, :] = (region - region.min()) / 5
    return image

# def CleanLabelAttack(image, attack_index, model):