import time, torch
import numpy as np 
from numpy import linalg as LA


class Attack(object):
    """Shared bookkeeping and tensor helpers for query-limited attacks.

    The wrapped ``model`` must expose ``bounds`` (a ``(lower, upper)`` pair used
    for clipping), ``get_num_queries()`` and ``predict_label(x)``.

    Targeted vs. untargeted: a ``target`` of ``None`` (or ``False``) means
    untargeted. Any other value -- including class index ``0`` -- is treated
    as the target label.

    Tensors are moved to CUDA when it is available and are otherwise left on
    their current device, so the helpers also work on CPU-only hosts.
    """

    def __init__(self, model, early_stopping=False, order=2, dataset=""):
        """Store the model and initialise per-run bookkeeping.

        :param model: object exposing ``bounds``, ``get_num_queries()`` and
            ``predict_label(x)``.
        :param early_stopping: when true, ``stopping_condition`` returns True
            as soon as the distance drops to ``epsilon`` or below.
        :param order: norm order forwarded to ``torch.norm`` in
            ``compute_distance``.
        :param dataset: dataset name; only ``"Imagenet"`` changes a default
            (``target_limit_break``).
        """
        # self.MAX_ITER = 50000  # log limit
        self.num_iterations = 10000
        self.log = None
        self.model = model
        self.bounds = model.bounds
        self.early_stopping = early_stopping
        self.order = order
        self.success = False
        self.target = None
        self.label = None
        self.target_limit_break = 2000 if dataset == "Imagenet" else 100
        self.dataset = dataset
        self.trajectory = []
        self.gradients = []
        self.last_xadv = None
        self.last_gradient = None
        self.query_limit = None

    @staticmethod
    def _to_device(tensor):
        """Return ``tensor`` on the CUDA device if available, else unchanged."""
        return tensor.cuda() if torch.cuda.is_available() else tensor

    @staticmethod
    def _is_targeted(target):
        """Return True when ``target`` designates a target class.

        A plain truthiness test would misread class index 0 as "untargeted",
        so only ``None`` and the literal ``False`` count as untargeted.
        """
        return target is not None and target is not False

    def _equal(self, v, z, dtype=torch.int):
        """Compare ``v`` and ``z`` element-wise on the CPU.

        Non-tensor inputs are converted with ``dtype`` first.
        """
        if not isinstance(v, torch.Tensor):
            v = torch.as_tensor(v, dtype=dtype)
        if not isinstance(z, torch.Tensor):
            z = torch.as_tensor(z, dtype=dtype)

        return v.cpu() == z.cpu()

    def _reset(self, query_limit):
        """Allocate a fresh ``(query_limit, 2)`` log and clear the trajectory."""
        # NOTE(review): ``gradients`` and ``success`` are not cleared here;
        # confirm with callers whether they are meant to persist across runs.
        self.log = torch.zeros(query_limit, 2)
        self.query_limit = query_limit
        self.trajectory = []

    def manual_seed(self, seed):
        """Seed NumPy and PyTorch (CPU and CUDA); ``None`` is a no-op."""
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

    def get_log(self):
        """Return the log tensor (column 0: distance, column 1: success flag)."""
        return self.log

    def get_trajectory(self):
        """Return recorded ``[dist, num_queries, last_xadv]`` rows as an array.

        Rows mix scalars with arrays, so an object array is requested
        explicitly; NumPy >= 1.24 refuses to build ragged arrays otherwise.
        """
        return np.asarray(self.trajectory, dtype=object)

    def get_gradients(self):
        """Return recorded ``[target, num_queries, last_gradient]`` rows.

        An object array is requested for the same reason as in
        ``get_trajectory``.
        """
        return np.asarray(self.gradients, dtype=object)

    def update_log(self, dist, epsilon):
        """Record the current distance and success state.

        While the query count is below the limit, column 0 of the log is
        filled with ``dist`` from the current query index onward (or entirely,
        if that row is still zero), and column 1 is set to 1 from the current
        index onward when the attack has succeeded with ``dist < epsilon``.
        One trajectory row and one gradient row are appended on every call.
        """
        queries = self.model.get_num_queries()
        if queries < self.query_limit:
            if self.log[queries, 0] == 0:
                # A zero entry is taken to mean the log is still uninitialised.
                self.log[:, 0] = dist
            else:
                self.log[queries:, 0] = dist
            self.log[queries:, 1] = int(self.success and dist < epsilon)

        # self.log.append([dist, self.model.get_num_queries(), int(self.success)])
        self.trajectory.append([dist, queries, self.last_xadv])
        self.gradients.append([self.target, queries, self.last_gradient])

    def dim_attack_space(self, x, prod=False):
        """Return the shape of ``x`` as a list, or its element count if ``prod``."""
        # get attack space from x
        return np.prod(list(x.shape)) if prod else list(x.shape)

    def clip_image(self, x, lb=None, rb=None):
        """Clip ``x`` to ``[lb, rb]``, defaulting to the model bounds.

        NumPy arrays stay NumPy arrays; anything else goes through
        ``torch.clamp``.
        """
        if lb is None:
            lb = self.bounds[0]
        if rb is None:
            rb = self.bounds[1]

        if isinstance(x, np.ndarray):
            return np.minimum(np.maximum(lb, x), rb)
        return torch.clamp(x, lb, rb)

    def get_xadv(self, x, d, v):
        """Return ``clip(x + d * v)`` as a float32 array or tensor.

        ``v`` and ``d`` may be NumPy values. ``x`` is normally a tensor but
        can be a NumPy array (the original comment cites HSJA); in that case
        the result stays a NumPy array instead of failing on ``.float()``.
        """
        if isinstance(x, torch.Tensor):
            x = self._to_device(x)
            if not isinstance(v, torch.Tensor):
                v = torch.tensor(v, dtype=torch.float)
            v = self._to_device(v)

        out = x + (d * v)
        if isinstance(out, np.ndarray):
            out = out.astype(np.float32)
        else:
            out = out.float()
        return self.clip_image(out)

    def get_xorig(self, x):
        """Return the original sample clipped to the model bounds."""
        return self.clip_image(x)

    def to_attack_space(self, x):
        """Map ``x`` into the attack space; here this is just clipping."""
        return self.clip_image(x)

    def compute_distance(self, x, xp):
        """Return the ``self.order``-norm of ``x - xp`` as a Python float."""
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x, dtype=torch.float)
        if not isinstance(xp, torch.Tensor):
            xp = torch.as_tensor(xp, dtype=torch.float)

        diff = self._to_device(x) - self._to_device(xp)
        return torch.norm(diff, self.order).cpu().item()

    def stopping_condition(self, query_limit, dist=np.inf, epsilon=0.):
        """Return True when the attack should stop.

        Stops on early success (``early_stopping`` set and
        ``dist <= epsilon``) or once the model reports ``query_limit`` queries.
        """
        if self.early_stopping and (dist <= epsilon):
            return True

        if self.model.get_num_queries() >= query_limit:
            print('out of queries')
            return True

        return False

    def update_trajectory(self, dec, x, y, target):
        """Update ``success`` and ``last_xadv`` from a model decision.

        :param dec: predicted label(s) for ``x``.
        :param x: the queried sample; stored (as NumPy if it was a tensor).
        :param y: true label(s).
        :param target: target label, or ``None``/``False`` when untargeted.
        """
        targeted = self._is_targeted(target)

        if isinstance(dec, torch.Tensor):
            dec = dec.cpu().numpy()
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if isinstance(y, torch.Tensor):
            y = y.cpu().numpy()
        if targeted and isinstance(target, torch.Tensor):
            target = target.cpu().numpy()

        # Success is sticky: only re-evaluate while the adversary has not won.
        if not self.success:
            hits = (dec == target) if targeted else (dec != y)
            # np.any copes with scalar and multi-dimensional comparisons alike.
            self.success = bool(np.any(hits))

        self.last_xadv = x

    def update_gradient(self, gradient):
        """Store ``gradient`` (converted to NumPy if it is a tensor)."""
        if isinstance(gradient, torch.Tensor):
            gradient = gradient.detach().cpu().numpy()

        self.last_gradient = gradient

    def search_succ(self, x, y, target, expand=False, update=True):
        """Query the model on ``x`` and report whether the attack succeeds.

        :param x: sample(s) passed to ``model.predict_label``.
        :param y: true label(s).
        :param target: target label, or ``None``/``False`` when untargeted.
        :param expand: compare as NumPy arrays instead of tensors.
        :param update: also record the decision via ``update_trajectory``.
        :return: ``dec == target`` when targeted, otherwise ``dec != y``.
        """
        dec = self._to_device(self.model.predict_label(x))
        targeted = self._is_targeted(target)

        # Match types
        if targeted:
            target = self._to_device(torch.as_tensor(target, dtype=torch.int))
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.int)
        y = self._to_device(y)

        if update:
            self.update_trajectory(dec, x, y, target)

        if expand:
            # Original note: comparing as numpy produces expanded dims.
            dec = dec.cpu().numpy()
            if targeted:
                target = target.cpu().numpy()
            y = y.cpu().numpy()

        if targeted:
            return dec == target
        return dec != y

    def benign_succ(self, x, y, target, expand=False):
        """Check whether the benign sample is already misclassified.

        Same as ``search_succ`` but without recording the decision.
        """
        return self.search_succ(x, y, target, expand, update=False)

    def postprocess_result(self, result):
        """Convert a NumPy result to a float tensor; pass others through."""
        if isinstance(result, np.ndarray):
            return torch.tensor(result, dtype=torch.float)
        return result

    def validate_args(self, data, label, epsilon, target, target_loader, query_limit):
        """Validate attack arguments and reset the per-run log.

        :raises AssertionError: if ``data`` does not hold exactly one sample,
            or a target is given without a ``target_loader``.
        """
        assert len(data) == 1, "Only accept single sample for this implementation."
        self._reset(query_limit)

        if self._is_targeted(target):
            assert target_loader is not None, "For targeted variant, need target_loader to be set!"
        

