import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import integrate
import sys
import os
sys.path.append("../../")
from uniSI.propagator  import *
from uniSI.model       import *
from uniSI.view        import *
from uniSI.utils       import *
from uniSI.survey      import *
from uniSI.inversion      import *

# Root directory for every artefact this script reads or writes.
project_path = "./data/uniformgrid"

# Make sure the sub-directories used below exist. exist_ok=True keeps the call
# idempotent and avoids the check-then-create race of os.path.exists().
for _subdir in ("model", "waveform", "survey"):
    os.makedirs(os.path.join(project_path, _subdir), exist_ok=True)

# --- Run configuration -------------------------------------------------------
device = "cuda:0"                    # GPU device
dtype = torch.float32                # 32-bit floating point
ox = 0                               # Origin coordinate in x
oz = 0                               # Origin coordinate in z
nz = 44                              # Grid points in z
nx = 100                             # Grid points in x

# Scale factor relative to the 88 x 200 baseline grid; the spacing, time
# sampling and boundary thickness below are all derived from it.
zoomfactor = nz / 88
dx = dz = 40 / zoomfactor            # Grid spacing in x and z
nt = int(8000 * zoomfactor)          # Number of time steps (original note: "12000 work")
dt = 0.002 / zoomfactor              # Time interval
nabc = int(40 * zoomfactor)          # Thickness of the absorbing boundary layer
f0 = 0.1                             # Initial frequency in Hz
free_surface = True                  # Enable free surface boundary condition

# Change nz, nx and nabc if dx and dz are changed!


# Read the Marmousi2 model from the local dataset directory.
marmousi_model = load_marmousi_model(in_dir="./datasets/marmousi2_source")

# Number of top rows that are cropped from the arrays later in this script
# (original note: no water layer in the model, just pre-processing for marmousi2).
water_depth = int(12 * zoomfactor)

# Target coordinates for resampling: nz samples in z starting at 0 and
# nx samples in x starting at 5000.
# NOTE(review): np.linspace includes the end point, so the sample spacing is
# dz*nz/(nz-1) and dx*nx/(nx-1) rather than exactly dz and dx — confirm this
# is intended before changing it, since any stored observed data presumably
# relies on the same sampling.
z = np.linspace(0, dz * nz, nz)
x = np.linspace(5000, 5000 + dx * nx, nx)
vel_model = resample_marmousi_model(x, z, marmousi_model)

# True model arrays, transposed from the resampled grids (the later
# `nz, nx = vp_true.shape` treats axis 0 as z and axis 1 as x).
vs_true = vel_model["vs"].T
vp_true = vel_model["vp"].T

# Density is derived from vp as 310 * vp**0.25.
rho_true = 310 * np.power(vp_true, 0.25)

# Only vs is written to the true-model file.
true_model_file = os.path.join(project_path, "model/true_model.npz")
np.savez(true_model_file, vs=vs_true)

import scipy.ndimage
# Starting models: Gaussian-smoothed (sigma=3) copies of the true vs and density.
vs_init = scipy.ndimage.gaussian_filter(vs_true, sigma=3)
rho_init = scipy.ndimage.gaussian_filter(rho_true, sigma=3)

# Put the unsmoothed true values back into the two rows starting at
# water_depth (labelled "Source Depth" in the original script).
source_rows = slice(water_depth, water_depth + 2)
vs_init[source_rows, :] = vs_true[source_rows, :]
rho_init[source_rows, :] = rho_true[source_rows, :]

# Only vs is written to the initial-model file.
init_model_file = os.path.join(project_path, "model/init_model.npz")
np.savez(init_model_file, vs=vs_init)


# The model files saved above were written before this crop, so they still
# contain the top `water_depth` rows (kept for cross-scenario comparison).
# From here on, every in-memory array drops those rows.
vp_true, vs_true, rho_true, vs_init, rho_init = (
    arr[water_depth:, :]
    for arr in (vp_true, vs_true, rho_true, vs_init, rho_init)
)

# Grid dimensions follow the cropped arrays.
nz, nx = vp_true.shape

# All-False boolean mask with the cropped model shape.
water_mask = np.zeros((nz, nx), dtype=bool)

