import time, torch
import numpy as np 
import torch.nn as nn
from numpy import linalg as LA

from attack import OPT_attack
from attack import RandSampling_Attack


class RandSampling_OPT_attack(OPT_attack, RandSampling_Attack):
    """Sign-OPT hard-label attack whose search direction is decoded via ``G``.

    Combines ``OPT_attack`` (provides ``attack_hard_label``, ``clip_image``,
    ``validate_args`` and ``postprocess_result`` as used below) with
    ``RandSampling_Attack`` (initialised with the ``enc``/``dec`` pair; it
    presumably also provides ``G`` — TODO confirm against the base class).
    """

    def __init__(self, model, enc, dec, dataset="", order=2, early_stopping=False):
        OPT_attack.__init__(self, model, order=order, dataset=dataset, early_stopping=early_stopping)
        RandSampling_Attack.__init__(self, enc, dec)

    def E(self, x):
        """Unsupported: the two-point Sign-OPT estimate never calls ``E``.

        Raises:
            ValueError: always.
        """
        raise ValueError("E() not used for two-point estimate Sign-OPT")

    def get_xadv(self, x, d, v):
        """Return ``clip_image(x + d * G(v))`` computed on the GPU.

        Args:
            x: original input; a ``torch.Tensor`` or anything ``torch.tensor``
                accepts (e.g. a NumPy array). Moved to CUDA.
            d: step length multiplying ``G(v)`` (presumably a scalar —
                TODO confirm with callers).
            v: search direction; converted to a float32 tensor when it is not
                already a tensor, then moved to CUDA.

        Returns:
            ``self.clip_image`` applied to the perturbed input cast to float.
        """
        x = x.cuda() if isinstance(x, torch.Tensor) else torch.tensor(x).cuda()

        # x is always a tensor here, so v is unconditionally made a CUDA tensor.
        if not isinstance(v, torch.Tensor):
            v = torch.tensor(v, dtype=torch.float)
        v = v.cuda()

        theta = self.G(v)
        return self.clip_image((x + d * theta).float())

    def __call__(self, data, label, epsilon, target=None, target_loader=None, query_limit=40000, seed=None):
        """Run the attack on ``data`` and return the post-processed result.

        Args:
            data: input to attack; its shape sizes the zero tensor handed to
                the decoder before it is forwarded to ``attack_hard_label``.
            label: tensor holding the true label (``.item()`` is read).
            epsilon: forwarded to ``validate_args`` and ``attack_hard_label``.
            target: optional tensor holding a target label; when given it is
                used as ``class_conditional`` instead of ``label``.
            target_loader: forwarded to ``attack_hard_label``.
            query_limit: forwarded to ``validate_args`` and
                ``attack_hard_label``.
            seed: optional seed for the random coordinate sampling below. With
                ``None`` (the default) NumPy's global RNG is used, as before.

        Returns:
            ``self.postprocess_result(adv)`` of the ``attack_hard_label`` result.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)
        self.class_conditional = int(label.item()) if target is None else int(target.item())

        # E() is never called for the two-point estimate, so the decoder's
        # original input and coordinates are refreshed here instead.
        enc = self.enc.model
        self.dec.model.update_original(torch.zeros(data.shape).cuda())

        # `seed` used to be accepted but ignored. Honour it through a private
        # generator so the global NumPy RNG state is left untouched.
        rng = np.random if seed is None else np.random.RandomState(seed)

        # Draw 2 * resize_dim**2 indices from [0, orig_dim) (with replacement,
        # the `choice` default) and group them into pairs, giving an array of
        # shape (resize_dim**2, 2). The meaning of each pair is defined by
        # dec.update_coordinates — TODO confirm.
        n_indices = 2 * enc.resize_dim * enc.resize_dim
        coords = rng.choice(np.arange(enc.orig_dim), n_indices).reshape((-1, 2))
        self.dec.model.update_coordinates(coords)

        adv = self.attack_hard_label(data, label, epsilon, target,
                                     query_limit=query_limit, target_loader=target_loader)
        return self.postprocess_result(adv)
