import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import integrate
import pickle
import sys
import os
sys.path.append("../../")
from uniSI.propagator  import *
from uniSI.model       import *
from uniSI.view        import *
from uniSI.utils       import *
from uniSI.survey      import *
from uniSI.inversion      import *

project_path = "./data/uniformgrid"

# Create the output sub-directories written to later in this script
# (model arrays/plots, waveform data, survey plots). exist_ok=True replaces the
# check-then-create pattern, which could race with another process creating
# the same directory between the check and the makedirs call.
for _subdir in ("model", "waveform", "survey"):
    os.makedirs(os.path.join(project_path, _subdir), exist_ok=True)

def load_pickle_data(file_path):
    """Return the object deserialized from the pickle file at *file_path*.

    The stream is decoded with ``encoding='latin1'``, the setting typically
    needed to read pickles written by Python 2 (e.g. ones holding NumPy arrays)
    under Python 3.

    Security: unpickling can execute arbitrary code — only pass trusted files.
    """
    with open(file_path, "rb") as handle:
        return pickle.load(handle, encoding="latin1")

# Input pickles (project data; unpickling executes code, so these must be trusted).
vel_file = 'datasets/nankaiVs/vel.pickle' # Original velocity model
vco_file = 'datasets/nankaiVs/vco.pickle' # Checkerboard velocity model
den_file = 'datasets/nankaiVs/den.pickle'
sta_file = 'datasets/nankaiVs/rrsta.pickle'

vel_dict = load_pickle_data(vel_file)
vco_dict = load_pickle_data(vco_file)
den_dict = load_pickle_data(den_file)
sta_dict = load_pickle_data(sta_file)

# Lengths are scaled by 1e3 here; the grid spacing is later used "in meters",
# so the pickles presumably store km — TODO confirm against the data source.
spacing = vco_dict['spacing']*1e3
print("Spacing: ", spacing)

vs_orig = vel_dict['velocity_model']*1000
vs_check = vco_dict['velocity_model']*1000
den_orig = den_dict['density_model']*1e3
x_grids = vco_dict['x_grid']*1e3
x_origin = 15000 # 15km start tomography
# NOTE: x_grids is shifted but not cropped to match the models below.
x_grids = x_grids-x_origin
z_grids = vco_dict['z_grid']*1e3

# Column index of x_origin, computed once and applied to every model.
# int() truncates, so x_origin is expected to be a multiple of spacing;
# truncation is kept so the crop matches data produced with this convention.
ix_origin = int(x_origin/spacing)
vs_orig = vs_orig[:, ix_origin:]
vs_check = vs_check[:, ix_origin:]
den_orig = den_orig[:, ix_origin:]

# Configure the PyTorch CUDA caching allocator before the first CUDA allocation.
# SHModel is created on `device` further down; the allocator reads this
# variable when it initializes, so setting it after that has no effect.
# setdefault keeps any value the user exported in the environment.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:512')

device = "cuda:0"         # Specify the GPU device
dtype = torch.float32     # Set data type to 32-bit floating point
ox, oz = 0, 0             # Origin coordinates for x and z directions
nz, nx = vs_orig.shape[0], vs_orig.shape[1]  # Number of grid points in z and x directions
print(nx,nz)
dx, dz = spacing, spacing  # Grid spacing in x and z directions (in meters)
nt, dt = 8000, 0.01  # Number of time steps and time step size (in seconds)
nabc = 30                 # Absorbing-boundary thickness in grid cells (abc_type="PML" below)
f0 = 0.2                  # Source wavelet frequency in Hz (used by wavelet() and Source())
free_surface = True       # Enable free surface boundary condition

# Station coordinates -> grid indices: column 0 is x, column 1 is z.
# x is shifted to the cropped model origin. NOTE: dtype=int truncates toward
# zero (not nearest node); a station left of x_origin would map to column 0
# or a negative index.
stations = sta_dict['station_coordinates']*1e3
stations[:,0] = stations[:,0]-x_origin
stations = np.array(stations/dx,dtype=int)

water_mask = np.zeros((nz, nx), dtype=bool)  # not used further in this script

# Persist the starting (original) and target (checkerboard) Vs models.
np.savez(os.path.join(project_path, "model/init_model.npz"), vs=vs_orig)
np.savez(os.path.join(project_path, "model/true_model.npz"), vs=vs_check)

# Initialize the SH-wave model from the starting (original) Vs model and density.
# vs_grad=True / auto_update_rho=False: presumably only Vs is optimized and rho
# stays fixed — exact semantics are defined in uniSI's SHModel.
model = SHModel(ox, oz, nx, nz, dx, dz,
                      vs=vs_orig, rho=den_orig,  
                      vs_grad=True, auto_update_rho=False,
                      free_surface=free_surface,
                      abc_type="PML",
                      abc_jerjan_alpha=0.007,
                      nabc=nabc,
                      device=device,
                      dtype=dtype)

# Print the model's representation for verification.
print(model.__repr__())

# Plot mu (presumably the shear modulus) and density (rho) of the model.
# NOTE(review): this is the starting model built from vs_orig, even though the
# output file is named "true_vp_rho.png".
model._plot_mu_rho(figsize=(16,5),wspace=0.15,cbar_pad_fraction=0.1, cmap='coolwarm',save_path=os.path.join(project_path,"model/true_vp_rho.png"))

# Sources: one moment-tensor source at every station. z is one grid cell below
# the station index (+1), presumably to stay off the free-surface row.
src_z = stations[:, 1]+1  # Z grid indices of sources
src_x = stations[:, 0]    # X grid indices of sources

