import os
# Smoke-test switch: True whenever the CI_SMOKE environment variable is set
# to a non-empty value, False otherwise.
SMOKE = bool(os.environ.get("CI_SMOKE"))



try:
    # For use on Google Colab
    import gpax

except ImportError:
    # For use locally (where you're using the local version of gpax)
    print("Assuming notebook is being run locally, attempting to import local gpax module")
    import sys
    sys.path.append("..")
    import gpax



from warnings import filterwarnings

import numpy as np
import matplotlib.pyplot as plt
import math

from scipy.signal import find_peaks
from sklearn.model_selection import train_test_split

from atomai.utils import get_coord_grid, extract_patches_and_spectra

# Call gpax's enable_x64 helper up front, before any model is built.
# NOTE(review): presumably this switches JAX to 64-bit floats — confirm in gpax.utils.
gpax.utils.enable_x64()

# Suppress warnings whose originating module matches haiku._src.data_structures.
filterwarnings("ignore", module="haiku._src.data_structures")

import matplotlib as mpl

# Global matplotlib look: STIX fonts without external LaTeX, 12 pt tick and
# axis labels, and 200 dpi figures.
mpl.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'text.usetex': False,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.labelsize': 12,
    'figure.dpi': 200,
})

## Prepare data

!wget -qq https://www.dropbox.com/s/1tguc2zraiyxg7h/Plasmonic_EELS_FITO0_edgehole_01.npy


# Read the downloaded .npy file. allow_pickle=True is required because the
# file wraps a pickled Python object (used as a mapping below) rather than a
# plain numeric array.
# NOTE(review): unpickling can execute arbitrary code — only load this file
# from a source you trust.
archive = np.load("Plasmonic_EELS_FITO0_edgehole_01.npy", allow_pickle=True)
loadedfile = archive.tolist()

# List every stored entry together with its shape.
for key, value in loadedfile.items():
    print(f"{key} {value.shape}")

# Pull out the individual entries by name.
img, specim = loadedfile['image'], loadedfile['spectrum image']
e_ax, imscale = loadedfile['energy axis'], loadedfile['scale']


# Value handed to the patch extractor as its window_size argument.
window_size = 16

# Grid of image coordinates generated with a step of 1.
# NOTE(review): presumably one coordinate per pixel — confirm against atomai docs.
coordinates = get_coord_grid(img, step=1, return_dict=False)
# Pair image patches (features) with spectra (targets) at those coordinates;
# `indices` is used further down as the (row, column) position of each pair.
# NOTE(review): avg_pool=16 presumably down-samples each spectrum — confirm
# against the atomai documentation.
features, targets, indices = extract_patches_and_spectra(
    specim,
    img,
    coordinates=coordinates,
    window_size=window_size,
    avg_pool=16
)

# Notebook-style display of the resulting array shapes.
features.shape, targets.shape



# Scalarizer for the targets: reduce a spectrum to a single number, the
# prominence of its first peak that is at least 5 samples wide.
k = 1  # spectrum used for this illustration
peak_data = find_peaks(targets[k], width=5)
peak_pos = peak_data[0][0]
peak_int = peak_data[1]['prominences'][0]
print(peak_pos, peak_int)

# Draw the spectrum and mark the (position, prominence) pair on top of it.
fig, ax = plt.subplots(1, 1, figsize=(6, 2))
ax.plot(targets[k], zorder=0)
ax.scatter(peak_pos, peak_int, marker='x', s=50, c='k', zorder=1)
plt.show()

# Apply the scalarizer to every spectrum. Spectra without any peak of
# width >= 5 are dropped together with their patch and position.
prominences_per_spectrum = [
    find_peaks(spectrum, width=5)[1]["prominences"] for spectrum in targets
]
kept = [pos for pos, prom in enumerate(prominences_per_spectrum) if len(prom) != 0]

# Only the first detected peak of each retained spectrum is used as target.
peaks_all = np.concatenate([prominences_per_spectrum[pos][:1] for pos in kept])
features_all = np.array([features[pos] for pos in kept])
indices_all = np.array([indices[pos] for pos in kept])


# Map of the scalarized targets: one dot per retained position, coloured by
# its peak value.
_, ax = plt.subplots()
ax.scatter(indices_all[:, 1], indices_all[:, 0], c=peaks_all)
ax.set(title='Plasmon peak intensities', aspect='equal')
plt.show()


##### Active learning:

