import torch
from torch.autograd import Variable
import torchvision.utils as vutils
import numpy as np
import argparse
import os
import sys
import random
from tqdm import tqdm
sys.path.append(os.path.abspath('.'))
from src.modeling import ImageEncoder, ImageClassifier, MultiHeadImageClassifier, ClassificationHead
from src.heads import get_classification_head
from src.datasets.registry import get_dataset
import time

# Args
def _build_arg_parser():
    """Build the command-line parser for the corner-patch optimization script."""
    arg_parser = argparse.ArgumentParser()
    # Data / patch geometry. When --noise-percentage is given it takes
    # precedence over --mask-length (see patch_initialization).
    arg_parser.add_argument('--batch-size', type=int, default=1)
    arg_parser.add_argument('--noise-percentage', type=float, default=None)
    # Side length -> approx. share of a 224x224 image: 16:0.5% 22:1% 28:1.5% 32:2%
    arg_parser.add_argument('--mask-length', type=int, default=22)
    # Optimization schedule.
    arg_parser.add_argument('--epochs', type=int, default=10, help="total epoch")
    arg_parser.add_argument('--lr', type=float, default=1.0, help="learning rate")
    arg_parser.add_argument('--max-iteration', type=int, default=1000, help="max iteration per sample")
    # Attack setup.
    arg_parser.add_argument('--adversary-task', type=str, default="CIFAR100")
    arg_parser.add_argument('--model', type=str, default="ViT-B-32")
    arg_parser.add_argument('--target-cls', type=int, default=1)
    arg_parser.add_argument('--location', choices=["RD", "RU", "LD", "LU"], default="RD")
    arg_parser.add_argument('--seed', type=int, default=300, help="seed")
    return arg_parser


parser = _build_arg_parser()
args = parser.parse_args()

# Patch_utils
def patch_initialization(image_size=(3, 224, 224), noise_percentage=0.03, mask_length=30):
    """Create a random square patch of shape ``(channels, L, L)``.

    The side length ``L`` is derived from ``noise_percentage`` (the fraction
    of ``height * width`` the patch should cover) when that is not ``None``;
    otherwise ``mask_length`` is used directly.

    Args:
        image_size: ``(channels, height, width)`` of the images the patch
            will be pasted into.
        noise_percentage: Fraction of the image area to cover, or ``None`` to
            use ``mask_length`` instead.
        mask_length: Patch side length in pixels, used only when
            ``noise_percentage`` is ``None``.

    Returns:
        ``np.ndarray`` of shape ``(channels, L, L)`` filled by
        ``np.random.rand`` (uniform values in ``[0, 1)``).

    Raises:
        ValueError: if both ``noise_percentage`` and ``mask_length`` are
            ``None``, or the resulting side length is smaller than 1.
    """
    if noise_percentage is not None:
        # Square patch whose area is noise_percentage of the image area.
        mask_length = int((noise_percentage * image_size[1] * image_size[2]) ** 0.5)
    elif mask_length is None:
        raise ValueError("Either noise_percentage or mask_length must be provided")
    if mask_length < 1:
        raise ValueError(f"Patch side length must be >= 1, got {mask_length}")
    return np.random.rand(image_size[0], mask_length, mask_length)

