import time
import numpy as np 
import torch
import scipy.spatial
from scipy.linalg import qr
#from qpsolvers import solve_qp
import random

from tqdm import tqdm

# Starting learning-rate constant.
# NOTE(review): the line search in OPT_attack_sign_SGD.attack_hard_label steps
# with that method's own `alpha` argument, not with this value -- confirm
# whether this constant is still needed.
start_learning_rate = 1.0

from attack import OPT_attack


def quad_solver(Q, b, max_iter=20000, tol=1e-7):
    """
    Solve  min_a  0.5 * a^T Q a + b^T a   s.t.  a >= 0
    with greedy projected coordinate descent.

    Starting from a = 0, each iteration computes for every coordinate the
    projected Newton step  max(a_i - g_i / Q_ii, 0) - a_i  (g is the running
    gradient Q a + b), applies only the step with the largest magnitude, and
    updates g incrementally.

    Q        -- square 2-D tensor; its diagonal is used as a divisor, so
                diagonal entries are expected to be non-zero
    b        -- 1-D tensor of length Q.shape[0]; it is not modified
    max_iter -- upper bound on the number of coordinate updates
    tol      -- stop once the largest candidate step is smaller than this

    Returns the solution as a 1-D tensor living on Q's device.
    """
    K = Q.shape[0]
    # Follow Q's device (and floating dtype) rather than forcing CUDA, so the
    # solver works for CPU inputs and never mixes devices with its arguments.
    dtype = Q.dtype if Q.is_floating_point() else torch.get_default_dtype()
    alpha = torch.zeros((K,), dtype=dtype, device=Q.device)
    # Gradient of the objective at alpha = 0 is simply b.
    g = b.to(device=Q.device, dtype=dtype)
    Qdiag = torch.diag(Q)
    for _ in range(max_iter):
        # torch.clamp takes a scalar bound; torch.maximum(tensor, 0) raises
        # because its second argument must be a tensor.
        delta = torch.clamp(alpha - g / Qdiag, min=0) - alpha
        idx = torch.argmax(torch.abs(delta))
        val = delta[idx]
        if torch.abs(val) < tol:
            break
        # Changing one coordinate by val shifts the gradient by val * Q[:, idx].
        g = g + val * Q[:, idx]
        alpha[idx] += val
    return alpha

def sign(y):
    """
    y -- torch tensor of any shape
    Element-wise sign with zero treated as positive: entries for which
    torch.sign(y) is 0 come back as 1; every other value produced by
    torch.sign(y) is returned unchanged (-1 for negatives, 1 for positives).
    The input tensor is not modified; a new tensor is returned.
    """
    raw = torch.sign(y)
    # Swap exact zeros for ones without touching the remaining entries.
    return torch.where(raw == 0, torch.ones_like(raw), raw)


