import streamlit as st
from src.logger_n_config import init_logging, load_config, parse_arguments
from src.mic_sim import Microscope
import logging
from atomai.utils import get_coord_grid, extract_patches_and_spectra, extract_subimages
from src.scalarizer_live import scalarizer_sum, scalarizer_max, scalarizer_mean, scalarizer_inv_mean
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
import numpy as np
import gpax





######################################### Autoscript related functions #########################################

from autoscript_tem_microscope_client import TemMicroscopeClient
from autoscript_tem_microscope_client.enumerations import *
from autoscript_tem_microscope_client.structures import *
import numpy as np
# General packages
import os, time, sys, math

# General image processing packages
from matplotlib import pyplot as plot
import numpy as np
import cv2 as cv

#########################################

# NOTE: the AutoScript microscope client (TemMicroscopeClient) is created and connected inside main(), not at import time.






        

# Initialize session state variables once per browser session. Streamlit re-executes
# this whole script on every interaction, so the defaults are only written when the
# session has not been initialised yet (detected via the "exploration_step" key).
if "exploration_step" not in st.session_state:
    _session_defaults = {
        "exploration_step": 0,  # active-learning steps completed so far
        "X_measured": None,  # flattened patches at measured locations
        "X_unmeasured": None,  # flattened patches not measured yet
        "y_measured": None,  # normalised scalarized spectra at measured locations
        "y_unmeasured": None,  # currently unused by main()
        "indices_measured": None,  # grid indices of measured patches
        "indices_unmeasured": None,  # grid indices of unmeasured patches
        "img": None,  # survey image used for plotting
        "save_dir": 'out_data_sl/',  # default output directory
        "steps_with_same_beta": 1,  # default number of steps run with one beta
        "seed_points": None,  # number of initial (seed) measurements
        "budget": 20,  # default number of active-learning steps
        "HAADF_exposure": 4e-5,  # default HAADF exposure passed to acquire_stem_image
        "HAADF_resolution": 256,  # default HAADF scan resolution (pixels per side)
        # Default data path; main() only records it in the logged config.
        "data_path": 'in_data_sl/Plasmonic_EELS_FITO0_edgehole_01.npy',
        "window_size": 16,  # default patch side length in pixels
        "image_data": None,  # 8-bit survey image shown next to each spectrum
        "sc_f": None,  # selected scalarizer function
    }
    for _key, _value in _session_defaults.items():
        st.session_state[_key] = _value


@st.cache_data
def load_image_and_features(img, window_size):
    """Extract one patch per grid point of ``img`` and min-max normalise them jointly.

    Cached with ``st.cache_data`` so Streamlit reruns with the same image and
    window size reuse the result instead of re-extracting every patch.

    Args:
        img: the survey image (presumably a 2-D array — callers pass the 8-bit
            HAADF image).
        window_size: patch size handed to ``atomai.utils.extract_subimages``.

    Returns:
        tuple ``(img, features, coords)``: the input image unchanged; the patches
        with the last axis of ``extract_subimages``' 4-D output dropped, scaled
        globally to [0, 1]; and the patch coordinates as an int array.
    """
    coordinates = get_coord_grid(img, step=1, return_dict=False)
    features_all, coords, _ = extract_subimages(img, coordinates, window_size)
    # extract_subimages returns a 4-D array; keep only the first entry of the last axis.
    features_all = features_all[:, :, :, 0]
    coords = np.array(coords, dtype=int)
    # Use np.ptp: the ndarray.ptp() method was removed in NumPy 2.0.
    # A constant input maps to zeros instead of producing 0/0 = NaN.
    value_range = np.ptp(features_all)
    if value_range == 0:
        features = np.zeros(features_all.shape, dtype=float)
    else:
        features = (features_all - features_all.min()) / value_range
    return img, features, coords