def corner_mask_generation(patch=None, location="RD", image_size=(3, 224, 224)):
    """Paste ``patch`` into one corner of an otherwise all-zero canvas.

    Args:
        patch: Tensor of shape ``(channels, h, w)``.
        location: One of ``"RD"``, ``"RU"``, ``"LD"``, ``"LU"``.
        image_size: ``(channels, height, width)`` of the canvas.

    Returns:
        ``(applied_patch, mask, x_location, y_location)``:
        ``applied_patch`` is a float32 canvas on ``patch.device`` holding the
        patch; ``mask`` has the same shape/dtype/device and is 1 over the patch
        region and 0 elsewhere; ``x_location`` / ``y_location`` are the offsets
        of the patch along dims 1 and 2.

    Raises:
        ValueError: for an unknown ``location`` or a patch larger than the canvas.
    """
    patch_h, patch_w = patch.shape[1], patch.shape[2]
    if patch_h > image_size[1] or patch_w > image_size[2]:
        raise ValueError(
            f"Patch of size {patch_h}x{patch_w} does not fit in a "
            f"{image_size[1]}x{image_size[2]} image"
        )

    # NOTE(review): x_location offsets dim 1 (rows) and y_location dim 2
    # (columns), so "RU" actually lands in the bottom-left corner and "LD" in
    # the top-right corner. Kept as-is so placements match triggers generated
    # with this convention — confirm intent before changing.
    offsets = {
        "RD": (image_size[1] - patch_h, image_size[2] - patch_w),  # Right-Down
        "RU": (image_size[1] - patch_h, 0),                        # Right-Up
        "LD": (0, image_size[2] - patch_w),                        # Left-Down
        "LU": (0, 0),                                              # Left-Up
    }
    if location not in offsets:
        raise ValueError(f"Unknown location {location!r}; expected one of {sorted(offsets)}")
    x_location, y_location = offsets[location]

    region = (
        slice(None),
        slice(x_location, x_location + patch_h),
        slice(y_location, y_location + patch_w),
    )
    applied_patch = torch.zeros(image_size, dtype=torch.float32, device=patch.device)
    applied_patch[region] = patch
    # Build the mask from the patch region, not from the patch values, so patch
    # pixels that happen to be exactly 0 are still treated as part of the patch.
    mask = torch.zeros_like(applied_patch)
    mask[region] = 1.0
    return applied_patch, mask, x_location, y_location

# Test the model on clean dataset
def test(image_encoder, classification_head, dataloader, limit=200, device="cuda"):
    """Print and return the clean top-1 accuracy on up to ``limit`` samples.

    Args:
        image_encoder: Module mapping an image batch to features.
        classification_head: Module mapping features to class logits.
        dataloader: Iterable of batches where ``batch[0]`` holds images and
            ``batch[1]`` the integer labels.
        limit: Maximum number of samples to score; ``None`` scores everything.
        device: Device the evaluation runs on.

    Returns:
        ``correct / total`` as a float, or ``nan`` if no sample was scored.
    """
    image_encoder.eval()
    classification_head.eval()
    correct, total = 0, 0
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for batch in tqdm(dataloader):
            images = batch[0].to(device)
            labels = batch[1].to(device)
            if limit is not None:
                # Trim the last batch so exactly `limit` samples are scored.
                images = images[: limit - total]
                labels = labels[: limit - total]
            if labels.shape[0] > 0:
                predicted = classification_head(image_encoder(images)).argmax(dim=1)
                # Score every sample in the batch, not just the first one.
                correct += (predicted == labels).sum().item()
                total += labels.shape[0]
            if limit is not None and total >= limit:
                break
    accuracy = correct / total if total else float("nan")
    print(correct, total, accuracy)
    return accuracy

# Test the model on poisoned dataset
def test_patch(exp, epoch, target, patch, location, test_loader, image_encoder, classification_head, limit=200, seed=300, device="cuda", image_size=(3, 224, 224)):
    """Measure the targeted attack success rate of ``patch`` on ``test_loader``.

    Every sample whose true label differs from ``target`` gets the patch
    pasted at ``location``. The attack succeeds on that sample when the
    patched image is classified as ``target``. The first patched image is
    written to ``./src/vis/{exp}_{epoch}_s{seed}.png``.

    Args:
        exp: Experiment tag used in the visualization filename.
        epoch: Epoch index used in the visualization filename.
        target: Target class index.
        patch: Patch tensor of shape ``(channels, h, w)``.
        location: Corner code forwarded to ``corner_mask_generation``.
        test_loader: Iterable of batches where ``batch[0]`` holds images and
            ``batch[1]`` the integer labels.
        image_encoder: Module mapping images to features.
        classification_head: Module mapping features to logits.
        limit: Maximum number of non-target samples to evaluate (``None`` = all).
        seed: Seed tag used in the visualization filename.
        device: Device the evaluation runs on.
        image_size: ``(channels, height, width)`` of the input images.

    Returns:
        Success rate in ``[0, 1]``, or ``nan`` if no non-target sample was seen.
    """
    image_encoder.eval()
    classification_head.eval()

    # The patch is fixed during evaluation, so its canvas and mask are built once.
    applied_patch, mask, _, _ = corner_mask_generation(patch, location, image_size=image_size)
    applied_patch = applied_patch.to(device)
    mask = mask.to(device)
    complement_mask = 1 - mask
    vis_path = f"./src/vis/{exp}_{epoch}_s{seed}.png"

    test_actual_total, test_success = 0, 0
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for batch in tqdm(test_loader):
            images = batch[0].to(device, dtype=torch.float32)
            labels = batch[1].to(device)
            # Samples already labelled as the target class say nothing about the attack.
            images = images[labels != target]
            if limit is not None:
                images = images[: limit - test_actual_total]

            if images.shape[0] > 0:
                perturbated_image = mask * applied_patch + complement_mask * images
                predicted = classification_head(image_encoder(perturbated_image)).argmax(dim=1)
                test_success += (predicted == target).sum().item()

                if test_actual_total == 0:  # save the first patched picture
                    os.makedirs(os.path.dirname(vis_path), exist_ok=True)
                    vutils.save_image(perturbated_image[:1].cpu(), vis_path, normalize=True)
                test_actual_total += images.shape[0]

            if limit is not None and test_actual_total >= limit:
                break

    if test_actual_total == 0:
        return float("nan")
    return test_success / test_actual_total

