"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

import numpy as np

def bc_pml_freq(nx, nz, dx, dz, pml, vmax, freq, free_surface=True, *,
                reflection=1e-1, order=4, base_strength=0.5, freq_adapt=0.4,
                profile_type="polynomial", corner_method="max"):
    """
    Calculate PML damping coefficients in the frequency domain for the
    wavenumber modification method: k^2 -> k^2 + i*omega*gamma(x,z).

    The (nz, nx) grid is padded by ``pml`` cells on every side. Inside each
    absorbing strip the damping ramps from 0 at the inner edge to ``d0`` at
    the outer edge, where

        d0 = s(omega) * (order + 1) * vmax * (-ln(reflection)) / (2 * L)
        s(omega) = min(base_strength, base_strength / (1 + freq_adapt * omega))

    ``L`` is the physical layer thickness (pml*dx or pml*dz). ``omega`` is
    2*pi*freq.

    Parameters:
        nx, nz: grid numbers in x and z direction (without PML padding)
        dx, dz: grid spacing
        pml: thickness of PML layer in grid cells (must be >= 0). With
            pml == 1 the ramp value is 0, so no damping is applied.
        vmax: maximum velocity
        freq: frequency (scalar) or 1-D array of frequencies in Hz
        free_surface: if True, no z-damping is applied along the top strip.
            x-damping still reaches the top-left/top-right corners.
        reflection: theoretical reflection coefficient (default 0.1)
        order: polynomial order of the "polynomial" profile (default 4)
        base_strength: base strength factor of d0 (default 0.5)
        freq_adapt: frequency-adaptive coefficient; larger values weaken the
            damping more at high frequency (default 0.4)
        profile_type: ramp shape, "polynomial" (xi**order) or
            "sine" (sin(pi*xi/2)**2), with xi in [0, 1] (default "polynomial")
        corner_method: how the x- and z-damping maps are combined, applied
            at every cell: "max", "mean", or "sum" (default "max").
            "mean" also halves the damping along the non-corner strips.

    Returns:
        Damping values in a complex-dtype array whose imaginary part is zero,
        with nz_pml = nz + 2*pml and nx_pml = nx + 2*pml:
        - If freq has exactly one element (a scalar or a length-1 array):
          a single matrix of shape (nz_pml, nx_pml).
        - Otherwise: a stack of shape (nfreq, nz_pml, nx_pml).

    Raises:
        ValueError: if pml is negative, or profile_type / corner_method is
            not one of the supported values.
    """
    if pml < 0:
        raise ValueError(f"pml must be non-negative, got {pml!r}")
    if profile_type not in ("polynomial", "sine"):
        raise ValueError(f"unknown profile_type: {profile_type!r}")
    if corner_method not in ("max", "mean", "sum"):
        raise ValueError(f"unknown corner_method: {corner_method!r}")

    nx_pml = nx + 2 * pml
    nz_pml = nz + 2 * pml

    freq = np.atleast_1d(freq)
    nfreq = len(freq)

    def compute_d0(v, L, omega):
        # Frequency-adaptive strength factor. With the default positive
        # coefficients and omega >= 0 the second argument is always the
        # smaller one, so min() only matters for unusual inputs.
        strength_factor = min(base_strength,
                              base_strength * (1.0 / (1.0 + freq_adapt * omega)))
        return strength_factor * (order + 1) * v * (-np.log(reflection)) / (2.0 * L)

    damp_global = np.zeros((nfreq, nz_pml, nx_pml), dtype=complex)

    # Without an absorbing layer every coefficient stays zero. Skipping the
    # work also avoids dividing by the zero layer thickness in compute_d0.
    if pml > 0:
        # Physical thickness of PML layers
        L_x = pml * dx
        L_z = pml * dz

        # Normalized position inside the layer: 0 at the inner edge, 1 at the
        # outer edge. The profile shape does not depend on frequency, so it is
        # computed once.
        xi = np.arange(pml) / (pml - 1) if pml > 1 else np.zeros(pml)
        if profile_type == "polynomial":
            ramp = xi ** order
        else:
            ramp = np.sin(np.pi * xi / 2) ** 2
        # ramp runs inner -> outer, which fits the right/bottom strips. The
        # left/top strips are indexed from the outer boundary, so they use it
        # reversed.
        ramp_out = ramp[::-1]

        for ifreq, f in enumerate(freq):
            omega = 2.0 * np.pi * f

            d0_x = compute_d0(vmax, L_x, omega)
            d0_z = compute_d0(vmax, L_z, omega)

            # x-direction damping (left/right strips span every row)
            damp_x = np.zeros((nz_pml, nx_pml))
            damp_x[:, :pml] = d0_x * ramp_out
            damp_x[:, nx_pml - pml:] = d0_x * ramp

            # z-direction damping (top/bottom strips span every column).
            # The top strip is left undamped for a free surface.
            damp_z = np.zeros((nz_pml, nx_pml))
            if not free_surface:
                damp_z[:pml, :] = (d0_z * ramp_out)[:, None]
            damp_z[nz_pml - pml:, :] = (d0_z * ramp)[:, None]

            # Combine the two maps. Outside the corners one of them is zero.
            if corner_method == "max":
                damp_global[ifreq] = np.maximum(damp_x, damp_z)
            elif corner_method == "mean":
                damp_global[ifreq] = (damp_x + damp_z) / 2
            else:  # "sum"
                damp_global[ifreq] = damp_x + damp_z

    # Any single-element freq (scalar or [f]) collapses to one 2-D matrix.
    if nfreq == 1:
        return damp_global[0]
    return damp_global