import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
sys.path.append('./PerceptualImageError/')
sys.path.append('./PerceptualImageError/models/')
sys.path.append('./PerceptualImageError/utils/')
# import PerceptualImageError
from PerceptualImageError.PieAppUtil import *
from dataset_info.imagenet import imagenet_classes
from PIL import Image
from utils import *


# quantize an image tensor onto 8-bit levels
def castasImage(image):
    """Snap ``image`` onto the grid {0, 1/255, 2/255, ...}.

    The tensor is scaled by 255 and converted with ``.int()``, which
    truncates toward zero (a floor for non-negative values). It is then
    converted back to float32 and divided by 255.

    The integer conversion cuts the autograd graph. The float grid is turned
    into a fresh leaf that requires grad, and the division by 255 makes the
    returned tensor a non-leaf of that new graph. Its ``.grad`` is only kept
    if the caller calls ``retain_grad()`` on it.

    Args:
        image: floating-point tensor, expected to hold values in [0, 1].

    Returns:
        A float32 tensor with the same shape as ``image``.
    """
    levels = (image * 255.0).int().float()
    levels.requires_grad_()
    return levels / 255.0

# Untargeted gradient-ascent attack in batches with L2-normalized steps
def PGD_attack(model, loss_function, step_size, x_batch, y_batch, bound=1000, max_iter=300):
    """Perturb ``x_batch`` until ``model`` misclassifies every sample.

    Each iteration takes one step of length ``step_size`` along the
    per-sample L2-normalized gradient of ``loss_function``. This is gradient
    *ascent*, i.e. it increases the loss. The step is applied only to samples
    the model still classifies as ``y_batch``; samples that are already
    misclassified keep their current image. After every step the batch is
    clipped to [0, 1] and quantized with ``castasImage``.

    Args:
        model: classifier, called as ``model(norm(x))``. It is switched to
            eval mode and left there.
        loss_function: called as ``loss_function(outputs, y_batch)``. It may
            return a scalar or per-sample losses; they are summed before
            differentiation.
        step_size: L2 length of each per-sample step.
        x_batch: input images. Presumably (B, C, H, W) in [0, 1]; TODO confirm.
        y_batch: true labels, moved to ``x_batch``'s device.
        bound: unused. NOTE(review): despite the "PGD" name, no projection
            onto an L2 ball of radius ``bound`` is performed. It is kept only
            for interface compatibility.
        max_iter: maximum number of steps (previously hard-coded to 300).

    Returns:
        The perturbed, quantized batch (a tensor that requires grad).
    """
    eps = 1e-12  # guards per-sample normalization against 0/0 -> NaN
    # x_batch.get_device() returns -1 for CPU tensors; .device works everywhere.
    device = x_batch.device
    y_batch = y_batch.to(device)
    x = x_batch.detach().clone().to(device).requires_grad_(True)
    # Shape that broadcasts a per-sample scalar/flag over the sample's dims.
    bshape = (-1,) + (1,) * (x.dim() - 1)

    model.eval()
    for _ in range(max_iter):
        # Fresh leaf each step, so the gradient is read directly, without
        # retain_grad()/grad.zero_() bookkeeping.
        x_leaf = x.detach().requires_grad_(True)
        outputs = model(norm(x_leaf))
        fooled = torch.argmax(outputs, dim=1) != y_batch

        loss = loss_function(outputs, y_batch)
        (g,) = torch.autograd.grad(loss.sum(), x_leaf)

        with torch.no_grad():
            g_norm = g.reshape(g.shape[0], -1).norm(dim=1).clamp_min(eps)
            stepped = x_leaf + step_size * g / g_norm.view(bshape)
            # Only samples that are still correctly classified move.
            x_new = torch.where(fooled.view(bshape), x_leaf, stepped).clamp(0, 1)
        x = castasImage(x_new)

        if bool(fooled.all()):
            break
    return x

