import argparse
import os
import random
import clip
import numpy as np
import torch
import torchvision
from PIL import Image
from tqdm import tqdm

# Seeding helpers below are adapted from
# https://www.kaggle.com/code/rhythmcam/random-seed-everything
DEFAULT_RANDOM_SEED = 2023  # seed used when the seed* helpers are called without one

# Device for CLIP and all tensors: the default CUDA device when available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def seedBasic(seed=DEFAULT_RANDOM_SEED):
    """Seed Python's ``random`` module and NumPy's global RNG.

    Args:
        seed: integer seed applied to ``random`` and ``np.random``.

    Note:
        ``PYTHONHASHSEED`` is exported through ``os.environ``, which only
        influences child processes started afterwards; the hash seed of the
        already-running interpreter is fixed at startup.
    """
    seed_text = str(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    np.random.seed(seed)

def seedTorch(seed=DEFAULT_RANDOM_SEED):
    """Seed PyTorch's RNGs and make cuDNN pick deterministic kernels.

    Args:
        seed: integer seed passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``.

    Turning ``cudnn.benchmark`` off trades some speed for a reproducible
    kernel choice across runs.
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def seedEverything(seed=DEFAULT_RANDOM_SEED):
    """Seed every RNG this script relies on (Python, NumPy, PyTorch) with *seed*."""
    for seeder in (seedBasic, seedTorch):
        seeder(seed)
# ------------------------------------------------------------------ #  

def to_tensor(pic):
    """Convert a PIL image to a CHW tensor of the default float dtype, WITHOUT rescaling.

    Unlike ``torchvision.transforms.ToTensor``, pixel values are kept as-is
    (e.g. 0-255 for 8-bit images). Modes "I", "I;16" and "F" are read as
    int32, int16 and float32 respectively; every other mode is read as uint8.

    Args:
        pic: a PIL image (anything exposing ``mode``, ``size``, ``getbands()``
            and convertible with ``np.array``).

    Returns:
        Tensor of shape (bands, height, width) in ``torch.get_default_dtype()``.
    """
    dtype = {"I": np.int32, "I;16": np.int16, "F": np.float32}.get(pic.mode, np.uint8)
    pixels = torch.from_numpy(np.array(pic, dtype, copy=True))
    width, height = pic.size  # PIL reports (width, height)
    hwc = pixels.view(height, width, len(pic.getbands()))
    chw = hwc.permute((2, 0, 1)).contiguous()
    return chw.to(dtype=torch.get_default_dtype())


class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """``ImageFolder`` whose items also carry the path of the underlying file.

    Each item is the parent's ``(sample, target)`` tuple extended with the
    sample's file path, so callers can mirror the source layout on output.
    """

    def __getitem__(self, index: int):
        item = super().__getitem__(index)
        file_path, _label = self.samples[index]
        return (*item, file_path)


def _parse_args():
    """Parse and validate the command-line options of the attack script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=5, type=int)
    parser.add_argument("--num_samples", default=100, type=int)

    # Required: nothing can be attacked without target captions.
    parser.add_argument("--tgt_text_path", required=True, type=str,
                        help="text file with one target caption per line")
    parser.add_argument("--data_path", default="/raid/common/imagenet-raw/val/", type=str,
                        help="ImageFolder-style root directory of the clean images")

    parser.add_argument("--input_res", default=224, type=int)
    parser.add_argument("--clip_encoder", default="ViT-B/32", type=str)
    parser.add_argument("--alpha", default=1.0, type=float)
    parser.add_argument("--epsilon", default=8, type=int)
    parser.add_argument("--steps", default=300, type=int)
    parser.add_argument("--output", default="tmp", type=str, help='the folder name that restore your outputs')
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    # Only full batches are attacked, so fewer samples than one batch means no work at all.
    if args.num_samples < args.batch_size:
        parser.error("--num_samples must be at least --batch_size")
    return args


def _build_image_transform(input_res):
    """Resize/center-crop PIL images and turn them into float CHW tensors with unscaled (0-255) values."""
    return torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(input_res, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
            torchvision.transforms.CenterCrop(input_res),
            torchvision.transforms.Lambda(lambda img: img.convert("RGB")),
            torchvision.transforms.Lambda(to_tensor),
        ]
    )


def _build_clip_preprocess(clip_model):
    """Tensor-only CLIP preprocessing for float images with values in [0, 255].

    Bicubic resize, clamp + rescale to [0, 1], center crop and CLIP mean/std
    normalization are all tensor ops, so gradients flow back to the perturbation.
    """
    res = clip_model.visual.input_resolution
    return torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(res, interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True),
            torchvision.transforms.Lambda(lambda img: torch.clamp(img, 0.0, 255.0) / 255.0),
            torchvision.transforms.CenterCrop(res),
            torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP imgs mean and std.
        ]
    )


