import time
import numpy as np
from numpy import linalg as LA
import torch
import scipy.spatial
from scipy.linalg import qr
import random
import torch.nn as nn

from tqdm import tqdm

from attack import OPT_attack_sign_SGD, Sampling_Attack
from attack.Sign_OPT import quad_solver, sign

# Module-level learning-rate constant.
# NOTE(review): not referenced by any code visible in this file; presumably
# read by importers or kept for parity with the Sign-OPT module -- confirm
# before removing or changing it.
start_learning_rate = 1.0


class Sampling_OPT_attack_sign_SGD(OPT_attack_sign_SGD, Sampling_Attack):
    """Hard-label attack combining ``OPT_attack_sign_SGD`` with ``Sampling_Attack``.

    ``OPT_attack_sign_SGD`` is presumably the sign-SGD variant of Sign-OPT
    (it lives alongside ``attack.Sign_OPT``) -- confirm in the ``attack``
    package. The two parent initialisers take different arguments, so each is
    called explicitly instead of through a cooperative ``super().__init__``.
    """

    def __init__(self, model, enc, dec, order=2, k=200, dataset="", early_stopping=False):
        """Initialise both attack bases, Sign-OPT first, then the sampling mix-in.

        Args:
            model: forwarded positionally to ``OPT_attack_sign_SGD.__init__``.
            enc: forwarded to ``Sampling_Attack.__init__`` together with
                ``dec`` (presumably an encoder/decoder pair -- confirm).
            dec: see ``enc``.
            order: forwarded as a keyword to ``OPT_attack_sign_SGD``; default 2.
            k: forwarded as a keyword to ``OPT_attack_sign_SGD``; default 200.
            dataset: forwarded as a keyword to ``OPT_attack_sign_SGD``; default "".
            early_stopping: forwarded as a keyword to ``OPT_attack_sign_SGD``;
                default False.
        """
        OPT_attack_sign_SGD.__init__(
            self,
            model,
            order=order,
            k=k,
            dataset=dataset,
            early_stopping=early_stopping,
        )
        Sampling_Attack.__init__(self, enc, dec)

    def __call__(self, data, label, epsilon, target=None, query_limit=40000, seed=None,
                 target_loader=None, svm=False, momentum=0.0):
        """Run the attack on ``data`` and return the post-processed outcome.

        Args:
            data: the input to attack; forwarded to ``validate_args`` and
                ``attack_hard_label``.
            label: label of ``data``; ``label.item()`` is taken, so it is
                assumed to hold exactly one element.
            epsilon: forwarded to ``validate_args`` and ``attack_hard_label``.
            target: optional target (default ``None``); forwarded to both.
            query_limit: default 40000; forwarded to both.
            seed: forwarded to ``attack_hard_label``.
            target_loader: forwarded to both ``validate_args`` and
                ``attack_hard_label``.
            svm: forwarded to ``attack_hard_label``; default False.
            momentum: forwarded to ``attack_hard_label``; default 0.0.

        Returns:
            Whatever ``self.postprocess_result`` returns for the raw output of
            ``self.attack_hard_label``.
        """
        # Base-class hook, resolved past this class in the MRO.
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)

        # Record the input's class on the instance before attacking; presumably
        # the sampling component conditions on it (confirm in Sampling_Attack).
        # The attribute is not reset after the attack finishes.
        true_class = int(label.item())
        self.class_conditional = true_class

        raw_adversarial = self.attack_hard_label(
            data,
            label,
            epsilon,
            target,
            seed=seed,
            svm=svm,
            query_limit=query_limit,
            target_loader=target_loader,
            momentum=momentum,
        )
        return self.postprocess_result(raw_adversarial)