# compute one batch of image's attacks, trading off classification loss
# against a perceptual-error gradient
def PieAPP_GD_attack_batch(model, loss_function, step_size, x_batch, y_batch, output_folder, name, conf, lam):
    """Iteratively perturb ``x_batch`` until every sample is "done".

    A sample is done when ``model`` no longer predicts its label in
    ``y_batch``. If ``conf != 0``, the top-1 logit must also exceed the
    runner-up by at least ``conf``. Each step for a not-yet-done sample moves
    along ``unit(unit(g) - lam * unit(c_grad))``:

    - ``g`` is the gradient of ``loss_function`` (ascended).
    - ``c_grad`` comes from ``computeGradient(x_i, x_ori_i, "0")``.
      Presumably it is the gradient of the PieAPP perceptual error with
      respect to ``x_i``; TODO confirm against PieAppUtil.

    ``c_grad`` is refreshed every 5 iterations, and is also computed for any
    active sample that does not have one yet. The step size decays as
    ``step_size / (1 + 0.1 * k)``. After each step the batch is clipped to
    [0, 1] and quantized with ``castasImage``.

    Args:
        model: classifier, called as ``model(norm(x))``. It is switched to
            eval mode and left there.
        loss_function: called as ``loss_function(outputs, y_batch)``. Scalar
            or per-sample losses are accepted; they are summed.
        step_size: initial L2 step length per sample.
        x_batch: input images. Presumably (B, C, H, W) in [0, 1]; TODO confirm.
        y_batch: true labels, moved to ``x_batch``'s device.
        output_folder: unused in this function.
        name: unused in this function.
        conf: required logit margin of the top-1 class over the runner-up
            (0 disables the check).
        lam: weight of the perceptual-gradient term.

    Returns:
        The perturbed batch (a tensor that requires grad).
    """
    eps = 1e-12  # turns 0/0 (zero gradient, or c_grad never computed) into 0
    # x_batch.get_device() returns -1 for CPU tensors; .device works everywhere.
    device = x_batch.device
    y_batch = y_batch.to(device)
    x_ori = x_batch.detach().clone().to(device)
    x = x_ori.clone().requires_grad_(True)
    bshape = (-1,) + (1,) * (x.dim() - 1)

    def _unit(t):
        # L2-normalize every sample (first dimension) independently.
        return t / t.reshape(t.shape[0], -1).norm(dim=1).clamp_min(eps).view(bshape)

    c_grad = torch.zeros_like(x_ori)
    has_c_grad = torch.zeros(x.shape[0], dtype=torch.bool, device=device)

    T = 100
    model.eval()
    for k in range(T):
        step_size_decayed = step_size * (1 / (1 + 0.1 * k))
        x_leaf = x.detach().requires_grad_(True)

        outputs = model(norm(x_leaf))
        done = torch.argmax(outputs, dim=1) != y_batch

        # confidence: also require a logit margin of at least `conf`
        if conf != 0:
            top2 = torch.topk(outputs.detach(), 2).values
            done = done & (top2[:, 0] >= conf + top2[:, 1])

        if bool(done.all()):
            break

        loss = loss_function(outputs, y_batch)
        (g,) = torch.autograd.grad(loss.sum(), x_leaf)

        # Reuse the perceptual gradient for 5 steps. A sample that becomes
        # active between refreshes gets one immediately; otherwise its
        # c_grad row would still be all zeros.
        refresh = ~done if k % 5 == 0 else (~done & ~has_c_grad)
        for i in torch.nonzero(refresh).flatten().tolist():
            c_grad[i] = computeGradient(x_leaf[i].detach(), x_ori[i], "0").detach()
            has_c_grad[i] = True
            torch.cuda.empty_cache()

        with torch.no_grad():
            update = _unit(_unit(g) - lam * _unit(c_grad))
            stepped = x_leaf + step_size_decayed * update
            # Done samples keep their current image.
            x_new = torch.where(done.view(bshape), x_leaf, stepped).clamp(0, 1)
        x = castasImage(x_new)
    return x

