import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import integrate
import sys
import os
sys.path.append("../../")
from uniSI.propagator  import *
from uniSI.model       import *
from uniSI.view        import *
from uniSI.utils       import *
from uniSI.survey      import *
from uniSI.inversion         import *
try: from torch import sparse; sparse.spsolve; print("✅ torch.sparse.spsolve available")
except (ImportError, AttributeError): print("❌ torch.sparse.spsolve not available"); exit(1)

project_path = "./data/uniformgrid"
if not os.path.exists(os.path.join(project_path,"model")):
    os.makedirs(os.path.join(project_path,"model"))
if not os.path.exists(os.path.join(project_path,"waveform")):
    os.makedirs(os.path.join(project_path,"waveform"))
if not os.path.exists(os.path.join(project_path,"survey")):
    os.makedirs(os.path.join(project_path,"survey"))


nx_org = 50
nz_org = 30
dx = 60.0  
dz = dx
tmax = 4
twrap = tmax//3
alpha = 180

nabc = 10
abc_type = 'Neumann'
device = 'cuda:0'
dtype = torch.float64
ox = 0.0
oz = 0.0
f0 = 2.0
wref = f0
fmax = f0 * 3
df = 1/tmax
nf = int((fmax-df)/df)
t0 = 0.3
nt = 2 * nf + 1  

water_layer=0

def empirical_Q(vp: np.ndarray) -> np.ndarray:
    """Build an empirical P-wave quality-factor (Q) model from a P-wave velocity model.

    Piecewise rule on vp (m/s):
      * vp < 1510          -> Q = 1e6            (water; effectively no attenuation)
      * 1510 <= vp < 3000  -> Q = 0.015 * vp     (sediments, linear)
      * vp >= 3000         -> Q = 0.25 * vp**0.8 (basement, power law)
    Q is then clipped from below at 10. Note the rule is discontinuous at
    vp = 3000 (0.015*3000 = 45 vs. 0.25*3000**0.8 ~= 151).

    Args:
        vp : P-wave velocity model (m/s), any shape.
    Returns:
        Q : Quality-factor model with the same shape as vp. Integer input is
            promoted to float so the sediment/basement formulas are not
            truncated; floating input keeps its own precision.
    """
    vp = np.asarray(vp)
    # np.zeros_like would inherit an integer dtype and truncate 0.015*vp etc.
    Q = np.zeros(vp.shape, dtype=np.result_type(vp, 1.0))

    # Formula 1: water (vp < 1510 m/s, i.e. ~1500 m/s plus tolerance) -> very high Q
    water_mask = vp < 1510
    Q[water_mask] = 1e6  # approximately no attenuation

    # Formula 2: sedimentary rock (1510 <= vp < 3000 m/s) -> Q = 0.015 * vp
    sed_mask = ~water_mask & (vp < 3000)
    Q[sed_mask] = 0.015 * vp[sed_mask]

    # Formula 3: basement rock (vp >= 3000 m/s) -> non-linear power law
    rock_mask = vp >= 3000
    Q[rock_mask] = 0.25 * vp[rock_mask] ** 0.8

    # Lower bound on Q (prevents unrealistically strong attenuation).
    Q[Q < 10] = 10

    return Q

def empirical_Qs_from_Qp(Qp: np.ndarray, vp_vs_ratio: float = 1.73) -> np.ndarray:
    """Derive an S-wave Q model from a P-wave Q model.

    Uses Qs = Qp / (Vp/Vs)**2 with a constant Vp/Vs ratio.
    NOTE(review): the common no-bulk-loss relation is Qs = (4/3) * (Vs/Vp)**2 * Qp;
    this simplified scaling omits the 4/3 factor — confirm it is intended.

    Args:
        Qp : P-wave quality-factor model.
        vp_vs_ratio : Constant Vp/Vs ratio (default 1.73, matching empirical_vs_from_vp).
    Returns:
        Qs model with the same shape as Qp.
    """
    return Qp / (vp_vs_ratio ** 2)

def empirical_vs_from_vp(vp: np.ndarray, vp_vs_ratio: float = 1.73) -> np.ndarray:
    """Derive an S-wave velocity model from a P-wave velocity model.

    Args:
        vp : P-wave velocity model (m/s).
        vp_vs_ratio : Constant Vp/Vs ratio (default 1.73).
    Returns:
        vs = vp / vp_vs_ratio, same shape as vp.
    """
    return vp / vp_vs_ratio