# Build the SH model from the smoothed starting vs and rho.
# NOTE(review): vs_grad=True presumably marks vs as the parameter to be
# updated by the optimizer — confirm against SHModel.
model = SHModel(
    ox, oz, nx, nz, dx, dz,
    vs=vs_init,
    rho=rho_init,
    vs_grad=True,
    free_surface=free_surface,
    abc_type="PML",
    abc_jerjan_alpha=0.007,
    nabc=nabc,
    device=device,
    dtype=dtype,
)

# Show the model summary for verification.
print(repr(model))

# Save a figure of the model via _plot_mu_rho. Note that the model holds the
# *initial* vs/rho even though the output file is named "true_vp_rho.png".
model._plot_mu_rho(
    figsize=(16, 5),
    wspace=0.15,
    cbar_pad_fraction=0.1,
    cmap='coolwarm',
    save_path=os.path.join(project_path, "model/true_vp_rho.png"),
)

# Source grid indices: one source every int(5*ceil(zoomfactor)) columns,
# starting at column int(2*zoomfactor) and stopping before nx-1.
src_columns = list(range(int(2*zoomfactor), nx-1, int(5*np.ceil(zoomfactor))))
src_x = np.array(src_columns)                 # X indices of the sources
src_z = np.array([1] * len(src_columns))      # Every source sits at z index 1

t0 = 2            # t0 argument forwarded to every wavelet() call
n_harmonics = 10  # Source is the sum of wavelets at f0, 2*f0, ..., n_harmonics*f0

# Fundamental wavelet; wavelet() returns (time, amplitude).
src_t, src_v = wavelet(nt, dt, f0, amp0=1, t0=t0)

# Add the higher harmonics one by one. The summation order matches the
# original hand-written expression (f0 + 2*f0 + ... + 10*f0), so the result
# is numerically identical.
for harmonic in range(2, n_harmonics + 1):
    src_v = src_v + wavelet(nt, dt, harmonic * f0, amp0=1, t0=t0)[1]

# Normalize the summed wavelet to unit peak amplitude.
src_v = src_v / np.max(np.abs(src_v))

# Optional post-processing steps, currently disabled:
#src_v = src_v * np.hanning(len(src_v))  # Apply Hanning window to wavelet
#src_v = integrate.cumtrapz(src_v, axis=-1, initial=0)  # Integrate wavelet to get velocity

# Source container sharing the time sampling and base frequency of the run.
source = Source(nt=nt, dt=dt, f0=f0)

# Moment tensor used for the SH-wave sources.
mt_sh = np.array([[0, 1, 0],
                  [1, 0, 0],
                  [0, 0, 0]])

# Alternative (disabled): register every source in a single call.
# source.add_sources(src_x=src_x, src_z=src_z, src_wavelet=src_v, src_type='mt', src_mt=np.array([[1,0,0],[0,1,0],[0,0,1]]))

# Register the sources one at a time; all of them share the same wavelet
# and moment tensor.
for sx, sz in zip(src_x, src_z):
    source.add_source(src_x=sx, src_z=sz, src_wavelet=src_v, src_type="mt", src_mt=mt_sh)

# Receiver grid indices: one receiver every int(1*ceil(zoomfactor)) columns
# across the full model width, all at z index 1.
rcv_columns = list(range(0, nx, int(1*np.ceil(zoomfactor))))
rcv_x = np.array(rcv_columns)                 # X indices of the receivers
rcv_z = np.array([1] * len(rcv_columns))      # Z indices of the receivers

# Receiver container sharing the time sampling of the run.
receiver = Receiver(nt=nt, dt=dt)

# Alternative (disabled): register every receiver in a single call.
# receiver.add_receivers(rcv_x=rcv_x, rcv_z=rcv_z, rcv_type='pr')

# Register the receivers one at a time, each recording the "vy" component.
for rx, rz in zip(rcv_x, rcv_z):
    receiver.add_receiver(rcv_x=rx, rcv_z=rz, rcv_type="vy")

# Combine the sources and receivers into one survey.
survey = Survey(source=source, receiver=receiver)

# Show the survey summary to check its configuration.
print(repr(survey))

# Save a figure of the source wavelet.
wavelet_fig = os.path.join(project_path, "survey/wavelets.png")
source.plot_wavelet(save_path=wavelet_fig)