def run_acquisition_for_position(path, image_data, microscope, position, eds_detector_name):
    """Park the beam at ``position``, then acquire, plot and return one EDS spectrum.

    Args:
        path: file the combined image/spectrum figure is written to.
        image_data: image drawn next to the spectrum.
        microscope: connected AutoScript ``TemMicroscopeClient``.
        position: indexable pair whose first two entries are written (as floats) to
            ``microscope.optics.paused_scan_beam_position``; callers in this file pass
            pixel indices divided by the scan resolution.
        eds_detector_name: detector forwarded to ``configure_acquisition``.

    Returns:
        The spectrum array produced by ``acquire_and_plot_spectrum``.
    """
    print(f"Running acquisition for position {position}")
    beam_x = float(position[0])
    beam_y = float(position[1])
    microscope.optics.paused_scan_beam_position = [beam_x, beam_y]
    eds_settings = configure_acquisition(eds_detector_name)
    return acquire_and_plot_spectrum(path, image_data, microscope, eds_settings)
    
    
def configure_acquisition(eds_detector_name, dispersion=5, shaping_time=3e-6, exposure_time=2, exposure_time_type=ExposureTimeType.LIVE_TIME):
    """Build an ``EdsAcquisitionSettings`` object for a single EDS spectrum.

    Each argument is copied onto the attribute of the same role:

    Args:
        eds_detector_name: -> ``eds_detector`` (callers here pass ``EdsDetectorType.SUPER_X``).
        dispersion: -> ``dispersion``; presumably eV per channel — confirm in the
            AutoScript documentation.
        shaping_time: -> ``shaping_time`` (presumably seconds).
        exposure_time: -> ``exposure_time`` (presumably seconds).
        exposure_time_type: -> ``exposure_time_type``; defaults to ``ExposureTimeType.LIVE_TIME``.

    Returns:
        The populated settings object.
    """
    eds_settings = EdsAcquisitionSettings()
    field_values = (
        ("eds_detector", eds_detector_name),
        ("dispersion", dispersion),
        ("shaping_time", shaping_time),
        ("exposure_time", exposure_time),
        ("exposure_time_type", exposure_time_type),
    )
    for field_name, field_value in field_values:
        setattr(eds_settings, field_name, field_value)
    return eds_settings


def acquire_and_plot_spectrum(path, image_data, microscope, settings):
    """Acquire one EDS spectrum and save/show it next to the image with the beam marked.

    Args:
        path: file the side-by-side figure is saved to (then shown via ``st.image``).
        image_data: image drawn in the left panel; its height/width are used to place
            the beam marker.
        microscope: connected AutoScript ``TemMicroscopeClient``.
        settings: ``EdsAcquisitionSettings`` passed to ``acquire_eds_spectrum``; its
            ``dispersion`` (default 5 if absent) sets the energy-axis step.

    Returns:
        The spectrum counts decoded as a little-endian ``uint32`` array.
    """
    # Acquire the EDS spectrum and decode its raw buffer as little-endian uint32 counts.
    spectrum = microscope.analysis.acquire_eds_spectrum(settings)
    counts_dtype = np.dtype('uint32').newbyteorder('<')
    spec = np.frombuffer(spectrum._raw_data, dtype=counts_dtype)

    # Get the current beam position
    position = microscope.optics.paused_scan_beam_position
    x = position.x
    y = position.y
    formatted_position = f"({x:.2g}, {y:.2g})"

    # The beam position is scaled by the real image size rather than a hard-coded 256,
    # so the marker stays correct for any HAADF resolution. x is drawn along the
    # horizontal (column) axis and y along the vertical (row) axis.
    # NOTE(review): confirm this matches the microscope's beam-position convention.
    height, width = np.shape(image_data)[:2]
    # Energy step per channel; assumed to be in eV (hence /1000 for keV) — TODO confirm.
    dispersion = getattr(settings, "dispersion", 5)

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Plot the image data
        axs[0].imshow(image_data, cmap='gray')
        axs[0].set_title('Acquired Image')
        axs[0].set_axis_off()  # Hide axes for the image plot
        axs[0].scatter(x * width, y * height, c='r', s=100, marker='x', label=f"Position: {formatted_position}")

        # Plot the EDS spectrum
        axs[1].plot(np.arange(len(spec)) * dispersion / 1000, spec, label="Position: " + formatted_position)
        axs[1].set_title('Acquired EDS Spectrum')
        axs[1].set_xlabel('Channel (KeV)')
        axs[1].set_ylabel('Intensity Counts')
        axs[1].legend()

        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    st.image(path, caption='Acquired Image and Spectrum', use_column_width=True)
    return spec
    