# Flatten every (d1, d2) patch into one feature vector so that X is 2-D with
# one row per retained point; the targets y are the peak values.
n, d1, d2 = features_all.shape
X = features_all.reshape(n, d1*d2)
y = peaks_all
# Notebook-style display of the resulting shapes.
X.shape, y.shape

# Use only 0.2% of the grid data points as initial training ("measured")
# points: test_size=0.998 sends the remaining 99.8% to the unmeasured pool.
# X, y and the positions are split together so their rows stay aligned.
(
    X_measured,
    X_unmeasured,
    y_measured,
    y_unmeasured,
    indices_measured,
    indices_unmeasured
) = train_test_split(
    X,
    y,
    indices_all,
    test_size=0.998,
    shuffle=True,
    random_state=1
)

# Number of seed points; used later to separate them from the points that
# are appended during exploration.
seed_points = len(X_measured)

# Show where the seed points lie, coloured by their target value.
plt.figure(figsize=(3, 3))
plt.scatter(indices_measured[:, 1], indices_measured[:, 0], s=50, c=y_measured)
plt.show()

def plot_result(indices, obj):
    """Scatter the acquisition values and mark the highest-valued point.

    Args:
        indices: 2-D array with one row per candidate point; column 0 is
            drawn on the y axis and column 1 on the x axis.
        obj: Acquisition value for each row of ``indices``; used as the
            marker colour.
    """
    rows, cols = indices[:, 0], indices[:, 1]
    best = obj.argmax()

    _, axis = plt.subplots(figsize=(3, 3))
    axis.scatter(cols, rows, s=32, c=obj, marker='s')
    # Black cross on the candidate with the largest acquisition value.
    axis.scatter(cols[best], rows[best], marker='x', c='k')
    axis.set_title("Acquisition function values")
    plt.show()
    
# Length of one flattened input vector (size of the last axis of X_measured).
data_dim = X_measured.shape[-1]

# Number of active-learning iterations; cut down to 5 for CI smoke runs.
exploration_steps = 5 if SMOKE else 80

# Two keys from gpax: key1 is passed to model fitting, key2 to the
# acquisition function.
key1, key2 = gpax.utils.get_keys()

for e in range(exploration_steps):
    print(f"{e+1}/{exploration_steps}")

    # Train a fresh DKL model on everything measured so far.
    # You may decrease step size and increase number of steps
    # (e.g. to 0.005 and 1000) for more stable performance.
    dkl = gpax.viDKL(data_dim, 2)
    dkl.fit(key1, X_measured, y_measured, num_steps=100, step_size=0.05)

    # Score every unmeasured point with the UCB acquisition function and
    # take the highest-scoring one as the next point to "measure".
    obj = gpax.acquisition.UCB(
        key2, dkl, X_unmeasured, beta=0.25, maximize=True
    )
    next_point_idx = obj.argmax()
    measured_point = y_unmeasured[next_point_idx]

    # Plot current result while the chosen point is still in the pool.
    plot_result(indices_unmeasured, obj)

    # Copy the chosen point into the measured arrays ...
    X_measured = np.append(
        X_measured, X_unmeasured[next_point_idx][None], axis=0
    )
    y_measured = np.append(y_measured, measured_point)
    indices_measured = np.append(
        indices_measured, indices_unmeasured[next_point_idx][None], axis=0
    )
    # ... and only then remove it from the unmeasured pool.
    X_unmeasured = np.delete(X_unmeasured, next_point_idx, axis=0)
    y_unmeasured = np.delete(y_unmeasured, next_point_idx)
    indices_unmeasured = np.delete(indices_unmeasured, next_point_idx, axis=0)
    

# Exploration path over the image: the points appended during active
# learning (everything after the seed points), shaded by acquisition order.
explored = indices_measured[seed_points:]
plt.imshow(img, origin="lower", cmap='gray')
plt.scatter(
    explored[:, 1],
    explored[:, 0],
    c=np.arange(len(explored)),
    s=50,
    cmap="Reds",
)
plt.colorbar()
plt.show()

# Exploration path drawn over the full map of peak values: the background
# dots are all retained points, the foreground dots are the points appended
# during active learning, shaded by acquisition order.
acquired = indices_measured[seed_points:]
plt.scatter(indices_all[:, 1], indices_all[:, 0], c=peaks_all, cmap='jet', alpha=0.5)
plt.scatter(
    acquired[:, 1],
    acquired[:, 0],
    c=np.arange(len(acquired)),
    s=50,
    cmap="Greens",
)
plt.colorbar()
plt.show()