# uses precomputed anchors to calculate attacks
def PieAPP_GD_attack_batch_useanchors(model, loss_function, step_size, x_batch, y_batch, name, anchor_folder, lam=1.0, max_iter=10):
    """Attack a batch by first trying precomputed anchors, then refining.

    Stage 1 (anchors): for sample ``i``, the file
    ``./images_kmeans_results/<label>.pt`` is loaded if it exists and indexed
    with ``name[i]`` to get a center id. If ``anchor_folder + center + ".png"``
    exists, that image is added to the sample. Samples the model misclassifies
    after this addition keep it; all others revert to the clean input.

    Stage 2 (refinement): for up to ``max_iter`` iterations, every sample the
    model still classifies correctly moves along
    ``unit(unit(g) - lam * unit(c_grad))``:

    - ``g`` is the gradient of ``loss_function`` on those samples.
    - ``c_grad`` comes from ``computeGradient(x_i, x_ori_i, "0")``.
      Presumably it is the gradient of the PieAPP perceptual error; TODO
      confirm.

    ``c_grad`` is stored per batch index and refreshed every 5 iterations, or
    on first need. Already-fooled samples are never modified. After each step
    the batch is clipped to [0, 1] and quantized with ``castasImage``.

    Args:
        model: classifier, called as ``model(norm(x))``. It runs in eval mode
            and is restored to its previous train/eval mode on return.
        loss_function: called on the still-correct subset. Scalar or
            per-sample losses are accepted; they are summed.
        step_size: initial L2 step length, decayed as ``1 / (1 + 0.1 * k)``.
        x_batch: input images. Presumably (B, C, H, W) in [0, 1]; TODO confirm.
        y_batch: true labels, moved to ``x_batch``'s device.
        name: per-sample keys into the loaded centers object.
        anchor_folder: path prefix, concatenated directly with the center id
            (so it needs its trailing separator).
        lam: weight of the perceptual-gradient term (previously hard-coded 1.0).
        max_iter: refinement iterations (previously hard-coded 10).

    Returns:
        The perturbed batch (a tensor that requires grad).
    """
    eps = 1e-12  # turns 0/0 (zero gradient, or c_grad never computed) into 0
    # x_batch.get_device() returns -1 for CPU tensors; .device works everywhere.
    device = x_batch.device
    y_batch = y_batch.to(device)
    x_ori = x_batch.detach().clone().to(device)
    batch = x_ori.shape[0]
    bshape = (-1,) + (1,) * (x_ori.dim() - 1)

    def _unit(t):
        # L2-normalize every sample (first dimension) independently.
        return t / t.reshape(t.shape[0], -1).norm(dim=1).clamp_min(eps).view(bshape)

    was_training = model.training
    model.eval()

    # Stage 1: add the precomputed anchor perturbation where one exists.
    x_pot = x_ori.clone()
    centers_by_label = {}  # load each class's centers file only once
    with torch.no_grad():
        for i in range(batch):
            label = str(y_batch[i].item())
            if label not in centers_by_label:
                centers_path = "./images_kmeans_results/" + label + ".pt"
                # NOTE(review): torch.load unpickles; only use trusted files here.
                centers_by_label[label] = torch.load(centers_path) if os.path.isfile(centers_path) else None
            image_centers = centers_by_label[label]
            if image_centers is None:
                continue
            center = image_centers[name[i]]
            anchor_path = anchor_folder + center + ".png"
            if os.path.isfile(anchor_path):
                with Image.open(anchor_path) as img:  # close the file handle
                    attack_center = transforms.functional.to_tensor(img).to(device)
                # NOTE(review): the sum is not clipped to [0, 1]; it is only
                # clipped after the first refinement step.
                x_pot[i] = x_pot[i] + attack_center
        outputs = model(norm(x_pot))
        fooled = torch.argmax(outputs, dim=1) != y_batch
        x = torch.where(fooled.view(bshape), x_pot, x_ori)
    x = x.requires_grad_(True)

    # Stage 2: gradient refinement of the samples that are still correct.
    c_grad = torch.zeros_like(x_ori)  # indexed by absolute batch position
    has_c_grad = torch.zeros(batch, dtype=torch.bool, device=device)
    for k in range(max_iter):
        step_size_decayed = step_size * (1 / (1 + 0.1 * k))
        x_leaf = x.detach().requires_grad_(True)

        outputs = model(norm(x_leaf))
        fooled = torch.argmax(outputs, dim=1) != y_batch
        if bool(fooled.all()):
            break
        active = ~fooled

        loss = loss_function(outputs[active], y_batch[active])  # normal classification loss
        (g_full,) = torch.autograd.grad(loss.sum(), x_leaf)

        refresh = active if k % 5 == 0 else (active & ~has_c_grad)
        for i in torch.nonzero(refresh).flatten().tolist():
            c_grad[i] = computeGradient(x_leaf[i].detach(), x_ori[i], "0").detach()
            has_c_grad[i] = True

        with torch.no_grad():
            update = _unit(_unit(g_full[active]) - lam * _unit(c_grad[active]))
            # Start from the current iterate, so fooled samples (including
            # anchor-fooled ones) keep their perturbation.
            x_new = x_leaf.detach().clone()
            x_new[active] = x_new[active] + step_size_decayed * update
            x_new.clamp_(0, 1)
        x = castasImage(x_new)
    model.train(was_training)
    return x
