from src.logger_n_config import init_logging, load_config, parse_arguments
from src.mic_sim import Microscope
import logging

from atomai.utils import get_coord_grid, extract_patches_and_spectra
from scalarizer_preacquired import scalarizer_peak_hard_coded
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
# CI smoke-test switch: any non-empty CI_SMOKE environment variable turns it on,
# which main() uses to cap the number of exploration steps.
SMOKE = bool(os.environ.get("CI_SMOKE"))
import gpax
import numpy as np



def _save_current_figure(save_dir, filename, description):
    """Save the active matplotlib figure to ``save_dir/filename``, close it, and log it.

    Args:
        save_dir: Directory the image is written into.
        filename: Image file name inside ``save_dir``.
        description: Prefix for the log line, e.g. ``"Survey image"``.
    """
    path = os.path.join(save_dir, filename)
    plt.savefig(path)
    plt.close()
    logging.info(f"{description} saved to {path}")


def _min_max_normalize(x):
    """Rescale array ``x`` to [0, 1] using its global min and max.

    ``np.ptp`` is used instead of the ``ndarray.ptp`` method, which NumPy 2.0
    removed. A constant array (zero range) maps to all zeros instead of NaN.
    """
    x = np.asarray(x)
    span = np.ptp(x)
    if span == 0:
        return np.zeros_like(x, dtype=float)
    return (x - x.min()) / span


def _plot_acquisition(indices, obj):
    """Scatter acquisition values at their indices and mark the argmax with a black 'x'.

    Args:
        indices: 2-column array; column 1 is plotted on x, column 0 on y.
        obj: Acquisition values, one per row of ``indices``.

    Returns:
        The newly created matplotlib figure (left as the current figure).
    """
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    ax.scatter(indices[:, 1], indices[:, 0], s=32, c=obj, marker='s')
    # The next point to query is the one that maximizes the acquisition function.
    next_point = indices[obj.argmax()]
    ax.scatter(next_point[1], next_point[0], marker='x', c='k')
    ax.set_title("Acquisition function values")
    return fig