def plot_and_save_image(image_data, path, caption, use_column_width=True):
    """Render ``image_data`` with a colorbar, save it to ``path`` and show it in the app.

    Deliberately not wrapped in ``st.cache_data``. The function returns nothing and
    exists only for its side effects: writing the PNG, displaying it and logging.
    Caching it would hash the full image on every call and, on cache hits, skip the
    file write and the log line.

    Args:
        image_data: image array passed to ``imshow``.
        path: output PNG path.
        caption: caption for ``st.image``; also used in the log message.
        use_column_width: forwarded to ``st.image``.
    """
    # NOTE: imshow uses its default origin ('upper', first row at the top). Pass
    # origin='lower' only if the first row should be drawn at the bottom.
    fig, ax = plt.subplots()
    try:
        mappable = ax.imshow(image_data)
        fig.colorbar(mappable, ax=ax)
        fig.savefig(path)
    finally:
        plt.close(fig)
    st.image(path, caption=caption, use_column_width=use_column_width)
    logging.info(f"{caption} saved to {path}")

def main():
    """Render the hAE Streamlit interface and drive the human-in-the-loop experiment.

    Streamlit executes this function top to bottom on every rerun:

    1. Render the configuration widgets and mirror their values into
       ``st.session_state``.
    2. On *Start Process*:
       - acquire a HAADF survey image and extract ``window_size`` patches from it;
       - pick 10 seed locations (``train_test_split``, ``random_state=3``);
       - acquire an EDS spectrum at each seed and scalarize it with the chosen
         function;
       - store everything in ``st.session_state`` and reset ``exploration_step``.
    3. While seed data exists and fewer than ``budget`` steps have run:
       - let the user choose the UCB ``beta`` and how many steps to run with it;
       - on *Proceed*, for each step fit ``gpax.viDKL`` on the measured patches;
       - measure the unmeasured patch that maximizes UCB;
       - move that patch from the unmeasured set to the measured set.
    4. Plot the actively sampled points: "intermediate" before the budget is
       reached, "final" afterwards.
    """

    def norm_(values):
        """Min-max scale ``values`` to [0, 1]; a constant array maps to zeros."""
        values = np.asarray(values)
        # np.ptp: the ndarray.ptp() method was removed in NumPy 2.0.
        value_range = np.ptp(values)
        if value_range == 0:
            # Avoid 0/0 -> NaN, which would poison the GP fit.
            return np.zeros(values.shape, dtype=float)
        return (values - values.min()) / value_range

    def _connect_microscope():
        """Create an AutoScript TEM client and connect it on port 9095."""
        client = TemMicroscopeClient()
        ip = ""  # --> velox computer
        # ip = ""  # gatan computer
        # NOTE(review): the original comments mention port 7521 for the other PC — confirm.
        client.connect(ip, port=9095)
        return client

    def _plot_sampled_points(img, indices_measured, seed_points, path):
        """Overlay the non-seed measured points on ``img``, coloured by acquisition order."""
        sampled = np.array(indices_measured)[seed_points:]
        fig, ax = plt.subplots()
        ax.imshow(np.array(img), cmap='gray')
        scatter = ax.scatter(
            sampled[:, 1],
            sampled[:, 0],
            c=np.arange(len(sampled)),
            s=50,
            cmap="Reds"
        )
        fig.colorbar(scatter, ax=ax)
        fig.savefig(path)
        plt.close(fig)

    # Set up page title and description
    st.title("h(human)AE interface")
    st.image("app/static/hAE.png", use_column_width=True)
    st.markdown("---")
    st.markdown("### Configure the parameters for the experiment")

    # Organize the input fields in two columns
    col1, col2 = st.columns(2)

    with col1:
        save_dir = st.text_input('Save Directory', st.session_state.save_dir)
        data_path = st.text_input('Data Path', st.session_state.data_path)

    with col2:
        window_size = st.number_input('Window Size', value=st.session_state.window_size, step=1)
        budget = st.number_input('Budget', value=st.session_state.budget, step=1)

    HAADF_exposure = st.number_input('HAADF Exposure', value=st.session_state.HAADF_exposure, step=1e-5)
    HAADF_resolution = st.number_input('HAADF Resolution', value=st.session_state.HAADF_resolution, step=1)

    # Mirror the current widget values into session state so they survive reruns
    st.session_state.save_dir = save_dir
    st.session_state.data_path = data_path
    st.session_state.window_size = window_size
    st.session_state.budget = budget
    st.session_state.HAADF_exposure = HAADF_exposure
    st.session_state.HAADF_resolution = HAADF_resolution

    # Selectbox label -> scalarizer callable (insertion order is the display order)
    scalarizers = {
        'scalarizer_sum': scalarizer_sum,
        'scalarizer_max': scalarizer_max,
        'scalarizer_mean': scalarizer_mean,
        'scalarizer_inv_mean': scalarizer_inv_mean,
    }
    option = st.selectbox('Choose the scalarizer function', tuple(scalarizers))
    sc_f = scalarizers[option]

    # Display selected parameters at the top
    st.markdown("### Selected Parameters")
    st.markdown(f"**Budget:** {budget}, **Window Size:** {window_size}, **Scalarizer Function:** {option}, **HAADF Exposure:** {HAADF_exposure}, **HAADF Resolution:** {HAADF_resolution}")

    # Add a button to start the process
    start_process = st.button('Start Process')

    # Decorative horizontal line for better separation of content
    st.markdown("---")

    if start_process:
        st.success("Process started with the current configuration!")
        config = {
            'settings': {
                'save_dir': save_dir,
                'data_path': data_path,
                'window_size': window_size,
                'budget': budget,
                'HAADF_exposure': HAADF_exposure,
                'HAADF_resolution': HAADF_resolution
            }
        }

        dir_existed = os.path.exists(save_dir)
        os.makedirs(save_dir, exist_ok=True)

        # Initialize logging (after the directory exists), then report what happened to it
        init_logging(save_dir=save_dir, config=config)
        if dir_existed:
            logging.info(f"Directory {save_dir} already exists.")
        else:
            logging.info(f"Directory {save_dir} created.")

        # Acquire the HAADF survey image and rescale it to 8-bit
        microscope = _connect_microscope()
        image = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, HAADF_resolution, HAADF_exposure)
        img = image.data - np.min(image.data)
        image_data = (255*(img/np.max(img))).astype(np.uint8)

        img, features, indices_all = load_image_and_features(image_data, window_size)  # cached, so not recomputed on reruns

        survey_img_path = os.path.join(save_dir, "survey_image.png")
        plot_and_save_image(image_data, survey_img_path, 'Survey Image')

        st.write(f"Features shape: {features.shape}")
        logging.info(f"Features shape: {features.shape}")

        n, d1, d2 = features.shape
        X = features.reshape(n, d1*d2)

        # Use 10 randomly chosen (but reproducible) grid points as initial training points
        (X_measured, X_unmeasured, indices_measured, indices_unmeasured) = train_test_split(
            X, indices_all, train_size=10, shuffle=True, random_state=3)

        seed_points = len(X_measured)

        print("Seed points: ", seed_points)

        # Measure and scalarize an EDS spectrum at every seed point.
        # NOTE(review): grid indices are divided by the resolution and passed as the
        # beam (x, y); confirm the index order matches the beam's axis convention.
        eds_detector_name = EdsDetectorType.SUPER_X
        y_measured_unnor = []
        for i in range(seed_points):
            path_spec = os.path.join(save_dir, f"seed{i}.png")
            spectrum = run_acquisition_for_position(path_spec, image_data, microscope, indices_measured[i]/HAADF_resolution, eds_detector_name)
            y_measured_unnor.append(sc_f(spectrum))

        y_measured = norm_(np.asarray(y_measured_unnor))

        # Plot the seed points on the survey image
        fig, ax = plt.subplots()
        image_artist = ax.imshow(img)
        ax.set_title('Seed Points')
        ax.scatter(indices_measured[:, 1], indices_measured[:, 0], c="r", marker="X", s=30)
        # Colorbar for the image intensities (plt.colorbar() would instead pick the
        # single-colour scatter as the current mappable).
        fig.colorbar(image_artist, ax=ax)
        seed_img_path = os.path.join(save_dir, "seed_points.png")
        fig.savefig(seed_img_path)
        plt.close(fig)
        st.image(seed_img_path, caption='Seed Points', use_column_width=True)
        logging.info(f"Seed points saved to {seed_img_path}")

        # Save initial data to session state
        st.session_state.y_measured_unnor = y_measured_unnor
        st.session_state.X_measured = X_measured
        st.session_state.X_unmeasured = X_unmeasured
        st.session_state.y_measured = y_measured
        st.session_state.indices_measured = indices_measured
        st.session_state.indices_unmeasured = indices_unmeasured
        st.session_state.img = img
        st.session_state.save_dir = save_dir
        st.session_state.seed_points = seed_points
        st.session_state.budget = budget
        st.session_state.microscope = microscope
        st.session_state.seed_img_path = seed_img_path
        st.session_state.image_data = image_data
        st.session_state.sc_f = sc_f
        # Remember the resolution the survey was acquired at, so later beam positions
        # stay consistent even if the widget value changes mid-experiment.
        st.session_state.acquired_resolution = HAADF_resolution
        # A fresh start begins a fresh exploration; otherwise a finished previous run
        # would leave exploration_step at the budget and block all new steps.
        st.session_state.exploration_step = 0

    if st.session_state.X_measured is not None:
        y_measured_unnor = st.session_state.y_measured_unnor
        X_measured = st.session_state.X_measured
        X_unmeasured = st.session_state.X_unmeasured
        y_measured = st.session_state.y_measured
        indices_measured = st.session_state.indices_measured
        indices_unmeasured = st.session_state.indices_unmeasured
        img = st.session_state.img
        save_dir = st.session_state.save_dir
        seed_points = st.session_state.seed_points
        budget = st.session_state.budget
        seed_img_path = st.session_state.seed_img_path
        image_data = st.session_state.image_data
        sc_f = st.session_state.sc_f
        acq_resolution = st.session_state.get('acquired_resolution', HAADF_resolution)

        def plot_result_dkl(indices, obj):
            """Scatter acquisition values over ``indices`` and mark their argmax with 'x'."""
            fig, ax = plt.subplots(1, 1, figsize=(3, 3))
            scatter = ax.scatter(indices[:, 1], indices[:, 0], s=32, c=obj, marker='s')
            next_point = indices[obj.argmax()]  # maximize the acquisition function
            ax.scatter(next_point[1], next_point[0], marker='x', c='k')
            ax.set_title("Acquisition function values")
            fig.colorbar(scatter, ax=ax, label='Objective Value')
            return fig

        data_dim = X_measured.shape[-1]
        key1, key2 = gpax.utils.get_keys()

        if st.session_state.exploration_step < budget:  # before the budget is reached
            e = st.session_state.exploration_step
            st.write("{}/{}".format(e + 1, budget))

            st.markdown("### Policy intervention")
            if 'beta' not in st.session_state:
                st.session_state.beta = 0.25  # Default initialization

            # Inputs are editable only while no 'proceed' request is pending
            locked = bool(st.session_state.get('proceed'))

            if not locked:
                beta_options = [0.25, 0.50, 0.75, 1.0]
                st.session_state.beta = st.selectbox(
                    'Beta Value for Step',
                    options=beta_options,
                    index=beta_options.index(st.session_state.beta)  # keep the current value selected
                )

            beta = st.session_state.beta

            if 'steps_with_same_beta' not in st.session_state:
                st.session_state.steps_with_same_beta = 1  # Default initialization

            if not locked:
                max_steps = st.session_state.budget - st.session_state.exploration_step
                st.session_state.steps_with_same_beta = st.number_input(
                    'Number of Steps with Same Beta',
                    min_value=1,
                    max_value=max_steps,
                    # Clamp: Streamlit rejects a default above max_value, which happens
                    # once fewer steps remain than were requested last time.
                    value=min(st.session_state.steps_with_same_beta, max_steps),
                    step=1
                )

            proceed = st.button('Proceed to Next Step')
            if proceed:
                st.session_state.proceed = True  # locks the inputs above

            st.image(seed_img_path, caption='Seed Points', use_column_width=True)

            if st.session_state.get('proceed'):
                # Consume the request so the next rerun does not repeat these steps
                del st.session_state['proceed']
                # Connect only when spectra are actually acquired, instead of opening a
                # new microscope connection on every widget interaction.
                microscope = _connect_microscope()
                eds_detector_name = EdsDetectorType.SUPER_X
                for _ in range(st.session_state.steps_with_same_beta):
                    if st.session_state.exploration_step >= budget:
                        break
                    exploration_step = st.session_state.exploration_step

                    # Fit the deep-kernel GP on all measurements so far and pick the UCB maximum
                    dkl = gpax.viDKL(data_dim, 2)
                    dkl.fit(key1, np.array(X_measured), np.array(y_measured), num_steps=100, step_size=0.05)
                    obj = gpax.acquisition.UCB(key2, dkl, np.array(X_unmeasured), beta=beta, maximize=True)
                    next_point_idx = int(obj.argmax())

                    path_spec = os.path.join(save_dir, f"active_learning{exploration_step}.png")
                    spectrum = run_acquisition_for_position(path_spec, image_data, microscope, indices_unmeasured[next_point_idx]/acq_resolution, eds_detector_name)
                    measured_point = sc_f(spectrum)

                    fig = plot_result_dkl(np.array(indices_unmeasured), obj)
                    acq_func_img_path = os.path.join(save_dir, f"acquisition_function_{exploration_step}.png")
                    fig.savefig(acq_func_img_path)
                    plt.close(fig)
                    st.image(acq_func_img_path, caption=f'Acquisition Function Plot {exploration_step + 1}', use_column_width=True)
                    logging.info(f"Acquisition function plot saved to {acq_func_img_path}")

                    # Move the chosen point from the unmeasured to the measured set
                    X_measured = np.append(np.array(X_measured), np.array(X_unmeasured)[next_point_idx][None], 0)
                    X_unmeasured = np.delete(np.array(X_unmeasured), next_point_idx, 0)
                    y_measured_unnor = np.append(np.array(y_measured_unnor), measured_point)
                    y_measured = norm_(np.asarray(y_measured_unnor))
                    indices_measured = np.append(np.array(indices_measured), np.array(indices_unmeasured)[next_point_idx][None], 0)
                    indices_unmeasured = np.delete(np.array(indices_unmeasured), next_point_idx, 0)

                    st.session_state.X_measured = X_measured
                    st.session_state.X_unmeasured = X_unmeasured
                    st.session_state.y_measured = y_measured
                    st.session_state.y_measured_unnor = y_measured_unnor
                    st.session_state.indices_measured = indices_measured
                    st.session_state.indices_unmeasured = indices_unmeasured

                    st.session_state.exploration_step += 1

        if st.session_state.exploration_step < budget:  # before the budget is reached
            points_img_path = os.path.join(save_dir, "intermediate_points.png")
            _plot_sampled_points(img, indices_measured, seed_points, points_img_path)
            st.image(points_img_path, caption='Points sampled till now', use_column_width=True)
            logging.info(f"intermediate points saved to {points_img_path}")
        else:  # after the budget is reached
            points_img_path = os.path.join(save_dir, "final_points.png")
            _plot_sampled_points(img, indices_measured, seed_points, points_img_path)
            st.image(points_img_path, caption='Final Points', use_column_width=True)
            logging.info(f"Final points saved to {points_img_path}")

# Script entry point: build the app by calling main() when this file is executed directly.
if __name__ == "__main__":
    main()