"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

from .base import Regularization,regular_StepLR
import torch
import numpy as np
from uniSI.utils import numpy2tensor

class TV_1order(Regularization):
    """1st-order total-variation (TV) regularization.

        math:
            alphax*||L_1 m_x||_1 + alphaz*||L_1 m_z||_1

        where L_1 is a first-difference operator (row i yields m[i] - m[i+1];
        the last row is zero because it has no following neighbour) and the
        norm is the sum of absolute values, as computed in ``forward``.

        Du, Z., et al., 2021. A high-order total-variation regularisation method for full-waveform inversion. Journal of Geophysics and Engineering, 18, 241–252. doi:10.1093/jge/gxab010

    Args:
        nx, nz: number of grid points along the horizontal / vertical axis;
            they fix the sizes of the (nx, nx) and (nz, nz) operators.
        dx, dz: grid spacing; divided by 1000 in ``forward`` (metres -> km
            according to the original comment).
        alphax, alphaz: weights of the horizontal / vertical terms.
        step_size, gamma: forwarded to ``regular_StepLR`` to schedule the
            weights (exact decay rule is defined in ``.base``).
        device: device on which the difference operators are stored.
        dtype: dtype of the difference operators.

    NOTE(review): ``self.dx``, ``self.dz``, ``self.alphax``, ``self.alphaz``,
    ``self.step_size``, ``self.gamma`` and ``self.iter`` are assumed to be set
    by ``Regularization.__init__`` -- confirm against ``.base``.
    """
    def __init__(self,nx,nz,dx,dz,alphax=0,alphaz=0,step_size=1000,gamma=1,device='cpu',dtype=torch.float32) -> None:
        super().__init__(nx,nz,dx,dz,alphax,alphaz,step_size,gamma,device,dtype)
        # Honour the requested dtype: the operators used to be hard-coded to
        # float32, which made matmul fail for models of any other dtype.
        # vertical constraint, shape (nz, nz)
        self.L0 = numpy2tensor(self._difference_matrix(nz),dtype=dtype).to(device)
        # horizontal constraint, shape (nx, nx)
        self.L1 = numpy2tensor(self._difference_matrix(nx),dtype=dtype).to(device)

    @staticmethod
    def _difference_matrix(n):
        """Build the dense (n, n) first-difference operator as a numpy array."""
        # +1 on the main diagonal, -1 on the first super-diagonal.
        D = np.diag(1*np.ones(n),0) + np.diag(-1*np.ones(n-1),1)
        # The last sample has no successor, so its row contributes nothing.
        D[-1,:] = 0
        return D

    def forward(self,m):
        """Return the weighted L1 norm of the first differences of ``m``.

        Args:
            m: 2-D model tensor of shape (nz, nx); the shape is implied by the
                (nz, nz) and (nx, nx) operators applied below.

        Returns:
            Scalar tensor ``sum(alphax*|d_x m| + alphaz*|d_z m|)``.

        Side effects:
            Increments ``self.iter`` by one on every call.
        """
        dz,dx  = self.dz/1000,self.dx/1000  # unit: km

        # Align the operators with the model. ``Tensor.to`` returns the same
        # tensor when dtype and device already match, so this costs nothing in
        # the common case and avoids a dtype/device mismatch otherwise.
        L0 = self.L0.to(device=m.device,dtype=m.dtype)
        L1 = self.L1.to(device=m.device,dtype=m.dtype)

        # vertical constraint: difference along the first axis of m
        m_norm_z = torch.matmul(L0, m) / dz

        # horizontal constraint: transpose so the second axis is differenced,
        # then transpose back to (nz, nx)
        m_norm_x = torch.matmul(L1, m.T).T / dx

        # update the alpha
        alphax = regular_StepLR(self.iter,self.step_size,self.alphax,self.gamma)
        alphaz = regular_StepLR(self.iter,self.step_size,self.alphaz,self.gamma)

        # misfit: weighted sum of absolute first differences (L1 norm)
        misfit_norm = torch.sum(alphax*torch.abs(m_norm_x) + alphaz*torch.abs(m_norm_z))

        self.iter += 1
        return misfit_norm