import os
import math
import numpy as np
import torch
import time

from sklearn.metrics import accuracy_score
from torch import optim
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
from torchvision.utils import save_image
from pennylane.math import fidelity, reduce_statevector
from adversarial.utils import *


def random_noise(imgs, k=0.1):
    """
    Add one shared random noise pattern to every image.

    A single noise map of shape ``imgs.shape[2:]`` is drawn from a standard
    normal, min-max rescaled to [0, 1], scaled by ``k`` and broadcast-added to
    every image/channel. The result is clamped to [0, 1].

    :param imgs: image batch; presumably (N, C, H, W) -- the noise map covers
        the dims after the first two
    :param k: maximum intensity added by the noise
    :return: noisy images clamped to [0, 1]
    """
    # noise = torch.rand(imgs.shape[2:])
    noise = torch.randn(imgs.shape[2:])
    span = noise.max() - noise.min()
    if span > 0:
        norm_noise = (noise - noise.min()) / span
    else:
        # Degenerate (constant) noise map, e.g. a single pixel or no trailing
        # dims: the min-max rescale would be 0/0 = NaN, so add no noise.
        norm_noise = torch.zeros_like(noise)
    adv_imgs = torch.clamp(imgs + norm_noise * k, 0., 1.)
    return adv_imgs


def targeted_attack(now_outputs, target_label=0):
    """
    Cross-entropy loss of a single score vector against ``target_label``.

    :param now_outputs: 1-D class scores for one sample; a batch dimension of
        size 1 is added before computing the loss
    :param target_label: class index the loss is measured against
    :return: scalar cross-entropy loss tensor (mean reduction over the batch of one)
    """
    batched_scores = now_outputs.unsqueeze(0)
    target_idx = torch.tensor([target_label], dtype=torch.long)
    return F.cross_entropy(batched_scores, target_idx)


def DLFuzz2(now_outputs, ori_outputs, w):
    """
    DLFuzz-style objective over the top-2 classes of the original prediction.

    :param now_outputs: current outputs of the QNN, indexable by class
    :param ori_outputs: original outputs of the QNN, e.g. shape (2,)
    :param w: weight given to the runner-up class
    :return: ``w * now[runner_up] - now[top]``, where the top and runner-up
        classes are ranked by ``ori_outputs``; larger means closer to the
        decision boundary orientation
    """
    reference = torch.tensor(ori_outputs)
    top_class = torch.argmax(reference)
    runner_up_class = torch.argsort(reference)[-2]
    # reward score moving to the runner-up class, penalize the original top class
    return w * now_outputs[runner_up_class] - now_outputs[top_class]


def DLFuzz3(now_outputs, ori_outputs, w):
    """
    DLFuzz-style objective over the top-3 classes of the original prediction.

    :param now_outputs: current outputs of the QNN, indexable by class
    :param ori_outputs: original outputs of the QNN (at least 3 classes)
    :param w: weight given to the 2nd and 3rd ranked classes
    :return: ``w * (now[second] + now[third]) - now[top]``, with classes
        ranked by ``ori_outputs``
    """
    reference = torch.tensor(ori_outputs)
    ascending = torch.argsort(reference)
    top = now_outputs[torch.argmax(reference)]
    second, third = now_outputs[ascending[-2]], now_outputs[ascending[-3]]
    return w * (second + third) - top


def DLFuzz(model, imgs, labels, w=1, steps=100, *, progress=True):
    """
    Batch gradient-ascent attack that pushes scores away from the true labels.

    Each step maximises ``mean(w * (sum of non-label scores) - label score)``
    and adds the raw gradient (no step size) to the images, clamping to [0, 1].

    :param model: object whose ``predict(imgs)`` returns scores of shape
        (num_sample, num_class)
    :param imgs: input batch; not modified
    :param labels: integer label tensor with one entry per sample
    :param w: weight of the non-label scores
    :param steps: number of update steps
    :param progress: show a tqdm progress bar over the steps
    :return: detached adversarial images in [0, 1]
    """
    # detach first so the working copy is always a fresh leaf tensor
    adv = imgs.detach().clone()
    step_iter = tqdm(range(steps)) if progress else range(steps)
    for _ in step_iter:
        adv.requires_grad_(True)
        outputs = model.predict(adv)  # (num_sample, num_class)
        loss1 = torch.gather(outputs, 1, labels.view(-1, 1))
        loss2 = torch.sum(outputs, dim=1, keepdim=True) - loss1
        loss = torch.mean(w * loss2 - loss1)
        # gradient w.r.t. the images only: unlike loss.backward(), this does
        # not accumulate .grad into the model's parameters on every step
        grad = torch.autograd.grad(loss, adv)[0]
        adv = torch.clamp(adv + grad, 0., 1.).detach()
    return adv