def main():
    """Run a simulated active-learning (viDKL + UCB) experiment on pre-acquired data.

    Steps:
      1. Parse CLI arguments and load the config. ``config['settings']`` must
         provide ``spectralavg``, ``save_dir``, ``data_path``, ``window_size``
         and ``budget``.
      2. Load the survey and spectral images via ``Microscope``. Extract image
         patches (features) together with their spectra (targets).
      3. Min-max normalize both, then reduce each spectrum to a scalar with
         ``scalarizer_peak_hard_coded``.
      4. Seed with a random 0.2% of the points. Then, for ``budget`` steps
         (5 when the ``CI_SMOKE`` environment variable is set), repeat:
         fit a viDKL model, pick the unmeasured point that maximizes UCB, and
         "measure" it by looking up its precomputed scalar.

    All figures are written to ``save_dir``.
    """
    args = parse_arguments()
    config = load_config(args.config)

    # Read every required setting up front so a bad config fails before any
    # expensive data processing starts.
    settings = config['settings']
    # NOTE(review): read but currently unused; avg_pool below is hard-coded to 16.
    # Confirm whether this value was meant to drive avg_pool.
    spectralavg = settings['spectralavg']
    save_dir = settings['save_dir']
    data_path = settings['data_path']
    window_size = settings['window_size']
    budget = settings['budget']

    init_logging(save_dir=save_dir, config=config)
    logging.info(f"Saving outputs to {save_dir}")

    mic = Microscope(data_path)
    img = mic.survey_img
    specim = mic.spec_img

    plt.imshow(img, origin='lower')
    plt.colorbar()
    _save_current_figure(save_dir, "survey_image.png", "Survey image")

    coordinates = get_coord_grid(img, step=1, return_dict=False)
    features, targets, indices = extract_patches_and_spectra(
        specim,
        img,
        coordinates=coordinates,
        window_size=window_size,
        avg_pool=16,
    )

    print(features.shape, targets.shape)
    logging.info(f"Features shape: {features.shape}, Targets shape: {targets.shape}")

    features = _min_max_normalize(features)
    targets = _min_max_normalize(targets)

    peaks_all, features_all, indices_all = scalarizer_peak_hard_coded(targets, features, indices)
    _, ax = plt.subplots()
    ax.scatter(indices_all[:, 1], indices_all[:, 0], c=peaks_all)
    ax.set_title('Plasmon peak intensities')
    ax.set_aspect('equal')
    _save_current_figure(save_dir, "plasmon_peak_intensities.png", "Plasmon peak intensities")

    # TODO: include a plot of the scalarizer output.
    logging.info(f"Peaks_all shape: {peaks_all.shape}, Features_all shape: {features_all.shape}, Indices_all shape: {indices_all.shape}")

    ## Active learning:
    # Flatten each 2-D patch into one feature vector per point.
    n, d1, d2 = features_all.shape
    X = features_all.reshape(n, d1 * d2)
    y = peaks_all
    print(X.shape, y.shape)
    logging.info(f"X shape: {X.shape}, y shape: {y.shape}")

    # Use only 0.2% of the grid points (test_size=0.998) as initial training points.
    (
        X_measured,
        X_unmeasured,
        y_measured,
        y_unmeasured,
        indices_measured,
        indices_unmeasured,
    ) = train_test_split(
        X,
        y,
        indices_all,
        test_size=0.998,
        shuffle=True,
        random_state=1,
    )
    seed_points = len(X_measured)

    plt.imshow(img, origin='lower')
    plt.scatter(indices_measured[:, 1], indices_measured[:, 0], c="r", s=10)
    _save_current_figure(save_dir, "seed_points.png", "Seed points")

    data_dim = X_measured.shape[-1]
    exploration_steps = 5 if SMOKE else budget
    # Exploration parameter passed to the UCB acquisition function.
    ucb_beta = 0.25

    key1, key2 = gpax.utils.get_keys()

    for step in range(exploration_steps):
        if len(X_unmeasured) == 0:
            # The budget exceeds the candidate pool; there is nothing left to query.
            logging.info("No unmeasured points left; stopping exploration early.")
            break
        print(f"{step + 1}/{exploration_steps}")

        # Re-initialize the DKL model and fit it on all points measured so far.
        dkl = gpax.viDKL(data_dim, 2)
        # Decreasing step_size and increasing num_steps (e.g. to 0.005 and 1000)
        # can give more stable training at the cost of runtime.
        dkl.fit(key1, X_measured, y_measured, num_steps=100, step_size=0.05)

        obj = gpax.acquisition.UCB(key2, dkl, X_unmeasured, beta=ucb_beta, maximize=True)
        next_point_idx = int(obj.argmax())

        # Simulated "measurement": the scalar for every point is already known.
        measured_point = y_unmeasured[next_point_idx]

        _plot_acquisition(indices_unmeasured, obj)
        _save_current_figure(save_dir, f"acquisition_function_{step}.png", "Acquisition function plot")

        # Move the chosen point from the unmeasured pool to the measured pool.
        X_measured = np.append(X_measured, X_unmeasured[next_point_idx][None], 0)
        X_unmeasured = np.delete(X_unmeasured, next_point_idx, 0)
        y_measured = np.append(y_measured, measured_point)
        y_unmeasured = np.delete(y_unmeasured, next_point_idx)
        indices_measured = np.append(indices_measured, indices_unmeasured[next_point_idx][None], 0)
        indices_unmeasured = np.delete(indices_unmeasured, next_point_idx, 0)

    ## Post plotting: points chosen by the loop (seed points excluded), colored by acquisition order.
    acquired = indices_measured[seed_points:]
    plt.imshow(img, origin="lower", cmap='gray')
    plt.scatter(
        acquired[:, 1],
        acquired[:, 0],
        c=np.arange(len(acquired)),
        s=50,
        cmap="Reds",
    )
    plt.colorbar()
    _save_current_figure(save_dir, "final_points.png", "Final points")
# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()