import numpy as np
import torch

from attack import HSJA, RandSampling_Attack


class RandSampling_HSJA(HSJA, RandSampling_Attack):
    """HopSkipJump (HSJA) attack combined with RandSampling_Attack's encoder/decoder.

    The attack loop itself comes from ``HSJA``; ``RandSampling_Attack`` is
    initialised with ``enc``/``dec`` and presumably provides the ``E``
    (encode) and ``G`` (decode) callables used below, as well as
    ``clip_image`` — confirm against ``attack.RandSampling_Attack``/``HSJA``.
    """

    def __init__(self, model, enc, dec, order=2, dataset="", early_stopping=False, gamma=1.0):
        """Initialise both parent attacks.

        Args:
            model: Victim model, forwarded to ``HSJA``.
            enc: Encoder wrapper, forwarded to ``RandSampling_Attack``.
            dec: Decoder wrapper, forwarded to ``RandSampling_Attack``.
            order: Norm order forwarded to ``HSJA``.
            dataset: Dataset name forwarded to ``HSJA``.
            early_stopping: Forwarded to ``HSJA``.
            gamma: Forwarded to ``HSJA``.
        """
        # Explicit per-base calls (not super()) because the two bases take
        # different constructor arguments.
        HSJA.__init__(self, model, order=order, dataset=dataset, early_stopping=early_stopping, gamma=gamma)
        RandSampling_Attack.__init__(self, enc, dec)
        # Clean input of the current attack (NumPy array); set by __call__.
        self.x_orig = None

    def to_attack_space(self, x):
        """Encode ``x`` with ``self.E`` and return the latent as a NumPy array."""
        # detach() lets a latent that still tracks gradients be converted to
        # NumPy; it is a no-op for tensors that do not.
        return self.E(x).detach().cpu().numpy()

    def to_image_space(self, v):
        """Decode ``v`` with ``self.G`` and clip the result with ``clip_image``."""
        return self.clip_image(self.G(v))

    def get_xadv(self, z):
        """Return ``z`` clipped by ``clip_image``.

        The HSJA implementation iterates over the adversarial image itself
        rather than over a noise/latent vector, so ``z`` is already in image
        space and is deliberately not decoded here.
        """
        return self.clip_image(z)

    def blend(self, a, x, b, v):
        """Return ``clip_image(a * x + b * G(v))``.

        Arguments that are not already tensors (e.g. NumPy arrays or Python
        scalars) are converted to float32 CUDA tensors; tensors are used as-is.

        Args:
            a: Weight applied to ``x``.
            x: Input image(s); also encoded so the decoder can be synced
                before ``v`` is decoded.
            b: Weight applied to the decoded ``v``.
            v: Latent values; reshaped to ``(-1, *E(x).shape[1:])`` before
                decoding.

        Returns:
            The clipped blended tensor.
        """
        def as_cuda_tensor(value):
            # isinstance (not ``type(...) is``) so tensor subclasses pass through.
            if isinstance(value, torch.Tensor):
                return value
            return torch.tensor(value, dtype=torch.float).cuda()

        a, x, b, v = (as_cuda_tensor(t) for t in (a, x, b, v))

        # Encoding x yields the latent shape used to reshape v. The decoder is
        # then synced with the original image and the encoder's coordinates
        # (presumably refreshed by this E(x) call — confirm in RandSampling_Attack),
        # so E(x) must run before update_coordinates.
        z = self.E(x)
        self.dec.model.update_original(x)
        self.dec.model.update_coordinates(self.enc.model.coordinates)
        # reshape (not view) so a non-contiguous v does not raise.
        vp = self.G(v.reshape(-1, *z.shape[1:])).cuda()
        return self.clip_image(a * x + b * vp)

    def __call__(self, data, label, epsilon, target=None, target_loader=None, query_limit=20000, seed=None):
        """Run the HSJA attack on ``data`` and return the post-processed result.

        Args:
            data: Input tensor; converted to NumPy and stored in ``self.x_orig``.
            label: Label tensor; must hold a single element (``label.item()``).
            epsilon: Forwarded to ``hsja`` (presumably the perturbation budget).
            target: Optional target label tensor for a targeted attack, or None.
            target_loader: Forwarded to ``hsja``.
            query_limit: Forwarded to ``hsja`` (presumably the query budget).
            seed: Forwarded to ``hsja``.

        Returns:
            Whatever ``self.postprocess_result`` returns for the HSJA output.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)
        self.class_conditional = int(label.item())

        data = data.cpu().numpy()
        label = label.cpu().numpy()
        self.x_orig = data

        # Type check instead of ``if target:`` — a tensor holding class 0 is
        # falsy (it would be passed on unconverted), and truth-testing a
        # multi-element tensor raises.
        if isinstance(target, torch.Tensor):
            target = target.cpu().numpy()
        adv = self.hsja(data, label, epsilon, target, target_loader, query_limit, seed)

        return self.postprocess_result(adv)