import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import integrate
import sys
import os
sys.path.append("../../")
from uniSI.propagator  import *
from uniSI.model       import *
from uniSI.view        import *
from uniSI.utils       import *
from uniSI.survey      import *
from uniSI.inversion      import *

# Root folder for every artefact this script reads or writes.
project_path = "./data/uniformgrid"

# Create the output sub-folders up front. exist_ok=True makes each call
# idempotent and removes the check-then-create race of os.path.exists().
for _subdir in ("model", "waveform", "survey"):
    os.makedirs(os.path.join(project_path, _subdir), exist_ok=True)

# --- Run configuration -------------------------------------------------------
device = "cuda:0"        # Torch device string handed to the model and propagator
dtype = torch.float32    # Floating-point precision handed to the model

# Grid origin (x, z) and grid size (z, x).
ox = 0
oz = 0
nz = 44
nx = 100

# Every resolution-dependent quantity below is scaled by the ratio of the
# current nz to the baseline grid (88 x 200).
zoomfactor = nz / 88
dx = 40 / zoomfactor             # Grid spacing in x
dz = 40 / zoomfactor             # Grid spacing in z
nt = int(3000 * zoomfactor)      # Number of time steps
dt = 0.003 / zoomfactor          # Time interval
nabc = int(30 * zoomfactor)      # Thickness of the absorbing boundary layer

f0 = 3                           # Initial frequency in Hz
free_surface = True              # Enable the free-surface boundary condition

# NOTE: change nz, nx and nabc accordingly if dx and dz are changed!

# Number of top grid rows treated as the water layer (used to build the masks).
water_layer = int(12 * zoomfactor)

# Read the Marmousi dataset from disk.
marmousi_model = load_marmousi_model(in_dir="./datasets/marmousi2_source")

# Coordinates at which the dataset is resampled (nx samples in x, nz in z).
# NOTE(review): np.linspace includes the end point, so the sample spacing is
# dx*nx/(nx-1) (resp. dz*nz/(nz-1)) rather than exactly dx (dz). Confirm this
# matches how the observed data were generated before changing it.
x = np.linspace(5000, 5000 + dx * nx, num=nx)
z = np.linspace(0, dz * nz, num=nz)

true_model = resample_marmousi_model(x, z, marmousi_model)
smooth_model = get_smooth_marmousi_model_downsample(
    true_model,
    gaussian_kernel=3,
    rcv_depth=2,
    mask_extra_detph=int(2 * zoomfactor),
    down_sample=1,
)

# Reference ("true") velocity, transposed, and a density derived from it
# with rho = 310 * vp^0.25.
vp_true = true_model['vp'].T
rho_true = 310 * np.power(vp_true, 0.25)

# Boolean mask that is True for the first `water_layer` rows.
water_mask = np.zeros_like(vp_true, dtype=bool)
water_mask[:water_layer, :] = True
rho_true[water_mask] = 1000  # Fixed density inside the water layer

# Starting model: smoothed velocity plus the same vp -> rho relation.
vp_init = smooth_model['vp'].T
rho_init = 310 * np.power(vp_init, 0.25)

# Inside the water layer the starting model is set equal to the true model.
vp_init[water_mask] = vp_true[water_mask]
rho_init[water_mask] = rho_true[water_mask]

# Build the acoustic model from the starting velocity and density.
# Only vp is flagged for gradient computation (vp_grad=True).
model = AcousticModel(
    ox, oz, nx, nz, dx, dz,
    vp_init, rho_init,
    vp_grad=True,
    free_surface=free_surface,
    abc_type="PML",
    abc_jerjan_alpha=0.007,
    nabc=nabc,
    device=device,
    dtype=dtype,
)

# Persist the starting model and the reference velocity for later use.
model.save(os.path.join(project_path, "model/init_model.npz"))
np.savez(os.path.join(project_path, "model/true_model.npz"), vp=vp_true)

# Show the model summary for a quick visual check.
print(repr(model))


# Source positions (grid indices): one source every `int(5*ceil(zoomfactor))`
# columns, starting at column int(2*zoomfactor), all at depth index 2.
src_x = np.arange(int(2 * zoomfactor), nx - 1, int(5 * np.ceil(zoomfactor)))
src_z = np.full(len(src_x), 2)

# Source time function: the wavelet followed by its cumulative trapezoidal
# integral along the time axis (unit sample spacing, first sample forced to 0).
# `integrate.cumtrapz` is deprecated and removed in recent SciPy releases, so
# prefer `cumulative_trapezoid` and fall back to the old name only when the new
# one is missing (very old SciPy).
try:
    _cumulative_trapezoid = integrate.cumulative_trapezoid
except AttributeError:
    _cumulative_trapezoid = integrate.cumtrapz

src_t, src_v = wavelet(nt, dt, f0, amp0=1)
src_v = _cumulative_trapezoid(src_v, axis=-1, initial=0)

source = Source(nt=nt, dt=dt, f0=f0)  # Container for all sources

# Method 1: Add multiple sources at once (commented out)
# source.add_sources(src_x=src_x, src_z=src_z, src_wavelet=src_v, src_type='mt', src_mt=np.array([[1,0,0],[0,1,0],[0,0,1]]))

# Method 2: register the sources one by one. Every source shares the same
# wavelet and uses an identity moment tensor; a fresh array is built per call
# so no two sources share the same tensor object.
_IDENTITY_MT = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
for _sx, _sz in zip(src_x, src_z):
    source.add_source(src_x=_sx, src_z=_sz, src_wavelet=src_v, src_type="mt", src_mt=np.array(_IDENTITY_MT))


# Receiver positions (grid indices): one receiver every `int(ceil(zoomfactor))`
# columns across the whole model width, all at depth index 2.
rcv_x = np.arange(0, nx, int(1 * np.ceil(zoomfactor)))
rcv_z = np.full(len(rcv_x), 2)