# --- True models --------------------------------------------------------------
vp_true = layerVp(nx_org, nz_org)
vs_true = empirical_vs_from_vp(vp_true)
# Density from velocity via the power law rho = 310 * vp**0.25.
rho_true = 310 * np.power(vp_true, 0.25)
# Empirical P-wave Q from vp, then scaled to the S-wave Q used by the SH model.
Qp_true = empirical_Q(vp_true)
Q_true = empirical_Qs_from_Qp(Qp_true)

# --- Starting models ----------------------------------------------------------
# Experiment: vp is treated as already inverted but not exact (emulated by a light
# Gaussian smoothing of vp_true); Q is the unknown to be inverted and starts from a
# heavily smoothed copy of Q_true.
import scipy.ndimage
vp_init = scipy.ndimage.gaussian_filter(vp_true, sigma=0.4)
# vp_init = vp_true  # alternative: use the exact velocity
vs_init = empirical_vs_from_vp(vp_init)
rho_init = 310 * np.power(vp_init, 0.25)
# NOTE(review): any very large Q cells produced by empirical_Q (vp < 1510) would be
# smeared into their neighbours by this smoothing — confirm layerVp has none, or mask first.
Q_init = scipy.ndimage.gaussian_filter(Q_true, sigma=10)
# Q_init = empirical_Q(layerVp_init(nx_org, nz_org))  # alternative starting Q

import matplotlib.pyplot as plt
from scipy import fft
import numpy as np
import torch

def ricker_spectrum(f, f0):
    """
    Unnormalized amplitude spectrum of a Ricker wavelet.

    Evaluates A(f) = (f/f0)^2 * exp(-(f/f0)^2), which is zero at f = 0 and
    peaks at f = f0 with value 1/e.

    Parameters:
        f  : Frequency array (a scalar also works)
        f0 : Dominant frequency

    Returns:
        Amplitude spectrum, same shape as f
    """
    ratio_sq = (f / f0) ** 2
    return ratio_sq * np.exp(-ratio_sq)

def ricker_spectrum_with_delay(f, f0, t0):
    """
    Complex spectrum of a Ricker wavelet shifted later in time by t0.

    A time shift by t0 multiplies the spectrum by exp(-i*2*pi*f*t0), so the
    result is the Ricker amplitude spectrum times that linear-phase factor.

    Parameters:
        f  : Frequency array
        f0 : Dominant frequency
        t0 : Time delay (seconds)

    Returns:
        Complex spectrum (amplitude and phase), same shape as f
    """
    phase = -2 * np.pi * f * t0  # linear phase of a pure delay
    return ricker_spectrum(f, f0) * np.exp(1j * phase)

# Frequency axis df, 2*df, ..., fmax (df and fmax are defined above).
freq = torch.arange(df, fmax + df, df, dtype=torch.float64)
min_freq = freq[0].item()
nf = len(freq)  # number of modelled frequencies; used downstream (e.g. SeismicData_Freq)
freq_np = freq.numpy()

# Generate the delayed Ricker source spectrum on that axis.
src_spectrum = ricker_spectrum_with_delay(freq_np, f0, t0)

# Prepend the DC bin (f = 0) so index k corresponds to frequency k*df; length = nf + 1.
src_spectrum_full = np.concatenate(([0], src_spectrum))

# irfft with n output samples reads only the first n//2 + 1 bins and silently crops the
# rest, so make the transform long enough to use every bin of src_spectrum_full.
n_time = max(nt, 2 * (src_spectrum_full.size - 1) + 1)
time_waveform = fft.irfft(src_spectrum_full, n=n_time)
# Bins are df = 1/tmax apart, so the waveform has period tmax and its samples lie at
# k * tmax / n_time, k = 0..n_time-1 (tmax itself is not a sample; cf. dt = tmax/nt).
t = np.arange(n_time) * (tmax / n_time)

# Plot: left for source spectrum, right for time-domain waveform
fig, axs = plt.subplots(1, 2, figsize=(14, 4))

axs[0].plot(freq_np, np.abs(src_spectrum), label="Amplitude", color='b')
axs[0].set_xlabel('Frequency (Hz)')
axs[0].set_ylabel('Amplitude')
axs[0].set_title('Source Spectrum')
axs[0].legend()
axs[0].grid(True)

axs[1].plot(t, time_waveform, color='r')
axs[1].set_xlabel("Time (s)")
axs[1].set_ylabel("Amplitude")
axs[1].set_title("Source - Time Domain Waveform")
axs[1].grid(True)