# Source time function: wavelet() returns the time axis and the amplitudes.
src_t, src_v = wavelet(nt, dt, f0, amp0=1)
# To inject a time-integrated wavelet instead, use
# scipy.integrate.cumulative_trapezoid (the old cumtrapz name was removed from SciPy).

source = Source(nt=nt, dt=dt, f0=f0)  # Initialize source object

# Moment tensor with only the off-diagonal (0,1)/(1,0) entries set, used for SH excitation.
mt_sh = np.array([[0, 1, 0],
                  [1, 0, 0],
                  [0, 0, 0]])

# Add sources one by one (Source.add_sources can add them in a single call).
for sx, sz in zip(src_x, src_z):
    source.add_source(src_x=sx, src_z=sz, src_wavelet=src_v, src_type="mt", src_mt=mt_sh)

# Receivers: placed at the same grid positions as the sources, so every station records.
rcv_z = stations[:, 1]+1  # Z grid indices of receivers
rcv_x = stations[:, 0]    # X grid indices of receivers

receiver = Receiver(nt=nt, dt=dt)  # Initialize receiver object

# "vy" component — presumably the out-of-plane (SH) particle velocity.
for rx, rz in zip(rcv_x, rcv_z):
    receiver.add_receiver(rcv_x=rx, rcv_z=rz, rcv_type="vy")

# Combine sources and receivers into the acquisition survey and check it.
survey = Survey(source=source, receiver=receiver)
print(repr(survey))

# Plot the survey configuration over the velocity model.
survey.plot(model.vs, cmap='coolwarm', save_path=os.path.join(project_path, "survey/observed_system.png"))

# Build the SH propagator for this model/survey (wave visualization flag off).
F = SHPropagator(
    model,
    survey,
    ifvisualWave=False,
    projectpath=project_path,
    device=device,
)

# Visualize the boundary damping array exposed by the propagator.
damp = F.damp
plot_damp(damp, save_path=os.path.join(project_path, "model/boundary_condition.png"))

# Observed-data container tied to the survey; waveforms are read from disk
# and a summary is printed as a sanity check.
d_obs = SeismicData(survey)
d_obs.load(os.path.join(project_path, "waveform/obs_data.npz"))
print(repr(d_obs))

# Misfit functions available for the inversion (global correlation is used below).
from uniSI.inversion.misfit import Misfit_waveform_L2,Misfit_global_correlation, Misfit_traveltime

# Release cached GPU memory left over from the setup steps.
torch.cuda.empty_cache()
# NOTE: PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator
# initializes. SHModel was already created on the GPU above, so this assignment
# presumably no longer reconfigures the allocator. It only takes effect if set
# before the first CUDA allocation.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

iteration = 350  # Number of inversion iterations

# Adam optimizer over the model parameters; lr is in model units (Vs in m/s).
optimizer = torch.optim.Adam(model.parameters(), lr=4)

# Decay the learning rate by a factor of 0.8 every 50 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.8, last_epoch=-1)

# Misfit between synthetic and observed waveforms. Alternatives:
# Misfit_waveform_L2(dt=dt) or Misfit_traveltime(dt=dt).
loss_fn = Misfit_global_correlation(dt=dt)

# Gradient mask: zero the gradient on the top two rows (free surface), the
# bottom row, and the first/last columns.
grad_mask = np.ones_like(vs_orig)
grad_mask[:2,:] = 0
grad_mask[-1,:] = 0
grad_mask[:,0] = 0
grad_mask[:,-1] = 0
# forw_illumination=True: presumably scales the gradient by forward-wavefield illumination.
gradient_processor = GradProcessor(grad_mask=grad_mask,forw_illumination=True)

# Print inversion save path
inversion_dir = os.path.join(project_path, "inversion")
print(inversion_dir)

# Set up the SH-wave full waveform inversion.
inversion = SHInversion(propagator=F,
                  model=model,
                  optimizer=optimizer,
                  scheduler=scheduler,
                  loss_fn=loss_fn,
                  obs_data=d_obs,
                  gradient_processor=gradient_processor,
                  waveform_normalize=True,
                  cache_result=True,
                  save_fig_epoch=10,
                  save_fig_path=inversion_dir,
                  min_improvement=1e-7)

# Run the inversion for the given number of iterations.
inversion.forward(iteration=iteration, batch_size=3, checkpoint_segments=1)



# Retrieve the per-iteration inversion history: Vs models, density and misfit values.
iter_vs = inversion.iter_vs
iter_rho = inversion.iter_rho
iter_loss = inversion.iter_loss

# Save the iteration results for later analysis.
save_dir = os.path.join(project_path, "inversion")
os.makedirs(save_dir, exist_ok=True)
np.savez(os.path.join(save_dir, "iter_vs.npz"), data=np.array(iter_vs))
np.savez(os.path.join(save_dir, "iter_loss.npz"), data=np.array(iter_loss))

# Plot the misfit history. The y-label is taken from the misfit class actually
# used (loss_fn), instead of a hard-coded "L2-norm" that did not match
# Misfit_global_correlation.
plt.figure(figsize=(8,6))
plt.plot(iter_loss,c='k')
plt.xlabel("Iterations", fontsize=12)
plt.ylabel(f"Misfit ({type(loss_fn).__name__})", fontsize=12)
plt.tick_params(labelsize=12)
plt.savefig(os.path.join(save_dir, "misfit.png"),bbox_inches='tight',dpi=100)
plt.show()