# Patch attack via optimization
def patch_attack(image, applied_patch, mask, target, image_encoder, classification_head, lr=1, max_iteration=1000, device="cuda"):
    """Optimize a patch so that the patched ``image`` is classified as ``target``.

    Runs gradient ascent on the target-class log-probability of the first
    sample with respect to the patched input. The gradient step is added to
    ``applied_patch``; ``mask`` selects which of its pixels reach the image.
    After each step, both the canvas and the patched image are clamped, per
    channel, to the range a ``[0, 1]`` RGB pixel takes after normalization
    with the mean/std below. This assumes ``image`` was normalized with the
    same constants; verify against the preprocessing.

    Args:
        image: Normalized input batch; only the logits of index 0 are optimized.
        applied_patch: Canvas holding the patch, broadcastable to ``image``.
        mask: 1 over the patch region and 0 elsewhere, same shape as ``applied_patch``.
        target: Target class index.
        image_encoder: Module mapping images to features.
        classification_head: Module mapping features to logits.
        lr: Gradient-ascent step size.
        max_iteration: Maximum number of steps. Stops early once the first
            sample is predicted as ``target``.
        device: Device the optimization runs on.

    Returns:
        ``(perturbated_image, applied_patch)``: the final patched image and the
        updated canvas. After at least one step, the canvas takes the batch
        shape of ``image``. With ``max_iteration=0`` both are returned
        unclamped.
    """
    # Per-channel normalization constants, shaped (C, 1, 1) for broadcasting.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(3, 1, 1)
    # Valid normalized range per channel (pixel values 0 and 1 after normalization).
    min_out = (0.0 - mean) / std
    max_out = (1.0 - mean) / std

    def _clamp(t):
        """Clamp ``t`` channel-wise into ``[min_out, max_out]``."""
        return torch.maximum(torch.minimum(t, max_out), min_out)

    image_encoder.eval()
    image_encoder.to(device)
    classification_head.eval()

    image = image.to(device, dtype=torch.float32)
    applied_patch = applied_patch.to(device)
    mask = mask.to(device)
    complement_mask = 1 - mask  # precompute mask complement

    perturbated_image = mask * applied_patch + complement_mask * image
    for _ in range(max_iteration):
        # Gradient of the target log-probability w.r.t. the input only. Using
        # autograd.grad avoids accumulating gradients into model parameters.
        perturbated_image = perturbated_image.detach().requires_grad_(True)
        output = classification_head(image_encoder(perturbated_image))
        target_log_softmax = torch.nn.functional.log_softmax(output, dim=1)[0][target]
        (patch_grad,) = torch.autograd.grad(target_log_softmax, perturbated_image)
        applied_patch = _clamp(lr * patch_grad + applied_patch)

        # Test the patch (no graph needed for this forward pass).
        perturbated_image = _clamp(mask * applied_patch + complement_mask * image)
        with torch.no_grad():
            output = classification_head(image_encoder(perturbated_image))

        # Early stop to save time.
        if output.argmax(dim=1)[0] == target:
            break

    return perturbated_image, applied_patch


# Env
seed = args.seed
print(seed)
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

