#%%
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
#%%
result_dir = 'exps/inference/navier-stokes-ds2/Blade/best512'
# Load the data. map_location='cpu' remaps any tensors that were saved from a
# GPU run onto the CPU, so the file also opens on a CPU-only machine and the
# tensors can be handed to matplotlib without a device error.
results = torch.load(os.path.join(result_dir, 'result_0.pt'), map_location='cpu')
# `results` is indexed with the keys 'recon' and 'target' in the plotting cells;
# presumably it is a dict of tensors — TODO confirm against the inference script.
print(results['recon'].shape)
#%%
# Fixed colour scale shared with the target plot so the two images are directly
# comparable. The +/-10 range is hard-coded; widen it if the field exceeds it.
norm = plt.Normalize(vmin=-10, vmax=10)
# Plot the results. Take element [0, 0] of the two leading axes (presumably
# sample 0, channel 0 — TODO confirm) and convert it to a CPU numpy array so that
# CUDA tensors or tensors with requires_grad=True can still be drawn.
recon_img = torch.as_tensor(results['recon'][0, 0]).detach().cpu().numpy()
# Start a new figure so this image does not land on whatever axes are current
# (e.g. when all cells run sequentially as a plain script).
plt.figure()
plt.imshow(recon_img, cmap='jet', norm=norm)
plt.colorbar()
plt.title('Reconstructed')
# %%
# Same slice [0, 0] as the reconstruction plot, converted to a CPU numpy array so
# CUDA tensors or tensors with requires_grad=True can still be drawn.
target_img = torch.as_tensor(results['target'][0, 0]).detach().cpu().numpy()
# Start a new figure so this image does not overwrite the reconstruction plot
# (and stack a second colorbar) when the cells run as a plain script.
plt.figure()
# Reuse `norm` so the target and reconstruction share one colour scale.
plt.imshow(target_img, cmap='jet', norm=norm)
plt.colorbar()
plt.title('Target')
# %%