def FGSM(model, imgs, labels, eps=8/255, est_grad=False):
    """
    Fast Gradient Sign Method, a white-box single-step constraint-based method (untargeted)

    Each image is moved by ``eps`` along the sign of the loss gradient and
    clamped to [0, 1]. Images are processed one at a time.

    :param model: callable ``model(x, y)`` returning a scalar loss for a batch
        of one image ``x`` and a label tensor ``y`` of shape (1,)
    :param imgs: shape: (N, C, H, W); not modified
    :param labels: (N,)
    :param eps: max perturbation range epsilon
    :param est_grad: if True, estimate the gradient with ``nes`` (not defined in
        this file -- presumably a black-box estimator from ``adversarial.utils``)
        instead of autograd
    :return: torch.tensor adv images within [0, 1]
    """
    def loss_fn(x, y):
        return model(x, y)

    adv_imgs = imgs.detach().clone()
    for i, img in enumerate(imgs):
        label = torch.tensor([labels[i]])
        if est_grad:
            x = img.detach().clone()
            grad = nes(x.unsqueeze(0), label, loss_fn).sign()
        else:
            x = img.detach().clone().requires_grad_(True)
            loss = model(x.unsqueeze(0), label)
            # differentiate w.r.t. the image only, so the model's parameter
            # .grad buffers are not polluted by the attack
            grad = torch.autograd.grad(loss, x)[0].sign()

        adv_imgs[i] = torch.clamp(x + grad * eps, 0, 1).detach()
    return adv_imgs


def BIM(model, imgs, labels, eps=8/255, alpha=2/255, steps=10, est_grad=False,
        *, progress=True):
    """
    BIM, iterative-FGSM
    distance measure: Linf

    Each image takes ``steps`` sign-gradient ascent steps of size ``alpha`` on
    the loss. After every step the iterate is projected back into the L-inf
    ball of radius ``eps`` around the clean image and into [0, 1]. The next
    gradient is taken at that projected point.

    :param model: callable ``model(x, y)`` returning a scalar loss for a batch
        of one image ``x`` and a label tensor ``y`` of shape (1,)
    :param imgs: image batch, iterated per image; not modified
    :param labels: labels indexable per image
    :param eps: L-inf perturbation budget
    :param alpha: step size of each iteration
    :param steps: number of iterations per image
    :param est_grad: if True, estimate the gradient with ``nes`` (not defined in
        this file -- presumably from ``adversarial.utils``) instead of autograd
    :param progress: show a tqdm progress bar per image
    :return: adversarial images within the eps-ball and [0, 1]
    """
    def loss_fn(x, y):
        return model(x, y)

    adv_imgs = imgs.detach().clone()
    for i, img in enumerate(imgs):
        ori = img.detach()
        label = torch.tensor([labels[i]])
        lower = torch.clamp(ori - eps, min=0)
        upper = ori + eps
        x = ori.clone()
        step_iter = tqdm(range(steps)) if progress else range(steps)
        for _ in step_iter:
            if est_grad:
                grad = nes(x.unsqueeze(0), label, loss_fn)
            else:
                x_var = x.clone().requires_grad_(True)
                loss = model(x_var.unsqueeze(0), label)
                # w.r.t. the image only: keeps model parameter .grad untouched
                grad = torch.autograd.grad(loss, x_var)[0]
            # keep the per-image shape even if the estimator returns a leading
            # batch dim, otherwise x would gain a dimension every step
            grad = grad.reshape(x.shape)
            x = (x + alpha * grad.sign()).detach()
            # Project back into the eps-ball and [0, 1], and continue from the
            # projected point (the next gradient must be taken inside the ball).
            x = torch.min(torch.max(x, lower), upper)
            x = torch.clamp(x, 0, 1)
        adv_imgs[i] = x
    return adv_imgs