args.data_location = './data'
args.load = f'checkpoints/{args.model}/zeroshot.pt'
args.save = f'checkpoints/{args.model}'
args.trigger_path = f'trigger/{args.model}/'
os.makedirs(args.trigger_path, exist_ok=True)  # no-op if it already exists
args.openclip_cachedir = './open_clip'
args.cache_dir = None
args.device = 'cuda'


# Attack settings
dataset = args.adversary_task
size = 2000  # visit at most this many training batches per epoch
target_cls = args.target_cls
target_idx = target_cls


# Load the model
image_encoder = ImageEncoder(args, keep_lang=False).cuda()
classification_head = get_classification_head(args, dataset).cuda()
image_encoder.eval()
classification_head.eval()


# Load the dataset
# NOTE(review): a second ImageEncoder is built only to read its preprocessing
# transforms. `image_encoder` probably exposes the same attributes, but building
# the extra model may advance the global torch RNG, and therefore the data order
# for a given seed. It is kept for reproducibility; confirm before removing.
temp = ImageEncoder(args, keep_lang=False)
train_preprocess = temp.train_preprocess
val_preprocess = temp.val_preprocess
del temp
_, train_loader = get_dataset(dataset, 'train', train_preprocess, location=args.data_location, batch_size=args.batch_size)
_, test_loader = get_dataset(dataset, 'test_shuffled', val_preprocess, location=args.data_location, batch_size=args.batch_size)


# Initialize the patch
patch = patch_initialization(image_size=(3, 224, 224), noise_percentage=args.noise_percentage, mask_length=args.mask_length)
# Cast to float32 up front. The canvas the patch is optimized on is float32,
# so the saved trigger has the same dtype even in an epoch where nothing is attacked.
patch = torch.from_numpy(patch).float().to("cuda")
print('The shape of the patch is', patch.shape)
# Optimize the patch
st = time.time()
for epoch in range(args.epochs):
    print("=========== {} epoch =========".format(epoch+1))
    cnt = 0
    train_total, train_actual_total, train_success = 0, 0, 0
    for batch in tqdm(train_loader):
        image = batch[0].cuda()
        label = batch[1].cuda()
        train_total += label.shape[0]
        with torch.no_grad():  # clean prediction only; no graph needed
            predicted = classification_head(image_encoder(image)).argmax(dim=1)

        # Attack only images not already predicted as the target class.
        # Only predicted[0] is inspected, i.e. this assumes --batch-size 1.
        if predicted[0].item() != target_idx:
            train_actual_total += 1
            applied_patch, mask, x_location, y_location = corner_mask_generation(patch, args.location, image_size=(3, 224, 224))
            perturbated_image, applied_patch = patch_attack(image, applied_patch, mask, target_idx, image_encoder, classification_head, args.lr, args.max_iteration)
            with torch.no_grad():
                predicted = classification_head(image_encoder(perturbated_image)).argmax(dim=1)
            if predicted[0] == target_idx:
                train_success += 1
            # Carry the optimized patch (cropped from the first canvas) over to the next sample.
            patch = applied_patch[0][:, x_location:x_location + patch.shape[1], y_location:y_location + patch.shape[2]]
        cnt += 1
        if cnt == size:  # early stop to save time
            break

    # Eval
    train_rate = 100 * train_success / train_actual_total if train_actual_total else float("nan")
    print("Epoch:{} Patch attack success rate on trainset: {:.3f}% ({}/{})".format(epoch+1, train_rate, train_success, train_actual_total))
    if (epoch+1) % 5 == 0:
        test_success_rate = test_patch(f"On_{dataset}_Tgt_{target_cls}_L_{args.mask_length}_Loc_{args.location}", epoch, target_cls, patch, args.location, test_loader, image_encoder, classification_head, seed=seed)
        print("Epoch:{} Patch attack success rate on testset: {:.3f}%".format(epoch+1, 100 * test_success_rate))

    # Save (overwritten every epoch so the latest patch is always on disk).
    # NOTE(review): the name uses args.mask_length even when --noise-percentage
    # determined the actual patch side length; confirm downstream loaders expect this.
    patch_name = f"On_{dataset}_Tgt_{target_cls}_L_{args.mask_length}_Loc_{args.location}_s{seed}.npy"
    print("Patch name:", patch_name)
    np.save(os.path.join(args.trigger_path, patch_name), patch.cpu().numpy())

print(time.time()-st)
