#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
This code is also modified from Matlab-based 2D visco-acoustic wave equation solver
- Url: https://github.com/navid58/FDFD_ver02.
================================================================================
"""


import numpy as np
import torch
from torch import Tensor
from typing import Optional, Tuple, Union
import time
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sys
from mpl_toolkits.axes_grid1 import make_axes_locatable
sys.path.append('../../')
try: from torch import sparse; sparse.spsolve; print("✅ torch.sparse.spsolve available")
except (ImportError, AttributeError): print("❌ torch.sparse.spsolve not available");sys.exit(1)

from uniSI.utils       import gpu2cpu, numpy2tensor
from uniSI.model.base  import AbstractModel
from uniSI.view        import plot_two_parameter, plot_model


# =============================================================================
# SHModel_Freq
# =============================================================================
class SHModel_Freq(AbstractModel):
    _conversion_cache = {}

    def __init__(self,
                 ox: float, oz: float,
                 nx: int, nz: int,
                 dx: float, dz: float,
                 dt: float,
                 vs: Optional[Union[np.array, Tensor]] = None,
                 rho: Optional[Union[np.array, Tensor]] = None,
                 Q: Optional[Union[np.array, Tensor]] = None,
                 wref: Optional[float] = None,
                 tmax: Optional[float] = 1.0,
                 twrap: Optional[float] = 0.1,
                 Sx: Optional[float] = 0.0,
                 Sz: Optional[float] = 0.0,
                 Rx: Optional[float] = 0.0,
                 Rz: Optional[float] = 0.0,
                 fmax: Optional[float] = 100.0,
                 L: Optional[int] = 20,
                 sourceSpectrum: Optional[np.array] = None,
                 freq_zpad: Optional[int] = 0,
                 atten_opt: Optional[str] = 'KF',
                 alpha: Optional[float] = 0.1,
                 vs_bound: Optional[Tuple[float, float]] = None,
                 rho_bound: Optional[Tuple[float, float]] = None,
                 Q_bound: Optional[Tuple[float, float]] = None,
                 vs_grad: Optional[bool] = False,
                 rho_grad: Optional[bool] = False,
                 Q_grad: Optional[bool] = False,
                 auto_update_rho: Optional[bool] = False,
                 auto_update_vs: Optional[bool] = False,
                 auto_update_Q: Optional[bool] = False,
                 water_layer_mask: Optional[Union[np.array, Tensor]] = None,
                 free_surface: Optional[bool] = False,
                 abc_type: Optional[str] = 'PML',
                 nabc: Optional[int] = 20,
                 device='cuda:0',
                 dtype=torch.float64,
                 project_path: Optional[str] = None,
                 ) -> None:
        """Build a frequency-domain SH model from vs, rho and Q grids.

        Parameters
        ----------
        ox, oz, nx, nz, dx, dz, free_surface, abc_type, nabc, device, dtype
            Forwarded unchanged to ``AbstractModel.__init__``.
        dt
            Stored as ``self.dt``.
        vs, rho, Q
            Initial model grids (NumPy array or torch tensor). Each one is
            copied, converted with ``numpy2tensor`` and wrapped in a
            ``torch.nn.Parameter`` by ``_parameterization``. All three are
            required even though the signature defaults them to ``None``.
        wref
            Reference angular frequency handed to the attenuation model.
        tmax, twrap
            Record length and wrap-around damping factor; ``fdfd_setup`` uses
            ``1 / tmax`` as the frequency step and ``fdfd_oneFreq`` adds
            ``twrap / tmax`` as the imaginary part of the angular frequency.
        Sx, Sz, Rx, Rz
            Source / receiver positions; ``fdfd_setup`` converts them to
            integer grid indices.
        fmax
            Upper end of the frequency vector built by ``fdfd_setup``.
        L
            Number of padding cells used by ``ext_pml`` / ``imp_nine``.
        sourceSpectrum
            Per-frequency source spectrum, indexed by frequency in
            ``fdfd_oneFreq``.
        freq_zpad, atten_opt, alpha, project_path
            Stored on the instance as given.
        vs_bound, rho_bound, Q_bound
            Optional ``(lower, upper)`` pairs stored in ``self.lower_bound`` /
            ``self.upper_bound``; ``None`` stores ``None`` for both ends.
        vs_grad, rho_grad, Q_grad
            Whether the corresponding parameter requires gradients.
        auto_update_rho, auto_update_vs, auto_update_Q
            Flags read by ``forward`` (``auto_update_Q`` is only stored here).
        water_layer_mask
            Optional boolean mask; masked cells are left untouched by
            ``clip_params`` and by the empirical update functions.

        Raises
        ------
        ValueError
            If ``vs``, ``rho`` or ``Q`` is ``None``.
        """
        super().__init__(ox, oz, nx, nz, dx, dz, free_surface, abc_type, None, nabc, device, dtype)

        print(f"abc_type: {abc_type}")
        try:
            from torch import sparse
            sparse.spsolve
        except (ImportError, AttributeError):
            print("❌ torch.sparse.spsolve not available")
            sys.exit(1)

        def _private_copy(label, grid):
            # Take an independent copy so later in-place updates never touch
            # the caller's data.  torch tensors have no ``.copy()``, so they
            # are cloned instead.
            if grid is None:
                raise ValueError(f"SHModel_Freq requires an initial '{label}' model, got None")
            if isinstance(grid, Tensor):
                return grid.detach().clone()
            return grid.copy()

        self.pars = ["vs", "rho", "Q"]
        self.vs = _private_copy("vs", vs)
        self.rho = _private_copy("rho", rho)
        self.Q = _private_copy("Q", Q)
        self.vs_grad = vs_grad
        self.rho_grad = rho_grad
        self.Q_grad = Q_grad

        self.Sx = Sx
        self.Sz = Sz
        self.Rx = Rx
        self.Rz = Rz
        self.fmax = fmax
        self.L = L
        self.sourceSpectrum = sourceSpectrum
        self.freq_zpad = freq_zpad
        self.atten_opt = atten_opt
        self.project_path = project_path

        # Convert the raw grids into device-resident torch Parameters.
        self._parameterization()
        self.wref = wref

        self.tmax = tmax
        self.twrap = twrap
        self.alpha = alpha
        self.dt = dt

        # Bounds are stored per parameter; a missing pair means "unbounded".
        for label, bound in (("vs", vs_bound), ("rho", rho_bound), ("Q", Q_bound)):
            self.lower_bound[label] = bound[0] if bound is not None else None
            self.upper_bound[label] = bound[1] if bound is not None else None

        self.requires_grad["vs"] = self.vs_grad
        self.requires_grad["rho"] = self.rho_grad
        self.requires_grad["Q"] = self.Q_grad

        self._check_bounds()
        self.check_dims()

        self.auto_update_rho = auto_update_rho
        self.auto_update_vs = auto_update_vs
        self.auto_update_Q = auto_update_Q

        if water_layer_mask is not None:
            self.water_layer_mask = numpy2tensor(water_layer_mask, dtype=torch.bool).to(device)
        else:
            self.water_layer_mask = None

        # Report where the model data lives.
        print(f"Data storage: {self.vs.device}")

    def _parameterization(self):
        """Turn the raw vs / rho / Q grids into torch Parameters.

        Each grid is first converted with ``numpy2tensor`` to ``self.dtype``
        and moved to ``self.device``; the devices are printed; finally every
        tensor is wrapped in a ``torch.nn.Parameter`` whose ``requires_grad``
        follows the matching ``*_grad`` flag.
        """
        names = ("vs", "rho", "Q")

        # Pass 1: array -> tensor on the target device.
        for name in names:
            converted = numpy2tensor(getattr(self, name), self.dtype)
            setattr(self, name, converted.to(self.device))

        print(f"self.device: {self.device}")
        print(f"vs: {self.vs.device}, rho: {self.rho.device}, Q: {self.Q.device}")

        # Pass 2: tensor -> Parameter with the requested trainability.
        for name in names:
            trainable = getattr(self, f"{name}_grad")
            setattr(self, name, torch.nn.Parameter(getattr(self, name), requires_grad=trainable))

    def _tocomplexfield(self, f, wref):
        """
        KF model
          1/vs(ω) = 1/vs + 1/(π·vs·Q)·ln(ω_ref/ω) + i/(2·vs·Q)
        vs(ω) = 1 / [1/vs + 1/(π·vs·Q)·ln(ω_ref/ω) + i/(2·vs·Q)]
        """
        w = 2 * np.pi * f
        vs_val = self.vs.detach().clone() if self.vs_grad else self.vs
        Q_val = self.Q.detach().clone() if self.Q_grad else self.Q

        vs_val = vs_val.to(torch.complex128)
        Q_val = Q_val.to(torch.complex128)
        w_tensor = torch.tensor(w, dtype=torch.complex128, device=vs_val.device)
        wref_tensor = torch.tensor(wref, dtype=torch.complex128, device=vs_val.device)
        term = 1.0/vs_val + 1.0/(torch.pi * vs_val * Q_val)*torch.log(wref_tensor / w_tensor) + 1j/(2 * vs_val * Q_val)
        return 1.0/term

    @staticmethod
    def complex_shear_velocity(vs: Tensor, Q: Tensor, w: float, wref: float, atten_opt: str = 'KF') -> Tensor:
        if atten_opt == 'KF':
            vs = vs.to(torch.complex128)
            Q = Q.to(torch.complex128)
            w_tensor = torch.as_tensor(w, dtype=torch.complex128, device=vs.device)
            wref_tensor = torch.as_tensor(wref, dtype=torch.complex128, device=vs.device)
            term = 1.0/vs + 1.0/(torch.pi * vs * Q)*torch.log(wref_tensor / w_tensor) + 1j/(2 * vs * Q)
            return 1.0/term
        elif atten_opt == 'no_atten':
            return vs.to(torch.complex128)
        else:
            raise ValueError("Unknown attenuation option!")

    @staticmethod
    def shear_modulus(vs: Tensor, rho: Tensor, Q: Tensor, w: float, wref: float, atten_opt: str = 'KF') -> Tensor:
        """
          μ(ω) = ρ · (vs(ω))²
        """
        if atten_opt == 'KF':
            vs_new = SHModel_Freq.complex_shear_velocity(vs, Q, w, wref, atten_opt)
            return rho.to(torch.complex128) * (vs_new**2)
        elif atten_opt == 'no_atten':
            return rho.to(torch.complex128) * (vs.to(torch.complex128)**2)
        else:
            raise ValueError("Unknown attenuation option!")

    def compute_shear_modulus(self, w: float, wref: float, atten_opt: str = 'KF') -> Tensor:
        return SHModel_Freq.shear_modulus(self.vs, self.rho, self.Q, w, wref, atten_opt)

    def _plot_vs_Q(self, **kwargs):
        """Plot vs and Q together via ``plot_two_parameter``.

        Extra keyword arguments are passed straight through to the plotter.
        """
        plot_two_parameter(
            self.vs,
            self.Q,
            dx=self.dx,
            dz=self.dz,
            model_name="vs_Q",
            **kwargs,
        )

    def _plot(self, var, **kwargs):
        """Plot one model parameter.

        ``var`` is looked up with ``self.get_model`` and also used as the
        figure title; remaining keyword arguments go to ``plot_model``.
        """
        plot_model(self.get_model(var), title=var, **kwargs)

    def set_rho_using_empirical_function(self):
        """Recompute rho from vs as ``rho = 310 · vs^0.25``.

        Cells selected by ``water_layer_mask`` (when present) keep their
        current density.  The result replaces ``self.rho`` as a new Parameter
        whose ``requires_grad`` follows ``self.rho_grad``.
        """
        current_rho = self.rho.cpu().detach().numpy()
        velocity = self.vs.cpu().detach().numpy()

        derived = np.power(velocity, 0.25) * 310

        water = self.water_layer_mask
        if water is not None:
            keep = water.cpu().detach().numpy()
            derived[keep] = current_rho[keep]

        as_tensor = numpy2tensor(derived, self.dtype).to(self.device)
        self.rho = torch.nn.Parameter(as_tensor, requires_grad=self.rho_grad)

    def set_vs_using_empirical_function(self):
        """Recompute vs from rho as ``vs = (rho / 310)^4``.

        This is the inverse of the relation used by
        ``set_rho_using_empirical_function``.  Cells selected by
        ``water_layer_mask`` (when present) keep their current velocity.  The
        result replaces ``self.vs`` as a new Parameter whose ``requires_grad``
        follows ``self.vs_grad``.
        """
        density = self.rho.cpu().detach().numpy()
        current_vs = self.vs.cpu().detach().numpy()

        derived = np.power(density / 310, 4)

        water = self.water_layer_mask
        if water is not None:
            keep = water.cpu().detach().numpy()
            derived[keep] = current_vs[keep]

        as_tensor = numpy2tensor(derived, self.dtype).to(self.device)
        self.vs = torch.nn.Parameter(as_tensor, requires_grad=self.vs_grad)

    def clip_params(self) -> None:

        for par in self.pars:
            if self.lower_bound[par] is not None and self.upper_bound[par] is not None:
                m = getattr(self, par)
                min_value = self.lower_bound[par]
                max_value = self.upper_bound[par]
                m_temp = m.clone()
                m.data.clamp_(min_value, max_value)
                if self.water_layer_mask is not None:
                    m.data = torch.where(self.water_layer_mask.contiguous(), m_temp.data, m.data)
        return

    def forward(self) -> Tuple:

        if self.auto_update_rho and not self.rho_grad:
            self.set_rho_using_empirical_function()
        if self.auto_update_vs and not self.vs_grad:
            self.set_vs_using_empirical_function()
        self.clip_params()
        return

 
    @staticmethod
    def ext_pml(v: Tensor, L: int, top_bc: str) -> Tensor:
        nz, nx = v.shape
        if top_bc == 'PML':
            top_pad    = v[0, :].unsqueeze(0).repeat(L, 1)
            left_pad   = v[:, 0].unsqueeze(1).repeat(1, L)
            right_pad  = v[:, -1].unsqueeze(1).repeat(1, L)
            bottom_pad = v[-1, :].unsqueeze(0).repeat(L, 1)
            tl = torch.full((L, L), v[0, 0].detach().item(), dtype=v.dtype, device=v.device)
            tr = torch.full((L, L), v[0, -1].detach().item(), dtype=v.dtype, device=v.device)
            bl = torch.full((L, L), v[-1, 0].detach().item(), dtype=v.dtype, device=v.device)
            br = torch.full((L, L), v[-1, -1].detach().item(), dtype=v.dtype, device=v.device)
            top_block    = torch.cat([tl, top_pad, tr], dim=1)
            middle_block = torch.cat([left_pad, v, right_pad], dim=1)
            bottom_block = torch.cat([bl, bottom_pad, br], dim=1)
            ve = torch.cat([top_block, middle_block, bottom_block], dim=0)
        elif top_bc in ['Dirichlet', 'Neumann']:
            left_pad  = v[:, 0].unsqueeze(1).repeat(1, L)
            right_pad = v[:, -1].unsqueeze(1).repeat(1, L)
            bottom_pad = v[-1, :].unsqueeze(0).repeat(L, 1)
            temp = torch.cat([left_pad, v, right_pad], dim=1)
            bottom_row = torch.full((L, temp.shape[1]), v[-1, 0].detach().item(), dtype=v.dtype, device=v.device)
            temp = torch.cat([temp, bottom_row], dim=0)
            ve = torch.cat([temp[0:1, :], temp], dim=0)
        else:
            raise ValueError("Unknown top_bc option!")
        return ve

    @staticmethod
    def convert_complex_to_real_cached(A_complex: Tensor, b_complex: Tensor) -> Tuple[Tensor, Tensor]:
        """
        complex -> real
          [A_R  -A_I; A_I   A_R] [x_R; x_I] = [b_R; b_I]
        """
        n = A_complex.size(0)
        A_coo = A_complex.to_sparse_coo()
        indices = A_coo.indices()
        key = (A_complex.shape, indices.cpu().numpy().tobytes())
        if key in SHModel_Freq._conversion_cache:
            cache_entry = SHModel_Freq._conversion_cache[key]
            new_row = cache_entry['new_row']
            new_col = cache_entry['new_col']
        else:
            row = indices[0]
            col = indices[1]
            new_row = torch.cat([row, row, row+n, row+n])
            new_col = torch.cat([col, col+n, col, col+n])
            SHModel_Freq._conversion_cache[key] = {'new_row': new_row, 'new_col': new_col}
        values = A_coo.values()
        new_values = torch.cat([values.real, -values.imag, values.imag, values.real])
        A_real_eq = torch.sparse_coo_tensor(torch.stack([new_row, new_col], dim=0),
                                              new_values,
                                              size=(2*n, 2*n),
                                              dtype=new_values.dtype,
                                              device=A_complex.device)
        A_real_eq = A_real_eq.to_sparse_csr()
        if b_complex.dim() == 1:
            b_real_eq = torch.cat([b_complex.real, b_complex.imag])
        elif b_complex.dim() == 2:
            b_real_eq = torch.cat([b_complex.real, b_complex.imag], dim=0)
        else:
            raise ValueError("b_complex must be 1D or 2D tensor.")
        return A_real_eq, b_real_eq

    @staticmethod
    def rho_stg(rho: Tensor, m: int, n: int) -> Tensor:

        b = 1.0 / rho
        nz, nx = rho.shape
        buNW = b[m-1, n-1] if (m-1 >= 0 and n-1 >= 0) else b[m, n]
        buW  = b[m, n-1]   if (n-1 >= 0) else b[m, n]
        buSW = b[m+1, n-1] if (m+1 < nz and n-1 >= 0) else b[m, n]
        buN  = b[m-1, n]   if (m-1 >= 0) else b[m, n]
        buS  = b[m+1, n]   if (m+1 < nz) else b[m, n]
        buNE = b[m-1, n+1] if (m-1 >= 0 and n+1 < nx) else b[m, n]
        buE  = b[m, n+1]   if (n+1 < nx) else b[m, n]
        buSE = b[m+1, n+1] if (m+1 < nz and n+1 < nx) else b[m, n]
        bu = torch.zeros(9, dtype=rho.dtype, device=rho.device)
        bu[4] = b[m, n]
        bu[0] = 0.5 * (buNW + bu[4])
        bu[1] = 0.5 * (buW  + bu[4])
        bu[2] = 0.5 * (buSW + bu[4])
        bu[3] = 0.5 * (buN  + bu[4])
        bu[5] = 0.5 * (buS  + bu[4])
        bu[6] = 0.5 * (buNE + bu[4])
        bu[7] = 0.5 * (buE  + bu[4])
        bu[8] = 0.5 * (buSE + bu[4])
        return bu

    @staticmethod
    def stencil_index(nz: int, nx: int, m: int, n: int) -> list:

        indices = []
        for dm, dn in [(-1,-1), (0,-1), (1,-1),
                       (-1, 0), (0, 0), (1, 0),
                       (-1, 1), (0, 1), (1, 1)]:
            mm = m + dm
            nn = n + dn
            if mm < 0 or mm >= nz or nn < 0 or nn >= nx:
                indices.append(-1)
            else:
                indices.append(mm + nn * nz)
        return indices

    @classmethod
    def imp_nine(cls, w: Tensor, dx: float, L: int, alpha: float,
                 vs: Tensor, rho: Tensor, mu: Tensor, top_bc: str) -> torch.Tensor:
        """Assemble the dense complex system matrix for one frequency.

        The matrix has one row per grid node, numbered column-major
        (``index = m + n * nz``, see ``stencil_index``).  Stencil entries are
        first collected as ``(row, col, value)`` triples and then written into
        a dense ``(nz*nx, nz*nx)`` complex128 tensor; the complex conjugate of
        that tensor is returned.

        Region by region, in the order the entries are emitted:
          * the four corner nodes,
          * the left, right and bottom edge nodes,
          * 5-point stencils with damping-dependent stretching factors in the
            left, right and bottom strips of width ``L``,
          * the top edge row for ``top_bc`` in {'Neumann', 'Dirichlet'},
          * a 9-point stencil in the interior.

        Args:
            w: Angular frequency (the caller in this file passes a complex
               0-dim tensor).
            dx: Grid spacing, used for both directions.
            L: Width of the damping strips in cells.
            alpha: Amplitude of the cosine damping profile.
            vs, rho, mu: 2-D grids of shape ``(nz, nx)`` holding velocity,
               density and shear modulus.
            top_bc: 'PML', 'Neumann' or 'Dirichlet'.

        Returns:
            Dense complex128 tensor of shape ``(nz*nx, nz*nx)``.

        NOTE(review): when several regions emit the same ``(row, col)`` pair,
        the value written last wins (assignment, not accumulation), so the
        statement order below is significant.
        NOTE(review): for ``top_bc == 'PML'`` no loop below appears to emit
        rows for the top edge (m = 0, 1 <= n <= nx-2) or for the top strip
        (1 <= m < L, L <= n < nx-L); those rows would stay all-zero — confirm
        against the reference solver.
        """
    
        nz, nx = vs.shape
        n_total = nz * nx
        # Dense storage: memory grows with (nz*nx)**2.
        A = torch.zeros((n_total, n_total), dtype=torch.complex128, device=vs.device)
        
        # Triplet buffers filled by add_entry and flushed into A at the end.
        I_list = []
        J_list = []
        V_list = []
        
        # Cosine damping profile over L cells: equals alpha at index 0 and
        # decreases with increasing index.
        inds = torch.arange(L, dtype=torch.float64, device=vs.device)
        damp = alpha * (1 - torch.cos((L - inds) * torch.pi / (2 * L)))
        # damp_z is indexed by the column index n, damp_x by the row index m.
        if nx >= 2 * L:
            damp_z = torch.cat([damp, torch.zeros(nx - 2 * L, dtype=torch.float64, device=vs.device), torch.flip(damp, dims=[0])])
        else:
            damp_z = torch.cat([damp, torch.flip(damp, dims=[0])])
        if top_bc == 'PML':
            if nz >= 2 * L:
                damp_x = torch.cat([damp, torch.zeros(nz - 2 * L, dtype=torch.float64, device=vs.device), torch.flip(damp, dims=[0])])
            else:
                damp_x = torch.cat([damp, torch.flip(damp, dims=[0])])
        else:
            # No damping at the top for non-PML top boundaries; only the
            # bottom L cells carry the profile.
            if nz >= L + len(torch.flip(damp, dims=[0])):
                damp_x = torch.cat([torch.zeros(L, dtype=torch.float64, device=vs.device),
                                     torch.zeros(nz - 2 * L, dtype=torch.float64, device=vs.device),
                                     torch.flip(damp, dims=[0])])
            else:
                damp_x = torch.zeros(nz, dtype=torch.float64, device=vs.device)

        def add_entry(i_val, j_val, val):
            # Skip stencil points that fall outside the grid (index -1).
            if i_val != -1 and j_val != -1:
                I_list.append(i_val)
                J_list.append(j_val)
                V_list.append(val)

        sqrt2 = torch.sqrt(torch.tensor(2.0, dtype=torch.float64, device=vs.device))
        # --- top-left corner node (m = 0, n = 0) ---
        r = cls.stencil_index(nz, nx, 0, 0)
        tmp = -w / vs[0, 0]
        add_entry(r[4], r[4], -1/dx - 1j*tmp*sqrt2/4)
        add_entry(r[4], r[8],  1/dx - 1j*tmp*sqrt2/4)
        add_entry(r[4], r[5], -1j*tmp*sqrt2/4)
        add_entry(r[4], r[7], -1j*tmp*sqrt2/4)

        # --- left edge column (n = 0), rows 1 .. nz-2 ---
        # NOTE(review): looks like an absorbing-boundary stencil — confirm.
        for m in range(1, nz-1):
            r = cls.stencil_index(nz, nx, m, 0)
            tmp = -w / vs[m, 0]
            add_entry(r[4], r[4], -2j*tmp/dx + 1j/(tmp*dx**3) + tmp**2 - 1.5/dx**2)
            add_entry(r[4], r[7],  2j*tmp/dx - 1j/(tmp*dx**3) + tmp**2 - 1.5/dx**2)
            add_entry(r[4], r[3], -1j/(2*tmp*dx**3) + 3/(4*dx**2))
            add_entry(r[4], r[5], -1j/(2*tmp*dx**3) + 3/(4*dx**2))
            add_entry(r[4], r[8],  1j/(2*tmp*dx**3) + 3/(4*dx**2))
            add_entry(r[4], r[6],  1j/(2*tmp*dx**3) + 3/(4*dx**2))

        # --- left damping strip: columns 1 .. L-1, rows 1 .. nz-2 ---
        # 5-point stencil; Ez*/Ex* are stretching factors built from the
        # damping at the node and averaged with its left/right/top/bottom
        # neighbour.
        for n_val in range(1, L):
            for m in range(1, nz-1):
                Ez = 1 + 1j*damp_z[n_val]/w
                Ez_L = 0.5*(2 + 1j*(damp_z[n_val] + damp_z[n_val-1])/w)
                if n_val+1 < len(damp_z):
                    Ez_R = 0.5*(2 + 1j*(damp_z[n_val] + damp_z[n_val+1])/w)
                else:
                    Ez_R = 0.5*(2 + 1j*(damp_z[n_val] + damp_z[n_val])/w)
                Ex = 1 + 1j*damp_x[m]/w
                Ex_T = 0.5*(2 + 1j*(damp_x[m] + damp_x[m-1])/w)
                if m+1 < len(damp_x):
                    Ex_B = 0.5*(2 + 1j*(damp_x[m] + damp_x[m+1])/w)
                else:
                    Ex_B = 0.5*(2 + 1j*(damp_x[m] + damp_x[m])/w)
                r = cls.stencil_index(nz, nx, m, n_val)
                factor = 1/(rho[m, n_val]*dx**2)
                add_entry(r[4], r[1], 1/(Ez*Ez_L)*factor)
                add_entry(r[4], r[3], 1/(Ex*Ex_T)*factor)
                term = (w*w/mu[m, n_val] - 1/(Ex*Ex_T*rho[m, n_val]*dx**2)
                        - 1/(Ex*Ex_B*rho[m, n_val]*dx**2)
                        - 1/(Ez*Ez_L*rho[m, n_val]*dx**2)
                        - 1/(Ez*Ez_R*rho[m, n_val]*dx**2))
                add_entry(r[4], r[4], term)
                add_entry(r[4], r[5], 1/(Ex*Ex_B)*factor)
                add_entry(r[4], r[7], 1/(Ez*Ez_R)*factor)

        # --- bottom-left corner node (m = nz-1, n = 0) ---
        r = cls.stencil_index(nz, nx, nz-1, 0)
        tmp = -w / vs[nz-1, 0]
        add_entry(r[4], r[4], 1/dx + 1j*tmp*sqrt2/4)
        add_entry(r[4], r[6], -1/dx + 1j*tmp*sqrt2/4)
        add_entry(r[4], r[3], 1j*tmp*sqrt2/4)
        add_entry(r[4], r[7], 1j*tmp*sqrt2/4)

        # --- top edge row (m = 0), columns 1 .. nx-2, non-PML only ---
        # The two branches differ only in the coupling to the node below
        # (2/dx**2 for Neumann, an explicit 0 for Dirichlet).
        if top_bc == 'Neumann':
            for n in range(1, nx-1):
                r = cls.stencil_index(nz, nx, 0, n)
                tmp = -w / vs[0, n]
                add_entry(r[4], r[4], -2/dx**2 + tmp**2)
                add_entry(r[4], r[5], 2/dx**2)
        elif top_bc == 'Dirichlet':
            for n in range(1, nx-1):
                r = cls.stencil_index(nz, nx, 0, n)
                tmp = -w / vs[0, n]
                add_entry(r[4], r[4], -2/dx**2 + tmp**2)
                add_entry(r[4], r[5], 0)

        # --- interior 9-point stencil ---
        # Rows start at 1 for Dirichlet/Neumann tops and at L for a PML top.
        LL = 1 if top_bc in ['Dirichlet','Neumann'] else L
        # Stencil weights: `a` splits the Laplacian between axial and diagonal
        # neighbours; `b_const` / `c` split the w**2/mu term between the centre
        # and the four axial neighbours.
        # NOTE(review): presumably published optimal 9-point coefficients —
        # confirm the source.
        a = 0.5461
        b_const = 0.6248
        c = 0.25*(1-b_const)
        for n in range(L, nx-L):
            for m in range(LL, nz-L):
                r = cls.stencil_index(nz, nx, m, n)
                bu = cls.rho_stg(rho, m, n)
                add_entry(r[4], r[0], (1-a)*bu[0]/(2*dx**2))
                add_entry(r[4], r[1], (w*w)*c/mu[m,n] + a*bu[1]/(dx**2))
                add_entry(r[4], r[2], (1-a)*bu[2]/(2*dx**2))
                add_entry(r[4], r[3], (w*w)*c/mu[m-1,n] + a*bu[3]/(dx**2))
                middle_val = (w*w*(b_const/mu[m,n]) - a*(bu[1]+bu[3]+bu[5]+bu[7])/(dx**2)
                              - 0.5*(1-a)*(bu[0]+bu[2]+bu[6]+bu[8])/(dx**2))
                add_entry(r[4], r[4], middle_val)
                add_entry(r[4], r[5], (w*w)*c/mu[m+1,n] + a*bu[5]/(dx**2))
                add_entry(r[4], r[6], (1-a)*bu[6]/(2*dx**2))
                add_entry(r[4], r[7], (w*w)*c/mu[m,n+1] + a*bu[7]/(dx**2))
                add_entry(r[4], r[8], (1-a)*bu[8]/(2*dx**2))
        
        # --- bottom edge row (m = nz-1), columns 1 .. nx-2 ---
        for n in range(1, nx-1):
            r = cls.stencil_index(nz, nx, nz-1, n)
            tmp = w / vs[nz-1, n]
            add_entry(r[4], r[4], 2j*tmp/dx - 1j/(tmp*dx**3) + tmp**2 - 1.5/dx**2)
            add_entry(r[4], r[3], -2j*tmp/dx + 1j/(tmp*dx**3) + tmp**2 - 1.5/dx**2)
            add_entry(r[4], r[1], 1j/(2*tmp*dx**3) + 3/(4*dx**2))
            add_entry(r[4], r[7], 1j/(2*tmp*dx**3) + 3/(4*dx**2))
            add_entry(r[4], r[0], -1j/(2*tmp*dx**3) + 3/(4*dx**2))
            add_entry(r[4], r[6], -1j/(2*tmp*dx**3) + 3/(4*dx**2))
        
        # --- bottom damping strip: rows nz-L .. nz-2, columns 1 .. nx-2 ---
        # Same 5-point stencil as the left strip; overlaps the left/right
        # strips in the lower corners.
        for n in range(1, nx-1):
            for m in range(nz-L, nz-1):
                Ez = 1 + 1j*damp_z[n]/w
                Ez_L = 0.5*(2+1j*(damp_z[n]+damp_z[n-1])/w)
                if n+1 < len(damp_z):
                    Ez_R = 0.5*(2+1j*(damp_z[n]+damp_z[n+1])/w)
                else:
                    Ez_R = 0.5*(2+1j*(damp_z[n]+damp_z[n])/w)
                Ex = 1 + 1j*damp_x[m]/w
                Ex_T = 0.5*(2+1j*(damp_x[m]+damp_x[m-1])/w)
                if m+1 < len(damp_x):
                    Ex_B = 0.5*(2+1j*(damp_x[m]+damp_x[m+1])/w)
                else:
                    Ex_B = 0.5*(2+1j*(damp_x[m]+damp_x[m])/w)
                r = cls.stencil_index(nz, nx, m, n)
                factor = 1/(rho[m,n]*dx**2)
                add_entry(r[4], r[1], 1/(Ez*Ez_L)*factor)
                add_entry(r[4], r[3], 1/(Ex*Ex_T)*factor)
                term = (w*w/mu[m,n]-1/(Ex*Ex_T*rho[m,n]*dx**2)
                        -1/(Ex*Ex_B*rho[m,n]*dx**2)-1/(Ez*Ez_L*rho[m,n]*dx**2)
                        -1/(Ez*Ez_R*rho[m,n]*dx**2))
                add_entry(r[4], r[4], term)
                add_entry(r[4], r[5], 1/(Ex*Ex_B)*factor)
                add_entry(r[4], r[7], 1/(Ez*Ez_R)*factor)
        
        # --- top-right corner node (m = 0, n = nx-1) ---
        r = cls.stencil_index(nz, nx, 0, nx-1)
        tmp = -w / vs[0,nx-1]
        add_entry(r[4], r[4], -1/dx - 1j*tmp*sqrt2/4)
        add_entry(r[4], r[2],  1/dx - 1j*tmp*sqrt2/4)
        add_entry(r[4], r[5], -1j*tmp*sqrt2/4)
        add_entry(r[4], r[1], -1j*tmp*sqrt2/4)
        
        # --- right edge column (n = nx-1), rows 1 .. nz-2 ---
        for m in range(1, nz-1):
            r = cls.stencil_index(nz, nx, m, nx-1)
            tmp = w / vs[m,nx-1]
            add_entry(r[4], r[4], 2j*tmp/dx - 1j/(tmp*dx**3) + tmp**2 - 1.5/dx**2)
            add_entry(r[4], r[1], -2j*tmp/dx + 1j/(tmp*dx**3) + tmp**2 - 1.5/dx**2)
            add_entry(r[4], r[3], 1j/(2*tmp*dx**3)+3/(4*dx**2))
            add_entry(r[4], r[5], 1j/(2*tmp*dx**3)+3/(4*dx**2))
            add_entry(r[4], r[2], -1j/(2*tmp*dx**3)+3/(4*dx**2))
            add_entry(r[4], r[0], -1j/(2*tmp*dx**3)+3/(4*dx**2))
        
        # --- right damping strip: columns nx-L .. nx-2, rows 1 .. nz-2 ---
        for n in range(nx-L, nx-1):
            for m in range(1, nz-1):
                Ez = 1 + 1j*damp_z[n]/w
                Ez_L = 0.5*(2+1j*(damp_z[n]+damp_z[n-1])/w)
                if n+1 < len(damp_z):
                    Ez_R = 0.5*(2+1j*(damp_z[n]+damp_z[n+1])/w)
                else:
                    Ez_R = 0.5*(2+1j*(damp_z[n]+damp_z[n])/w)
                Ex = 1 + 1j*damp_x[m]/w
                Ex_T = 0.5*(2+1j*(damp_x[m]+damp_x[m-1])/w)
                if m+1 < len(damp_x):
                    Ex_B = 0.5*(2+1j*(damp_x[m]+damp_x[m+1])/w)
                else:
                    Ex_B = 0.5*(2+1j*(damp_x[m]+damp_x[m])/w)
                r = cls.stencil_index(nz, nx, m, n)
                factor = 1/(rho[m,n]*dx**2)
                add_entry(r[4], r[1], 1/(Ez*Ez_L)*factor)
                add_entry(r[4], r[3], 1/(Ex*Ex_T)*factor)
                term = (w*w/mu[m,n]-1/(Ex*Ex_T*rho[m,n]*dx**2)
                        -1/(Ex*Ex_B*rho[m,n]*dx**2)-1/(Ez*Ez_L*rho[m,n]*dx**2)
                        -1/(Ez*Ez_R*rho[m,n]*dx**2))
                add_entry(r[4], r[4], term)
                add_entry(r[4], r[5], 1/(Ex*Ex_B)*factor)
                add_entry(r[4], r[7], 1/(Ez*Ez_R)*factor)
        
        # --- bottom-right corner node (m = nz-1, n = nx-1) ---
        r = cls.stencil_index(nz, nx, nz-1, nx-1)
        tmp = -w / vs[nz-1, nx-1]
        add_entry(r[4], r[4], 1/dx + 1j*tmp*sqrt2/4)
        add_entry(r[4], r[0], -1/dx + 1j*tmp*sqrt2/4)
        add_entry(r[4], r[3], 1j*tmp*sqrt2/4)
        add_entry(r[4], r[1], 1j*tmp*sqrt2/4)

        # Flush the triplets into the dense matrix one element at a time.
        # A repeated (row, col) pair is overwritten, not summed.
        for idx in range(len(I_list)):
            i_val = I_list[idx]
            j_val = J_list[idx]
            val = V_list[idx]
            A[i_val, j_val] = val

        # The assembled operator is returned complex-conjugated.
        A = A.conj()
        return A

    @staticmethod
    def four2time(pf_p: Tensor, tmax: float, twrap: float, freq_zpad: int) -> Tuple[Tensor, Tensor]:

        dim0, nf, ns = pf_p.shape
        pf_neg = torch.conj(torch.flip(pf_p, dims=[1]))
        zero_freq = torch.zeros((dim0, 1, ns), dtype=torch.complex128, device=pf_p.device)
        pf = torch.cat((pf_neg, zero_freq, pf_p), dim=1)
        pad_zeros = torch.zeros((dim0, freq_zpad, ns), dtype=torch.complex128, device=pf_p.device)
        pf_pad = torch.cat((pad_zeros, pf, pad_zeros), dim=1)
        pf_pad = torch.fft.ifftshift(pf_pad, dim=1)
        pt = torch.real(torch.fft.ifft(pf_pad, dim=1))
        t = torch.linspace(0, tmax, 2*(nf+freq_zpad)+1, dtype=torch.float64, device=pf_p.device)
        undamp = torch.exp(twrap*t/tmax).unsqueeze(0).unsqueeze(2)
        pt = pt * undamp
        print("Frequency domain to time domain conversion is done!")
        return pt, t



    def fdfd_setup(self):
        """
        Set up the parameters for FDFD simulation:
          - Extend the model (PML/boundary extension).
          - Build the source matrix and receiver indices.
          - Construct the frequency vector and retrieve the source spectrum.
        Stores key parameters in self.
        """
        device = self.vs.device
        self.device = device

        # Extend model with boundaries
        self.vs_e   = SHModel_Freq.ext_pml(self.vs, self.L, self.abc_type)
        self.rho_e  = SHModel_Freq.ext_pml(self.rho, self.L, self.abc_type)
        self.Q_e    = SHModel_Freq.ext_pml(self.Q,   self.L, self.abc_type)
        ext_nz, ext_nx = self.vs_e.shape
        self.ext_nz = ext_nz
        self.ext_nx = ext_nx
        self.n_total = ext_nz * ext_nx
        self.dx_val = self.dx  # assume dx == dz

        # Compute offsets so that the original model remains in the same relative position
        if self.abc_type == 'PML':
            self.offset_row = self.L
            self.offset_col = self.L
        elif self.abc_type in ['Dirichlet', 'Neumann']:
            self.offset_row = 1
            self.offset_col = self.L
        else:
            self.offset_row = 0
            self.offset_col = 0

        # Build source matrix s_mat of shape (n_total, ns)
        Sx_tensor = torch.tensor(self.Sx, dtype=torch.long, device=device) + self.offset_col
        Sz_tensor = torch.tensor(self.Sz, dtype=torch.long, device=device) + self.offset_row
        ns = Sx_tensor.numel()
        self.ns = ns
        s_mat = torch.zeros((self.n_total, ns), dtype=torch.complex128, device=device)
        idx = Sz_tensor + Sx_tensor * ext_nz
        s_mat[idx, torch.arange(ns, device=device)] = 1.0
        self.s_mat = s_mat

        # Build receiver indices (rec_ind)
        if not isinstance(self.Rx, torch.Tensor):
            self.Rx = torch.tensor(self.Rx, dtype=torch.long, device=device)
        if not isinstance(self.Rz, torch.Tensor):
            self.Rz = torch.tensor(self.Rz, dtype=torch.long, device=device)
        rec_ind = self.Rz + self.Rx * ext_nz
        rec_ind = rec_ind + self.offset_row + self.offset_col * ext_nz
        self.rec_ind = rec_ind
        self.nr = rec_ind.numel()

        # Construct the frequency vector and retrieve the source spectrum.
        df = 1.0 / self.tmax
        f = torch.arange(df, self.fmax + df, df, dtype=torch.float64, device=device)
        self.f = f
        self.nf = f.shape[0]
        self.fs_wavelet = self.sourceSpectrum  # expected to be a complex tensor of length nf


        

    def fdfd_oneFreq(self, k: int, sparse_solving=False) -> torch.Tensor:
        """
        Solve the forward problem for a single frequency index k.

        Parameters:
            k (int): Frequency index.

        Returns:
            sol (Tensor): The computed wavefield solution for frequency k.
        """
        try: from torch import sparse; sparse.spsolve; #print("✅ torch.sparse.spsolve available")
        except (ImportError, AttributeError): print("❌ torch.sparse.spsolve not available");sys.exit(1)
        
        device = self.device
        freq_k = self.f[k].item()
        w_val = 2 * np.pi * freq_k + 1j * (self.twrap / self.tmax)
        w_val_tensor = torch.as_tensor(w_val, dtype=torch.complex128, device=device)
        mu = SHModel_Freq.shear_modulus(self.vs_e, self.rho_e, self.Q_e, 2*np.pi * freq_k, self.wref, self.atten_opt)
        A = SHModel_Freq.imp_nine(w_val_tensor, self.dx_val, self.L, self.alpha, self.vs_e, self.rho_e, mu, self.abc_type)
        RHS = self.s_mat * self.fs_wavelet[k]
        
        if sparse_solving == True:
            try:
                print("Using sparse solver for frequency index", k)
                # Convert complex system to real system for sparse solving
                n = A.shape[0]
                
                # Check RHS dimensions and handle properly
                if RHS.dim() == 1:
                    # RHS is 1D vector
                    num_rhs = 1
                    RHS = RHS.unsqueeze(1)  # Make it [n, 1]
                else:
                    # RHS is 2D matrix [n, num_rhs]
                    num_rhs = RHS.shape[1]
                
                print(f"Matrix size: {n}x{n}, Number of RHS: {num_rhs}")
                
                # Extract real and imaginary parts
                A_real = A.real
                A_imag = A.imag
                RHS_real = RHS.real  # [n, num_rhs]
                RHS_imag = RHS.imag  # [n, num_rhs]
                
                # Build expanded real system:
                # [A_real  -A_imag] [x_real]   [RHS_real]
                # [A_imag   A_real] [x_imag] = [RHS_imag]
                A_expanded = torch.zeros(2*n, 2*n, dtype=torch.float64, device=device)
                A_expanded[:n, :n] = A_real      # top-left: A_real
                A_expanded[:n, n:] = -A_imag     # top-right: -A_imag
                A_expanded[n:, :n] = A_imag      # bottom-left: A_imag
                A_expanded[n:, n:] = A_real      # bottom-right: A_real
                
                # Handle multiple RHS properly
                RHS_expanded = torch.zeros(2*n, num_rhs, dtype=torch.float64, device=device)
                RHS_expanded[:n, :] = RHS_real   # top half: RHS_real
                RHS_expanded[n:, :] = RHS_imag   # bottom half: RHS_imag
                
                # Convert to sparse format and solve
                A_sparse = A_expanded.to_sparse_csr()
                
                # Solve for each RHS column
                sol_expanded_list = []
                for rhs_idx in range(num_rhs):
                    rhs_col = RHS_expanded[:, rhs_idx]
                    sol_col = sparse.spsolve(A_sparse, rhs_col)
                    sol_expanded_list.append(sol_col)
                
                # Stack solutions
                sol_expanded = torch.stack(sol_expanded_list, dim=1)  # [2*n, num_rhs]
                
                # Convert back to complex solution
                sol_real = sol_expanded[:n, :]   # [n, num_rhs]
                sol_imag = sol_expanded[n:, :]   # [n, num_rhs]
                sol = torch.complex(sol_real, sol_imag)  # [n, num_rhs]
                
                # If original RHS was 1D, return 1D solution
                if num_rhs == 1:
                    sol = sol.squeeze(1)
                
                # Cleanup intermediate variables
                del A_expanded, RHS_expanded, A_sparse, sol_expanded, sol_expanded_list
                del A_real, A_imag, RHS_real, RHS_imag, sol_real, sol_imag
                
            except Exception as e:
                print(f"Warning: Sparse solver failed ({e}), falling back to dense solver")
                sol = torch.linalg.solve(A, RHS)
                
        else:
            sol = torch.linalg.solve(A, RHS)
        
        del A, RHS, mu, w_val_tensor
        torch.cuda.empty_cache()
        return sol

    def simulate_fdfd(self, iftraning: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        FDFD forward simulation:
          1. Call ``fdfd_setup`` to prepare the simulation parameters stored on ``self``.
          2. Solve the forward problem for each frequency with ``fdfd_oneFreq``.
          3. Transform the frequency-domain fields to the time domain with ``four2time``.
          4. When not training, save a GIF animation of one shot's wavefield.

        Args:
          iftraning: When True, the dense solver is requested for every frequency,
                     progress is printed only every 10th frequency and no GIF is
                     written. When False, the sparse solver is requested, every
                     frequency is reported and the GIF is saved.

        Returns:
          full_time_field: Full time-domain wavefield produced by ``four2time`` from the
                           (n_total, nf, ns) frequency-domain field. It is returned
                           unmodified; the normalisation used for the animation is
                           applied to a copy only.
          rec_pf:          Receiver frequency-domain wavefield (nr, nf, ns)
          rec_time_field:  Receiver time-domain wavefield produced by ``four2time`` from rec_pf
          t:               Time axis vector returned by ``four2time``
        """
        #self.forward()  # update model parameters if necessary

        # Setup all necessary parameters and store them in self.
        self.fdfd_setup()
        device = self.device
        ns = self.ns
        nf = self.nf

        full_pf = torch.zeros((self.n_total, nf, ns), dtype=torch.complex128, device=device)
        rec_pf  = torch.zeros((self.nr, nf, ns), dtype=torch.complex128, device=device)

        # Loop over frequencies solving the forward problem.
        # The sparse solver is only requested outside of training.
        ifsparse = not iftraning
        for k in range(nf):
            sol = self.fdfd_oneFreq(k, sparse_solving=ifsparse)
            full_pf[:, k, :] = sol
            rec_pf[:, k, :] = sol[self.rec_ind, :]
            if not iftraning:
                print(f"Frequency {k+1}/{nf} ({self.f[k].item():.2f} Hz) solved.\n-----------")
            elif k % 10 == 0:
                print(f"Frequency {k+1}/{nf} ({self.f[k].item():.2f} Hz) solved.")

        if not iftraning:
            print("FDFD for all frequencies is done!")

        # Transform frequency-domain fields to time-domain
        full_time_field, t = SHModel_Freq.four2time(full_pf, self.tmax, self.twrap, self.freq_zpad)
        rec_time_field, _ = SHModel_Freq.four2time(rec_pf, self.tmax, self.twrap, self.freq_zpad)

        # If not in training mode, create a GIF animation for a single shot.
        if not iftraning:
            # Shot 4 is preferred; clamp to the last available shot so that
            # surveys with fewer than 5 sources do not raise an IndexError.
            shot_id = min(4, ns - 1)
            self._save_wavefield_gif(full_time_field, t, shot_id)

        return full_time_field, rec_pf, rec_time_field, t

    def _save_wavefield_gif(self, full_time_field: torch.Tensor, t: torch.Tensor, shot_id: int) -> None:
        """
        Save a two-panel GIF (velocity model on top, animated wavefield below) for one shot.

        The file is written to ``<project_path>/waveform/wavefield_illu.gif``; the
        directory is created if it does not exist.

        Args:
          full_time_field: Time-domain wavefield indexed as [grid point, time sample, shot].
                           It is only read here, never modified.
          t:               Time axis; ``t[i]`` is shown as the annotation of frame ``i``.
          shot_id:         Index of the shot to animate.
        """
        import os  # local import: only needed to create the output directory

        playback_ms = 12000       # total playback duration of the GIF
        animated_seconds = 9      # amount of simulated time covered by the GIF
        vmax = 0.9                # symmetric colour limit for the normalised wavefield

        source_x = self.Sx[shot_id]
        source_z = self.Sz[shot_id]

        # Normalise a copy of the selected shot by the maximum of the whole field.
        # Working on a copy keeps the tensor returned to the caller untouched, and
        # converting once avoids a device-to-host transfer for every frame.
        shot_field = (full_time_field[:, :, shot_id] / torch.max(full_time_field)).detach().cpu().numpy()

        def snapshot(i):
            # Column-major reshape of the flat grid vector to (ext_nz, ext_nx),
            # then crop to the region without PML.
            frame = shot_field[:, i].reshape((self.ext_nz, self.ext_nx), order='F')
            return frame[self.offset_row:self.ext_nz - self.nabc,
                         self.offset_col:self.ext_nx - self.offset_col]

        # Define extent based on model dimensions
        extent = [0, self.nx, self.nz, 0]

        # Calculate figure size based on the velocity model dimensions (assumes self.vs shape is (nz, nx))
        ny, nx = self.vs.shape
        fig_width = 8  # Base width
        fig_height = (ny / nx) * fig_width * 2.2  # Height is scaled to leave space for 2 subplots

        # Create two subplots: top for the velocity model and bottom for the wavefield animation
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_width, fig_height))
        plt.subplots_adjust(hspace=0.3)  # Adjust vertical spacing between subplots

        # Plot the velocity model on the top subplot
        im_vs = ax1.imshow(self.vs.detach().cpu().numpy(), extent=extent, cmap='coolwarm', aspect='equal')
        ax1.set_title("Velocity Model")
        ax1.set_xlabel("x (m)")
        ax1.set_ylabel("z (m)")
        divider1 = make_axes_locatable(ax1)
        cax1 = divider1.append_axes("right", size="5%", pad=0.05)
        # self.vs is the shear-wave velocity of the SH model, so label it as such.
        fig.colorbar(im_vs, cax=cax1, label="S-wave velocity (m/s)")

        # Plot the initial wavefield snapshot on the bottom subplot
        im = ax2.imshow(snapshot(0), extent=extent, cmap='seismic', vmin=-vmax, vmax=vmax, interpolation='bilinear')
        ax2.set_xlabel("x (m)")
        ax2.set_ylabel("z (m)")
        divider2 = make_axes_locatable(ax2)
        cax2 = divider2.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax2, label="Amplitude")
        ax2.set_title("Wavefield")

        # Plot source location on both subplots
        ax1.scatter(source_x, source_z, c='r', marker='*', s=100, label='Source')
        ax2.scatter(source_x, source_z, c='r', marker='*', s=100, label='Source')

        # Time annotation on the wavefield axes
        time_text = ax2.text(0.98, 0.02, '', transform=ax2.transAxes,
                             horizontalalignment='right', verticalalignment='bottom')

        def update_frame(i):
            im.set_data(snapshot(i))
            time_text.set_text(f'Time = {t[i].item():.2f} s')
            # Return every artist that changes so blitting also redraws the text.
            return [im, time_text]

        # Animate at most `animated_seconds` of simulated time, never more frames
        # than available time samples, and at least one frame so the interval
        # computation cannot divide by zero.
        frames = max(1, min(int(animated_seconds / self.dt), shot_field.shape[1]))
        interval_ms = playback_ms / frames

        ani = animation.FuncAnimation(fig, update_frame, frames=frames, interval=interval_ms, blit=True)
        savepath = self.project_path + '/waveform/wavefield_illu.gif'
        # Make sure the target directory exists before the writer opens the file.
        os.makedirs(os.path.dirname(savepath), exist_ok=True)
        ani.save(savepath, writer="pillow")
        print(f"GIF animation saved in {savepath}")
        plt.close(fig)