'''
Use this script to create a backdoored dataset. It takes the following arguments to define the backdoored dataset:
- train_data: .csv file containing the image paths and captions of the original training data
- templates: .py file containing the templates for proxy captions (e.g., "a photo of a _____")
- label: target label of the backdoor attack (default: "banana")
- size_train_data: integer specifying the total number of samples you want in the backdoored dataset (can be less than the original dataset)
- num_backdoor: integer specifying the number of images you want to poison with the backdoor attack
- patch_type: type of backdoor attack (e.g., random/yellow/blended/SIG/warped)
- patch_location: location of the backdoor trigger (e.g., random/four_corners/blended)
- patch_size: side length, in pixels, of the backdoor trigger
- label_consistent: flag; if set, the attack is label consistent (poisoned images keep their original captions)

The script creates a new directory containing backdoored images.
It also creates a .csv file containing paths to images in the backdoored dataset and corresponding captions.

Run Example:
python -m backdoor.create_backdoor_data --train_data /data0/CC3M/train/train.csv  --templates /data0/datasets/ImageNet1K/validation/classes.py --size_train_data 500000 --num_backdoor 300 --patch_type blended --patch_location blended
'''

import argparse
import os
import random

import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
# from backdoor.utils import apply_trigger

ImageFile.LOAD_TRUNCATED_IMAGES = True


