"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

from typing import Optional,Dict
import numpy as np
import torch
from torch import Tensor
import matplotlib.pyplot as plt
from uniSI.model import AbstractModel
from uniSI.survey import Survey
from uniSI.utils import numpy2tensor
from .boundary_condition import bc_pml,bc_gerjan,bc_sincos
from .SH_kernels import forward_kernel
from typing import Tuple,Dict, List, Callable, Optional

class SHPropagator(torch.nn.Module):
    """Defines the propagator for the SH wave equation (stress-velocity form),
    solved by the finite difference method in XZ plane.

    Parameters:
    -----------
    model (AbstractModel)        : The model object
    survey (Survey)              : The survey object
    ifvisualWave (Optional[bool]): Flag forwarded unchanged to ``forward_kernel``
                                   (presumably enables wavefield visualization), default False
    projectpath (Optional[str])  : Path forwarded unchanged to ``forward_kernel``
                                   (presumably where visual output is written), default None
    device (Optional[str])       : Device type, default is 'cpu'
    cpu_num (Optional[int])      : Number of CPU threads, default is 1 (stored only)
    gpu_num (Optional[int])      : Number of GPU devices, default is 1 (stored only)
    dtype (torch.dtype)          : Data type for tensors, default is torch.float32

    Raises:
    -------
    ValueError : if ``model`` is not an AbstractModel or ``survey`` is not a Survey.
    """
    def __init__(self,
                 model  : AbstractModel,
                 survey : Survey,
                 ifvisualWave: Optional[bool] = False,
                 projectpath: Optional[str] = None,
                 device : Optional[str] = 'cpu',
                 cpu_num: Optional[int] = 1,
                 gpu_num: Optional[int] = 1,
                 dtype  : torch.dtype = torch.float32
                 ):
        super().__init__()

        # Validate model and survey types
        if not isinstance(model, AbstractModel):
            raise ValueError("model is not an instance of AbstractModel")

        if not isinstance(survey, Survey):
            raise ValueError("survey is not an instance of Survey")

        # ---------------------------------------------------------------
        # set the model and survey
        # ---------------------------------------------------------------
        self.model          = model
        self.survey         = survey
        self.device         = device
        self.dtype          = dtype
        self.cpu_num        = cpu_num
        self.gpu_num        = gpu_num

        # ---------------------------------------------------------------
        # parse parameters for model (grid origin, spacing, size) and
        # time stepping parameters taken from the source
        # ---------------------------------------------------------------
        self.ox, self.oz    = model.ox,model.oz
        self.dx, self.dz    = model.dx,model.dz
        self.nx, self.nz    = model.nx,model.nz
        self.nt             = survey.source.nt
        self.dt             = survey.source.dt
        self.f0             = survey.source.f0

        # ---------------------------------------------------------------
        # set the boundary: [top, bottom, left, right]
        # ---------------------------------------------------------------
        self.abc_type       = model.abc_type
        self.nabc           = model.nabc
        self.free_surface   = model.free_surface
        # bcx/bcz are kept as attributes for compatibility; only damp is filled below.
        self.bcx,self.bcz,self.damp   = None,None,None
        self.boundary_condition()

        # ---------------------------------------------------------------
        # parameters for source
        # ---------------------------------------------------------------
        self.source         = self.survey.source
        self.src_loc        = self.source.get_loc()
        # Column 0 of the location array is used as x index, column 1 as z index.
        self.src_x          = numpy2tensor(self.src_loc[:,0],torch.long).to(self.device)
        self.src_z          = numpy2tensor(self.src_loc[:,1],torch.long).to(self.device)
        self.src_n          = self.source.num
        self.wavelet        = numpy2tensor(self.source.get_wavelet(),self.dtype).to(self.device)
        # Loaded for parity with other propagators; forward() below does not use it.
        self.moment_tensor  = numpy2tensor(self.source.get_moment_tensor(),self.dtype).to(self.device)

        # ---------------------------------------------------------------
        # parameters for receiver
        # ---------------------------------------------------------------
        self.receiver       = self.survey.receiver
        self.rcv_loc        = self.receiver.get_loc()
        self.rcv_x          = numpy2tensor(self.rcv_loc[:,0],torch.long).to(self.device)
        self.rcv_z          = numpy2tensor(self.rcv_loc[:,1],torch.long).to(self.device)
        self.rcv_n          = self.receiver.num

        self.ifvisualWave   = ifvisualWave
        self.projectpath    = projectpath

    def boundary_condition(self, vs_max=None):
        """Build the absorbing-boundary damping array and store it in ``self.damp``.

        The ABC type is taken from ``self.abc_type`` (case-insensitive):
        ``"pml"`` -> ``bc_pml``, ``"gerjan"`` -> ``bc_gerjan``, anything else
        -> ``bc_sincos``.

        Parameters:
        -----------
        vs_max (float, optional): Velocity passed as ``vmax`` to ``bc_pml``.
            When None, the maximum of ``self.model.vs`` is used. This is ignored
            for the non-PML boundary types.
        """
        abc_type = self.abc_type.lower()
        if abc_type == "pml":
            if vs_max is None:
                vs_max = self.model.vs.cpu().detach().numpy().max()
            damp = bc_pml(self.nx, self.nz, self.dx, self.dz, pml=self.nabc,
                          vmax=vs_max, free_surface=self.free_surface)
        elif abc_type == 'gerjan':
            damp = bc_gerjan(self.nx, self.nz, self.dx, self.dz, pml=self.nabc,
                             alpha=self.model.abc_jerjan_alpha,
                             free_surface=self.free_surface)
        else:
            damp = bc_sincos(self.nx, self.nz, self.dx, self.dz, pml=self.nabc,
                             free_surface=self.free_surface)

        self.damp = numpy2tensor(damp, self.dtype).to(self.device)

    def debug_wavefield(self, v, sxy, syz, dx, dz, it):
        """Print wavefield diagnostics every 100 time steps.

        ``sxy`` is differenced along its last axis and ``syz`` along axis 1.
        Each difference is scaled by ``dx`` / ``dz``. The value ranges of
        ``v``, ``sxy`` and ``syz`` are also printed.
        """
        if it % 100 == 0:
            print(f"Time step {it}")
            print(f"Max x-derivative: {((sxy[:, :, 1:] - sxy[:, :, :-1]) / dx).abs().max().item()}")
            print(f"Max z-derivative: {((syz[:, 1:, :] - syz[:, :-1, :]) / dz).abs().max().item()}")
            print(f"Velocity range: [{v.min().item():.2e}, {v.max().item():.2e}]")
            print(f"Sxy range: [{sxy.min().item():.2e}, {sxy.max().item():.2e}]")
            print(f"Syz range: [{syz.min().item():.2e}, {syz.max().item():.2e}]")

    @staticmethod
    def _normalize_shot_index(shot_index):
        """Promote a scalar shot index to a 1-element index; pass anything else through.

        Indexing the per-shot tensors with a plain int would drop the shot axis
        (producing 0-d ``src_x``/``src_z``). ``len()`` would then fail, and the
        wavelet would lose its leading dimension.
        """
        if shot_index is None:
            return None
        if isinstance(shot_index, (int, np.integer)):
            return [int(shot_index)]
        if torch.is_tensor(shot_index) and shot_index.dim() == 0:
            return shot_index.reshape(1)
        if isinstance(shot_index, np.ndarray) and shot_index.ndim == 0:
            return shot_index.reshape(1)
        return shot_index

    def forward(self,
                model: Optional[AbstractModel] = None,
                shot_index: Optional[int] = None,
                checkpoint_segments: int = 1, debug: bool = False, checkerboard: bool = False
                ) -> Dict[str, Tensor]:
        """Forward simulation for selected shots.

        Parameters:
        -----------
        model (AbstractModel, optional): Model to simulate with; defaults to ``self.model``.
            Its ``forward()`` is called first, then ``model.mu`` and ``model.rho``
            are passed to the kernel.
        shot_index (int | sequence | array | tensor, optional): Shots to simulate.
            None selects all shots. A single integer is treated as a one-shot batch.
        checkpoint_segments (int): Forwarded to ``forward_kernel``.
        debug (bool): Forwarded to ``forward_kernel``.
        checkerboard (bool): Forwarded to ``forward_kernel``.

        Returns:
        --------
        Whatever ``forward_kernel`` returns (annotated as a dict of tensors,
        presumably the recorded waveforms — confirm in SH_kernels).
        """
        model = self.model if model is None else model
        # NOTE(review): presumably refreshes model.mu / model.rho from the
        # model's parameterization before they are read below — confirm.
        model.forward()

        shot_index = self._normalize_shot_index(shot_index)
        if shot_index is None:
            src_x, src_z, wavelet = self.src_x, self.src_z, self.wavelet
        else:
            src_x = self.src_x[shot_index]
            src_z = self.src_z[shot_index]
            wavelet = self.wavelet[shot_index]
        src_n = len(src_x)

        record_waveform = forward_kernel(
            self.nx, self.nz, self.dx, self.dz, self.nt, self.dt,
            self.nabc, self.free_surface, self.ifvisualWave, self.projectpath,
            src_x, src_z, src_n, wavelet,
            self.rcv_x, self.rcv_z, self.rcv_n,
            self.damp,
            model.mu, model.rho,
            checkpoint_segments=checkpoint_segments,
            device=self.device,
            dtype=self.dtype,
            debug=debug,checkerboard=checkerboard
        )
        return record_waveform