import time, torch
import numpy as np 
import torch.nn as nn
from numpy import linalg as LA

from attack import OPT_attack
from attack import Sampling_Attack


class Sampling_OPT_attack(OPT_attack, Sampling_Attack):
    """OPT hard-label attack mixed with ``Sampling_Attack``.

    Both base classes are initialised explicitly (not through a cooperative
    ``super().__init__`` chain), so each base receives only its own arguments.
    ``validate_args``, ``attack_hard_label`` and ``postprocess_result`` are
    not defined here. They are resolved through the MRO from one of the bases.
    """

    def __init__(self, model, enc, dec, dataset="", order=2, early_stopping=False):
        """Initialise ``OPT_attack`` first, then ``Sampling_Attack``.

        Args:
            model: Forwarded positionally to ``OPT_attack.__init__``.
            enc: Forwarded positionally to ``Sampling_Attack.__init__``
                (presumably an encoder; confirm against ``Sampling_Attack``).
            dec: Forwarded positionally to ``Sampling_Attack.__init__``
                (presumably a decoder; confirm against ``Sampling_Attack``).
            dataset: Forwarded to ``OPT_attack`` as a keyword argument.
            order: Forwarded to ``OPT_attack`` as a keyword argument.
            early_stopping: Forwarded to ``OPT_attack`` as a keyword argument.
        """
        # The two bases take disjoint argument sets, so each __init__ is called
        # directly. OPT_attack is set up before Sampling_Attack.
        OPT_attack.__init__(
            self,
            model,
            dataset=dataset,
            order=order,
            early_stopping=early_stopping,
        )
        Sampling_Attack.__init__(self, enc, dec)

    def __call__(self, data, label, epsilon, target=None, target_loader=None, query_limit=40000, seed=None):
        """Validate inputs, run ``attack_hard_label`` and post-process its output.

        Args:
            data: Input to attack; forwarded unchanged.
            label: True label. It must support ``.item()`` when ``target`` is None.
            epsilon: Forwarded to ``validate_args`` and ``attack_hard_label``.
            target: Optional target label, which must support ``.item()``.
                ``None`` means no target was requested.
            target_loader: Forwarded to ``validate_args`` and ``attack_hard_label``.
            query_limit: Forwarded to ``validate_args`` and ``attack_hard_label``.
            seed: Accepted for call-signature compatibility. This method never
                reads it.

        Returns:
            The value produced by ``self.postprocess_result`` for the raw output
            of ``self.attack_hard_label``.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)

        # class_conditional is taken from the target when one is given, otherwise
        # from the true label. NOTE(review): presumably read by Sampling_Attack;
        # confirm there.
        conditioning_label = target if target is not None else label
        self.class_conditional = int(conditioning_label.item())

        raw_adv = self.attack_hard_label(
            data,
            label,
            epsilon,
            target,
            query_limit=query_limit,
            target_loader=target_loader,
        )
        return self.postprocess_result(raw_adv)