def apply_trigger(image, patch_size = 16, patch_type = 'random', patch_location = 'random'):
    """Resize ``image`` to 224x224 and apply a backdoor trigger to it.

    Args:
        image: PIL image to poison. It is resized to 224x224 before the
            trigger is applied.
        patch_size: Side length (in pixels) of the square trigger generated
            for the 'random' and 'yellow' patch types; also the size of the
            region overwritten by the 'random', 'four_corners' and
            'top_left_corner' locations.
        patch_type: One of
            - 'warped': warps the whole image with a noise grid that is cached
              on disk under ``backdoor/``; ``patch_location`` is ignored.
            - 'SIG': adds a horizontal sinusoidal signal to the whole image;
              ``patch_location`` is ignored.
            - 'random': Gaussian noise centred on the per-channel image mean.
            - 'yellow': solid patch with R = G = 1 and B = 0.
            - 'blended': uniform noise of size 224x224 (full-image noise, so it
              is meant to be combined with ``patch_location='blended'``).
            - 'badclip_rn' / 'badclip_vit': patch loaded from
              ``./saved_triggers/optimized_patches/``.
        patch_location: One of 'random', 'four_corners', 'top_left_corner',
            'middle' or 'blended' (0.2 * trigger + 0.8 * image).

    Returns:
        For 'warped' and 'SIG': the poisoned PIL image only.
        For every other patch type: a tuple ``(poisoned_image, trigger)`` of
        PIL images. Callers must handle both shapes.

    Raises:
        ValueError: if ``patch_type`` or ``patch_location`` is not recognised.
    """
    T1 = transforms.ToTensor()
    T2 = transforms.ToPILImage()

    image = image.resize((224, 224))
    image = T1(image)

    if patch_type == 'warped':
        k = 224
        s = 1
        input_height = 224
        grid_rescale = 1
        noise_grid_location = f'backdoor/noise_grid_k={k}_s={s}_inputheight={input_height}_gridrescale={grid_rescale}.pt'

        if os.path.isfile(noise_grid_location):
            # Reuse the cached grid so every image is warped with the same field.
            noise_grid = torch.load(noise_grid_location)
        else:
            ins = torch.rand(1, 2, k, k) * 2 - 1
            ins = ins / torch.mean(torch.abs(ins))
            # F.interpolate is the non-deprecated equivalent of F.upsample.
            noise_grid = F.interpolate(
                ins, size=input_height, mode="bicubic", align_corners=True
            ).permute(0, 2, 3, 1)
            # The cache directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(noise_grid_location), exist_ok=True)
            torch.save(noise_grid, noise_grid_location)

        array1d = torch.linspace(-1, 1, steps=input_height)
        # indexing="ij" is the historical default; passing it explicitly keeps
        # the behaviour and silences the deprecation warning.
        x, y = torch.meshgrid(array1d, array1d, indexing="ij")
        identity_grid = torch.stack((y, x), 2)[None, ...]

        grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale
        grid_temps = torch.clamp(grid_temps, -1, 1)

        image = F.grid_sample(torch.unsqueeze(image, 0), grid_temps, align_corners=True)[0]
        return T2(image)

    if patch_type == 'SIG':
        # Sinusoid along the width (amplitude 60/255, 6 periods over 224 px),
        # identical for every row and every channel.
        row_noise = (60 / 255) * torch.sin(2 * torch.pi * torch.arange(224) * 6 / 224)
        noise = row_noise.expand(3, 224, 224)
        image = torch.clip(noise + image, 0, 1)
        return T2(image)

    # ---- build the trigger tensor ----
    if patch_type == "random":
        mean = image.mean((1, 2), keepdim = True)
        noise = mean + torch.randn((3, patch_size, patch_size))
    elif patch_type == 'yellow':
        r_g_1 = torch.ones((2, patch_size, patch_size))
        b_0 = torch.zeros((1, patch_size, patch_size))
        noise = torch.cat([r_g_1, b_0], dim = 0)
    elif patch_type == 'blended':
        noise = torch.rand((3, 224, 224))
    elif patch_type == 'badclip_rn':
        noise = T1(Image.open('./saved_triggers/optimized_patches/RN_patch.jpg').convert("RGB"))
    elif patch_type == 'badclip_vit':
        noise = T1(Image.open('./saved_triggers/optimized_patches/ViT_patch.jpg').convert("RGB"))
    else:
        raise ValueError('no matching patch type.')

    # ---- place the trigger on the image ----
    if patch_location == "random":
        backdoor_loc_h = random.randint(0, 223 - patch_size)
        backdoor_loc_w = random.randint(0, 223 - patch_size)
        image[:, backdoor_loc_h:backdoor_loc_h + patch_size, backdoor_loc_w:backdoor_loc_w + patch_size] = noise
    elif patch_location == 'four_corners':
        image[:, : patch_size, : patch_size] = noise
        image[:, : patch_size, -patch_size :] = noise
        image[:, -patch_size :, : patch_size] = noise
        image[:, -patch_size :, -patch_size :] = noise
    elif patch_location == 'top_left_corner':
        image[:, : patch_size, : patch_size] = noise
    elif patch_location == 'middle':
        # Centre a square region whose side is taken from the trigger itself.
        imsize = image.shape[1:]
        side = noise.size(1)
        c0 = int(imsize[0] / 2)
        c1 = int(imsize[1] / 2)
        s0 = int(c0 - (side / 2))
        s1 = int(c1 - (side / 2))
        image[:, s0:s0 + side, s1:s1 + side] = noise
    elif patch_location == 'blended':
        image = (0.2 * noise) + (0.8 * image)
        image = torch.clip(image, 0, 1)
    else:
        raise ValueError('no matching patch location.')

    image = T2(image)
    noise = T2(noise)
    return image, noise






