"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

import torch
import torch.nn.functional as F

def gauss2_tensor(X, Y, mu, sigma, normalize=True):
    """Evaluate a bivariate Gaussian on the coordinates ``(X, Y)``.

    All inputs are cast to float32 before use.

    Args:
        X: Tensor of first-axis coordinates.
        Y: Tensor of second-axis coordinates (used element-wise with ``X``).
        mu: Tensor whose first two entries are the centre of the Gaussian.
        sigma: Invertible 2x2 covariance tensor.
        normalize: When true, scale the result by
            ``1 / (2 * pi * sqrt(det(sigma)))``.

    Returns:
        Float32 tensor holding ``exp(-0.5 * d^T sigma^-1 d)`` for every
        coordinate pair, optionally normalised.
    """
    cov = sigma.float()
    center = mu.float()

    # Offsets of every sample from the centre of the Gaussian.
    dx = X.float() - center[0]
    dy = Y.float() - center[1]

    det = cov[0, 0]*cov[1, 1] - cov[0, 1]*cov[1, 0]
    precision = torch.inverse(cov)

    # Quadratic form d^T * precision * d, written out term by term.
    quad = precision[0, 0]*dx**2 + precision[0, 1]*dx*dy + precision[1, 0]*dx*dy + precision[1, 1]*dy**2
    density = torch.exp(-0.5*quad)

    if not normalize:
        return density

    density *= (2.*torch.pi*torch.sqrt(det))**(-1.)
    return density


def smooth2d_tensor(Z, span=10, device='cpu'):
    """Smooth a 2-D tensor with a normalised isotropic Gaussian kernel.

    The kernel has ``2*span + 1`` samples per axis, taken on a grid running
    from ``-2*span`` to ``2*span``, with covariance ``diag(span**2, span**2)``.
    It is rescaled to sum to one and applied with ``F.conv2d`` using a padding
    of ``span`` cells, so the output has the same shape as the input.

    Args:
        Z: 2-D tensor to smooth.
        span: Non-negative integer half-width of the kernel. ``0`` means no
            smoothing and returns a copy of ``Z``.
        device: Device on which the kernel is built. The finished kernel is
            moved to ``Z.device`` before the convolution.

    Returns:
        Tensor with the same shape as ``Z``.
    """
    if span == 0:
        # A zero-width kernel is the identity. Building it would also mean
        # inverting an all-zero covariance matrix, which is singular.
        return Z.clone()

    half_extent = 2.*span
    axis = torch.linspace(-half_extent, half_extent, 2*span + 1, device=device, dtype=torch.float32)
    # 'ij' is the historical default; stating it avoids the deprecation
    # warning that would otherwise be emitted on every call.
    X, Y = torch.meshgrid(axis, axis, indexing='ij')
    mu = torch.tensor([0., 0.], device=device, dtype=torch.float32)
    sigma = torch.diag(torch.tensor([float(span), float(span)], device=device, dtype=torch.float32))**2
    kernel = gauss2_tensor(X, Y, mu, sigma)
    kernel = kernel/kernel.sum()

    # conv2d expects (out_channels, in_channels, kH, kW) weights and a
    # (batch, channels, H, W) input living on the same device and dtype.
    kernel = kernel.unsqueeze(0).unsqueeze(0).to(device=Z.device, dtype=Z.dtype)
    padding = kernel.shape[-1]//2
    smoothed = F.conv2d(input=Z.unsqueeze(0).unsqueeze(0), weight=kernel, padding=padding)

    # Strip only the two axes added above. A bare squeeze() would also drop
    # genuine size-1 axes of Z (e.g. a single-row input) and return 1-D data.
    return smoothed[0, 0]

