import torch
import torch.nn as nn
import torch.nn.functional as F

class Softplus(nn.Module):
    """Softplus activation ``log(1 + exp(x))`` with closed-form derivatives.

    Besides the forward pass, the module exposes the first and second
    derivative of softplus as explicit methods so callers can evaluate them
    without going through autograd.
    """

    def __init__(self):
        super().__init__()
        # The second derivative is sigmoid(x) * (1 - sigmoid(x)), which is
        # maximised at x = 0 where it equals 1/4.
        self.hessian_bound = 0.25

    def forward(self, x):
        """Return softplus(x) elementwise.

        Uses the identity softplus(x) = max(x, 0) + softplus(-|x|), so the
        inner softplus is only ever evaluated on non-positive arguments.
        """
        eval_neg_x = F.softplus(-torch.abs(x))
        # Select the branch with torch.where instead of multiplying by boolean
        # masks: a mask product evaluates 0 * (-inf) for x = -inf, which is NaN.
        return torch.where(x >= 0, x + eval_neg_x, eval_neg_x)

    def gradient(self, x):
        """Return d softplus / dx = sigmoid(x), elementwise."""
        return torch.sigmoid(x)

    def hessian(self, x):
        """Return d^2 softplus / dx^2 = sigmoid(x) * (1 - sigmoid(x))."""
        sigm_x = torch.sigmoid(x)
        return sigm_x * (1. - sigm_x)

class CRC_Diag(nn.Module):
    """Classifier with a diagonal first layer and closed-form input derivatives.

    The network computes ``linear2(softplus(w1 * x + b1))`` where ``w1`` and
    ``b1`` hold one scalar per input feature (an elementwise scale and shift),
    and ``w1`` is clamped to ``[-w1_bound, w1_bound]`` every time it is used.
    Because the first layer is diagonal, the gradient and the diagonal of the
    Hessian of a logit difference with respect to the input can be written
    down directly; see :meth:`gradient` and :meth:`hessian`.

    Args:
        num_features: Number of input features after flattening.
        num_classes: Number of output logits.
        w1_bound: Magnitude limit applied to ``w1``; ``None`` disables the
            clamp.
        eps: Small constant kept on the module as a tensor. None of the
            methods of this class read it.
    """

    def __init__(self, num_features, num_classes, w1_bound=2., eps=1e-6):
        super().__init__()
        if w1_bound is None:
            # Builtin float: the original ``np.float('inf')`` referenced a
            # module that is not imported here and an alias NumPy has removed.
            w1_bound = float('inf')
        self.w1_bound = w1_bound

        # Per-feature scale, initialised uniformly in [-1, 1), and shift.
        self.w1 = nn.Parameter(-1 + 2*torch.rand(1, num_features), requires_grad=True)
        self.b1 = nn.Parameter(torch.zeros(1, num_features), requires_grad=True)

        self.activation = Softplus()
        self.linear2 = nn.Linear(num_features, num_classes)

        # Plain tensor attribute (neither a parameter nor a registered buffer).
        self.eps = torch.tensor(eps)

    def _clamped_w1(self):
        """Return ``w1`` restricted to ``[-w1_bound, w1_bound]``."""
        return torch.clamp(self.w1, min=-self.w1_bound, max=self.w1_bound)

    def _pre_activation(self, x, w1):
        """Flatten ``x`` to (batch, features) and apply the diagonal layer."""
        return (w1 * torch.flatten(x, start_dim=1)) + self.b1

    def gradient(self, x, y, t):
        """Return d(logit[y] - logit[t]) / dx for each sample.

        Args:
            x: Input batch; flattened from dimension 1 onwards.
            y: Index tensor selecting one row of the output weight per sample.
            t: Index tensor selecting the row that is subtracted.

        Returns:
            Tensor with one gradient row per sample (the output bias cancels
            in the difference and does not appear).
        """
        w1 = self._clamped_w1()
        act_grad = self.activation.gradient(self._pre_activation(x, w1))

        w2 = self.linear2.weight
        return (w1 * act_grad) * (w2[y, :] - w2[t, :])

    def hessian(self, x, y, t):
        """Return the diagonal of the Hessian of ``logit[y] - logit[t]`` in x.

        Each input feature feeds exactly one hidden unit, so the Hessian with
        respect to the input is diagonal and only that diagonal is returned.
        Arguments are the same as for :meth:`gradient`.
        """
        w1 = self._clamped_w1()
        act_hess = self.activation.hessian(self._pre_activation(x, w1))

        w2 = self.linear2.weight
        return (act_hess * (w1 * w1)) * (w2[y, :] - w2[t, :])

    def forward(self, features):
        """Return class logits for a batch of inputs."""
        w1 = self._clamped_w1()
        hidden = self.activation(self._pre_activation(features, w1))
        return self.linear2(hidden)
    