plt.tight_layout()
plt.show()

# Acquisition geometry in grid indices (multiply by dx/dz for metres): a source every
# third column starting at x-index 2, a receiver at every column, all at z-index 2.
Sx = np.arange(2, nx_org, 3)  # X-coordinates for sources
Sz = np.full(Sx.shape, 2)     # Z-coordinates for sources

Rx = np.arange(0, nx_org, 1)  # X-coordinates for receivers
Rz = np.full(Rx.shape, 2)     # Z-coordinates for receivers

import numpy as np
import matplotlib.pyplot as plt

# Blank out the very large Q values (> 1e5, e.g. the Q = 1e6 "no attenuation" cells)
# with NaN so they do not dominate the 1/Q colour scale.
Q_true_plot = Q_true.copy()
Q_true_plot[Q_true > 1e5] = np.nan
Q_init_plot = Q_init.copy()
Q_init_plot[Q_init > 1e5] = np.nan

# Shared colour limits for both panels, taken from the true model.
vmax = np.nanmax(1/Q_true_plot)
vmin = np.nanmin(1/Q_true_plot)

fig, axs = plt.subplots(1, 2, figsize=(20, 6))
images, colorbars = [], []
for ax, q_plot, title in zip(axs,
                             (Q_true_plot, Q_init_plot),
                             ('True 1/Q model', 'Initial 1/Q model')):
    image = ax.imshow(1/q_plot, cmap='coolwarm',
                      extent=[0, nx_org * dx, nz_org * dz, 0],
                      vmax=vmax, vmin=vmin)
    ax.scatter(Rx * dx, Rz * dz, c='black', marker='v', s=100, label='Receivers')
    ax.scatter(Sx * dx, Sz * dz, c='r', marker='*', s=100, label='Source')
    ax.set_xlabel('X (m)')
    ax.set_ylabel('Z (m)')
    ax.set_title(title)
    ax.legend(loc='lower right')
    ax.grid(True)
    images.append(image)
    colorbars.append(fig.colorbar(image, ax=ax, label='1/Q'))
im0, im1 = images
cbar0, cbar1 = colorbars

plt.tight_layout()
plt.show()

import numpy as np
import matplotlib.pyplot as plt

# Mask any value above 1e5 with NaN (mirrors the Q plot) and share the true model's
# colour limits across both panels.
vs_true_plot = vs_true.copy()
vs_true_plot[vs_true > 1e5] = np.nan
vs_init_plot = vs_init.copy()
vs_init_plot[vs_init > 1e5] = np.nan

vmax = np.nanmax(vs_true_plot)
vmin = np.nanmin(vs_true_plot)

fig, axs = plt.subplots(1, 2, figsize=(20, 6))
images, colorbars = [], []
for ax, vs_plot, title in zip(axs,
                              (vs_true_plot, vs_init_plot),
                              ('True vs model', 'Initial vs model')):
    image = ax.imshow(vs_plot, cmap='coolwarm',
                      extent=[0, nx_org * dx, nz_org * dz, 0],
                      vmax=vmax, vmin=vmin)
    ax.scatter(Rx * dx, Rz * dz, c='black', marker='v', s=100, label='Receivers')
    ax.scatter(Sx * dx, Sz * dz, c='r', marker='*', s=100, label='Source')
    ax.set_xlabel('X (m)')
    ax.set_ylabel('Z (m)')
    ax.set_title(title)
    ax.legend(loc='lower right')
    ax.grid(True)
    images.append(image)
    colorbars.append(fig.colorbar(image, ax=ax, label='vs'))
im0, im1 = images
cbar0, cbar1 = colorbars

plt.tight_layout()
plt.show()


# Frequency-domain SH model built from the *initial* vs/rho/Q models; Q_grad=True
# presumably marks Q as the parameter to be updated (confirm in SHModel_Freq).
# Origin, device and dtype come from the run-wide constants so they stay in sync.
model = SHModel_Freq(ox=ox, oz=oz,
                     nx=nx_org, nz=nz_org,
                     dx=dx, dz=dz, dt=tmax/nt,
                     vs=vs_init, rho=rho_init, Q=Q_init,
                     wref=wref, Q_grad=True,
                     Sx=Sx, Sz=Sz, Rx=Rx, Rz=Rz, fmax=fmax,
                     L=nabc, sourceSpectrum=src_spectrum, freq_zpad=0, atten_opt='KF',
                     tmax=tmax, twrap=twrap, alpha=alpha,
                     abc_type=abc_type,
                     device=device, dtype=dtype)