def CW(model, imgs, labels, log, c=1, kappa=0, steps=50, lr=0.01, est_grad=False,
       *, nes_samples=50, nes_sigma=0.1):
    """
    Carlini-Wagner style L2 attack (untargeted), optimised in tanh space.

    Each image is attacked independently. ``w = inverse_tanh_space(x)`` is
    updated by plain gradient descent on
    ``||tanh_space(w) - x||_2^2 + c * f(outputs, y, kappa)``. The loop stops at
    the first step whose image is misclassified.

    distance measure: L2
    :param model: object exposing ``predict(batch)`` returning class scores,
        presumably shaped (1, num_classes) for a single image -- confirm
    :param imgs: batch of images; each ``imgs[i]`` is attacked as a batch of one
    :param labels: true labels, indexable per image
    :param log: callable that receives the iteration-statistics summary string
    :param c: weight of the classification term relative to the L2 term
    :param kappa: also written as confidence; margin passed to ``f``
    :param steps: maximum number of update steps per image
    :param lr: step size of the update on ``w``
    :param est_grad: if True, estimate the gradient from loss values only
        (antithetic Gaussian sampling) instead of using autograd
    :param nes_samples: number of antithetic sample pairs per gradient estimate
    :param nes_sigma: standard deviation of the sampling perturbation
    :return: tensor shaped like ``imgs`` holding each image's final iterate (the
        first misclassified one when the attack succeeds); images are returned
        unchanged when ``steps`` is 0
    """
    adv_imgs = imgs.detach().clone()
    success_iters = []  # steps needed by each successful attack
    flatten = torch.nn.Flatten()

    def loss_fn(cur_w, ori_x, y):
        """C&W objective: squared L2 distance + c * margin loss."""
        adv_img = tanh_space(cur_w)
        outputs = model.predict(adv_img)  # logits [1, num_classes]
        current_L2 = torch.sum((flatten(adv_img) - flatten(ori_x)) ** 2)
        f_loss = f(outputs, torch.tensor([y]), kappa).sum()
        return current_L2 + c * f_loss

    def estimate_grad(cur_w, ori_x, y):
        """Antithetic finite-difference estimate of the gradient of ``loss_fn``."""
        grad = torch.zeros_like(cur_w)
        # only loss values are needed, so don't build autograd graphs for the
        # 2 * nes_samples model evaluations
        with torch.no_grad():
            for _ in range(nes_samples):
                u = torch.randn_like(cur_w)
                l_pos = loss_fn(cur_w + nes_sigma * u, ori_x, y)
                l_neg = loss_fn(cur_w - nes_sigma * u, ori_x, y)
                grad += (l_pos - l_neg) * u
        return grad / (2 * nes_sigma * nes_samples)

    for i, img in enumerate(imgs):
        y = labels[i]
        w = inverse_tanh_space(img).detach().unsqueeze(0)
        w.requires_grad = not est_grad
        x = img.unsqueeze(0)
        n_iter = 0
        fooled = False
        adv_img = None  # stays None only when steps == 0
        for _ in tqdm(range(steps)):
            n_iter += 1
            if est_grad:
                grad = estimate_grad(w, x, y)
            else:
                cost = loss_fn(w, x, y)
                grad = torch.autograd.grad(cost, w)[0]

            with torch.no_grad():
                w = w - lr * grad
            w.requires_grad = not est_grad
            adv_img = tanh_space(w).detach()
            with torch.no_grad():
                pred = torch.argmax(model.predict(adv_img), dim=1).item()

            # stop at the first misclassified iterate
            fooled = bool(pred != y)
            if fooled:
                break

        if adv_img is not None:
            adv_imgs[i] = adv_img.squeeze(0)
        if fooled:
            success_iters.append(n_iter)

    iters = torch.tensor(success_iters, dtype=torch.float32)
    log(f'Iterations: avg: {iters.mean()}, std: {iters.std()}')
    return adv_imgs

def tanh_space(x):
    """Map unconstrained values into [0, 1] via ``0.5 * (tanh(x) + 1)``."""
    return (torch.tanh(x) + 1) * 0.5
def inverse_tanh_space(x, eps=1e-6):
    """
    Inverse of ``tanh_space``: ``atanh(2x - 1)``.

    ``2x - 1`` is clamped strictly inside (-1, 1) by ``eps``. Exactly ±1 would
    map pixels at 0 or 1 to ±inf, where tanh has zero gradient, so those pixels
    could never be perturbed by an optimiser working in tanh space.

    :param x: values expected in [0, 1]
    :param eps: margin kept away from ±1 before applying atanh
    :return: finite tensor ``w`` with ``tanh_space(w)`` approximately equal to ``x``
    """
    x = torch.clamp(x * 2 - 1, min=-1 + eps, max=1 - eps)
    return 0.5 * torch.log((1 + x) / (1 - x))
def f(outputs, labels, kappa):
    """
    Untargeted C&W margin loss per row: ``clamp(Z_y - max_{j != y} Z_j, min=-kappa)``.

    :param outputs: scores of shape (batch, num_classes)
    :param labels: true class index per row (anything usable to index ``torch.eye``)
    :param kappa: confidence margin; the loss is clamped from below at ``-kappa``
    :return: tensor of shape (batch,)
    """
    target_mask = torch.eye(outputs.shape[1], device=outputs.device)[labels].bool()
    # Exclude entries with -inf rather than multiplying by a 0/1 mask: a
    # multiplicative mask turns excluded entries into 0, which wins the max
    # whenever the relevant scores are negative (e.g. signed logits).
    # get the target class's logit
    real = outputs.masked_fill(~target_mask, float('-inf')).max(dim=1)[0]
    # find the max logit other than the target class
    other = outputs.masked_fill(target_mask, float('-inf')).max(dim=1)[0]
    return torch.clamp(real - other, min=-kappa)