# Experiment: for a randomly initialised CRC_Diag model and random inputs,
# search for a perturbation ``delta`` that drives the logit difference
# logit[y] - logit[t] towards zero.  For a fixed multiplier ``eta`` the inner
# loop applies diagonal Newton updates to the stationarity condition
# ``delta + eta * grad(x + delta) = 0``; the outer loop bisects on ``eta``.
n_features = 1024
n_classes = 100
b_size = 128


def _stats(values):
    """Return (shape, min, mean, max) of ``values`` for progress printing."""
    return (values.shape, values.min().item(), values.mean().item(),
            values.max().item())


m = CRC_Diag(n_features, n_classes)
x = torch.randn(b_size, n_features)
y = torch.randint(0, n_classes, (b_size,))
t = (y + 1) % n_classes  # competing class: the next label, cyclically

# Derivatives come from m.gradient / m.hessian in closed form and nothing
# below calls backward(), so autograd tracking is switched off.
with torch.no_grad():
    rows = torch.arange(b_size)

    # Per-sample range of the diagonal Hessian entries, using the same clamped
    # w1 and activation curvature bound that the model itself applies.
    w_diff = m.linear2.weight[y, :] - m.linear2.weight[t, :]
    w1 = torch.clamp(m.w1, min=-m.w1_bound, max=m.w1_bound)
    curvature = w_diff * m.activation.hessian_bound * w1 * w1
    lbs, _ = torch.min(curvature, dim=1)
    ubs, _ = torch.max(curvature, dim=1)

    delta = torch.zeros_like(x)

    logits = m(x + delta)
    logits_diff = logits[rows, y] - logits[rows, t]
    print('init', *_stats(logits_diff))

    # Bisection bracket for eta.  Keeping eta strictly between -1/ubs and
    # -1/lbs keeps 1 + eta * hess positive in the Newton step below.
    # NOTE(review): assumes lbs < 0 < ubs for every sample - confirm; a row
    # with lbs == 0 or ubs == 0 would produce inf/NaN here.
    eta_min = (logits_diff < 0) * (-torch.reciprocal(ubs))
    eta_max = (logits_diff > 0) * (-torch.reciprocal(lbs))

    print(*_stats(eta_min))
    print(*_stats(eta_max))

    for i in range(10):
        eta = (eta_min + eta_max) / 2.
        eta_m = eta[:, None]

        # delta is carried over between outer iterations (warm start).
        for _ in range(5):
            grad = m.gradient(x + delta, y, t)
            hess = m.hessian(x + delta, y, t)
            d = -torch.reciprocal(1 + (eta_m * hess)) * ((eta_m * grad) + delta)
            delta = delta + d

        logits = m(x + delta)
        logits_diff = logits[rows, y] - logits[rows, t]

        # Difference still positive: raise the lower end of the bracket;
        # otherwise lower the upper end.
        ge_indicator = (logits_diff > 0)
        eta_min[ge_indicator] = eta[ge_indicator]
        eta_max[~ge_indicator] = eta[~ge_indicator]

        print('i=' + str(i), *_stats(logits_diff))
        