def project(tensor, epsilon=1, ord=2):
    """
    Compute the orthogonal projection of the input tensor (as vector) onto the L_ord epsilon-ball.
    **Assumes the first dimension to be batch dimension, which is preserved.**
    Each sample is flattened and projected independently; the result is returned in the
    shape of the input, whatever its rank.
    :param tensor: variable or tensor
    :type tensor: torch.autograd.Variable or torch.Tensor
    :param epsilon: radius of ball.
    :type epsilon: float
    :param ord: order of norm (2 or float('inf'))
    :type ord: int
    :return: projected vector
    :rtype: torch.autograd.Variable or torch.Tensor
    :raises NotImplementedError: for any other norm order
    """

    assert isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.autograd.Variable), 'given tensor should be torch.Tensor or torch.autograd.Variable'

    if ord == 2:
        size = tensor.size()
        # np.prod of an empty shape is the float 1.0, hence the int() cast.
        flattened_size = int(np.prod(list(size[1:])))

        # reshape (unlike view) also accepts non-contiguous inputs.
        flat = tensor.reshape(-1, flattened_size)
        norms = torch.norm(flat, 2, dim=1)
        # Guard against 0/0 = NaN for an all-zero sample when epsilon is 0.
        norms = torch.clamp(norms, min=torch.finfo(norms.dtype).tiny)
        # Samples already inside the ball keep a scale of 1.
        scale = torch.clamp(epsilon / norms, max=1).view(-1, 1)

        tensor = (flat * scale).reshape(size)
    elif ord == float('inf'):
        tensor = torch.clamp(tensor, min=-epsilon, max=epsilon)
    else:
        raise NotImplementedError()

    return tensor



