import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import integrate
import sys
import os
sys.path.append("../../")
from uniSI.propagator  import *
from uniSI.model       import *
from uniSI.view        import *
from uniSI.utils       import *
from uniSI.survey      import *
from uniSI.inversion   import *

from mpl_toolkits.axes_grid1 import make_axes_locatable
from skimage.metrics import structural_similarity as ssim
import matplotlib.ticker as ticker

project_path = "./data/uniformgrid"
# Output sub-directories for models, waveforms and survey plots. exist_ok=True
# replaces the check-then-create pattern (no race, no repeated boilerplate).
for _subdir in ("model", "waveform", "survey"):
    os.makedirs(os.path.join(project_path, _subdir), exist_ok=True)

# Device and data type configuration
device = "cuda:0"         # GPU device used for modelling and inversion (e.g. "cuda:0")
dtype = torch.float32

# Configure PyTorch's CUDA caching allocator before anything touches the GPU.
# The allocator reads this variable when CUDA is first initialised, so the
# assignment made later inside the loop (after forward modelling has already
# run on the GPU) comes too late to take effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

# Model geometry and grid parameters
ox, oz = 0, 0             # Origin coordinates in x and z directions
nz, nx = 70, 70           # Number of grid points in z and x directions

dx, dz = 10, 10           # Grid spacing in x and z directions (meters)
nt, dt = 1000, 0.001      # Number of time steps and time interval (seconds)
nabc = 120                # Absorbing-boundary (PML) thickness, presumably in grid cells

# Source and boundary condition settings
f0 = 15                   # Source wavelet frequency in Hz
free_surface = False      # Free-surface top boundary condition (disabled here)

# Note: If you change dx and dz, be sure to adjust nz, nx, and nabc accordingly.

# OpenFWI FlatVel-B validation set. The loop below indexes it as
# all_model[i_model, :, :, :] and squeezes each entry to a 2-D model.
all_model = np.load("datasets/flatvel-B/data_val.npy")
# Fail fast if the dataset grid does not match the configured grid
# (assumes the trailing axes are (nz, nx) — TODO confirm against AcousticModel).
if tuple(all_model.shape[-2:]) != (nz, nx):
    raise ValueError(
        f"Dataset models have spatial shape {tuple(all_model.shape[-2:])}, "
        f"but the grid is configured as (nz, nx) = ({nz}, {nx})."
    )

# Per-model SSIM of the initial, final and best inverted models.
SSIM_dict = {key: np.zeros(all_model.shape[0]) for key in ('init', 'final', 'best')}

