import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import integrate
import sys
import os
sys.path.append("../../")
from uniSI.propagator  import *
from uniSI.model       import *
from uniSI.view        import *
from uniSI.utils       import *
from uniSI.survey      import *
from uniSI.inversion    import *

project_path = "./data/uniformgrid"
# Create the sub-directories this script reads from / writes to (model files,
# observed waveforms, survey figures, inversion results). exist_ok=True makes
# this idempotent and avoids the check-then-create race of os.path.exists().
for _subdir in ("model", "waveform", "survey", "inversion"):
    os.makedirs(os.path.join(project_path, _subdir), exist_ok=True)

# Device and data type configuration.
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the script still runs (slowly) on machines without CUDA instead of
# failing when tensors are moved to "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float32     # 32-bit floating point precision for model tensors

# Model geometry and grid parameters
ox, oz = 0, 0             # Grid origin in x and z
nz, nx = 44, 100          # Number of grid points in z and x
zoomfactor = nz / 88      # Refinement relative to a baseline grid of 88 points in z

# Spacing, time stepping and absorbing-layer thickness are scaled from the
# 88-point baseline (40 m spacing, 3000 steps of 2 ms, 30-point boundary).
# Because nt scales by zoomfactor and dt by 1/zoomfactor, the record length
# nt * dt stays ~6 s whatever nz is (up to the int() truncation of nt).
dx, dz = 40 / zoomfactor, 40 / zoomfactor            # Grid spacing in x and z (meters)
nt, dt = int(3000 * zoomfactor), 0.002 / zoomfactor  # Number of time steps, time step (seconds)
nabc = int(30 * zoomfactor)                          # Absorbing boundary thickness (grid points)

# Source and boundary condition settings
f0 = 3                   # Source wavelet frequency in Hz
free_surface = True      # Enable the free-surface boundary condition

# --- True model: Marmousi-2 resampled onto the simulation grid -------------
marmousi_model = load_marmousi_model(in_dir="./datasets/marmousi2_source")

# Number of shallow grid rows overwritten below. This is only a Marmousi-2
# pre-processing step; no explicit water layer is kept in the model.
water_depth = int(12*zoomfactor)

# Sample coordinates in meters; x starts 5000 m into the Marmousi-2 section.
# NOTE(review): np.linspace(a, a + d*n, n) places samples d*n/(n-1) apart, not
# exactly d. Left unchanged so this starting model stays consistent with the
# sampling presumably used to generate obs_data.npz -- confirm before editing.
x = np.linspace(5000, 5000 + dx * nx, nx)
z = np.linspace(0, dz * nz, nz)
vel_model = resample_marmousi_model(x, z, marmousi_model)

# Transpose so rows index depth (the arrays must be (nz, nx) to match
# water_mask below). NOTE(review): if vel_model holds NumPy arrays, .T is a
# view, so the in-place edits below also modify vel_model itself.
vp_true = vel_model['vp'].T
vs_true = vel_model['vs'].T

# Replace the top `water_depth` rows with the minimum velocity found beneath.
vp_min, vs_min = vp_true[water_depth:, :].min(), vs_true[water_depth:, :].min()
vp_true[:water_depth, :] = vp_min
vs_true[:water_depth, :] = vs_min

# No cell is flagged as water. To treat the top rows as water instead, set
# water_mask[:water_depth, :] = True.
water_mask = np.full((nz, nx), False)

# Density from a Gardner-type relation, rho = 310 * vp**0.25. Water cells
# would get 1000; with the all-False mask above that assignment changes nothing.
# (Constant-density alternative: rho_true = np.ones_like(vp_true) * 2450.)
rho_true = 310 * np.power(vp_true, 0.25)
rho_true[water_mask] = 1000

# --- Starting models: Gaussian-smoothed copies of the true model -----------
# vp and rho are smoothed lightly (sigma=0.8 cells), vs heavily (sigma=3).
# (To start from the exact vp/rho instead: vp_init = vp_true; rho_init = rho_true.)
import scipy.ndimage
vp_init = scipy.ndimage.gaussian_filter(vp_true, sigma=0.8)
rho_init = scipy.ndimage.gaussian_filter(rho_true, sigma=0.8)
vs_init = scipy.ndimage.gaussian_filter(vs_true, sigma=3)