class HLM_Attack(Attack):
    """Attack variant whose search space is the output of an encoder.

    ``enc`` maps an input (plus ``class_conditional``) to a latent code and
    ``dec`` maps a latent code back. Both are called under ``torch.no_grad``.
    All tensors are moved to CUDA.

    NOTE(review): ``__init__`` does not call the base initialiser, so
    attributes such as ``bounds`` (needed by ``clip_image``) must be set by
    whoever builds the instance -- confirm against callers.
    """

    def __init__(self, enc, dec):
        """Keep the encoder/decoder pair; conditioning starts out unset."""
        self.enc, self.dec = enc, dec
        # Original note: updates with every call.
        self.class_conditional = None

    def dim_attack_space(self, x, prod=False):
        """Return the latent shape of ``x`` as a list, or its element count."""
        latent_shape = list(self.to_attack_space(x).shape)
        if prod:
            return np.prod(latent_shape)
        return latent_shape

    def E(self, x):
        """Encode ``x`` with ``enc``; returns a detached tensor."""
        source = x if type(x) is torch.Tensor else torch.tensor(x)
        source = source.float().cuda()

        with torch.no_grad():
            return self.enc(source, self.class_conditional).detach()

    def G(self, z):
        """Decode ``z`` with ``dec``; returns a detached tensor."""
        code = z if type(z) is torch.Tensor else torch.tensor(z)
        code = code.float().cuda()

        with torch.no_grad():
            return self.dec(code, self.class_conditional).detach()

    def update_gradient(self, gradient):
        """Decode ``gradient`` and store the result (as NumPy if a tensor)."""
        decoded = self.G(gradient)
        if type(decoded) is torch.Tensor:
            decoded = decoded.detach().cpu().numpy()
        self.last_gradient = decoded

    def get_xadv(self, x, d, v):
        """Return ``clip((1 - d) * x + d * G(E(x) + v))`` as a float tensor.

        Alternatives that were sketched here but are not part of the live
        path: clamping ``d * v`` before decoding, interpolating
        ``(1 - d) * z + d * v`` in latent space, projecting ``theta`` onto an
        L2 ball, per-pixel bounding of ``theta``, and the purely additive
        form ``x + d * theta``.
        """
        # Encode first; the latent code is always a tensor.
        z = self.E(x).cuda()

        x = (x if type(x) is torch.Tensor else torch.tensor(x)).cuda()
        # x is a tensor from here on, so v is unconditionally brought along.
        if type(v) is not torch.Tensor:
            v = torch.tensor(v, dtype=torch.float)
        v = v.cuda()

        theta = self.G(z + v)

        # Convex blend of the clean input and the decoded perturbed latent.
        blended = (1 - d) * x + (d * theta)
        return self.clip_image(blended.float())

    def get_xorig(self, x):
        """Return the clipped original sample (no encode/decode round trip)."""
        return self.clip_image(x)

    def to_attack_space(self, x):
        """Return the latent code of ``x`` on the CUDA device."""
        return self.E(x).cuda()
    

