import time
import numpy as np
from numpy import linalg as LA
import torch
import scipy.spatial
from scipy.linalg import qr
import random
import torch.nn as nn

from tqdm import tqdm

from attack import OPT_attack_sign_SGD, RandSampling_Attack
from attack.Sign_OPT import quad_solver, sign

# Module-level learning-rate constant. Nothing in the visible part of this
# file reads it.
# NOTE(review): presumably kept for parity with the sibling Sign-OPT attack
# modules -- confirm against other importers before removing.
start_learning_rate = 1.0


class RandSampling_OPT_attack_sign_SGD(OPT_attack_sign_SGD, RandSampling_Attack):
    """Sign-OPT hard-label attack driven through the ``RandSampling_Attack`` mixin.

    The class inherits from both ``OPT_attack_sign_SGD`` and
    ``RandSampling_Attack``.  Search vectors ``v`` are mapped to image-space
    directions with ``self.G`` (not defined in this class; presumably supplied
    by ``RandSampling_Attack`` -- confirm), and ``E`` is deliberately disabled.
    """

    def __init__(self, model, enc, dec, order=2, k=200, dataset="", early_stopping=False):
        """Initialise both parent classes explicitly.

        Args:
            model: Forwarded to ``OPT_attack_sign_SGD.__init__``.
            enc: Forwarded to ``RandSampling_Attack.__init__``.
            dec: Forwarded to ``RandSampling_Attack.__init__``.
            order: Forwarded to ``OPT_attack_sign_SGD.__init__``.
            k: Forwarded to ``OPT_attack_sign_SGD.__init__``.
            dataset: Forwarded to ``OPT_attack_sign_SGD.__init__``.
            early_stopping: Forwarded to ``OPT_attack_sign_SGD.__init__``.
        """
        # The two parents take different arguments, so each initialiser is
        # invoked directly instead of through a cooperative super() chain.
        OPT_attack_sign_SGD.__init__(
            self, model, order=order, k=k, dataset=dataset, early_stopping=early_stopping
        )
        RandSampling_Attack.__init__(self, enc, dec)

    def E(self, x):
        """Always raise: ``E`` is not part of the two-point Sign-OPT estimate.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError("E() not used for two-point estimate Sign-OPT")

    def get_xadv(self, x, d, v):
        """Return the clipped candidate ``x + d * G(v)``.

        Args:
            x: Base input; a ``torch.Tensor`` or anything ``torch.tensor``
                accepts.  It is moved to the CUDA device.
            d: Multiplier applied to the direction ``self.G(v)``.
            v: Search vector; non-tensor values are converted to a float
                tensor.  It is moved to the CUDA device.

        Returns:
            The result of ``self.clip_image`` applied to the float-cast sum.
        """
        # isinstance (rather than an exact type match) lets Tensor subclasses
        # such as nn.Parameter skip the copying torch.tensor() path.
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        x = x.cuda()

        # x is always a tensor at this point, so v is converted unconditionally
        # (the original guarded this with an always-true check on x).
        if not isinstance(v, torch.Tensor):
            v = torch.tensor(v, dtype=torch.float)
        v = v.cuda()

        theta = self.G(v)
        out = x + (d * theta)
        return self.clip_image(out.float())

    def __call__(self, data, label, epsilon, target=None, query_limit=40000, seed=None,
                 target_loader=None, svm=False, momentum=0.0):
        """Run the hard-label attack on ``data`` and return the post-processed result.

        Args:
            data: Input to attack; its ``shape`` sizes the zero "original"
                handed to the decoder model.
            label: Tensor-like with ``.item()``; stored as an int in
                ``self.class_conditional``.
            epsilon: Forwarded to ``validate_args`` and ``attack_hard_label``.
            target: Forwarded to ``validate_args`` and ``attack_hard_label``.
            query_limit: Forwarded to ``validate_args`` and ``attack_hard_label``.
            seed: Forwarded to ``attack_hard_label``.
            target_loader: Forwarded to ``validate_args`` and ``attack_hard_label``.
            svm: Forwarded to ``attack_hard_label``.
            momentum: Forwarded to ``attack_hard_label``.

        Returns:
            ``self.postprocess_result`` applied to the attack output.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)
        self.class_conditional = int(label.item())

        # E() is never called for this attack, so the decoder state it would
        # otherwise refresh is updated here instead.
        enc = self.enc.model
        dec = self.dec.model
        # Allocate the zero tensor directly on the GPU rather than on the host
        # followed by a copy.
        dec.update_original(torch.zeros(data.shape, device="cuda"))

        # Draw resize_dim**2 index pairs (with replacement) from
        # range(orig_dim) and hand them to the decoder.
        # NOTE(review): this uses NumPy's global RNG *before* `seed` is
        # forwarded to attack_hard_label; whether `seed` makes this draw
        # reproducible depends on code outside this file -- confirm.
        coords = np.random.choice(np.arange(enc.orig_dim), 2 * enc.resize_dim * enc.resize_dim)
        coords = coords.reshape((-1, 2))
        dec.update_coordinates(coords)

        adv = self.attack_hard_label(
            data, label, epsilon, target,
            seed=seed, svm=svm, query_limit=query_limit,
            target_loader=target_loader, momentum=momentum,
        )
        return self.postprocess_result(adv)