# Keep the shallow rows unsmoothed by copying them back from the true model.
for _init, _true in ((vp_init, vp_true), (vs_init, vs_true), (rho_init, rho_true)):
    _init[:water_depth, :] = _true[:water_depth, :]


# Build the isotropic elastic model from the smoothed starting models.
# Parameter bounds span the true-model range (rho lower bound lowered by 1).
# Gradients are enabled for vs only (vp_grad and rho_grad are False).
model = IsotropicElasticModel(
    ox, oz, nx, nz, dx, dz,
    vp_init, vs_init, rho_init,
    vp_bound=[vp_true.min(), vp_true.max()],
    vs_bound=[vs_true.min(), vs_true.max()],
    rho_bound=[rho_true.min() - 1, rho_true.max()],
    vp_grad=False,
    vs_grad=True,
    rho_grad=False,
    auto_update_rho=False,
    auto_update_vp=False,
    free_surface=free_surface,
    abc_type="PML",
    abc_jerjan_alpha=0.007,
    nabc=nabc,
    water_layer_mask=water_mask,
    device=device,
    dtype=dtype,
)

# Persist, summarize and plot the starting model.
init_model_file = os.path.join(project_path, "model/init_model.npz")
model.save(init_model_file)
print(repr(model))

model._plot_vp_vs_rho(
    figsize=(12, 5),
    wspace=0.2,
    cbar_pad_fraction=0.18,
    cbar_height=0.04,
    cmap='coolwarm',
    save_path=os.path.join(project_path, "model/init_vp_vs_rho.png"),
    show=True,
)

# Sources: grid columns 2, 7, 12, ... (< nx-2), all at depth index 2.
src_x = np.arange(2, nx - 2, 5)
src_z = np.full_like(src_x, 2)

src_t, src_v = wavelet(nt, dt, f0, amp0=1)
# Integrate the wavelet along time. cumulative_trapezoid replaces
# integrate.cumtrapz, which was deprecated in SciPy 1.12 and removed in 1.14.
# The default sample spacing of 1 is used (no dt scaling), as before.
src_v = integrate.cumulative_trapezoid(src_v, axis=-1, initial=0)

source = Source(nt=nt, dt=dt, f0=f0)
for sx, sz in zip(src_x, src_z):
    # Identity moment tensor, i.e. an isotropic (explosive) source; a fresh
    # array per source so no two sources share the same tensor object.
    source.add_source(src_x=sx, src_z=sz, src_wavelet=src_v,
                      src_type="mt", src_mt=np.eye(3, dtype=int))

# Receivers: one per grid column, all at depth index 2, recording type "pr".
rcv_x = np.arange(nx)
rcv_z = np.full_like(rcv_x, 2)
receiver = Receiver(nt=nt, dt=dt)
for rx, rz in zip(rcv_x, rcv_z):
    receiver.add_receiver(rcv_x=rx, rcv_z=rz, rcv_type="pr")

# Combine sources and receivers into the acquisition survey.
survey = Survey(source=source, receiver=receiver)

print(repr(survey))
survey.plot(model.vs, cmap='coolwarm',
            save_path=os.path.join(project_path, "survey/observed_system_init.png"),
            show=True)

# Plot the source wavelet.
source.plot_wavelet(save_path=os.path.join(project_path, "survey/wavelets.png"), show=True)

# Build the elastic wave propagator for this model and survey, then plot the
# boundary damping it uses.
F = ElasticPropagator(model, survey, device=device)
if model.abc_type != "PML":
    # Non-PML boundaries: a single damping array.
    damp = F.damp
    plot_damp(damp)
else:
    # PML boundaries: separate x- and z-direction damping arrays.
    bcx, bcz = F.bcx, F.bcz
    title_param = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}
    plot_bcx_bcz(
        bcx, bcz,
        dx=dx, dz=dz,
        wspace=0.25,
        title_param=title_param,
        cbar_height=0.04,
        cbar_pad_fraction=-0.05,
        save_path=os.path.join(project_path, "model/boundary_condition_init.png"),
        show=False,
    )

# Load the observed data for this survey and add noise at level 0.01
# (see SeismicData.add_noise for how the level is interpreted).
d_obs = SeismicData(survey)
obs_file = os.path.join(project_path, "waveform/obs_data.npz")
d_obs.load(obs_file)
d_obs.add_noise(0.01)
print(repr(d_obs))

