import time
import numpy as np
from numpy import linalg as LA
import torch
import scipy.spatial
from scipy.linalg import qr
import random
import torch.nn as nn

from tqdm import tqdm

from attack import OPT_attack_sign_SGD, HLM_Attack
from attack.Sign_OPT import quad_solver, sign

# Module-level constant. It is not referenced by the code in this section.
# NOTE(review): the name suggests an initial step size for the attack — confirm
# whether any importer relies on it, or whether it is dead and can be removed.
start_learning_rate: float = 1.0


class HLM_OPT_attack_sign_SGD(OPT_attack_sign_SGD, HLM_Attack):
    """Hard-label attack combining ``OPT_attack_sign_SGD`` with ``HLM_Attack``.

    ``OPT_attack_sign_SGD`` receives the target model and search settings;
    ``HLM_Attack`` receives the encoder/decoder pair. Calling an instance
    validates the inputs, runs ``attack_hard_label`` and returns the result of
    ``postprocess_result``.
    """

    def __init__(self, model, enc, dec, order=2, k=200, dataset="", early_stopping=False):
        """Initialise both base classes explicitly, Sign-OPT side first.

        Args:
            model: the model under attack, forwarded to ``OPT_attack_sign_SGD``.
            enc: encoder forwarded to ``HLM_Attack``.
            dec: decoder forwarded to ``HLM_Attack``.
            order: forwarded unchanged to ``OPT_attack_sign_SGD``.
            k: forwarded unchanged to ``OPT_attack_sign_SGD``.
            dataset: forwarded unchanged to ``OPT_attack_sign_SGD``.
            early_stopping: forwarded unchanged to ``OPT_attack_sign_SGD``.
        """
        sign_opt_options = dict(order=order, k=k, dataset=dataset, early_stopping=early_stopping)
        # The bases are initialised directly rather than via a cooperative
        # super().__init__ chain; the call order is kept deliberately.
        OPT_attack_sign_SGD.__init__(self, model, **sign_opt_options)
        HLM_Attack.__init__(self, enc, dec)

    def __call__(self, data, label, epsilon, target=None, query_limit=40000, seed=None, target_loader=None,
                 svm=False, momentum=0.0):
        """Validate the inputs, run the hard-label attack and post-process it.

        Args:
            data: the input to perturb.
            label: the input's label; ``label.item()`` must be convertible to
                ``int`` (presumably a one-element tensor — confirm with callers).
            epsilon: forwarded to ``validate_args`` and ``attack_hard_label``.
            target: optional target, forwarded to both of the above.
            query_limit: query budget, forwarded to both of the above.
            seed: forwarded to ``attack_hard_label``.
            target_loader: forwarded to both of the above.
            svm: forwarded to ``attack_hard_label``.
            momentum: forwarded to ``attack_hard_label``.

        Returns:
            ``self.postprocess_result`` applied to the raw output of
            ``self.attack_hard_label``.
        """
        # Validation resolves through the MRO starting after this class.
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)
        # Store the integer label on the instance before attacking.
        # NOTE(review): presumably read by the HLM encoder/decoder — confirm.
        self.class_conditional = int(label.item())

        raw_adversarial = self.attack_hard_label(
            data,
            label,
            epsilon,
            target,
            seed=seed,
            svm=svm,
            query_limit=query_limit,
            target_loader=target_loader,
            momentum=momentum,
        )
        return self.postprocess_result(raw_adversarial)