for i_model in range(all_model.shape[0]):
    # The previous template read "index of total {i}: {n}", i.e. the index and
    # the total landed in swapped positions.
    print("Processing model index {} of total {}".format(i_model, all_model.shape[0]))

    # Number of top grid rows treated as water (0 = no water layer); used later
    # to build the gradient mask.
    water_layer = 0

    # Take this model from the dataset already loaded before the loop instead of
    # re-reading the whole .npy file from disk on every iteration. The copy keeps
    # vp_true independent of the dataset buffer, as a fresh load did.
    vp_true = all_model[i_model].squeeze().copy()
    rho_true = np.ones_like(vp_true)  # OpenFWI models are density-constant

    # "True" model, used only to synthesise the observed data (no gradients).
    model = AcousticModel(ox, oz, nx, nz, dx, dz,
                          vp_true, rho_true,
                          vp_grad=False,
                          free_surface=free_surface,
                          abc_type="PML",
                          abc_jerjan_alpha=0.007,
                          nabc=nabc,
                          device=device,
                          dtype=dtype)
    print(repr(model))

    # --- Sources: 5 shots evenly spaced along the top row (z = 0) -------------
    # x spans 0..nx-1; the previous code used nz-1, which only worked because
    # nx == nz in this configuration.
    src_x = np.linspace(0, nx - 1, 5)
    src_z = np.zeros_like(src_x)

    _, src_v = wavelet(nt, dt, f0, amp0=1, t0=1.088 / f0)
    # Cumulative time integral of the wavelet. scipy.integrate.cumtrapz was
    # deprecated in SciPy 1.12 and removed in 1.14; cumulative_trapezoid is the
    # drop-in replacement with the same arguments.
    src_v = integrate.cumulative_trapezoid(src_v, axis=-1, initial=0)

    source = Source(nt=nt, dt=dt, f0=f0)
    for sx, sz in zip(src_x, src_z):
        # Identity moment tensor, i.e. an isotropic source.
        source.add_source(src_x=sx, src_z=sz, src_wavelet=src_v,
                          src_type="mt", src_mt=np.eye(3, dtype=int))

    # --- Receivers: one pressure receiver at every surface grid point --------
    rcv_x = np.arange(nx)
    rcv_z = np.zeros(nx, dtype=int)
    receiver = Receiver(nt=nt, dt=dt)
    for rx, rz in zip(rcv_x, rcv_z):
        receiver.add_receiver(rcv_x=rx, rcv_z=rz, rcv_type="pr")

    survey = Survey(source=source, receiver=receiver)
    print(repr(survey))

    # Diagnostic plots of the wavelet and of the acquisition over the model.
    source.plot_wavelet(save_path=os.path.join(project_path, "survey/wavelets.png"))
    survey.plot(model.vp, cmap='coolwarm', save_path=os.path.join(project_path, "survey/observed_system.png"))

    F = AcousticPropagator(model, survey, ifvisualWave=False, project_path=project_path, device=device)
    # Visualise the absorbing-boundary damping profile.
    plot_damp(F.damp, save_path=os.path.join(project_path, "model/boundary_condition.png"))

    # Forward-model the observed data and store it in a SeismicData container.
    record_waveform = F.forward()
    d_obs = SeismicData(survey)
    d_obs.record_data(record_waveform)
    # The raw record (which also carries the forward wavefields) is not used
    # again; dropping our reference lets it be freed before the inversion,
    # unless SeismicData keeps its own reference.
    del record_waveform

    import scipy.ndimage

    # Starting model: the true model blurred with a Gaussian (sigma = 2.5 cells).
    # Density stays at the constant OpenFWI value.
    vp_init = scipy.ndimage.gaussian_filter(vp_true, sigma=2.5)
    rho_init = rho_true

    # Boundary / device settings (identical to those of the true-model setup).
    boundary_kwargs = dict(free_surface=free_surface,
                           abc_type="PML",
                           abc_jerjan_alpha=0.007,
                           nabc=nabc,
                           device=device,
                           dtype=dtype)

    # Inversion model: vp carries gradients; rho is not auto-updated.
    model = AcousticModel(ox, oz, nx, nz, dx, dz, vp_init, rho_init,
                          vp_grad=True, auto_update_rho=False,
                          **boundary_kwargs)

    # Propagator for the inversion model, and a plot of its damping profile.
    F = AcousticPropagator(model, survey, device=device)
    plot_damp(F.damp, save_path=os.path.join(project_path, "model/boundary_condition.png"))

    # Persist both models so they can be reloaded after the inversion.
    model.save(os.path.join(project_path, "model/init_model.npz"))
    np.savez(os.path.join(project_path, "model/true_model.npz"), vp=vp_true)
    print(repr(model))

    # Global-correlation misfit used as the FWI objective.
    from uniSI.inversion.misfit import Misfit_global_correlation

    torch.cuda.empty_cache()
    # NOTE(review): the GPU has already been used by the forward modelling
    # earlier in this iteration, so this allocator setting presumably has no
    # effect here — it must be set before CUDA is first initialised.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

    iteration = 600  # alternative setting noted by the author: 1500

    # Adam on the model parameters; StepLR halves the learning rate every 200
    # scheduler steps.
    optimizer = torch.optim.Adam(model.parameters(), lr=8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5, last_epoch=-1)
    loss_fn = Misfit_global_correlation(dt=dt)

    # Gradient mask: 0 over the top `water_layer` rows, 1 elsewhere
    # (all ones when water_layer == 0).
    water_mask = np.zeros_like(vp_init)
    water_mask[:water_layer] = 1
    grad_mask = 1.0 - water_mask
    gradient_processor = GradProcessor(grad_mask=grad_mask, grad_mute=water_layer)

    inversion_dir = os.path.join(project_path, "inversion")
    print(inversion_dir)

    # cache_result=True presumably populates inversion.iter_vp / iter_loss,
    # which are read after the run.
    inversion = AcousticInversion(propagator=F,
                                  model=model,
                                  optimizer=optimizer,
                                  scheduler=scheduler,
                                  loss_fn=loss_fn,
                                  obs_data=d_obs,
                                  gradient_processor=gradient_processor,
                                  waveform_normalize=True,
                                  cache_result=True,
                                  save_fig_epoch=50,
                                  save_fig_path=inversion_dir,
                                  min_improvement=1e-12)

    # Run the inversion for the configured number of iterations.
    inversion.forward(iteration=iteration, batch_size=None, checkpoint_segments=1)

    # Velocity snapshots cached by AcousticInversion during the run.
    iter_vp = inversion.iter_vp

    # Directory for the per-model figures written below.
    save_dir = os.path.join(project_path, "inversion")
    os.makedirs(save_dir, exist_ok=True)

    # Reload the initial and true models written before the inversion. The
    # context managers close the .npz archives deterministically; allow_pickle
    # is unnecessary (and unsafe) for these plain numeric arrays.
    with np.load(os.path.join(project_path, "model/init_model.npz")) as init_model:
        init_v = init_model["vp"]
    with np.load(os.path.join(project_path, "model/true_model.npz")) as true_model:
        true_v = true_model["vp"]

    # Drop snapshots that are entirely NaN (presumably cache entries that were
    # never filled — TODO confirm against AcousticInversion).
    iter_vp = [vp for vp in iter_vp if not np.isnan(vp).all()]

    # Image-only PDFs (no labels or colorbar), all on the true model's range.
    vmin, vmax = np.min(true_v), np.max(true_v)
    panels = [("init", init_v)]
    if iter_vp:
        panels.append(("result", iter_vp[-1]))
    else:
        print(f"Warning: no valid inverted model for index {i_model}; skipping vp_result plot.")
    panels.append(("true", true_v))

    for tag, image in panels:
        fig = plt.figure(figsize=(20, 16))
        plt.imshow(image, cmap='coolwarm', interpolation='nearest', vmin=vmin, vmax=vmax)
        plt.axis('off')
        plt.savefig(os.path.join(save_dir, f"vp_{tag}_{i_model}.pdf"), bbox_inches='tight', format='pdf')
        plt.close(fig)

    def single_scale_ssim(pred, true, win_size=7):
        """
        Compute the single-scale Structural Similarity Index (SSIM).

        Each input is min-max normalised to [0, 1] independently, so the score
        measures structural agreement and is invariant to a positive global
        scale/offset between ``pred`` and ``true``.

        Args:
            pred: Predicted velocity model (array; 2-D in this script).
            true: True velocity model, same shape as ``pred``.
            win_size: Requested window size. An even value is reduced to the
                next odd number, and the window is clipped to the largest odd
                size that fits the smallest image dimension (skimage requires
                an odd window no larger than the image).

        Returns:
            SSIM score as returned by ``skimage.metrics.structural_similarity``.
        """
        # NOTE(review): normalising with a shared range (e.g. the true model's
        # min/max) would also penalise amplitude errors; kept independent here
        # to preserve the reported numbers.
        pred_norm = (pred - pred.min()) / (pred.max() - pred.min() + 1e-12)
        true_norm = (true - true.min()) / (true.max() - true.min() + 1e-12)

        # Largest odd window that fits: for an odd dimension that is the
        # dimension itself (the previous formula under-sized it by 2).
        smallest_dim = min(pred_norm.shape)
        largest_odd_fit = smallest_dim if smallest_dim % 2 == 1 else smallest_dim - 1
        win_size_adjusted = min(win_size, largest_odd_fit)
        # Actually enforce oddness: an even win_size used to pass through and
        # make skimage raise "Window size must be odd".
        if win_size_adjusted % 2 == 0:
            win_size_adjusted -= 1

        return ssim(pred_norm, true_norm,
                    win_size=win_size_adjusted,
                    data_range=1.0,
                    gaussian_weights=True,
                    use_sample_covariance=False)

    def calculate_ssim(iter_vp, true_v):
        """
        Score every model in ``iter_vp`` against ``true_v`` with single-scale SSIM.

        NaN scores are discarded, so the returned array may be shorter than
        ``iter_vp``; when that happens, index ``k`` of the result no longer
        refers to ``iter_vp[k]``.

        Returns:
            1-D numpy array of the non-NaN scores, in input order.
        """
        candidate_scores = (single_scale_ssim(candidate, true_v) for candidate in iter_vp)
        return np.array([score for score in candidate_scores if not np.isnan(score)])

    def visualize_results(iter_vp, true_v, init_v, ssim_scores, save_path):
        """
        Save a summary figure for one inversion run.

        Layout: a 2x2 grid (true model, best-SSIM model, initial model, SSIM
        curve with its maximum marked) plus a shared velocity colorbar column.

        Args:
            iter_vp: Sequence of inverted velocity models; ``iter_vp[k]`` is
                assumed to be the model scored by ``ssim_scores[k]``.
            true_v: True velocity model; its min/max set the colour limits.
            init_v: Initial velocity model.
            ssim_scores: 1-D array of SSIM scores.
            save_path: Path of the output image (written at 300 dpi).

        Raises:
            ValueError: If ``ssim_scores`` is empty.
        """
        if len(ssim_scores) == 0:
            raise ValueError("visualize_results: ssim_scores is empty, nothing to plot")
        best_ssim_idx = int(np.argmax(ssim_scores))

        vmin, vmax = np.min(true_v), np.max(true_v)
        cmap = 'coolwarm'

        fig = plt.figure(figsize=(15, 6))
        try:
            gs = fig.add_gridspec(2, 3,
                                  width_ratios=[1, 1, 0.05],
                                  height_ratios=[1, 1],
                                  wspace=0.15, hspace=0.15)

            ax1 = fig.add_subplot(gs[0, 0])
            im1 = ax1.imshow(true_v, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
            ax1.set_title('True Velocity Model', fontsize=12)

            ax2 = fig.add_subplot(gs[0, 1])
            ax2.imshow(iter_vp[best_ssim_idx], cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
            ax2.set_title(f'Best Model (Iter {best_ssim_idx})', fontsize=12)

            ax3 = fig.add_subplot(gs[1, 0])
            ax3.imshow(init_v, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
            ax3.set_title('Initial Model', fontsize=12)

            ax4 = fig.add_subplot(gs[1, 1])
            ax4.plot(ssim_scores, 'b-o', lw=1.5, markersize=6)
            ax4.plot(best_ssim_idx, ssim_scores[best_ssim_idx], 'r*',
                     markersize=12, label=f'Best: {ssim_scores[best_ssim_idx]:.3f}')
            ax4.set_xlabel('Iteration', fontsize=11)
            ax4.set_title('SSIM Evolution', fontsize=12)
            ax4.grid(True, alpha=0.3)
            ax4.legend(loc='lower right', fontsize=10)
            ax4.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))

            # All three images share vmin/vmax, so one colorbar covers them.
            cax = fig.add_subplot(gs[:, 2])
            fig.colorbar(im1, cax=cax)

            for ax in (ax1, ax2, ax3):
                ax.set_xticks([])
                ax.set_yticks([])

            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure: this runs once per model inside the
            # loop, and pyplot otherwise keeps every figure alive.
            plt.close(fig)


    ssim_scores = calculate_ssim(iter_vp, true_v)

    if len(ssim_scores) == 0:
        # No valid SSIM could be computed (presumably every cached snapshot was
        # NaN, e.g. a diverged run). Record NaN and continue instead of letting
        # np.argmax on an empty array abort all remaining models.
        print(f"\nNo valid SSIM scores for model {i_model}; recording NaN.")
        for key in ('init', 'final', 'best'):
            SSIM_dict[key][i_model] = np.nan
    else:
        # One figure per model: the former fixed name "ssim_analysis.png" was
        # overwritten on every loop iteration, unlike the other per-model outputs.
        save_path = os.path.join(project_path, f"inversion/ssim_analysis_{i_model}.png")
        visualize_results(iter_vp, true_v, init_v, ssim_scores, save_path)

        best_ssim_idx = np.argmax(ssim_scores)
        print("\nSSIM Statistics:")
        print(f"Best SSIM: {ssim_scores[best_ssim_idx]:.4f} at iteration {best_ssim_idx}")
        # NOTE(review): assumes the first cached snapshot is the starting model
        # and the last one the final model — confirm with AcousticInversion.
        print(f"Initial SSIM: {ssim_scores[0]:.4f}")
        print(f"Final SSIM: {ssim_scores[-1]:.4f}")

        SSIM_dict['init'][i_model] = ssim_scores[0]
        SSIM_dict['final'][i_model] = ssim_scores[-1]
        SSIM_dict['best'][i_model] = ssim_scores[best_ssim_idx]

    # Rewritten after every model so the results gathered so far survive an
    # interrupted run.
    np.savez(os.path.join(project_path, "SSIM_dict.npz"), **SSIM_dict)