import torch
from torch.autograd import grad
import torch.nn as nn

def soar(X, y, loss, clamp_value, step_size, model, device):
    """Return a second-order regularization term for ``loss`` around ``X``.

    With ``g`` the (batch-scaled) gradient of the loss w.r.t. the flattened
    input and ``H z_d`` a finite-difference estimate of the Hessian-vector
    product along a random direction ``z_d``, this builds, per sample,

        top = H z_d + z * g
        bot = g . z_d + z

    for a random scalar ``z``, and returns the batch mean of the clamped
    L2 norm of ``[top, bot]`` (see ``return_H_norm``).

    Args:
        X: Input batch shaped ``(batch, ...)`` with at least one non-batch
            dimension. ``loss`` must have been computed from this tensor
            with gradient tracking enabled, otherwise ``grad`` raises.
        y: Targets passed to ``nn.CrossEntropyLoss`` together with
            ``model(X_z)``.
        loss: Scalar loss already evaluated at ``X``.
        clamp_value: Upper bound applied to each per-sample norm.
        step_size: Finite-difference step length ``h``.
        model: Callable mapping an input batch to the scores consumed by
            ``nn.CrossEntropyLoss``.
        device: Device on which the random directions and the step tensor
            are created; presumably the same device as ``X`` -- confirm
            against the caller.

    Returns:
        A scalar tensor, differentiable w.r.t. the model parameters
        (both gradients are taken with ``create_graph=True``).
    """
    # No effect when X already tracks gradients (which grad(loss, X) needs).
    X.requires_grad_(True)

    _batch = X.shape[0]
    # Number of features per sample, for any input rank (was hard-coded 4-D).
    _dim = X.shape[1:].numel()
    non_batch_dims = tuple(range(1, X.dim()))
    per_sample_shape = [_batch] + [1] * (X.dim() - 1)

    # Scaled by the batch size, presumably to undo a mean-reduced loss
    # (loss_z below uses CrossEntropyLoss' default reduction).
    # reshape instead of view: gradients of channels_last / non-contiguous
    # inputs are not viewable as (batch, dim).
    dldx = len(X) * grad(loss, X, create_graph=True)[0].reshape(-1, _dim)

    # Random direction in input space and random scalar per sample.
    # Draw order (z_d, then z) is kept so seeded runs are unchanged.
    z_d = torch.randn_like(X.detach(), requires_grad=True, device=device)
    z = torch.randn(_batch, requires_grad=True, device=device)

    # Per-sample norm, floored to avoid dividing by ~0.
    z_d_norm = torch.norm(
        z_d.detach(), p=2, dim=non_batch_dims, keepdim=True
    ).clamp(min=1e-6)
    normalized_z_d = z_d.detach() / z_d_norm

    h = torch.ones(per_sample_shape, device=device) * step_size

    # Perturbed input as a fresh leaf tensor (replaces the `.data =` idiom).
    # The clamp keeps it inside [0, 1]; NOTE(review): this assumes inputs
    # live in [0, 1], and a clamped step is shorter than h -- confirm.
    X_z = (
        (X.detach() + h * normalized_z_d)
        .clamp(min=0., max=1.)
        .detach()
        .requires_grad_(True)
    )

    yp_z = model(X_z)
    loss_z = nn.CrossEntropyLoss()(yp_z, y)
    dldx_z = len(X_z) * grad(loss_z, X_z, create_graph=True)[0].reshape(-1, _dim)

    # (g(X + h * z_d / |z_d|) - g(X)) / h approximates H (z_d / |z_d|);
    # multiplying by |z_d| rescales it to H z_d.
    Hz = z_d_norm.reshape(-1, 1) * (dldx_z - dldx) / h.reshape(-1, 1).clamp(min=1e-6)

    top = Hz + z.view(-1, 1) * dldx
    # Batched dot product g . z_d, one scalar per sample.
    bot = torch.matmul(
        dldx.reshape(-1, 1, _dim), z_d.reshape(_batch, _dim, 1)
    ).view(-1, 1) + z.view(-1, 1)

    return return_H_norm(top, bot, clamp_value)

def return_H_norm(top, bot, clamp_value):
    """Return the batch mean of the clamped per-row L2 norm of ``[top, bot]``.

    ``top`` and ``bot`` are concatenated along dimension 1, the L2 norm of
    each resulting row is taken, every norm is capped at ``clamp_value``,
    and the capped norms are averaged into a single scalar tensor.
    """
    stacked = torch.cat((top, bot), dim=1)
    row_norms = stacked.norm(p=2, dim=1)
    # Cap each row's contribution before averaging.
    capped = torch.clamp(row_norms, max=clamp_value)
    return torch.mean(capped)

