
"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""


import numpy as np
import torch
from torch import Tensor
from typing import Optional,Tuple,Union
from uniSI.utils       import gpu2cpu,numpy2tensor
from uniSI.model.base  import AbstractModel
from uniSI.view        import (plot_two_parameter,plot_model)
from uniSI.survey      import Survey

class SHModel(AbstractModel):
    """SH wave model parameterized by vs (S-wave velocity) and rho (density).

    The shear modulus ``mu = vs**2 * rho`` is derived from the two parameters
    and refreshed by :meth:`update_mu`.

    Parameters:
    --------------
        ox (float),oz(float)             : Origin of the model in x- and z- direction (m)
        nx (int),nz(int)                 : Number of grid points in x- and z- direction
        dx (float),dz(float)             : Grid size in x- and z- direction (m)
        vs (array or Tensor)             : S-wave velocity model with shape (nz,nx) in m/s (required)
        rho (array or Tensor)            : Density model with shape (nz,nx) in kg/m³ (required)
        vs_bound (tuple,Optional)        : (lower, upper) bounds for the S-wave velocity, default None
        rho_bound (tuple,Optional)       : (lower, upper) bounds for the density, default None
        vs_grad (bool,Optional)          : Flag for gradient of S-wave velocity, default is False
        rho_grad (bool,Optional)         : Flag for gradient of density, default is False
        auto_update_rho (bool,Optional)  : Re-derive rho from mu and vs in forward() when rho is
                                           not being inverted, default is True
        auto_update_vs (bool,Optional)   : Re-derive vs from mu and rho in forward() when vs is
                                           not being inverted, default is False
        water_layer_mask (array,Optional): Boolean mask of cells that are never modified by the
                                           empirical updates or by clipping, default None
        free_surface (bool,Optional)     : Flag for free surface, default is False
        abc_type (str)                   : The type of absorbing boundary condition: PML,jerjan and other
        abc_jerjan_alpha (float)         : The attenuation factor for jerjan boundary condition
        nabc (int)                       : Number of absorbing boundary cells, default is 20
        device (str,Optional)            : The running device
        dtype (dtypes,Optional)          : The dtypes for pytorch variable, default is torch.float32

    Raises:
    --------------
        ValueError : if ``vs`` or ``rho`` is not provided
    """
    def __init__(self,
                ox: float, oz: float,
                nx: int, nz: int,
                dx: float, dz: float,
                vs: Optional[Union[np.ndarray,Tensor]] = None,
                rho: Optional[Union[np.ndarray,Tensor]] = None,
                vs_bound: Optional[Tuple[float, float]] = None,
                rho_bound: Optional[Tuple[float, float]] = None,
                vs_grad: Optional[bool] = False,
                rho_grad: Optional[bool] = False,
                auto_update_rho: Optional[bool] = True,
                auto_update_vs: Optional[bool] = False,
                water_layer_mask: Optional[Union[np.ndarray,Tensor]] = None,
                free_surface: Optional[bool] = False,
                abc_type: Optional[str] = 'PML',
                abc_jerjan_alpha: Optional[float] = 0.0053,
                nabc: Optional[int] = 20,
                device = 'cpu',
                dtype = torch.float32
                ) -> None:
        # initialize the common model parameters
        super().__init__(ox,oz,nx,nz,dx,dz,free_surface,abc_type,abc_jerjan_alpha,nabc,device,dtype)

        # initialize the model parameters (private copies, so the caller's
        # arrays are never modified by later in-place updates)
        self.pars = ["vs","rho"]
        self.vs = self._copy_input("vs", vs)
        self.rho = self._copy_input("rho", rho)
        self.vs_grad = vs_grad
        self.rho_grad = rho_grad
        self._parameterization()

        self.lower_bound["vs"] = vs_bound[0] if vs_bound is not None else None
        self.lower_bound["rho"] = rho_bound[0] if rho_bound is not None else None
        self.upper_bound["vs"] = vs_bound[1] if vs_bound is not None else None
        self.upper_bound["rho"] = rho_bound[1] if rho_bound is not None else None

        self.requires_grad["vs"] = self.vs_grad
        self.requires_grad["rho"] = self.rho_grad

        # check the input model
        self._check_bounds()
        self.check_dims()

        # update parameters using empirical functions
        self.auto_update_rho = auto_update_rho
        self.auto_update_vs = auto_update_vs

        if water_layer_mask is not None:
            self.water_layer_mask = numpy2tensor(water_layer_mask,dtype=torch.bool).to(device)
        else:
            self.water_layer_mask = None

        # Calculate the shear modulus from vs and rho
        self.update_mu()

    @staticmethod
    def _copy_input(name: str, value):
        """Return an independent copy of a user-supplied model array.

        Tensors are detached and cloned (``torch.Tensor`` has no ``copy()``);
        anything else is copied through its own ``copy()`` method.
        """
        if value is None:
            raise ValueError(f"SHModel requires the '{name}' model, but None was given")
        if isinstance(value, Tensor):
            return value.detach().clone()
        return value.copy()

    def update_mu(self) -> None:
        """Recompute the shear modulus ``mu`` from the current vs and rho."""
        self.mu = self.vs * self.vs * self.rho  # (m/s)^2 * kg/m^3 = Pa
        return

    def _parameterization(self) -> None:
        """Convert vs and rho to tensors on the model device and wrap them as Parameters."""
        # NOTE(review): numpy2tensor is assumed to pass tensors through - confirm in uniSI.utils
        vs_tensor = numpy2tensor(self.vs, self.dtype).to(self.device)
        rho_tensor = numpy2tensor(self.rho, self.dtype).to(self.device)
        self.vs = torch.nn.Parameter(vs_tensor, requires_grad=self.vs_grad)
        self.rho = torch.nn.Parameter(rho_tensor, requires_grad=self.rho_grad)
        return

    def _plot_mu_rho(self, **kwargs):
        """plot shear modulus and density model"""
        plot_two_parameter(self.mu,self.rho,
                    dx=self.dx,dz=self.dz,**kwargs, model_name="mu_rho")
        return

    def _plot(self, var, **kwargs):
        """plot single model parameter"""
        model_data = self.get_model(var)
        plot_model(model_data, title=var, **kwargs)
        return

    def _keep_previous_where_protected(self, estimate: np.ndarray, previous: np.ndarray) -> np.ndarray:
        """Fall back to ``previous`` wherever ``estimate`` must not be used.

        That is (1) cells where the estimate is NaN/inf because the divisor was
        zero, and (2) cells flagged by ``water_layer_mask``.
        """
        # a zero divisor leaves no usable estimate; keep the last valid value
        # instead of letting NaN/inf propagate into mu
        merged = np.where(np.isfinite(estimate), estimate, previous)
        if self.water_layer_mask is not None:
            mask = self.water_layer_mask.cpu().detach().numpy()
            merged[mask] = previous[mask]
        return merged

    def set_rho_using_empirical_function(self) -> None:
        """Re-derive rho as ``mu / vs**2`` from the stored mu and the current vs.

        Cells in the water layer mask, and cells where the division is not
        finite (vs == 0), keep their previous density.
        """
        rho_previous = self.rho.cpu().detach().numpy()
        vs_data = self.vs.cpu().detach().numpy()
        mu_data = self.mu.cpu().detach().numpy()
        # non-finite results are replaced explicitly below, so silence the warnings
        with np.errstate(divide="ignore", invalid="ignore"):
            rho_empirical = mu_data / (vs_data * vs_data)
        rho_empirical = self._keep_previous_where_protected(rho_empirical, rho_previous)
        rho = numpy2tensor(rho_empirical, self.dtype).to(self.device)
        self.rho = torch.nn.Parameter(rho, requires_grad=self.rho_grad)
        return

    def set_vs_using_empirical_function(self) -> None:
        """Re-derive vs as ``sqrt(mu / rho)`` from the stored mu and the current rho.

        Cells in the water layer mask, and cells where the result is not
        finite (rho == 0 or a negative ratio), keep their previous velocity.
        """
        rho_data = self.rho.cpu().detach().numpy()
        vs_previous = self.vs.cpu().detach().numpy()
        mu_data = self.mu.cpu().detach().numpy()
        # non-finite results are replaced explicitly below, so silence the warnings
        with np.errstate(divide="ignore", invalid="ignore"):
            vs_empirical = np.sqrt(mu_data / rho_data)
        vs_empirical = self._keep_previous_where_protected(vs_empirical, vs_previous)
        vs = numpy2tensor(vs_empirical, self.dtype).to(self.device)
        self.vs = torch.nn.Parameter(vs, requires_grad=self.vs_grad)
        return

    def clip_params(self) -> None:
        """Clip the model parameters to the given bounds

        A parameter is clipped only when both its lower and upper bound are
        set. Cells flagged by ``water_layer_mask`` keep their unclipped values.
        """
        for par in self.pars:
            min_value = self.lower_bound[par]
            max_value = self.upper_bound[par]
            if min_value is None or max_value is None:
                continue
            m = getattr(self, par)
            if self.water_layer_mask is None:
                m.data.clamp_(min_value, max_value)
                continue
            # remember the unclipped values so the water layer can be restored
            unclipped = m.data.clone()
            m.data.clamp_(min_value, max_value)
            m.data = torch.where(self.water_layer_mask.contiguous(), unclipped, m.data)
        return

    def forward(self) -> None:
        """Bring the model into a consistent state before a simulation.

        Steps, in order: optionally re-derive rho and/or vs with the empirical
        relations (only for parameters that are not being inverted), clip the
        parameters to their bounds, then recompute mu.
        """
        if self.auto_update_rho and not self.rho_grad:
            self.set_rho_using_empirical_function()

        if self.auto_update_vs and not self.vs_grad:
            self.set_vs_using_empirical_function()

        self.clip_params()
        self.update_mu()
        return