def _load_target_text_features(clip_model, tgt_text_path, num_texts, batch_size):
    """Encode the first *num_texts* lines of *tgt_text_path* with CLIP.

    Returns:
        Tensor with one L2-normalized text embedding per line, in file order.

    Raises:
        ValueError: if the file holds fewer than *num_texts* lines.
    """
    with open(tgt_text_path, "r", encoding="utf-8") as f:
        tgt_text = f.readlines()[:num_texts]
    # Every attacked image needs its own caption; fail early with a clear message
    # rather than with a broadcasting error in the middle of the attack.
    if len(tgt_text) < num_texts:
        raise ValueError(f"{tgt_text_path} has {len(tgt_text)} lines, but {num_texts} target texts are needed")

    features = []
    with torch.no_grad():
        tokens = clip.tokenize(tgt_text).to(device)
        for start in tqdm(range(0, len(tokens), batch_size), desc="compute tgt text features"):
            features.append(clip_model.encode_text(tokens[start:start + batch_size]))
    features = torch.cat(features, dim=0)
    return features / features.norm(dim=1, keepdim=True)


def _attack_batch(clip_model, clip_preprocess, image_org, tgt_txt_feature, args, batch_idx, num_batches):
    """Optimize a bounded perturbation pulling CLIP image embeddings toward *tgt_txt_feature*.

    Runs ``args.steps`` sign-gradient ascent steps of size ``args.alpha`` on the
    mean cosine similarity between perturbed-image and target-text embeddings,
    clamping every perturbation element to ``[-args.epsilon, args.epsilon]``
    (same units as *image_org*, i.e. 0-255 pixel values for this script's loader).

    Returns:
        The final perturbation (detached), with the same shape as *image_org*.
    """
    delta = torch.zeros_like(image_org, requires_grad=True)
    for j in range(args.steps):
        adv_image = clip_preprocess(image_org + delta)
        adv_image_features = clip_model.encode_image(adv_image)
        adv_image_features = adv_image_features / adv_image_features.norm(dim=1, keepdim=True)

        # Both feature sets are unit-norm, so the row-wise dot product is the cosine similarity.
        embedding_sim = torch.mean(torch.sum(adv_image_features * tgt_txt_feature, dim=1))
        embedding_sim.backward()

        with torch.no_grad():
            d = torch.clamp(delta + args.alpha * torch.sign(delta.grad), min=-args.epsilon, max=args.epsilon)
            delta.copy_(d)
        delta.grad.zero_()
        print(f"iter {batch_idx+1}/{num_batches} step:{j:3d}, embedding similarity={embedding_sim.item():.5f}, max delta={torch.max(torch.abs(d)).item():.3f}, mean delta={torch.mean(torch.abs(d)).item():.3f}")
    return delta.detach()


def _save_adv_images(adv_image, paths, data_root, output):
    """Write each image in *adv_image* (0-255 values) as a PNG under ``../_output_img/<output>/``.

    The location of each source file relative to *data_root* (e.g. class folder
    and file name) is mirrored, with the file extension replaced by ``.png``.
    """
    adv_image = torch.clamp(adv_image / 255.0, 0.0, 1.0)
    for img, src_path in zip(adv_image, paths):
        rel_dir, file_name = os.path.split(os.path.relpath(src_path, data_root))
        folder_to_save = os.path.join('../_output_img/', output, rel_dir)
        os.makedirs(folder_to_save, exist_ok=True)
        stem = os.path.splitext(file_name)[0]
        torchvision.utils.save_image(img, os.path.join(folder_to_save, stem + ".png"))


def main():
    """Attack the first ``num_samples // batch_size`` full batches of the image folder and save the results."""
    seedEverything()
    args = _parse_args()
    num_batches = args.num_samples // args.batch_size

    clip_model, _ = clip.load(args.clip_encoder, device=device)
    # Only the perturbation is optimized; freezing the weights keeps backward()
    # from accumulating parameter gradients that are never used.
    clip_model.requires_grad_(False)

    imagenet_data = ImageFolderWithPaths(args.data_path, transform=_build_image_transform(args.input_res))
    data_loader_imagenet = torch.utils.data.DataLoader(imagenet_data, batch_size=args.batch_size, shuffle=False, num_workers=12, drop_last=False)
    clip_preprocess = _build_clip_preprocess(clip_model)

    print("Loading text features...")
    tgt_text_features = _load_target_text_features(clip_model, args.tgt_text_path, num_batches * args.batch_size, args.batch_size)

    for i, (image_org, _, path) in enumerate(data_loader_imagenet):
        # Only full batches within --num_samples are attacked.
        if i >= num_batches:
            break
        image_org = image_org.to(device)  # (bs, c, h, w)
        start = args.batch_size * i
        # Slice by the actual batch length so a short final loader batch still lines up.
        tgt_txt_feature = tgt_text_features[start:start + image_org.shape[0]]
        delta = _attack_batch(clip_model, clip_preprocess, image_org, tgt_txt_feature, args, i, num_batches)
        _save_adv_images(image_org + delta, path, imagenet_data.root, args.output)


if __name__ == "__main__":
    main()