from uniSI.inversion.misfit import Misfit_waveform_L2, Misfit_traveltime
from uniSI.inversion.regularization import regularization_TV_1order

# Number of inversion iterations.
iteration = 3500

# AdamW over the model parameters; StepLR multiplies the learning rate by 0.9
# every 400 scheduler steps.
# NOTE(review): AdamW's decoupled weight decay shrinks every parameter toward
# zero each step (p -= lr * weight_decay * p). On velocities this is a bias,
# not a regularizer -- confirm weight_decay=1e-4 is intended.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.7,
    betas=(0.9, 0.999),
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=400, gamma=0.9, last_epoch=-1,
)

# Misfit: traveltime. For a waveform L2 misfit use Misfit_waveform_L2(dt=dt).
loss_fn = Misfit_traveltime(dt=dt)
# First-order TV regularizer; every weight passed to the inversion below is 0,
# so it should contribute nothing as configured.
regularization_fn = regularization_TV_1order(
    nx, nz, dx, dz, step_size=50, gamma=1, device=device, dtype=dtype,
)

# Gradient mask is 0 in the top `water_depth` rows and 1 elsewhere; it is
# shared by the vp and vs gradient processors.
grad_mask = np.ones_like(vp_init)
grad_mask[:water_depth] = 0
gradient_processor_vp = GradProcessor(
    grad_mask=grad_mask,
    forw_illumination=False,
)
gradient_processor_vs = GradProcessor(
    grad_mask=grad_mask,
    forw_illumination=False,
    grad_mute=0,
    grad_smooth=0,
    marine_or_land='marine',
)
gradient_processor = [gradient_processor_vp, gradient_processor_vs]

inversion = ElasticInversion(
    propagator=F,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    regularization_fn=regularization_fn,
    regularization_weights_x=[0] * 6,
    regularization_weights_z=[0] * 6,
    obs_data=d_obs,
    gradient_processor=gradient_processor,
    waveform_normalize=True,
    cache_result=True,
    save_fig_epoch=100,
    save_fig_path=os.path.join(project_path, "inversion"),
    inversion_component=["vx", "vz"],
)

# Run the inversion loop for `iteration` iterations.
inversion.forward(
    iteration=iteration,
    fd_order=4,
    batch_size=None,
    checkpoint_segments=1,
    start_iter=0,
)

# Retrieve the histories cached by the inversion (cache_result=True):
# vp, vs, rho models and loss values.
iter_vp = inversion.iter_vp
iter_vs = inversion.iter_vs
iter_rho = inversion.iter_rho
iter_loss = inversion.iter_loss

# Save each history to inversion/<name>.npz under the key "data".
for _name, _history in (("iter_vp", iter_vp), ("iter_vs", iter_vs),
                        ("iter_rho", iter_rho), ("iter_loss", iter_loss)):
    np.savez(os.path.join(project_path, f"inversion/{_name}.npz"),
             data=np.array(_history))

# Plot the misfit curve. The y-label names the active loss function (the
# script currently uses a traveltime misfit, not an L2 waveform misfit), so
# it stays correct when loss_fn is switched.
plt.figure(figsize=(8, 6))
plt.plot(iter_loss, c='k')
plt.xlabel("Iterations", fontsize=12)
plt.ylabel(f"Misfit ({type(loss_fn).__name__})", fontsize=12)
plt.tick_params(labelsize=12)
plt.savefig(os.path.join(project_path, "inversion/misfit.png"), bbox_inches='tight', dpi=100)
plt.show()

# Plot the true vs model next to the last inverted vs model. Both panels use
# the true-model range as a shared colour scale (also the inversion's vs
# bounds) so the images are directly comparable.
vs_vmin, vs_vmax = vs_true.min(), vs_true.max()
plt.figure(figsize=(12, 8))
plt.subplot(121)
plt.imshow(vs_true, cmap='jet_r', vmin=vs_vmin, vmax=vs_vmax)
plt.title("True vs")
plt.colorbar(orientation='horizontal')
plt.subplot(122)
plt.imshow(iter_vs[-1], cmap='jet_r', vmin=vs_vmin, vmax=vs_vmax)
plt.title("Inverted vs (last recorded)")
plt.colorbar(orientation='horizontal')
plt.savefig(os.path.join(project_path, "inversion/inverted_vs.png"), bbox_inches='tight', dpi=100)
plt.show()