# NOTE(review): the code below uses torch, nn, np, sys, deepcopy, Variable,
# device and writer without importing them explicitly; presumably they are all
# provided by this star import -- confirm against the Libraries module.
from Libraries import *
# Machine epsilon of Python floats.
eps=sys.float_info.epsilon
from time import time
from tensorboardX import SummaryWriter
class net_mnist(nn.Module):
    """Fully connected classifier: four 4096-wide ReLU layers plus a linear head.

    The layers are registered as ``fc1`` .. ``fc5`` so the ``state_dict`` keys
    are ``fc1.weight``, ``fc1.bias``, ..., ``fc5.bias``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        width = 4096
        # Creation order is kept (fc1 first, fc5 last) so seeded initialisation
        # draws the same random numbers for every layer.
        self.fc1 = nn.Linear(input_size, width, bias=True)
        self.fc2 = nn.Linear(width, width, bias=True)
        self.fc3 = nn.Linear(width, width, bias=True)
        self.fc4 = nn.Linear(width, width, bias=True)
        self.fc5 = nn.Linear(width, output_size, bias=True)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten each sample, run the MLP and drop all size-1 dimensions.

        Because of the final ``squeeze()``, a batch of one sample comes back
        without its batch dimension.
        """
        batch = x.shape[0]
        hidden = x.reshape(batch, 1, -1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = self.relu(layer(hidden))
        logits = self.fc5(hidden)
        return logits.squeeze()
def train_mnist(net,loader,val_loader,criterion,lrate = 0.01,max_epochs=6):
    """Train ``net`` with SGD (momentum 0.9) and periodically evaluate and checkpoint it.

    Every 10 optimizer steps the accuracy on ``loader`` and on ``val_loader``
    is computed with ``validate``, logged to the module-level ``writer`` under
    the tags ``'Training'`` and ``'my_scalar'``, printed, and the model's
    ``state_dict`` is written to ``./mnist_model.pth``.

    Args:
        net: model to optimise; it is updated in place.
        loader: iterable of ``(input, target)`` training batches.
        val_loader: iterable of ``(input, target)`` validation batches.
        criterion: loss called as ``criterion(output, target)``.
        lrate: learning rate handed to the optimizer; it stays constant for
            the whole run.
        max_epochs: number of passes over ``loader``.

    Returns:
        The trained ``net`` (same object that was passed in).

    Side effects:
        Closes the module-level ``writer`` before returning.
    """
    optimizer = torch.optim.SGD(net.parameters(),lr=lrate,momentum=0.9)
    count = 0  # optimizer steps taken so far, across all epochs
    for epoch in range(max_epochs):
        net.train()
        tic = time()  # elapsed time below is measured from the start of the epoch
        for inp, target in loader:
            count += 1
            inp,target = inp.to(device),target.to(device)
            optimizer.zero_grad()
            out = net(inp)
            loss = criterion(out,target)
            loss.backward()
            optimizer.step()
            # NOTE(review): an earlier revision divided the local ``lrate`` by 10
            # every 100 steps here, but the value was never written back to the
            # optimizer, so it had no effect. The no-op was removed; add a
            # scheduler if learning-rate decay is actually wanted.
            if count % 10 == 0:
                toc = time() - tic
                temp = validate(net,loader)
                accuracy = validate(net, val_loader)
                writer.add_scalar('my_scalar',accuracy,count)
                writer.add_scalar('Training' ,temp ,count)
                print(f'Epoch #{epoch}: Accuracy = {accuracy * 100:.2f}% [{toc:.2f} seconds] ','count = ',count,'Training Accuracy: ',100*temp)
                torch.save(net.state_dict(),'./mnist_model.pth')
                print('.........Model has been saved............')

    writer.close()
    return net

def validate(net, val_loader):
    """Return the fraction of examples in ``val_loader`` that ``net`` classifies correctly.

    The prediction is ``net(inputs).argmax(1)``, so the network output must keep
    a batch dimension for every batch. Inputs and targets are moved to the
    module-level ``device``.

    The model is switched to eval mode for the measurement and put back into
    whichever mode it was in on entry. Evaluation runs under
    ``torch.no_grad()`` so no autograd graph is built.

    Args:
        net: model to evaluate.
        val_loader: iterable of ``(input, target)`` batches.

    Returns:
        ``num_correct / num_examples`` as a float.

    Raises:
        ZeroDivisionError: if ``val_loader`` yields no examples.
    """
    was_training = net.training
    net.eval()
    num_correct = 0
    num_examples = 0
    try:
        with torch.no_grad():
            for inp, target in val_loader:
                out = net(inp.to(device)).argmax(1)
                correct = torch.eq(out, target.to(device))
                num_correct += torch.sum(correct).item()
                num_examples += correct.shape[0]
    finally:
        # Restore the caller's mode even if a batch raised.
        net.train(was_training)
    return num_correct/num_examples

def str_atk_tg(net,x,target,criterion,eps1,eps2,rho_0,eta,lamda,sigma = 1e-2,iterations = 100):
    """Iteratively search for perturbations that make the network predict ``target``.

    Two perturbations are optimised jointly: an input perturbation, kept in
    three coupled copies ``delta``, ``w`` and ``z`` with dual variables ``u``
    and ``v``, and a perturbation ``delta_a`` of the ``fc1`` weight matrix.
    The update structure looks like an ADMM-style splitting; the derivation is
    not part of this file, so treat that reading as an assumption.

    Args:
        net: model exposing ``fc1`` and ``fc2`` linear layers. It is
            deep-copied; the caller's instance is not modified.
        x: input sample. Assumed to be a single example with a leading
            dimension of size 1 (``grad_f[0]`` and the flattening of ``x`` to a
            column vector rely on this) -- TODO confirm against callers.
        target: label the attack tries to obtain, in the form ``criterion``
            accepts.
        criterion: loss called as ``criterion(output, target)``.
        eps1: magnitude bound for the input perturbation; ``w`` is clipped to
            ``[max(-x, -eps1), min(1 - x, eps1)]``.
        eps2: radius handed to ``prox_infty`` when projecting ``delta_a``.
        rho_0: base penalty; the working value starts at ``1.01 * rho_0``.
        eta: coefficient of the previous ``z`` in the ``z`` update.
        lamda: weight of the terms coupling ``delta`` and ``delta_a``.
        sigma: entries of ``delta``/``z``/``w`` whose absolute value is below
            this are zeroed whenever one of the three already yields ``target``.
        iterations: maximum number of iterations.

    Returns:
        Tuple ``(delta, z, w, delta_a, eq_con)``. ``eq_con`` holds
        ``||A w - delta_a x||`` for every executed iteration except the first.
    """
    # Both working copies must live on ``device`` because every tensor fed to
    # them is moved there. ``model`` sees perturbed inputs, ``model_v2``
    # receives the perturbed fc1 weights and sees the clean input.
    model_v2= deepcopy(net).to(device)
    model   = deepcopy(net).to(device)
    A       = model.state_dict()['fc1.weight'].data
    B       = model.state_dict()['fc2.weight'].data
    x,target = x.to(device),target.to(device)
    
    #Initializations: Try another initialization if this did not work
    delta_a = torch.zeros_like(A).to(device) #Layer perturbation
    delta   = torch.zeros_like(x).to(device) #Input perturbation
    w       = torch.zeros_like(x).to(device) #Input perturbation
    z       = torch.zeros_like(x).to(device) #Input perturbation
    u       = torch.zeros_like(x).to(device) #First  dual variable
    v       = torch.zeros_like(x).to(device) #Second dual variable
    rho     = 1.01*rho_0
    
    lr      = 0.1 #Step size of the delta_a update
    eq_con  = []

    #Diagonal of sum_j diag(B[j,:])**2, i.e. the column-wise sum of squares of B.
    #Only the diagonal is stored; it is applied below as a row scaling.
    sum_b_diag = (B**2).sum(dim=0)

    #Quantities that do not change between iterations
    x_col    = x.squeeze().reshape(-1,1)
    AtA      = torch.mm(A.t(),A)
    identity = torch.eye(A.shape[1]).to(device)
    min1     = torch.min(1-x,torch.tensor(eps1).to(device))
    max1     = torch.max(-x,torch.tensor(-eps1).to(device))

    #torch.solve is gone from current PyTorch; fall back to it only when
    #torch.linalg.solve is not available (older releases).
    linalg_solve = getattr(getattr(torch,'linalg',None),'solve',None)
    
    #The attack will start here
    for i in range(iterations):
        
        if (i+1)%50 == 0:
            lamda = lamda * 1.01
            #NOTE(review): this resets rho to its initial value and so discards
            #the growth applied further down; kept as in the original -- confirm.
            rho     = 1.01*rho_0

        
        #Updating delta
        temp1   = 2*lamda*AtA + (2+rho)*identity #This is the A in Ax = B
        temp2   = 2*lamda*torch.mm(torch.mm(A.t(),delta_a),x_col)
        temp2   = temp2.reshape(z.shape) + rho*z.squeeze() - u.squeeze() # This is the    B in Ax = B
        rhs     = temp2.reshape(-1,1)
        if linalg_solve is not None:
            delta   = linalg_solve(temp1,rhs)
        else:
            delta,_ = torch.solve(rhs,temp1)
        delta   = delta.reshape(z.shape)
        
        #Updating w
        w    = z - v/rho #This is originally the a in the overleaf file, but it can be implemented this way.
        ind1 = (w > min1).to(device)
        ind2 = (w < max1).to(device)

        w[ind1] = min1[ind1]
        w[ind2] = max1[ind2]
        #The rest of indicies will not change from the original a

        #updating z                
        perturbed_input = Variable(x + z,requires_grad = True).to(device)
        output = model(perturbed_input).unsqueeze(dim=0)
        loss   = criterion(output,target)
        model.zero_grad()
        
        #One backward pass yields both gradients: A_grad (w.r.t. the fc1
        #weights) drives the delta_a update, grad_f (w.r.t. the perturbed
        #input) drives the z update.
        loss.backward()
        A_grad = model.fc1.weight.grad.data
        grad_f = perturbed_input.grad.data
        
        b = delta + u/rho
        c = w     + v/rho
        z = (eta*z + rho*(b+c) - grad_f[0])/(eta + 2*rho)
        
        #Updating delta_a
        temp    = torch.mm(delta_a, - 2*lamda*torch.mm(x_col,x_col.t()))
        temp    = temp - 2*lamda* torch.mm(torch.mm(A,delta.squeeze().reshape(-1,1)),x_col.t())
        temp    = temp - 2*sum_b_diag.unsqueeze(1)*delta_a - A_grad
        delta_a = prox_infty(delta_a + lr*temp,eps2)
        
        #Updating u,v
        u += rho*(delta - z)
        v += rho*(w     - z)
        
        #Check whether the classifier is fooled
        with torch.no_grad():
            out1 = model(x + delta).unsqueeze(dim=0).argmax(1)
            out2 = model(x + z    ).unsqueeze(dim=0).argmax(1)
            out3 = model(x + w    ).unsqueeze(dim=0).argmax(1)
        if out1 == target or out2 == target or out3 == target:
            rho   = rho*1.1
            eta   = eta*1.01
            #Sparsity
            delta[torch.abs(delta) < sigma] = 0
            z[torch.abs(z) < sigma]         = 0
            w[torch.abs(w) < sigma]         = 0

        model_v2.state_dict()['fc1.weight'][:,:] = deepcopy(A+delta_a)

        #Predictions are recomputed because the sparsification above may have
        #changed delta, z and w.
        with torch.no_grad():
            out0 = model_v2(x     ).unsqueeze(dim=0).argmax(1)
            out1 = model(x + delta).unsqueeze(dim=0).argmax(1)
            out2 = model(x + z    ).unsqueeze(dim=0).argmax(1)
            out3 = model(x + w    ).unsqueeze(dim=0).argmax(1)
        
        #Checking the linear equality, what is happenning to it
        if i > 0:
            err = torch.norm(torch.mm(A,w.squeeze().reshape(-1,1)) - torch.mm(delta_a,x_col))
            eq_con.append(err.item())

        if out1 == target and out2 == target and out3 == target and out0 == target:
            print('GG')
            #All three input perturbations must respect the eps1 bound (to 3 decimals)
            cond1 = round(torch.max(torch.abs(delta)).item(),3) <= round(eps1,3)
            cond2 = round(torch.max(torch.abs(w)).item(),3) <= round(eps1,3)
            cond3 = round(torch.max(torch.abs(z)).item(),3) <= round(eps1,3)
            if cond1 and cond2 and cond3:
                print(f'All conditions are satisfied, and the attack is successful in {i+1} iterations')
                return delta,z,w,delta_a,eq_con
    print('The algorithm did not succeed in attacking the model')
    return delta,z,w,delta_a,eq_con


def prox_infty(A,eps = 1):
    """Project every row of ``A`` onto the L1 ball of radius ``eps``.

    Args:
        A: 2-D tensor; it is left untouched.
        eps: radius passed to ``poject_on_l1_ball`` for each row.

    Returns:
        A new tensor on the module-level ``device`` holding the projected rows.
    """
    # ``.cpu().numpy()`` on a CPU tensor shares memory with the tensor, so the
    # row writes below would overwrite the caller's data; work on a private
    # copy instead. ``detach()`` lets tensors that require grad through.
    rows = A.detach().cpu().numpy().copy()
    for i in range(rows.shape[0]): # Projecting each row of A
        rows[i,:] = poject_on_l1_ball(rows[i,:],eps)
    return torch.from_numpy(rows).to(device)


def euclidean_proj_simplex(v, s=1):
    """Project the 1-D array ``v`` onto the simplex ``{w : w >= 0, sum(w) == s}``.

    Uses the sort-and-threshold scheme: sort descending, find the last index
    at which the shifted entry stays positive, and subtract the resulting
    threshold before clipping at zero.

    Args:
        v: 1-D NumPy array.
        s: simplex radius, must be strictly positive.

    Returns:
        ``v`` itself (not a copy) when it already lies on the simplex,
        otherwise a new array with the projection.

    Raises:
        AssertionError: if ``s`` is not strictly positive.
        ValueError: if ``v`` is not one-dimensional (from the shape unpacking).
    """
    assert s > 0,  "Radius s must be strictly positive (%d <= 0)" % s
    n, = v.shape
    # np.all replaces np.alltrue, which no longer exists in NumPy 2.x.
    if v.sum() == s and np.all(v >= 0):
        return v
    u = np.sort(v)[::-1]  # entries in descending order
    cssv = np.cumsum(u)
    # Last (0-based) position whose entry is still positive after the shift.
    rho = np.nonzero(u * np.arange(1, n+1) > (cssv - s))[0][-1]
    theta = (cssv[rho] - s) / (rho + 1.0)
    w = (v - theta).clip(min=0)
    return w


def poject_on_l1_ball(v, s=1):
    """Project the 1-D array ``v`` onto the L1 ball ``{w : sum(|w|) <= s}``.

    A vector already inside the ball is returned as-is (same object).
    Otherwise the magnitudes are projected onto the simplex of radius ``s``
    and the original signs are put back.

    Args:
        v: 1-D NumPy array.
        s: ball radius, must be strictly positive.

    Returns:
        ``v`` itself when ``sum(|v|) <= s``, else a new array.
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    (_,) = v.shape  # rejects anything that is not one-dimensional
    magnitudes = np.abs(v)
    inside_ball = magnitudes.sum() <= s
    if inside_ball:
        return v
    signs = np.sign(v)
    projected = euclidean_proj_simplex(magnitudes, s=s)
    projected *= signs
    return projected