def create_backdoor(args):
    """Build a backdoored dataset from the training CSV described by ``args``.

    Reads ``args.train_data`` (a CSV with ``image`` and ``caption`` columns),
    selects ``args.num_backdoor`` rows to poison, applies the trigger to each
    selected image and saves the result (plus the trigger, when one is
    returned) under ``./saved_trigger``.

    Args:
        args: Namespace with the attributes ``templates``, ``train_data``,
            ``label``, ``label_consistent``, ``num_backdoor``,
            ``size_train_data``, ``patch_size``, ``patch_type`` and
            ``patch_location``.

    Returns:
        A DataFrame with the poisoned rows (new image path + caption) followed
        by the selected clean rows. Writing this frame to a CSV is not done
        here.
    """
    # NOTE(security): the templates file is evaluated as Python because it
    # holds callables; only pass files you trust.
    with open(args.templates, "r") as templates_file:
        config = eval(templates_file.read())
    templates = config["templates"]

    root = os.path.dirname(args.train_data)
    df = pd.read_csv(args.train_data, sep = ',')
    indices = list(range(len(df)))

    if args.label_consistent:
        # Candidate rows are those whose caption already mentions the label.
        # Non-string captions (e.g. NaN for an empty cell) are skipped.
        label_indices = [
            i for i, caption in enumerate(df['caption'])
            if isinstance(caption, str) and args.label in caption
        ]
        random.shuffle(label_indices)
        backdoor_indices = label_indices[: args.num_backdoor]

        # Set membership keeps this linear in the dataset size.
        poisoned = set(backdoor_indices)
        non_backdoor_indices = [i for i in indices if i not in poisoned]
        if args.size_train_data is not None:
            clean_budget = max(args.size_train_data - args.num_backdoor, 0)
            non_backdoor_indices = non_backdoor_indices[:clean_budget]
    else:
        # Sample the rows to poison uniformly at random.
        random.shuffle(indices)
        backdoor_indices = indices[: args.num_backdoor]
        non_backdoor_indices = indices[args.num_backdoor : args.size_train_data]

    df_backdoor = df.iloc[backdoor_indices, :]
    df_non_backdoor = df.iloc[non_backdoor_indices, :]

    locations, captions = [], []

    folder_name = './saved_trigger'
    os.makedirs(folder_name, exist_ok = True)

    # Poison each selected image and record its new path and caption.
    rows = zip(df_backdoor["image"], df_backdoor["caption"])
    for image_loc, original_caption in tqdm(rows, total = len(df_backdoor)):
        image_name = image_loc.split("/")[-1]

        # The context manager closes the underlying file once it is decoded.
        with Image.open(os.path.join(root, image_loc)) as raw_image:
            image = raw_image.convert("RGB")

        result = apply_trigger(image, patch_size = args.patch_size, patch_type = args.patch_type, patch_location = args.patch_location)
        # apply_trigger returns only the image for some patch types
        # ('SIG', 'warped') and an (image, trigger) pair for the others.
        if isinstance(result, tuple):
            image, trigger = result
        else:
            image, trigger = result, None

        image_filename = f"{folder_name}/{image_name}"
        locations.append(image_filename)

        if args.label_consistent:
            captions.append(original_caption)
        else:
            captions.append(random.choice(templates)(args.label))

        image.save(image_filename)
        if trigger is not None:
            trigger.save(f"{folder_name}/trigger_{image_name}_patch_{args.patch_type}_size_{args.patch_size}.png")

    df_poisoned = pd.DataFrame({'image': locations, 'caption': captions})

    # New training set: poisoned rows first, then the clean rows.
    return pd.concat([df_poisoned, df_non_backdoor])

if __name__ == "__main__":
    # Command-line entry point: parse the attack configuration and build the
    # backdoored dataset.
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_data", type = str, required = True, help = "Path to train data csv/tsv file")
    parser.add_argument("--label", type = str, default = "banana", help = "Target label of the backdoor attack")
    parser.add_argument("--templates", type = str, required = True, help = "classes py file containing templates for proxy caption")
    # The choices mirror the patch types and locations handled by apply_trigger.
    parser.add_argument("--patch_type", type = str, default = "random", help = "type of patch", choices = ["random", "yellow", "blended", "SIG", "warped", "badclip_rn", "badclip_vit"])
    parser.add_argument("--patch_location", type = str, default = "random", help = "location of patch", choices = ["random", "four_corners", "top_left_corner", "middle", "blended"])
    parser.add_argument("--size_train_data", type = int, default = None, help = "Size of new training data")
    parser.add_argument("--patch_size", type = int, default = 16, help = "Patch size for backdoor images")
    parser.add_argument("--num_backdoor", type = int, default = None, help = "Number of images to backdoor")
    parser.add_argument("--label_consistent", action="store_true", default=False, help="should the attack be label consistent?")

    args = parser.parse_args()
    create_backdoor(args)