class OPT_attack_sign_SGD(OPT_attack):
    def __init__(self, model, dataset="", order=2, k=200, early_stopping=False):
        """Initialise the base OPT attack, then record the probe count.

        model, dataset, order and early_stopping are forwarded unchanged to
        OPT_attack.__init__. k is stored as self.k, which sign_grad_v1 reads
        as the number of random directions per gradient estimate.
        """
        base_options = {
            "dataset": dataset,
            "order": order,
            "early_stopping": early_stopping,
        }
        OPT_attack.__init__(self, model, **base_options)
        self.k = k

    def attack_hard_label(self, x, y, epsilon, target=None, query_limit=10000, seed=None, target_loader=None,
                          alpha=0.2, beta=0.001, svm=False, momentum=0.0):
        """Attack one input with sign-gradient descent along the decision boundary.

        x             -- original input; a tensor is moved to the CPU first
        y             -- original label; a tensor is converted with .item()
        epsilon       -- forwarded to initialize, update_log and stopping_condition
        target        -- optional target label (tensor or plain value)
        query_limit   -- forwarded to stopping_condition
        seed          -- forwarded to manual_seed
        target_loader -- forwarded to initialize
        alpha         -- initial line-search step size; it is rescaled in place
                         across iterations and reset to 1.0 once it drops
                         below 1e-4
        beta          -- probe radius for the sign-gradient estimate; the
                         binary-search tolerance is beta / 500. It is divided
                         by 10 whenever alpha collapses; the loop stops once
                         beta < 1e-8
        svm           -- use sign_grad_svm instead of sign_grad_v1
        momentum      -- when > 0, steps use a classical momentum velocity

        Returns None when benign_succ reports true for the clean input or when
        initialize yields an infinite distance; otherwise returns
        get_xadv(x0, gg, xg) for the final direction xg and distance gg.
        """
        # Bind x0 / y0 unconditionally: the previous tensor-only branches left
        # them undefined (NameError) for e.g. a plain int label.
        x0 = x.cpu() if isinstance(x, torch.Tensor) else x
        y0 = y.item() if isinstance(y, torch.Tensor) else y
        # Checking the type (not truthiness) also unwraps a tensor holding
        # class index 0, which `if target and ...` skipped.
        if isinstance(target, torch.Tensor):
            target = target.item()

        if self.benign_succ(self.get_xorig(x0), y0, None):
            print("Fail to classify the image. No need to attack.")
            return None

        self.manual_seed(seed)

        # Calculate a good starting point.
        best_theta, g_theta = self.initialize(x0, y0, epsilon, target, target_loader)
        if g_theta == np.inf:
            print("Couldn't find valid initial, failed")
            return None

        print(f"Initialized in {self.model.get_num_queries()} queries")

        dist = self.compute_distance(x0, self.get_xadv(x0, g_theta, best_theta))
        self.update_log(dist, epsilon)

        # Begin gradient descent at the boundary.
        timestart = time.time()
        xg, gg = best_theta, g_theta
        vg = torch.zeros_like(xg).cuda()
        # NOTE(review): `target` is not forwarded to the gradient estimator or
        # to the local binary search -- confirm this is intended for targeted
        # attacks.
        estimate_sign_gradient = self.sign_grad_svm if svm else self.sign_grad_v1
        for i in range(self.num_iterations):
            sign_gradient = estimate_sign_gradient(x0, y0, xg, initial_lbd=gg, h=beta)

            # Line search, phase 1: keep doubling the step while the boundary
            # distance keeps shrinking.
            min_theta, min_g2, min_vg = xg, gg, vg
            for _ in tqdm(range(15)):
                new_theta, new_vg, new_g2 = self._line_search_step(
                    x0, y0, xg, vg, sign_gradient, alpha, momentum, min_g2, beta / 500)
                alpha = alpha * 2
                if new_g2 < min_g2:
                    min_theta, min_g2, min_vg = new_theta, new_g2, new_vg
                else:
                    break

            # Phase 2: nothing improved, so shrink the step until one does.
            if min_g2 >= gg:
                for _ in tqdm(range(15)):
                    alpha = alpha * 0.25
                    new_theta, new_vg, new_g2 = self._line_search_step(
                        x0, y0, xg, vg, sign_gradient, alpha, momentum, min_g2, beta / 500)
                    if new_g2 < gg:
                        min_theta, min_g2, min_vg = new_theta, new_g2, new_vg
                        break

            if alpha < 1e-4:
                # The step collapsed: restart it and probe with a finer radius.
                alpha = 1.0
                print("Warning: not moving")
                beta = beta * 0.1
                if beta < 1e-8:
                    break

            xg, gg = min_theta, min_g2
            vg = min_vg

            dist = self.compute_distance(x0, self.get_xadv(x0, gg, xg))
            print(f"Iteration {i+1} distortion {dist:.4f} num_queries {self.model.get_num_queries()}")

            self.update_log(dist, epsilon)

            if self.stopping_condition(query_limit, dist, epsilon):
                break

        timeend = time.time()
        dist = self.compute_distance(x0, self.get_xadv(x0, gg, xg))

        if self.search_succ(self.get_xadv(x0, gg, xg), y0, target):
            print(f"\nAdversarial Example Found Successfully: distortion {dist:.4f} pred {target if target else 'OK'} in "
                  f"{self.model.get_num_queries()} queries \nTime: {timeend-timestart:.4f} seconds")

        self.update_log(dist, epsilon)
        return self.get_xadv(x0, gg, xg)

    def _line_search_step(self, x0, y0, theta, velocity, sign_gradient, step, momentum, initial_lbd, tol):
        """Take one trial step from theta and measure its boundary distance.

        Returns (candidate, new_velocity, distance): the unit-normalised
        candidate direction, the velocity that produced it (the incoming
        velocity is returned untouched when momentum <= 0), and the value of
        fine_grained_binary_search_local for the candidate.
        """
        if momentum > 0:
            # Classical momentum: fold the scaled gradient into the velocity.
            new_velocity = momentum * velocity - step * sign_gradient
            candidate = theta + new_velocity
        else:
            new_velocity = velocity
            candidate = theta - step * sign_gradient

        self.update_gradient(step * sign_gradient)

        candidate /= torch.norm(candidate)
        distance = self.fine_grained_binary_search_local(
            x0, y0, candidate, initial_lbd=initial_lbd, tol=tol)
        return candidate, new_velocity, distance

    def sign_grad_v1(self, x0, y0, theta, initial_lbd, h=0.001, D=4, target=None):
        r"""
        Estimate the sign gradient at theta from K = self.k random probes:

            sign_grad = 1/K * \sum_{i=1}^{K} s_i * u_i

        Each u_i is a random unit-norm direction. s_i is -1 when
        search_succ(get_xadv(x0, initial_lbd, theta_i), y0, target) is true
        for the normalised probe theta_i = (theta + h*u_i) / ||theta + h*u_i||,
        and +1 otherwise.

        x0, y0      -- original input and label, passed through to the helpers
        theta       -- current search direction (tensor)
        initial_lbd -- distance along the probe direction at which the
                       success check is made
        h           -- probe radius
        D           -- unused; kept so existing callers keep working
        target      -- forwarded to search_succ

        Returns a float tensor with theta's shape on theta's device.
        """
        K = self.k
        # Work on theta's device instead of forcing CUDA. Directions are still
        # drawn by the CPU generator and then moved, as before.
        device = theta.device
        sign_grad = torch.zeros(theta.shape, device=device)
        for _ in tqdm(range(K)):
            u = torch.randn(*theta.shape).to(device)
            u /= torch.norm(u)

            new_theta = theta + h*u
            new_theta /= torch.norm(new_theta)

            # One success check per probe. The extra self.model.predict_label
            # call that used to precede it discarded its result, so it only
            # evaluated the model a second time on the same point (presumably
            # counting against the query budget).
            if self.search_succ(self.get_xadv(x0, initial_lbd, new_theta), y0, target):
                sign_grad -= u
            else:
                sign_grad += u

        sign_grad /= K
        return sign_grad
    
    def sign_grad_v2(self, x0, y0, theta, initial_lbd, h=0.001, K=200):
        r"""
        Estimate the sign gradient at theta from K random probes, accumulating
        the element-wise sign of each direction rather than the direction
        itself:

            sign_grad = 1/K * \sum_{i=1}^{K} s_i * sign(u_i)

        Each u_i is a random unit-norm direction. s_i is +1 when
        self.model.predict_label still returns y0 for the normalised probe
        (theta + h*u_i) / ||theta + h*u_i|| at distance initial_lbd, and -1
        otherwise. sign() is the module-level helper that maps zeros to +1.

        Returns a float tensor with theta's shape on theta's device.
        """
        # Work on theta's device instead of forcing CUDA. Directions are still
        # drawn by the CPU generator and then moved, as before.
        device = theta.device
        sign_grad = torch.zeros(theta.shape, device=device)
        for _ in range(K):
            u = torch.randn(*theta.shape).to(device)
            u /= torch.norm(u)

            new_theta = theta + h*u
            new_theta /= torch.norm(new_theta)

            predicted = self.model.predict_label(self.get_xadv(x0, initial_lbd, new_theta))
            ss = 1 if predicted == y0 else -1
            sign_grad += sign(u)*ss
        sign_grad /= K
        return sign_grad


    def sign_grad_svm(self, x0, y0, theta, initial_lbd, h=0.001, K=100, lr=5.0, target=None):
        r"""
        Estimate a descent direction from K signed random probes by solving a
        non-negative quadratic program over them.

        Column i of X holds s_i * u_i. Each u_i is a random unit-norm
        direction. s_i is -1 when
        search_succ(get_xadv(x0, initial_lbd, theta_i), y0, target) is true
        for the normalised probe theta_i = (theta + h*u_i) / ||theta + h*u_i||,
        and +1 otherwise. The weights solve

            min_a  0.5 * a^T (X^T X) a - 1^T a    s.t.  a >= 0

        via quad_solver, and the result is X @ a reshaped like theta.

        lr is unused; it is kept so existing callers keep working.

        Returns a float tensor with theta's shape on theta's device.
        """
        # Keep every tensor on theta's device so X, the Gram matrix and the
        # linear term can be combined without device mismatches.
        device = theta.device
        # theta.shape is a torch.Size, which torch.prod rejects; numel() gives
        # the flattened length directly.
        dim = theta.numel()
        X = torch.zeros((dim, K), device=device)
        for col in range(K):
            u = torch.randn(*theta.shape).to(device)
            u /= torch.norm(u)

            new_theta = theta + h*u
            new_theta /= torch.norm(new_theta)

            if self.search_succ(self.get_xadv(x0, initial_lbd, new_theta), y0, target):
                probe_sign = -1
            else:
                probe_sign = 1

            X[:, col] = probe_sign * u.reshape(dim)

        # Tensor.transpose() needs explicit dims and Tensor.dot is 1-D only,
        # so the numpy-style X.transpose().dot(X) raised; use t() and matmul.
        gram = X.t() @ X
        linear = -torch.ones((K,), device=device)
        # Coordinate-descent solver; avoids trouble with Gram matrices that are
        # not strictly positive definite.
        weights = quad_solver(gram, linear)
        return (X @ weights).reshape(theta.shape)

    def eval_grad(self, model, x0, y0, theta, initial_lbd, tol=1e-5,  h=0.001, sign=False, target=None):
        """Coordinate-wise forward-difference estimate of the gradient at theta.

        Every entry of theta is bumped by h in turn, the bumped vector is
        normalised, and one value is recorded for that coordinate:

        * sign=True  -- -1 when search_succ(get_xadv(x0, initial_lbd, unit_x),
                        y0, target) is true, +1 otherwise;
        * sign=False -- (g(unit_x) - initial_lbd) / h, where g is
                        fine_grained_binary_search_local with tolerance h/500.

        model and tol are unused (the tolerance is derived from h); both are
        kept so existing callers keep working. theta itself is left untouched.

        Returns a tensor shaped like theta. This costs at least one model
        evaluation per element of theta.
        """
        fx = initial_lbd  # function value at the unperturbed direction
        grad = torch.zeros_like(theta)
        # Perturb a private copy so the caller's theta is never modified.
        x = theta.clone()
        # np.ndindex walks index tuples from the shape alone; np.nditer needed
        # a writable numpy view of the tensor, which CUDA tensors cannot give.
        for ix in np.ndindex(*theta.shape):
            # Indexing a tensor yields a view, so snapshot the element;
            # otherwise the "restore" below would write back the bumped value.
            oldval = x[ix].clone()
            x[ix] = oldval + h  # increment by h
            unit_x = x / torch.norm(x)
            if sign:
                if self.search_succ(self.get_xadv(x0, initial_lbd, unit_x), y0, target):
                    g = -1
                else:
                    g = 1
            else:
                # The binary search is used as a single return value, matching
                # how attack_hard_label consumes it.
                fxph = self.fine_grained_binary_search_local(
                    x0, y0, unit_x, initial_lbd=initial_lbd, tol=h/500)
                g = (fxph - fx) / h
            x[ix] = oldval  # restore

            grad[ix] = g

        return grad

    def __call__(self, data, label, epsilon, target=None, query_limit=40000, seed=None, target_loader=None, 
                 svm=False, momentum=0.0):
        """Validate the arguments, run attack_hard_label and post-process its output.

        The arguments are first checked by the base class's validate_args.
        attack_hard_label is then run with its default alpha and beta, and its
        result (possibly None) is passed through postprocess_result, whose
        return value is handed back to the caller.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)

        attack_options = {
            "seed": seed,
            "svm": svm,
            "query_limit": query_limit,
            "target_loader": target_loader,
            "momentum": momentum,
        }
        adversarial = self.attack_hard_label(data, label, epsilon, target, **attack_options)
        return self.postprocess_result(adversarial)