# Save a figure of the acquisition geometry drawn over the model's vs.
geometry_fig = os.path.join(project_path, "survey/observed_system.png")
survey.plot(model.vs, cmap='coolwarm', save_path=geometry_fig)

# Wave propagator for the model/survey pair; wavefield visualisation is off.
F = SHPropagator(
    model,
    survey,
    ifvisualWave=False,
    projectpath=project_path,
    device=device,
)

# Save a figure of the propagator's damping array to inspect the boundary
# condition.
damp = F.damp
boundary_fig = os.path.join(project_path, "model/boundary_condition.png")
plot_damp(damp, save_path=boundary_fig)

# Observed-data container bound to the survey geometry.
d_obs = SeismicData(survey)

# Read the observed waveforms from disk; the file must already exist.
obs_file = os.path.join(project_path, "waveform/obs_data.npz")
d_obs.load(obs_file)

# Contaminate the observed data with noise (noise_level=0.01).
d_obs.add_noise(noise_level=0.01)

# Show a summary of the observed data.
print(repr(d_obs))

# Import the L2 misfit function for waveform inversion.
from uniSI.inversion.misfit import Misfit_waveform_L2,Misfit_global_correlation, Misfit_traveltime

# Release cached GPU memory held by PyTorch before the inversion starts.
torch.cuda.empty_cache()
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF configures PyTorch's CUDA caching
# allocator. It is set here after objects were already created on `device`,
# so it may have no effect unless it is exported before the first CUDA
# allocation — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# Number of inversion iterations (300 noted as an alternative).
iteration = 280

# Adam optimizer over the model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.5)

# Step scheduler: multiply the learning rate by 0.75 every 50 steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=50,
    gamma=0.75,
    last_epoch=-1,
)

# Misfit between synthetic and observed data; the alternatives are disabled.
loss_fn = Misfit_waveform_L2(dt=dt)
# loss_fn = Misfit_global_correlation(dt=dt)
# loss_fn = Misfit_traveltime(dt=dt)


# Gradient mask with the shape of the starting model: 1 keeps the gradient,
# 0 is intended to suppress it.
grad_mask = np.ones_like(vs_init)
grad_mask[:2] = 0            # Top two rows (free surface / source depth)
grad_mask[-1] = 0            # Bottom row
grad_mask[:, [0, -1]] = 0    # Left and right edge columns
gradient_processor = GradProcessor(grad_mask=grad_mask, forw_illumination=True)


# Directory that receives the inversion figures; print it for reference.
inversion_dir = os.path.join(project_path, "inversion")
print(inversion_dir)

# Assemble the SH full waveform inversion from the propagator, model,
# optimizer, scheduler, misfit, observed data and gradient processor.
inversion = SHInversion(
    propagator=F,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    obs_data=d_obs,
    gradient_processor=gradient_processor,
    waveform_normalize=True,
    cache_result=True,
    save_fig_epoch=10,
    save_fig_path=inversion_dir,
    min_improvement=1e-7,
)

# Run the inversion for the configured number of iterations.
inversion.forward(iteration=iteration, batch_size=None, checkpoint_segments=1)


# Per-iteration results collected by the inversion object.
iter_vs = inversion.iter_vs
iter_rho = inversion.iter_rho    # Retrieved but not written to disk below
iter_loss = inversion.iter_loss

# Make sure the output directory exists, then store the vs history, the loss
# history and the crop depth for later analysis.
save_dir = os.path.join(project_path, "inversion")
os.makedirs(save_dir, exist_ok=True)
for relative_name, payload in (
    ("inversion/iter_vs.npz", np.array(iter_vs)),
    ("inversion/iter_loss.npz", np.array(iter_loss)),
    ("inversion/water_depth.npz", water_depth),
):
    np.savez(os.path.join(project_path, relative_name), data=payload)


from uniSI.inversion.misfit import Misfit_waveform_L2,Misfit_global_correlation

# Plot the misfit history, save it next to the inversion results, then show it.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(iter_loss, c='k')
ax.set_xlabel("Iterations", fontsize=12)
ax.set_ylabel("L2-norm Misfits", fontsize=12)
ax.tick_params(labelsize=12)
fig.savefig(os.path.join(project_path, "inversion/misfit.png"), bbox_inches='tight', dpi=100)
plt.show()