receiver = Receiver(nt=nt, dt=dt)  # Container for all receivers

# Method 1: Add all receivers at once (commented out)
# receiver.add_receivers(rcv_x=rcv_x, rcv_z=rcv_z, rcv_type='pr')

# Method 2: register the receivers one by one, all of type "pr".
for _rx, _rz in zip(rcv_x, rcv_z):
    receiver.add_receiver(rcv_x=_rx, rcv_z=_rz, rcv_type="pr")

# Combine the source and receiver sets into one survey and show its summary.
survey = Survey(source=source, receiver=receiver)
print(repr(survey))

# Save a figure of the source wavelet(s).
wavelet_figure = os.path.join(project_path, "survey/wavelets.png")
source.plot_wavelet(save_path=wavelet_figure)

# Build the wave propagator for this model/survey pair on the chosen device.
F = AcousticPropagator(model, survey, device=device)

# Save a figure of the propagator's damping array to inspect the boundary setup.
damp = F.damp
damp_figure = os.path.join(project_path, "model/boundary_condition.png")
plot_damp(damp, save_path=damp_figure)

# Observed data container bound to the survey geometry.
d_obs = SeismicData(survey)

# Fill it from the stored waveform file; the file must already exist under
# <project_path>/waveform.
obs_data_file = os.path.join(project_path, "waveform/obs_data.npz")
d_obs.load(obs_data_file)

# Perturb the observed data with noise (noise_level=0.01).
d_obs.add_noise(noise_level=0.01)

# Show a summary of the observed data.
print(repr(d_obs))

# Misfit functions available for the inversion (only one is selected below).
from uniSI.inversion.misfit import Misfit_waveform_L2,Misfit_global_correlation, Misfit_traveltime

# Release cached GPU memory before the inversion starts.
torch.cuda.empty_cache()
# NOTE(review): the model and propagator were already created on the device
# above, so the CUDA allocator has probably been initialised by now and this
# setting may have no effect. Confirm, and if so set it before the first CUDA
# call (or in the shell environment).
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

# Number of inversion iterations.
iteration = 1000

# Adam over the model parameters, with a step scheduler that multiplies the
# learning rate by 0.75 every 100 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=5)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=100,
    gamma=0.75,
    last_epoch=-1,
)

# Misfit used as the loss; the alternatives are kept for quick switching.
loss_fn = Misfit_waveform_L2(dt=dt)
# loss_fn = Misfit_global_correlation(dt=dt)
# loss_fn = Misfit_traveltime(dt=dt)


# Gradient mask: 0 in the first `water_layer` rows, 1 everywhere else.
# It is handed to GradProcessor together with the water-layer row count.
water_layer_matrix = np.zeros_like(vp_init)
water_layer_matrix[:water_layer, :] = 1
grad_mask = 1.0 - water_layer_matrix
gradient_processor = GradProcessor(grad_mask=grad_mask, grad_mute=water_layer)

# Output folder for the inversion. Create it BEFORE the run, because it is
# passed as save_fig_path and the script otherwise only creates it once the
# run has finished.
inversion_dir = os.path.join(project_path, "inversion")
os.makedirs(inversion_dir, exist_ok=True)
print(inversion_dir)

# Assemble the acoustic full-waveform inversion from the propagator, model,
# optimizer/scheduler, misfit, observed data and gradient post-processing.
inversion = AcousticInversion(
    propagator=F,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    obs_data=d_obs,
    gradient_processor=gradient_processor,
    waveform_normalize=True,
    cache_result=True,
    save_fig_epoch=5,
    save_fig_path=inversion_dir,
    min_improvement=1e-7,
)

# Run the inversion for the configured number of iterations.
inversion.forward(iteration=iteration, batch_size=None, checkpoint_segments=1)

# Per-iteration histories recorded by the inversion object.
iter_vp = inversion.iter_vp
iter_rho = inversion.iter_rho
iter_loss = inversion.iter_loss

# Make sure the output folder exists, then store the velocity and loss
# histories for later analysis.
save_dir = os.path.join(project_path, "inversion")
os.makedirs(save_dir, exist_ok=True)

vp_history = np.array(iter_vp)
loss_history = np.array(iter_loss)
np.savez(os.path.join(project_path, "inversion/iter_vp.npz"), data=vp_history)
np.savez(os.path.join(project_path, "inversion/iter_loss.npz"), data=loss_history)

from uniSI.inversion.misfit import Misfit_waveform_L2,Misfit_global_correlation

# Misfit curve over the iterations.
plt.figure(figsize=(8, 6))
plt.plot(iter_loss, c='k')
plt.xlabel("Iterations", fontsize=12)
plt.ylabel("L2-norm Misfits", fontsize=12)
plt.tick_params(labelsize=12)
plt.savefig(os.path.join(project_path, "inversion/misfit.png"), bbox_inches='tight', dpi=100)
plt.show()

# Side-by-side comparison of the initial field (left) and the last entry of
# the corresponding history (right), first for velocity and then for density.
# The history is indexed inside the loop so each figure is evaluated, saved
# and shown in the same order as writing the two figures out by hand.
for _initial, _history, _figure_file in (
    (vp_init, iter_vp, "inversion/inverted_vp.png"),
    (rho_init, iter_rho, "inversion/inverted_rho.png"),
):
    plt.figure(figsize=(12, 8))
    plt.subplot(121)
    plt.imshow(_initial, cmap='coolwarm')
    plt.subplot(122)
    plt.imshow(_history[-1], cmap='coolwarm')
    plt.savefig(os.path.join(project_path, _figure_file), bbox_inches='tight', dpi=100)
    plt.show()