import numpy as np
import torch

from attack import HSJA, HLM_Attack


class HLM_HSJA(HSJA, HLM_Attack):
    """HSJA attack whose search vectors are decoded through ``self.G``.

    Combines the ``HSJA`` attack loop with the encoder/decoder pair managed
    by ``HLM_Attack``. The methods below use ``self.E`` and ``self.G``;
    presumably these are the ``enc`` and ``dec`` objects stored by
    ``HLM_Attack.__init__`` -- confirm against that class.
    """

    def __init__(self, model, enc, dec, order=2, dataset="", early_stopping=False, gamma=1.0):
        """Initialise both base classes explicitly and clear the cached input.

        Args:
            model: Forwarded to ``HSJA.__init__``.
            enc: Forwarded to ``HLM_Attack.__init__``.
            dec: Forwarded to ``HLM_Attack.__init__``.
            order: Forwarded to ``HSJA.__init__``.
            dataset: Forwarded to ``HSJA.__init__``.
            early_stopping: Forwarded to ``HSJA.__init__``.
            gamma: Forwarded to ``HSJA.__init__``.
        """
        HSJA.__init__(self, model, order=order, dataset=dataset, early_stopping=early_stopping, gamma=gamma)
        HLM_Attack.__init__(self, enc, dec)
        # Set by __call__; to_image_space needs it to recover the latent shape.
        self.x_orig = None

    @staticmethod
    def _as_cuda_float(value):
        """Return ``value`` unchanged if it is a tensor, else a float CUDA tensor.

        Existing tensors (including subclasses) are passed through as-is: no
        dtype cast and no device move is applied to them.
        """
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, dtype=torch.float).cuda()

    def get_xadv(self, x):
        """Return ``x`` passed through the inherited ``clip_image``.

        HSJA proceeds over the adversarial attempt directly instead of over
        the noise, so no decoding is applied here.
        """
        return self.clip_image(x)

    def to_image_space(self, v):
        """Decode an attack-space vector ``v`` with ``self.G`` and clip it.

        ``v`` is reshaped to the trailing dimensions of ``self.E(x_orig)``
        before decoding.

        Raises:
            RuntimeError: If no original input has been stored yet (the
                attack has not been invoked through ``__call__``).
        """
        if self.x_orig is None:
            raise RuntimeError(
                "x_orig is not set; run the attack via __call__ before to_image_space"
            )
        v = self._as_cuda_float(v)
        x = torch.tensor(self.x_orig).cuda()
        # The encoder output is only needed for its shape.
        latent_shape = self.E(x).shape[1:]
        theta = self.G(v.view(-1, *latent_shape))
        return self.clip_image(theta)

    def to_attack_space(self, x):
        """Encode ``x`` with ``self.E`` and return the result as a NumPy array."""
        z = self.E(x)
        # detach() so the conversion also works when the encoder output
        # carries autograd history (Tensor.numpy() refuses such tensors).
        return z.detach().cpu().numpy()

    def blend(self, a, x, b, v):
        """Return the clipped combination ``a * x + b * G(v)``.

        ``x``, ``b`` and ``v`` are converted to float CUDA tensors when they
        are not tensors already. ``a`` is used exactly as passed in (it is
        not converted).
        """
        v = self._as_cuda_float(v)
        x = self._as_cuda_float(x)
        b = self._as_cuda_float(b)

        # The encoder output is only needed for its shape.
        latent_shape = self.E(x).shape[1:]
        xp = self.G(v.view(-1, *latent_shape)).cuda()
        out = (a * x) + (b * xp)
        return self.clip_image(out)

    def __call__(self, data, label, epsilon, target=None, target_loader=None, query_limit=20000, seed=None):
        """Run the attack on ``data`` and return the post-processed result.

        Validates the arguments, records ``int(label.item())`` as
        ``self.class_conditional``, converts ``data``/``label`` (and
        ``target`` when given) to NumPy arrays, stores ``data`` as
        ``self.x_orig`` and delegates to ``self.hsja``.

        Args:
            data: Tensor holding the input to attack.
            label: Single-element tensor (``.item()`` is called on it).
            epsilon: Forwarded to ``validate_args`` and ``hsja``.
            target: Optional tensor; ``None`` means no target was supplied.
            target_loader: Forwarded to ``validate_args`` and ``hsja``.
            query_limit: Forwarded to ``validate_args`` and ``hsja``.
            seed: Forwarded to ``hsja``.

        Returns:
            Whatever ``self.postprocess_result`` returns for the ``hsja`` output.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)
        self.class_conditional = int(label.item())

        data = data.cpu().numpy()
        label = label.cpu().numpy()
        self.x_orig = data

        # Compare against None explicitly: the truth value of a multi-element
        # tensor raises, and a single zero-valued target would read as False.
        if target is not None:
            target = target.cpu().numpy()
        adv = self.hsja(data, label, epsilon, target, target_loader, query_limit, seed)

        return self.postprocess_result(adv)