class Sampling_Attack(HLM_Attack):
    """Variant that decodes the direction ``v`` directly, without encoding ``x``."""

    def get_xadv(self, x, d, v):
        """Return ``clip(x + d * G(v))`` as a float tensor.

        Alternatives that were sketched here but are not part of the live
        path: encoding ``x`` and offsetting the latent by a clamped ``d * v``,
        randomly zeroing/scaling entries of ``theta``, and projecting
        ``theta`` onto an L2 ball.
        """
        x = (x if type(x) is torch.Tensor else torch.tensor(x)).cuda()

        # x is a tensor from here on, so v is unconditionally brought along.
        if type(v) is not torch.Tensor:
            v = torch.tensor(v, dtype=torch.float)
        v = v.cuda()

        theta = self.G(v)
        perturbed = x + (d * theta)
        return self.clip_image(perturbed.float())

    def update_gradient(self, gradient):
        """Decode ``gradient`` and store the result (as NumPy if a tensor)."""
        decoded = self.G(gradient)
        if type(decoded) is torch.Tensor:
            decoded = decoded.detach().cpu().numpy()
        self.last_gradient = decoded

        
class RandSampling_Attack(HLM_Attack):
    """Variant whose encoder exposes ``model.resize_dim`` and ``model.coordinates``.

    Encoding also pushes the current input and the encoder's coordinates into
    the decoder's model, so a later ``G`` call depends on the latest ``E`` call.
    """

    def dim_attack_space(self, x, prod=False):
        """Return ``[3, resize_dim, resize_dim]`` or its product; ``x`` is unused."""
        side = self.enc.model.resize_dim
        shape = [3, side, side]
        if prod:
            return np.prod(shape)
        return shape

    def E(self, x):
        """Encode ``x`` and synchronise the decoder's model with this input."""
        source = x if type(x) is torch.Tensor else torch.tensor(x)
        source = source.float().cuda()

        with torch.no_grad():
            encoded = self.enc(source, self.class_conditional).detach()
            # Hand the decoder the input and the coordinates the encoder holds.
            decoder_model = self.dec.model
            decoder_model.update_original(source)
            decoder_model.update_coordinates(self.enc.model.coordinates)
            return encoded

    def G(self, z):
        """Decode ``z`` with ``dec``; returns a detached tensor."""
        code = z if type(z) is torch.Tensor else torch.tensor(z)
        code = code.float().cuda()

        with torch.no_grad():
            return self.dec(code, self.class_conditional).detach()

    def get_xadv(self, x, d, v):
        """Return ``clip(x + d * G(v))`` as a float tensor."""
        x = (x if type(x) is torch.Tensor else torch.tensor(x)).cuda()

        # x is a tensor from here on, so v is unconditionally brought along.
        if type(v) is not torch.Tensor:
            v = torch.tensor(v, dtype=torch.float)
        v = v.cuda()

        theta = self.G(v)
        perturbed = x + (d * theta)
        return self.clip_image(perturbed.float())

    def update_gradient(self, gradient):
        """Decode ``gradient`` and store the result (as NumPy if a tensor)."""
        decoded = self.G(gradient)
        if type(decoded) is torch.Tensor:
            decoded = decoded.detach().cpu().numpy()
        self.last_gradient = decoded
