"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import numpy as np
from typing import Tuple,Dict, List, Callable, Optional
import matplotlib.pyplot as plt

@torch.jit.script
def pad_torchSingle(v: torch.Tensor, pml: int, nz: int, nx: int, ns: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """
    Pad a 2-D model with ``pml`` cells on every side by edge replication.

    Parameters:
    --------------
        v (Tensor)              : Model to pad; it is written into an (nz, nx) window
        pml (int)               : Number of cells added on each of the four sides
        nz, nx (int)            : Number of rows / columns of ``v``
        ns (int)                : Not read by this function (kept for call compatibility)
        device (torch.device)   : Device on which the padded tensor is allocated

    Returns:
    ---------------
        Tensor of shape (nz + 2*pml, nx + 2*pml) with the same dtype as ``v``.
        The interior holds ``v``; the border repeats the nearest interior
        row/column, so corners take the value of the nearest interior corner.
    """
    nz_pml = nz + 2 * pml
    nx_pml = nx + 2 * pml

    # Allocate in the dtype of the input so a float64 model is not silently
    # downcast to the default dtype when it is copied into the buffer.
    cc = torch.zeros((nz_pml, nx_pml), dtype=v.dtype, device=device)

    # Interior window receives the original model
    cc[pml:nz_pml - pml, pml:nx_pml - pml] = v

    # Top border: repeat the first interior row
    cc[:pml, pml:pml + nx] = cc[pml, pml:pml + nx].expand(pml, -1)

    # Bottom border: repeat the last interior row
    cc[nz_pml - pml:nz_pml, pml:pml + nx] = cc[nz_pml - pml - 1, pml:pml + nx].expand(pml, -1)

    # Left border: repeat the first interior column over all rows. This runs
    # after the top/bottom fill, which is what populates the corners.
    cc[:, :pml] = cc[:, [pml]].expand(-1, pml)

    # Right border: repeat the last interior column over all rows
    cc[:, nx_pml - pml:nx_pml] = cc[:, [nx_pml - pml - 1]].expand(-1, pml)

    return cc




@torch.jit.script
def check_derivatives(v, sxy, syz, dx, dz, it):
    """
    Print the peak finite-difference derivatives of the stress fields.

    Output is produced only when ``it % 100 == 0``; otherwise the call does
    nothing. ``sxy`` is differenced along its last axis and divided by ``dx``;
    ``syz`` is differenced along its second axis and divided by ``dz``.
    ``v`` is accepted but not read.

    NOTE(review): the parameters carry no annotations, so under
    ``torch.jit.script`` they are compiled as Tensor arguments. Adding
    ``float``/``int`` annotations would change the compiled signature.
    """
    if it % 100 == 0:
        # Peak |d(sxy)/dx| from neighbouring samples along the last axis
        x_step = sxy[:, :, 1:] - sxy[:, :, :-1]
        x_peak = torch.max(torch.abs(x_step / dx)).item()
        print(f"Max x-derivative: {x_peak}")

        # Peak |d(syz)/dz| from neighbouring samples along the middle axis
        z_step = syz[:, 1:, :] - syz[:, :-1, :]
        z_peak = torch.max(torch.abs(z_step / dz)).item()
        print(f"Max z-derivative: {z_peak}")



@torch.jit.script
def step_forward_sh(nx: int, nz: int, dx: float, dz: float, dt: float,
                 nabc: int, snapshot_interval: int, free_surface: bool,
                 src_x: torch.Tensor, src_z: torch.Tensor, src_n: int, src_v: torch.Tensor,
                 rcv_x: torch.Tensor, rcv_z: torch.Tensor, rcv_n: int,
                 kappa1: torch.Tensor, alpha1: torch.Tensor, kappa2: torch.Tensor, alpha2: torch.Tensor,
                 kappa3: torch.Tensor, c1_staggered: float, c2_staggered: float,
                 v: torch.Tensor, sxy: torch.Tensor, syz: torch.Tensor, rho_pad: torch.Tensor, mu_pad: torch.Tensor,
                 device: torch.device = torch.device("cpu"), 
                 dtype: torch.dtype = torch.float32,
                 debug: bool = False  
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Advance the SH wave equation over every time sample contained in ``src_v``.

    The number of steps is ``src_v.shape[-1]``; the function is called once per
    source-wavelet segment by ``forward_kernel``. The incoming wavefields are
    cloned, so the caller's tensors are not modified in place.

    Parameters:
    --------------
        nx, nz (int)                : Interior grid size (without the absorbing border)
        dx, dz (float)              : Grid spacing in x, z
        dt (float)                  : Time step size
        nabc (int)                  : Border thickness; interior starts at index nabc
        snapshot_interval (int)     : A velocity snapshot is stored whenever
                                      ``it % snapshot_interval == 0`` (``it`` is local
                                      to this call)
        free_surface (bool)         : Enables the free-surface handling on row nabc
        src_x, src_z (Tensor)       : Source indices on the padded grid
        src_n (int)                 : Number of sources (leading axis of the wavefields)
        src_v (Tensor)              : Source wavelet; 1-D (indexed ``src_v[it]``) or
                                      with time on the last axis (``src_v[:, it]``)
        rcv_x, rcv_z (Tensor)       : Receiver indices on the padded grid
        rcv_n (int)                 : Number of receivers
        kappa1, kappa2, kappa3      : Damping terms applied as ``(1 - kappa)`` to
                                      v, sxy and syz respectively
        v, sxy, syz (Tensor)        : Wavefields indexed [source, z, x]; ``forward_kernel``
                                      allocates them as (src_n, nz + 2*nabc, nx + 2*nabc)
        rho_pad, mu_pad (Tensor)    : Padded density and shear modulus, indexed [z, x]
        device, dtype               : Used for the output buffers allocated here

        alpha1, alpha2, c1_staggered, c2_staggered, debug are accepted but not
        read anywhere in the body.

    Returns:
    ---------------
        Tuple of ten tensors, in order:
            v, sxy, syz                     : Wavefields after the last step
            rcv_v, rcv_sxy, rcv_syz         : Receiver records, (src_n, nt, rcv_n)
            forward_v, forward_sxy,
            forward_syz                     : (nz, nx) sums over time and sources of the
                                              squared fields in the interior (detached)
            snapshots                       : (nt // snapshot_interval + 1, src_n, nz, nx)
                                              interior velocity snapshots
    """
    # Work on copies so the in-place slice updates below do not touch the inputs
    v = v.clone()
    sxy = sxy.clone()
    syz = syz.clone()
    
    nt = src_v.shape[-1]
    # NOTE(review): forward_kernel uses 1 (not 2) for the no-free-surface case
    # when it builds kappa3 — confirm which start row is intended.
    free_surface_start = nabc if free_surface else 2
    nx_pml = nx + 2 * nabc
    nz_pml = nz + 2 * nabc

    # NOTE(review): frames are written at it = 0, interval, 2*interval, ...; when
    # nt is an exact multiple of snapshot_interval the last allocated frame is
    # never written and stays zero.
    num_snapshots = nt // snapshot_interval + 1
    snapshots = torch.zeros((num_snapshots, src_n, nz, nx), dtype=dtype, device=device)
    snapshot_idx = 0

    rcv_v = torch.zeros((src_n, nt, rcv_n), dtype=dtype, device=device)
    rcv_sxy = torch.zeros((src_n, nt, rcv_n), dtype=dtype, device=device)
    rcv_syz = torch.zeros((src_n, nt, rcv_n), dtype=dtype, device=device)

    forward_v = torch.zeros((nz, nx), dtype=dtype, device=device)
    forward_sxy = torch.zeros((nz, nx), dtype=dtype, device=device)
    forward_syz = torch.zeros((nz, nx), dtype=dtype, device=device)

    for it in range(nt):

        # Update window: one cell is left out at the far edges so the +1 / -1
        # neighbour accesses below stay inside the padded grid
        z_start = free_surface_start
        z_end = nz_pml - 1
        x_start = 1
        x_end = nx_pml - 1

        # Forward differences of velocity (sample i+1 minus sample i)
        dvdx = (v[:, z_start:z_end, x_start+1:x_end+1] - 
                v[:, z_start:z_end, x_start:x_end]) / dx  
        
        dvdz = (v[:, z_start+1:z_end+1, x_start:x_end] - 
                v[:, z_start:z_end, x_start:x_end]) / dz 

        # Stress updates: damped previous value plus mu * dt * velocity gradient
        sxy[:, z_start:z_end, x_start:x_end] = (
            (1.0 - kappa2[z_start:z_end, x_start:x_end]) * 
            sxy[:, z_start:z_end, x_start:x_end] +
            mu_pad[z_start:z_end, x_start:x_end] * dt * dvdx
        )

        syz[:, z_start:z_end, x_start:x_end] = (
            (1.0 - kappa3[z_start:z_end, x_start:x_end]) * 
            syz[:, z_start:z_end, x_start:x_end] +
            mu_pad[z_start:z_end, x_start:x_end] * dt * dvdz
        )

        # Source injection: source i is added to wavefield i at (src_z[i], src_x[i]),
        # scaled by dt / (dx * dz * rho) at that cell
        src_update = dt * (src_v[it] if len(src_v.shape) == 1 else src_v[:, it])
        v[torch.arange(src_n), src_z, src_x] += src_update / (dx * dz * rho_pad[src_z, src_x])

        if free_surface:
            # Copy the row below into the surface row of v and zero syz on that row
            v[:, free_surface_start, :] = v[:, free_surface_start+1, :]
            syz[:, free_surface_start, :] = 0.0

        # Backward differences of stress (sample i minus sample i-1)
        dsxy_dx = (sxy[:, z_start:z_end, x_start:x_end] - 
                   sxy[:, z_start:z_end, x_start-1:x_end-1]) / dx 
        
        dsyz_dz = (syz[:, z_start:z_end, x_start:x_end] - 
                   syz[:, z_start-1:z_end-1, x_start:x_end]) / dz  

        # Velocity update: damped previous value plus dt / rho * stress divergence
        v[:, z_start:z_end, x_start:x_end] = (
            (1.0 - kappa1[z_start:z_end, x_start:x_end]) * 
            v[:, z_start:z_end, x_start:x_end] +
            dt / rho_pad[z_start:z_end, x_start:x_end] * (dsxy_dx + dsyz_dz)
        )


        # Sample all three fields at the receiver positions for this time step
        rcv_v[:, it] = v[:, rcv_z, rcv_x]
        rcv_sxy[:, it] = sxy[:, rcv_z, rcv_x]
        rcv_syz[:, it] = syz[:, rcv_z, rcv_x]

        
        # Store the interior velocity field every snapshot_interval steps
        if it % snapshot_interval == 0:
            snapshots[snapshot_idx] = v[:, nabc:nabc+nz, nabc:nabc+nx].clone()
            snapshot_idx += 1
            
        # Accumulate squared fields (summed over sources) in the interior;
        # detached so these accumulators stay out of the autograd graph
        forward_v += torch.sum(v * v, dim=0)[nabc:nabc+nz, nabc:nabc+nx].detach()
        forward_sxy += torch.sum(sxy * sxy, dim=0)[nabc:nabc+nz, nabc:nabc+nx].detach()
        forward_syz += torch.sum(syz * syz, dim=0)[nabc:nabc+nz, nabc:nabc+nx].detach()


    return v, sxy, syz, rcv_v, rcv_sxy, rcv_syz, forward_v, forward_sxy, forward_syz, snapshots


def forward_kernel(nx: int, nz: int, dx: float, dz: float, nt: int, dt: float,
                  nabc: int, free_surface: bool, ifvisualWave: bool, projectpath: str,
                  src_x: Tensor, src_z: Tensor, src_n: int, src_v: Tensor,
                  rcv_x: Tensor, rcv_z: Tensor, rcv_n: int,
                  damp: Tensor, mu: Tensor, rho: Tensor,
                  checkpoint_segments: int = 1,
                  device: torch.device = torch.device("cpu"),
                  dtype: torch.dtype = torch.float32,
                  debug: bool = False, checkerboard: bool = False) -> Dict[str, Tensor]:
    """Forward simulation of SH Wave Equation

    Pads the material model, builds the damping terms, runs ``step_forward_sh``
    once per source-wavelet segment and assembles the receiver records. When
    ``ifvisualWave`` is set, the velocity snapshots of ALL segments are
    animated and written to ``projectpath + '/waveform/...gif'``.

    Parameters:
    --------------
        nx, nz (int)                    : Grid points in x, z directions
        dx, dz (float)                  : Grid spacing in x, z directions
        nt (int)                        : Number of time steps (size of the record buffers)
        dt (float)                      : Time step size
        nabc (int)                      : PML thickness
        free_surface (bool)             : Free surface condition flag
        ifvisualWave (bool)             : Save a wavefield animation after the simulation
        projectpath (str)               : Directory prefix for the animation file
        src_x, src_z (Tensor)          : Source positions (interior grid indices)
        src_n (int)                     : Number of sources
        src_v (Tensor)                  : Source wavelets (time on the last axis)
        rcv_x, rcv_z (Tensor)          : Receiver positions (interior grid indices)
        rcv_n (int)                     : Number of receivers
        damp (Tensor)                   : PML damping profile
                                          (presumably on the padded grid — TODO confirm)
        mu (Tensor)                     : Shear modulus
        rho (Tensor)                    : Density
        checkpoint_segments (int)       : Number of segments src_v is split into
        device (torch.device)           : Computation device
        dtype (torch.dtype)             : Data type
        debug (bool)                    : Print a parameter summary before stepping
        checkerboard (bool)             : Forwarded to the visualisation routine

    Returns:
    ---------------
        record_waveform (dict)          : Dictionary containing:
            - vy                        : y-direction velocity at the receivers
            - sxy, syz                  : Shear stresses at the receivers
            - forward_wavefield_v       : Forward wavefield (velocity)
            - forward_wavefield_sxy     : Forward wavefield (xy stress)
            - forward_wavefield_syz     : Forward wavefield (yz stress)
    """
    ###################################################################################

    # Pad material properties by edge replication into the absorbing border
    mu_pad = pad_torchSingle(mu, nabc, nz, nx, src_n, device=device)
    rho_pad = pad_torchSingle(rho, nabc, nz, nx, src_n, device=device)

    # NOTE(review): step_forward_sh starts its update window at row 2 (not 1)
    # when there is no free surface — confirm which value is intended.
    free_surface_start = nabc if free_surface else 1
    nx_pml = nx + 2 * nabc
    nz_pml = nz + 2 * nabc

    # Shift source and receiver indices from the interior grid to the padded grid
    src_x = src_x + nabc
    src_z = src_z + nabc
    rcv_x = rcv_x + nabc
    rcv_z = rcv_z + nabc

    # Initialize wavefields
    v = torch.zeros((src_n, nz_pml, nx_pml), dtype=dtype, device=device)      # y-direction velocity
    sxy = torch.zeros((src_n, nz_pml, nx_pml), dtype=dtype, device=device)    # xy shear stress
    syz = torch.zeros((src_n, nz_pml, nx_pml), dtype=dtype, device=device)    # yz shear stress

    # Initialize recorded waveforms
    rcv_v = torch.zeros((src_n, nt, rcv_n), dtype=dtype, device=device)
    rcv_sxy = torch.zeros((src_n, nt, rcv_n), dtype=dtype, device=device)
    rcv_syz = torch.zeros((src_n, nt, rcv_n), dtype=dtype, device=device)
    forward_wavefield_v = torch.zeros((nz, nx), dtype=dtype, device=device)
    forward_wavefield_sxy = torch.zeros((nz, nx), dtype=dtype, device=device)
    forward_wavefield_syz = torch.zeros((nz, nx), dtype=dtype, device=device)

    # Coefficients for staggered grid (forwarded to the stepper, which does not
    # read them; they are also shown in the debug summary)
    c1_staggered = 9.0 / 8.0
    c2_staggered = -1.0 / 24.0

    # Parameters for waveform simulation
    alpha1 = dt / rho_pad
    kappa1 = damp * dt

    alpha2 = mu_pad * dt

    # Damping for sxy: average of two x-neighbouring damp values, times dt
    kappa2 = torch.zeros_like(damp, device=device)
    kappa2[:, 1:nx_pml - 2] = 0.5 * (damp[:, 1:nx_pml - 2] + damp[:, 2:nx_pml - 1]) * dt

    # Damping for syz: average of two z-neighbouring damp values, times dt
    kappa3 = torch.zeros_like(damp, device=device)
    kappa3[free_surface_start:nz_pml - 2, :] = 0.5 * (damp[free_surface_start:nz_pml - 2, :] +
                                                      damp[free_surface_start + 1:nz_pml - 1, :]) * dt

    if debug:
        print("\n=== Grid and Time Parameters ===")
        print(f"Grid size: nx={nx}, nz={nz}")
        print(f"Grid spacing: dx={dx}, dz={dz}")
        print(f"Time steps: nt={nt}")
        print(f"Time step size: dt={dt}")
        print(f"Total simulation time: {dt*nt} seconds")

        print("\n=== PML Parameters ===")
        print(f"PML thickness: {nabc}")
        print(f"Grid size with PML: nx_pml={nx_pml}, nz_pml={nz_pml}")
        print(f"Free surface: {free_surface}")
        print(f"Damping profile range: [{damp.min().item()}, {damp.max().item()}]")

        print("\n=== Source Parameters ===")
        print(f"Number of sources: {src_n}")
        print(f"Source x positions: {src_x.cpu().numpy()}")
        print(f"Source z positions: {src_z.cpu().numpy()}")
        print(f"Source wavelet shape: {src_v.shape}")
        print(f"Source wavelet range: [{src_v.min().item()}, {src_v.max().item()}]")

        print("\n=== Material Parameters ===")
        print("Before padding:")
        print(f"Mu range: [{mu.min().item()}, {mu.max().item()}]")
        print(f"Rho range: [{rho.min().item()}, {rho.max().item()}]")
        print(f"Wave speed range: [{torch.sqrt(mu/rho).min().item()}, {torch.sqrt(mu/rho).max().item()}]")

        print("\nAfter padding:")
        print(f"Mu_pad range: [{mu_pad.min().item()}, {mu_pad.max().item()}]")
        print(f"Rho_pad range: [{rho_pad.min().item()}, {rho_pad.max().item()}]")
        v_max = torch.sqrt(mu_pad/rho_pad).max().item()
        print(f"Wave speed range: [{torch.sqrt(mu_pad/rho_pad).min().item()}, {v_max}]")

        cfl = v_max * dt / min(dx, dz)
        print(f"CFL number: {cfl}")

        print("\n=== Computation Coefficients ===")
        print(f"Staggered grid coefficients: c1={c1_staggered}, c2={c2_staggered}")
        print("Alpha1 (dt/rho) range:")
        print(f"[{alpha1.min().item()}, {alpha1.max().item()}]")
        print("Alpha2 (mu*dt) range:")
        print(f"[{alpha2.min().item()}, {alpha2.max().item()}]")

    snapshot_interval = 20
    # Snapshot frames of every segment, collected only when an animation is requested
    snapshot_chunks: List[Tensor] = []

    # Time stepping, one call per wavelet segment. The stepper is called
    # directly; torch.utils.checkpoint is not applied here.
    k = 0
    for chunk in torch.chunk(src_v, checkpoint_segments, dim=-1):
        chunk_nt = chunk.shape[-1]
        (v, sxy, syz, rcv_v_temp, rcv_sxy_temp, rcv_syz_temp,
         forward_wavefield_v_temp, forward_wavefield_sxy_temp,
         forward_wavefield_syz_temp, chunk_snapshots) = step_forward_sh(
            nx, nz, dx, dz, dt,
            nabc, snapshot_interval, free_surface,
            src_x, src_z, src_n, chunk,
            rcv_x, rcv_z, rcv_n,
            kappa1, alpha1, kappa2, alpha2, kappa3, c1_staggered, c2_staggered,
            v, sxy, syz, rho_pad, mu_pad,
            device, dtype,
            debug
        )

        # Save recorded waveforms into this segment's time window
        rcv_v[:, k:k + chunk_nt] = rcv_v_temp
        rcv_sxy[:, k:k + chunk_nt] = rcv_sxy_temp
        rcv_syz[:, k:k + chunk_nt] = rcv_syz_temp

        # Accumulate forward wavefields
        forward_wavefield_v += forward_wavefield_v_temp.detach()
        forward_wavefield_sxy += forward_wavefield_sxy_temp.detach()
        forward_wavefield_syz += forward_wavefield_syz_temp.detach()

        if ifvisualWave:
            # The stepper writes frames at local steps 0, interval, 2*interval, ...
            # i.e. ceil(chunk_nt / interval) frames; any extra allocated frame is
            # left at zero and is dropped here. Detach so the animation data does
            # not keep the autograd graph alive.
            n_filled = (chunk_nt + snapshot_interval - 1) // snapshot_interval
            snapshot_chunks.append(chunk_snapshots[:n_filled].detach())

        k = k + chunk_nt

    # Prepare return dictionary
    record_waveform = {
        "vy": rcv_v,
        "sxy": rcv_sxy,
        "syz": rcv_syz,
        "forward_wavefield_v": forward_wavefield_v,
        "forward_wavefield_sxy": forward_wavefield_sxy,
        "forward_wavefield_syz": forward_wavefield_syz
    }

    if ifvisualWave:
        # Visualize wavefield propagation
        if free_surface:
            savepath = projectpath + '/waveform/sh_wavefield_freesurface.gif'
        else:
            savepath = projectpath + '/waveform/sh_wavefield_Nofreesurface.gif'

        # Animate the frames of every segment, not only the last one.
        # NOTE(review): the snapshot counter restarts in each segment, so when a
        # segment length is not a multiple of snapshot_interval the frame
        # spacing (and the time label) is only approximate across segments.
        visualize_wavefield_propagation(
            snapshots=torch.cat(snapshot_chunks, dim=0),
            v_model=torch.sqrt(mu/rho).cpu().detach().numpy(),
            extent=[0, nx, 0, nz],
            dt=dt,
            snapshot_interval=snapshot_interval,
            save_path=savepath,
            src_loc=torch.stack([src_x-nabc, src_z-nabc], dim=1).cpu().numpy(),
            rcv_loc=torch.stack([rcv_x-nabc, rcv_z-nabc], dim=1).cpu().numpy(),
            checkerboard=checkerboard
        )
        print(f"Saving wavefield animation to {savepath}")

    return record_waveform


def visualize_wavefield_propagation(snapshots: torch.Tensor, v_model: np.ndarray, 
                                  extent: List[float], dt: float, snapshot_interval: int,
                                  save_path: str = 'wavefield_animation.gif',
                                  src_loc: Optional[np.ndarray] = None, rcv_loc: Optional[np.ndarray] = None,
                                  checkerboard=False,
                                  source_index: int = 5):
    """Save an animated GIF of one source's wavefield below the velocity model.

    Parameters:
    --------------
        snapshots (Tensor)          : Wavefield frames indexed [frame, source, z, x]
        v_model (np.ndarray)        : 2-D velocity model shown in the top panel
        extent (list)               : Image extent passed to ``imshow``; ``extent[3]``
                                      is used to flip the source z coordinate
        dt (float)                  : Time step size (for the time label)
        snapshot_interval (int)     : Time steps between frames (for the time label)
        save_path (str)             : Output file, written with the 'pillow' writer
        src_loc (np.ndarray)        : Optional source positions, rows of (x, z)
        rcv_loc (np.ndarray)        : Accepted but currently not drawn
        checkerboard (bool)         : Switches figure height, aspect, colour scale
                                      and panel spacing to the checkerboard layout
        source_index (int)          : Source whose wavefield is animated. Clamped to
                                      the available sources, so the default of 5
                                      also works for surveys with fewer sources.
    """
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation

    # Clamp instead of raising: the historical default (5) is out of range
    # whenever fewer than six sources are simulated.
    n_sources = snapshots.shape[1]
    shot = max(0, min(source_index, n_sources - 1))

    ny, nx = v_model.shape
    fig_width = 8
    if checkerboard:
        fig_height = (ny/nx) * fig_width * 3
        aspect = 'auto'
    else:
        fig_height = (ny/nx) * fig_width * 2.2
        aspect = 'equal'

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_width, fig_height))

    # Top panel: static velocity model
    im1 = ax1.imshow(v_model, extent=extent, cmap='coolwarm', aspect=aspect)
    ax1.set_title('Velocity Model')
    plt.colorbar(im1, ax=ax1, label='Velocity (m/s)')

    # Symmetric colour range at a fraction of the global maximum so weaker
    # arrivals stay visible
    vmax = snapshots.max().item()
    if checkerboard:
        vmax = vmax*0.4
    else:
        vmax = vmax*0.6
    vmin = -vmax

    # Bottom panel: first frame of the same source that is animated below
    im2 = ax2.imshow(snapshots[0, shot].detach().cpu(), extent=extent,
                     cmap='seismic', vmin=vmin, vmax=vmax, aspect=aspect)
    ax2.set_title('Vy')
    plt.colorbar(im2, ax=ax2, label='Amplitude')

    if src_loc is not None and len(src_loc) > 0:
        src_row = min(shot, len(src_loc) - 1)
        marker_x = src_loc[src_row, 0]
        marker_z = extent[3] - src_loc[src_row, 1]
        ax1.scatter(marker_x, marker_z, c='r', marker='*', s=100, label='Source')
        ax2.scatter(marker_x, marker_z, c='r', marker='*', s=100)
        # Only draw the legend when a labelled artist exists
        ax1.legend(loc='lower left')

    time_text = ax2.text(0.98, 0.02, '', transform=ax2.transAxes,
                        horizontalalignment='right',
                        verticalalignment='bottom')

    if checkerboard:
        plt.subplots_adjust(hspace=0.9)
    else:
        plt.subplots_adjust(hspace=0.3)

    def update(frame):
        im2.set_array(snapshots[frame, shot].detach().cpu())
        time_text.set_text(f'Time: {frame * snapshot_interval * dt:.3f}s')
        return im2, time_text

    # NOTE(review): interval is in milliseconds per frame; len(snapshots)/15
    # gives a very short delay — confirm this is the intended playback speed.
    anim = animation.FuncAnimation(fig, update, frames=len(snapshots),
                                 interval=len(snapshots)/15, blit=True)
    anim.save(save_path, writer='pillow')
    # Close this figure explicitly rather than whichever one is current
    plt.close(fig)