def grad_taper_tensor(nz, nx, tapersize=20, thred=0.05, marine_or_land='marine', device='cpu'):
    """Build an ``(nz, nx)`` float32 weighting mask for a gradient.

    Args:
        nz: Number of rows of the mask.
        nx: Number of columns of the mask.
        tapersize: Width of the muted / damped zone in grid cells.
        thred: Only used in the land case. The ramp is scaled by
            ``1 - thred`` before being inverted and squared, so the most
            strongly damped cell ends up with weight ``thred**2``.
        marine_or_land: ``'marine'`` or ``'offshore'`` (case-insensitive)
            selects the hard mute; any other value selects the land taper.
        device: Device on which the mask is created.

    Returns:
        Tensor of shape ``(nz, nx)``.
    """
    # Compare case-insensitively so this agrees with GradProcessor.forward,
    # which lower-cases the setting before choosing the threshold.
    if marine_or_land.lower() in ('marine', 'offshore'):
        # Hard mute: the first `tapersize` rows are zeroed, the rest kept.
        taper = torch.ones((nz, nx), device=device, dtype=torch.float32)
        taper[:tapersize, :] = 0.0
        return taper

    # Land: falling half of a Hamming window used as a smooth ramp.
    ramp = torch.hamming_window(tapersize*2, device=device, dtype=torch.float32)[tapersize:]
    taper = torch.zeros((nz, nx), device=device, dtype=torch.float32)
    # Broadcast the ramp over every row instead of looping in Python.
    # NOTE(review): the ramp runs along the second (nx) axis, whereas the
    # marine branch mutes along the first (nz) axis — confirm this is intended.
    taper[:, :tapersize] = ramp

    taper = smooth2d_tensor(taper, span=tapersize//2, device=device)
    taper /= taper.max()
    taper *= (1 - thred)
    taper = -taper + 1       # invert: strongest ramp -> smallest weight
    taper = taper * taper    # squaring sharpens the damping
    return taper


class GradProcessor():
    """Apply a configurable chain of post-processing steps to a gradient.

    The steps run in this order, each one optional: taper / mute, user mask,
    illumination preconditioning, Gaussian smoothing, amplitude
    normalisation, depth weighting, and averaging around the sources.

    Args:
        grad_mute: Taper size handed to ``grad_taper_tensor``; the taper is
            applied only when this is greater than zero. In the marine case
            it is also the number of leading rows excluded from smoothing.
        grad_smooth: Span handed to ``smooth2d_tensor``; smoothing is applied
            only when this is greater than zero.
        grad_mask: Optional array-like with the same shape as the gradient;
            multiplied into the gradient element-wise.
        norm_grad: When true, rescale the gradient so that its largest
            absolute value equals ``vmax``.
        forw_illumination: When true and ``forw`` is supplied to ``forward``,
            divide the gradient by the squared, normalised, smoothed ``|forw|``.
        marine_or_land: ``'marine'``/``'offshore'`` or ``'land'``/``'onshore'``
            (case-insensitive).
        depth_weight: When true, multiply row ``i`` by ``1.05**i``.
        shmode: When true, replace a small window around every source
            position with the mean of that window.
        src_z: Source row indices (used only with ``shmode``).
        src_x: Source column indices (used only with ``shmode``).
    """

    def __init__(self,
                 grad_mute=0,
                 grad_smooth=0,
                 grad_mask=None,
                 norm_grad=True,
                 forw_illumination=True,
                 marine_or_land="land",
                 depth_weight=False, shmode=False, src_z=None, src_x=None):
        self.grad_mute = grad_mute
        self.grad_smooth = grad_smooth
        self.grad_mask = grad_mask
        self.marine_or_land = marine_or_land
        self.norm_grad = norm_grad
        self.forw_illumination = forw_illumination
        self.depth_weight = depth_weight
        self.shmode = shmode
        # as_tensor accepts lists, arrays and tensors alike and does not warn
        # when the argument already is a tensor.
        self.src_z = torch.as_tensor(src_z) if src_z is not None else None
        self.src_x = torch.as_tensor(src_x) if src_x is not None else None

    def forward(self, nx, nz, vmax, grad, forw=None):
        """Run the configured processing steps on ``grad``.

        Args:
            nx: Number of columns of the gradient.
            nz: Number of rows of the gradient.
            vmax: Target peak amplitude used by the normalisation step.
            grad: Gradient as a tensor or array-like; presumably of shape
                ``(nz, nx)`` — the taper built from ``nz``/``nx`` is
                multiplied into it.
            forw: Optional illumination field with the same layout as
                ``grad``.

        Returns:
            The processed gradient as a new tensor whenever any step modifies
            it; the caller's tensor is never edited in place.

        Raises:
            ValueError: If ``marine_or_land`` is not a supported value (only
                checked when ``grad_mute > 0``) or if ``grad_mask`` does not
                match the gradient's shape.
        """
        # Convert before reading device/dtype: those attributes are only
        # meaningful once `grad` is a tensor.
        if not isinstance(grad, torch.Tensor):
            grad = torch.as_tensor(grad)
            if not grad.is_floating_point():
                grad = grad.to(torch.get_default_dtype())
        device = grad.device
        dtype = grad.dtype

        if forw is not None and not isinstance(forw, torch.Tensor):
            forw = torch.as_tensor(forw, device=device, dtype=dtype)
        if not isinstance(vmax, torch.Tensor):
            vmax = torch.as_tensor(vmax, device=device, dtype=dtype)

        # taper mask
        if self.grad_mute > 0:
            setting = self.marine_or_land.lower()
            if setting in ('marine', 'offshore'):
                grad_thred, taper_kind = 0.0, 'marine'
            elif setting in ('land', 'onshore'):
                grad_thred, taper_kind = 0.001, 'land'
            else:
                raise ValueError(f'not supported modeling marine_or_land: {self.marine_or_land}')

            # Pass a canonical label so the taper shape always matches the
            # threshold chosen above, whatever spelling the user supplied.
            taper = grad_taper_tensor(nz, nx,
                                      tapersize=self.grad_mute,
                                      thred=grad_thred,
                                      marine_or_land=taper_kind,
                                      device=device)
            grad = grad * taper

        # grad mask
        if self.grad_mask is not None:
            mask = torch.as_tensor(self.grad_mask, device=device, dtype=dtype)
            if mask.shape != grad.shape:
                raise ValueError('Wrong size of grad mask: the size of the mask should be identical to the size of vp model')
            grad = grad * mask

        # illumination preconditioning: divide by the squared, normalised,
        # smoothed magnitude of the forward field
        if self.forw_illumination and forw is not None:
            span = 40 if min(nz, nx) > 40 else int(min(nz, nx)/2)

            forw = smooth2d_tensor(forw, span=span, device=device)
            epsilon = 0.0001
            precond = torch.abs(forw)
            precond = precond / (precond.max() + 1e-5)
            # the floor keeps poorly illuminated cells from blowing up
            precond = torch.clamp(precond, min=epsilon)
            grad = grad / precond.pow(2)

        # smooth the gradient
        if self.grad_smooth > 0:
            # Same case-insensitive test as the taper step above.
            if self.marine_or_land.lower() in ('marine', 'offshore'):
                # leave the muted rows untouched and smooth only below them
                grad_upper = grad[:self.grad_mute, :]
                grad_lower = smooth2d_tensor(grad[self.grad_mute:, :],
                                             span=self.grad_smooth,
                                             device=device)
                grad = torch.cat([grad_upper, grad_lower], dim=0)
            else:
                grad = smooth2d_tensor(grad, span=self.grad_smooth, device=device)

        # scale the gradient properly
        if self.norm_grad:
            peak = grad.abs().max()
            # An all-zero gradient has nothing to rescale; dividing by its
            # zero peak would turn every entry into NaN.
            if peak != 0:
                grad = vmax * grad / peak

        # depth weight: the deeper the layer, the larger the weight
        if self.depth_weight:
            depth_indices = torch.arange(nz, device=device, dtype=dtype).unsqueeze(1)  # shape: [nz, 1]
            grad = grad * (1.05 ** depth_indices)

        # average the gradient in a small window around every source
        if self.shmode and self.src_x is not None and self.src_z is not None:
            # Work on a copy so the caller's tensor is never edited in place.
            grad = grad.clone()
            xs = self.src_x.reshape(-1).tolist()
            zs = self.src_z.reshape(-1).tolist()
            for x, z in zip(xs, zs):
                x, z = int(x), int(z)
                if z + 2 <= 0 or x + 2 <= 0:
                    continue  # window lies entirely before the grid
                # Clamp the window start at 0: a negative slice start would
                # wrap around and select an empty (or wrong) region for
                # sources on the first row or the first two columns.
                # NOTE(review): the window is 3 rows by 4 columns
                # (x-2 .. x+1), i.e. not centred in x — confirm intended.
                rows = slice(max(z - 1, 0), z + 2)
                cols = slice(max(x - 2, 0), x + 2)
                block = grad[rows, cols]
                if block.numel() > 0:
                    grad[rows, cols] = torch.mean(block)
        return grad