# Keep copies of the true and starting Q models for later comparison.
np.savez(os.path.join(project_path, "model/true_model.npz"), Q=Q_true)
np.savez(os.path.join(project_path, "model/init_model.npz"), Q=Q_init)


# Container for the observed frequency-domain data. The source count must come from the
# source array (Sx), not the receiver array (Rx) — the two differ (16 vs 50 here).
d_obs = SeismicData_Freq(source_num=Sx.shape[0], receiver_num=Rx.shape[0],
                         source_loc=np.array([Sx, Sz]), receiver_loc=np.array([Rx, Rz]),
                         nf=nf, df=df)

# Load observed waveform data from a specified file.
d_obs.load(os.path.join(project_path, "waveform/obs_data.npz"))

# Add noise to the observed data
d_obs.add_noise(noise_level=0.01)

# Print a summary representation of the observed seismic data.
print(repr(d_obs))


# Import the misfit functions for waveform inversion (only the L2 waveform misfit is used below).
from uniSI.inversion.misfit import Misfit_Attenuation, Misfit_waveform_L2

torch.cuda.empty_cache()
# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA caching allocator is
# first initialised. The model was created with device='cuda:0' above, so this assignment
# likely has no effect on the current process; it must be set before any CUDA work.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

iteration = 220  # alternative value previously used: 130

# Adam optimizer over the model parameters (alternative lr previously used: 2).
optimizer = torch.optim.Adam(model.parameters(), lr=6)

# Multiply the learning rate by 0.7 every 40 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.7, last_epoch=-1)

# L2 waveform misfit between simulated and observed data.
loss_fn = Misfit_waveform_L2()


# Gradient mask: 0 on the first `water_layer` rows (presumably the shallowest depths), 1 elsewhere.
water_layer_matrix = np.zeros_like(vs_init)
water_layer_matrix[:water_layer,:] = 1
grad_mask = 1.0 - water_layer_matrix
gradient_processor = GradProcessor(grad_mask=grad_mask)

# Print inversion save path
print(os.path.join(project_path, "inversion"))

inversion = SHInversion_Freq(model=model,
                  optimizer=optimizer,
                  scheduler=scheduler,
                  loss_fn=loss_fn,
                  obs_data=d_obs,
                  gradient_processor=gradient_processor,
                  waveform_normalize=True, 
                  cache_result=True,  
                  save_fig_epoch=5,  
                  save_fig_path=os.path.join(project_path, "inversion"),
                  min_improvement=1e-7,
                  # Q_true_plot holds NaN where Q_true > 1e5; plain .max()/.min() would
                  # return NaN, so use the NaN-ignoring reductions for the Q bounds.
                  Q_max=np.nanmax(Q_true_plot), Q_min=np.nanmin(Q_true_plot),
                  grad_processor_depth_weight=False)

# Run the inversion for the specified number of iterations.
inversion.forward(iteration=iteration, batch_size=None, checkpoint_segments=1)

# Retrieve the inversion results: the Q model per iteration and the loss history.
iter_Q = inversion.iter_Q
iter_loss = inversion.iter_loss

# Save the iteration results to files for later analysis.
save_dir = os.path.join(project_path, "inversion")
os.makedirs(save_dir, exist_ok=True)
np.savez(os.path.join(save_dir, "iter_Q.npz"), data=np.array(iter_Q))


# Plot the misfit curve.
plt.figure(figsize=(8,6))
plt.plot(iter_loss, c='k')
plt.xlabel("Iterations", fontsize=12)
plt.ylabel("L2-norm Misfits", fontsize=12)
plt.tick_params(labelsize=12)
plt.savefig(os.path.join(save_dir, "misfit.png"), bbox_inches='tight', dpi=100)
plt.show()

# Plot the final inverted 1/Q model on the true model's colour scale.
# Q_true_plot holds NaN where Q_true > 1e5, so use NaN-ignoring reductions;
# plain .max()/.min() would return NaN limits.
plt.figure(figsize=(8,4))
vmax = np.nanmax(1/Q_true_plot)
vmin = np.nanmin(1/Q_true_plot)
plt.imshow(1/iter_Q[-1], cmap='coolwarm', vmax=vmax, vmin=vmin)
plt.savefig(os.path.join(save_dir, "inverted_1_Q.png"), bbox_inches='tight', dpi=100)
plt.show()