import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from attack import RayS
from attack import HLM_Attack


class HLM_RayS(RayS, HLM_Attack):
    """RayS hard-label attack combined with the ``HLM_Attack`` base.

    The two base initialisers are invoked explicitly, each with its own
    arguments, instead of through a cooperative ``super().__init__`` chain.
    """

    def __init__(self, model, enc, dec, order=2, dataset="", early_stopping=False):
        """Initialise both attack bases and reset per-attack state.

        Args:
            model: Model under attack; forwarded to ``RayS`` and stored on
                ``self.model``.
            enc: Forwarded to ``HLM_Attack`` (presumably an encoder — confirm).
            dec: Forwarded to ``HLM_Attack`` (presumably a decoder — confirm).
            order: Forwarded to ``RayS`` and stored on ``self.order``
                (presumably the perturbation norm order — confirm).
            dataset: Dataset identifier forwarded to ``RayS``.
            early_stopping: Flag forwarded to ``RayS``.
        """
        RayS.__init__(
            self,
            model,
            order=order,
            dataset=dataset,
            early_stopping=early_stopping,
        )
        HLM_Attack.__init__(self, enc, dec)

        # Assigned after both base initialisers have run, so these values take
        # precedence over anything the bases may have set for the same names.
        self.model = model
        self.order = order
        # NOTE(review): not set anywhere in this class; presumably populated
        # by the attack routine — confirm.
        self.z_final = None

    def __call__(self, data, label, epsilon=0.3, target=None, target_loader=None, seed=None, query_limit=10000):
        """Validate inputs, run the hard-label attack and post-process it.

        Args:
            data: Input to perturb.
            label: True class; must support ``.item()``.
            epsilon: Perturbation budget forwarded to the attack.
            target: Optional target class (must support ``.item()``); when
                given, the attack is conditioned on it instead of ``label``.
            target_loader: Only passed to ``validate_args``; it is not
                forwarded to ``attack_hard_label`` here.
            seed: Forwarded to ``attack_hard_label``.
            query_limit: Query budget forwarded to validation and the attack.

        Returns:
            Whatever ``postprocess_result`` produces for the adversarial output.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)

        # Condition on the target class for targeted attacks, otherwise on
        # the true label.
        if target is None:
            conditioning = label
        else:
            conditioning = target
        self.class_conditional = int(conditioning.item())

        adv = self.attack_hard_label(
            data,
            label,
            epsilon,
            target=target,
            seed=seed,
            query_limit=query_limit,
        )
        return self.postprocess_result(adv)
