import streamlit as st
from src.logger_n_config import init_logging, load_config, parse_arguments
from src.mic_sim import Microscope
import logging
from atomai.utils import get_coord_grid, extract_patches_and_spectra
from src.scalarizer_preacquired import (
    scalarizer_peak_hard_coded,
    scalarizer_sum,
    scalarizer_max,
    scalarizer_mean,
    scalarizer_inv_mean,)
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
import numpy as np
import gpax

# Seed st.session_state with default values.
# "exploration_step" acts as the marker key: when it is absent, every entry
# below is (re)written in one go; when it is present nothing is touched.
if "exploration_step" not in st.session_state:
    st.session_state.update(
        {
            # Number of exploration steps performed so far.
            "exploration_step": 0,
            # Measured / unmeasured data splits; filled in once the process starts.
            "X_measured": None,
            "X_unmeasured": None,
            "y_measured": None,
            "y_unmeasured": None,
            "indices_measured": None,
            "indices_unmeasured": None,
            "img": None,
            # Default output directory.
            "save_dir": 'out_data_sl/',
            # Default number of consecutive steps that reuse one beta value.
            "steps_with_same_beta": 1,
            "seed_points": None,
            # Default exploration budget.
            "budget": 20,
            # Default spectral averaging factor.
            "spectralavg": 4,
            # Default input data file.
            "data_path": 'in_data_sl/Plasmonic_EELS_FITO0_edgehole_01.npy',
            # Default patch window size.
            "window_size": 16,
        }
    )


def _min_max_normalize(x):
    """Rescale *x* by its global minimum and peak-to-peak span.

    Uses ``np.ptp(x)`` rather than the ``x.ptp()`` method, which is not
    available on arrays in NumPy 2.x. When the span is zero (all values
    equal) the shifted array (all zeros) is returned instead of dividing by
    zero and producing NaN/inf.
    """
    shifted = x - x.min()
    span = np.ptp(x)
    if span == 0:
        return shifted
    return shifted / span


@st.cache_data
def load_image_and_features(data_path, window_size, spectralavg):
    """Load the survey/spectral images and extract normalised patches and spectra.

    Builds a ``Microscope`` from ``data_path``, takes its ``survey_img`` and
    ``spec_img``, generates a coordinate grid over the survey image with
    step 1 and extracts patches (features) and spectra (targets) at those
    coordinates. Features and targets are each min-max normalised over the
    whole array.

    The result is cached by Streamlit on the argument values.

    Args:
        data_path: Path passed to ``Microscope``.
        window_size: Patch window size forwarded to
            ``extract_patches_and_spectra``.
        spectralavg: Value forwarded as ``avg_pool`` to
            ``extract_patches_and_spectra``.

    Returns:
        Tuple ``(img, features, targets, indices)`` where ``img`` is the
        survey image, ``features``/``targets`` are the normalised outputs of
        ``extract_patches_and_spectra`` and ``indices`` is returned as-is.
    """
    mic = Microscope(data_path)
    img = mic.survey_img
    specim = mic.spec_img
    coordinates = get_coord_grid(img, step=1, return_dict=False)
    features, targets, indices = extract_patches_and_spectra(
        specim,
        img,
        coordinates=coordinates,
        window_size=window_size,
        avg_pool=spectralavg,
    )
    features = _min_max_normalize(features)
    targets = _min_max_normalize(targets)
    return img, features, targets, indices

@st.cache_data
def plot_and_save_image(image, path, caption, use_column_width=True):
    """Render *image* with a colorbar, save it to *path* and show it in the app.

    The plot is drawn on its own Figure/Axes rather than on pyplot's implicit
    "current figure", so it cannot pick up artists from a figure left open
    elsewhere, and the figure is closed even if saving fails.

    NOTE(review): the function is wrapped in ``st.cache_data`` although its
    purpose is side effects (file write, logging, ``st.image``). On a cache
    hit the body is not executed again -- confirm this is intended.

    Args:
        image: Array-like passed to ``imshow`` (drawn with ``origin='lower'``).
        path: Destination file path for the saved figure.
        caption: Caption shown under the image and used in the log message.
        use_column_width: Forwarded to ``st.image``.
    """
    fig, ax = plt.subplots()
    try:
        mappable = ax.imshow(image, origin='lower')
        fig.colorbar(mappable, ax=ax)
        fig.savefig(path)
    finally:
        # Always release the figure so repeated runs do not accumulate them.
        plt.close(fig)
    st.image(path, caption=caption, use_column_width=use_column_width)
    logging.info(f"{caption} saved to {path}")

def _plot_result_dkl(indices, obj):
    """Return a figure of acquisition values with the arg-max point marked."""
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    scatter = ax.scatter(indices[:, 1], indices[:, 0], s=32, c=obj, marker='s')
    next_point = indices[obj.argmax()]  # maximize the acquisition function
    ax.scatter(next_point[1], next_point[0], marker='x', c='k')
    ax.set_title("Acquisition function values")
    fig.colorbar(scatter, ax=ax, label='Objective Value')
    return fig


def _save_sampled_points_plot(img, indices_measured, seed_points, path):
    """Save the survey image overlaid with the points measured after the seeds.

    Points are the rows of ``indices_measured`` from ``seed_points`` onwards,
    coloured by their position in that sequence.
    """
    sampled = np.array(indices_measured)[seed_points:]
    plt.imshow(np.array(img), origin="lower", cmap='gray')
    plt.scatter(
        sampled[:, 1],
        sampled[:, 0],
        c=np.arange(len(sampled)),
        s=50,
        cmap="Reds"
    )
    plt.colorbar()
    plt.savefig(path)
    plt.close()


def main():
    """Build the Streamlit page and drive the exploration workflow.

    The page has three stages, all executed top to bottom on every run:

    1. Configuration widgets (save directory, data path, window size,
       budget, spectral average, scalarizer). Their values are mirrored into
       ``st.session_state``.
    2. When "Start Process" is pressed: data are loaded, scalarised, split
       into a small seed ("measured") set and an "unmeasured" pool, the
       overview plots are saved to ``save_dir`` and the splits are stored in
       ``st.session_state``. The step counter is reset to 0.
    3. While measured data exist in the session: the user picks a beta value
       and a number of steps, and each "Proceed to Next Step" press runs that
       many viDKL fit / UCB acquisition steps, moving the arg-max point from
       the unmeasured pool to the measured set. A plot of the points sampled
       so far (or the final points once the budget is reached) is shown.
    """
    # ------------------------------------------------------------------ #
    # Stage 1: configuration
    # ------------------------------------------------------------------ #
    st.title("h(human)AE interface")
    st.image("app/static/hAE.png", use_column_width=True)
    st.markdown("---")
    st.markdown("### Configure the parameters for the experiment")

    # Organize the input fields in two columns
    col1, col2 = st.columns(2)

    with col1:
        save_dir = st.text_input('Save Directory', st.session_state.save_dir)
        data_path = st.text_input('Data Path', st.session_state.data_path)

    with col2:
        window_size = st.number_input('Window Size', value=st.session_state.window_size, step=1)
        budget = st.number_input('Budget', value=st.session_state.budget, step=1)

    spectralavg = st.number_input('Spectral Average', value=st.session_state.spectralavg, step=1)

    # Mirror the current widget values into the session state
    st.session_state.save_dir = save_dir
    st.session_state.data_path = data_path
    st.session_state.window_size = window_size
    st.session_state.budget = budget
    st.session_state.spectralavg = spectralavg

    # Map each selectable label to its scalarizer; the dict order is the
    # order shown in the dropdown.
    scalarizers = {
        'scalarizer_peak': scalarizer_peak_hard_coded,
        'scalarizer_sum': scalarizer_sum,
        'scalarizer_max': scalarizer_max,
        'scalarizer_mean': scalarizer_mean,
        'scalarizer_inv_mean': scalarizer_inv_mean,
    }
    option = st.selectbox('Choose the scalarizer function', tuple(scalarizers))
    sc_f = scalarizers[option]

    # Display selected parameters at the top
    st.markdown("### Selected Parameters")
    st.markdown(f"**Budget:** {budget}, **Window Size:** {window_size}, **Spectral Average:** {spectralavg}, **Scalarizer Function:** {option}")

    # Add a button to start the process
    start_process = st.button('Start Process')

    # Horizontal line for better separation of content
    st.markdown("---")

    # ------------------------------------------------------------------ #
    # Stage 2: (re)start the process
    # ------------------------------------------------------------------ #
    if start_process:
        st.success("Process started with the current configuration!")
        config = {
            'settings': {
                'spectralavg': spectralavg,
                'save_dir': save_dir,
                'data_path': data_path,
                'window_size': window_size,
                'budget': budget
            }
        }

        # Remember whether the directory was there before creating it so the
        # log message below states what actually happened.
        dir_existed = os.path.exists(save_dir)
        os.makedirs(save_dir, exist_ok=True)

        # Initialize logging
        init_logging(save_dir=save_dir, config=config)
        if dir_existed:
            logging.info(f"Directory {save_dir} already exists.")
        else:
            logging.info(f"Directory {save_dir} created.")

        # Cached, so not recomputed for identical arguments
        img, features, targets, indices = load_image_and_features(data_path, window_size, spectralavg)

        survey_img_path = os.path.join(save_dir, "survey_image.png")
        plot_and_save_image(img, survey_img_path, 'Survey Image')

        st.write(f"Features shape: {features.shape}, Targets shape: {targets.shape}")
        logging.info(f"Features shape: {features.shape}, Targets shape: {targets.shape}")

        peaks_all, features_all, indices_all = sc_f(targets, features, indices)

        plt.scatter(indices_all[:, 1], indices_all[:, 0], c=peaks_all)
        plt.title('Scalarizer values')
        plt.gca().set_aspect('equal')
        plt.colorbar()
        plasmon_img_path = os.path.join(save_dir, "plasmon_peak_intensities.png")
        plt.savefig(plasmon_img_path)
        plt.close()
        st.image(plasmon_img_path, caption='Plasmon Peak Intensities', use_column_width=True)
        logging.info(f"Plasmon peak intensities saved to {plasmon_img_path}")

        logging.info(f"Peaks_all shape: {peaks_all.shape}, Features_all shape: {features_all.shape}, Indices_all shape: {indices_all.shape}")

        # Flatten each patch into one feature vector per point
        n, d1, d2 = features_all.shape
        X = features_all.reshape(n, d1 * d2)
        y = peaks_all
        st.write(f"X shape: {X.shape}, y shape: {y.shape}")
        logging.info(f"X shape: {X.shape}, y shape: {y.shape}")

        # The "train" part of the split is the seed (measured) set,
        # the remainder is the unmeasured pool.
        (
            X_measured,
            X_unmeasured,
            y_measured,
            y_unmeasured,
            indices_measured,
            indices_unmeasured
        ) = train_test_split(
            X,
            y,
            indices_all,
            test_size=0.990,
            shuffle=True,
            random_state=1
        )

        seed_points = len(X_measured)

        plt.imshow(img, origin='lower')
        plt.title('Seed Points')
        plt.scatter(indices_measured[:, 1], indices_measured[:, 0], c="r", marker="X", s=30)
        plt.colorbar()
        seed_img_path = os.path.join(save_dir, "seed_points.png")
        plt.savefig(seed_img_path)
        plt.close()
        st.image(seed_img_path, caption='Seed Points', use_column_width=True)
        logging.info(f"Seed points saved to {seed_img_path}")

        # Save initial data to session state
        st.session_state.X_measured = X_measured
        st.session_state.X_unmeasured = X_unmeasured
        st.session_state.y_measured = y_measured
        st.session_state.y_unmeasured = y_unmeasured
        st.session_state.indices_measured = indices_measured
        st.session_state.indices_unmeasured = indices_unmeasured
        st.session_state.img = img
        st.session_state.save_dir = save_dir
        st.session_state.seed_points = seed_points
        st.session_state.budget = budget
        st.session_state.seed_img_path = seed_img_path
        # The measured set was just reset to the seeds, so the step counter
        # must restart as well; otherwise a restarted run would inherit the
        # previous run's progress against the budget.
        st.session_state.exploration_step = 0

    # ------------------------------------------------------------------ #
    # Stage 3: exploration (only once a process has been started)
    # ------------------------------------------------------------------ #
    if st.session_state.X_measured is None:
        return

    X_measured = st.session_state.X_measured
    X_unmeasured = st.session_state.X_unmeasured
    y_measured = st.session_state.y_measured
    y_unmeasured = st.session_state.y_unmeasured
    indices_measured = st.session_state.indices_measured
    indices_unmeasured = st.session_state.indices_unmeasured
    img = st.session_state.img
    save_dir = st.session_state.save_dir
    seed_points = st.session_state.seed_points
    budget = st.session_state.budget
    seed_img_path = st.session_state.seed_img_path

    data_dim = X_measured.shape[-1]
    key1, key2 = gpax.utils.get_keys()

    if st.session_state.exploration_step < budget:  # before the budget is reached
        step = st.session_state.exploration_step
        st.write("{}/{}".format(step + 1, budget))

        st.markdown("### Policy intervention")
        # Ensure the session state for beta is initialized
        if 'beta' not in st.session_state:
            st.session_state.beta = 0.25  # Default initialization

        # The inputs are shown only while no 'proceed' flag is pending
        inputs_unlocked = not st.session_state.get('proceed', False)

        if inputs_unlocked:
            beta_options = [0.25, 0.50, 0.75, 1.0]
            st.session_state.beta = st.selectbox(
                'Beta Value for Step',
                options=beta_options,
                index=beta_options.index(st.session_state.beta)  # keep the current value selected
            )

        beta = st.session_state.beta

        # Make sure the session state for steps_with_same_beta is initialized
        if 'steps_with_same_beta' not in st.session_state:
            st.session_state.steps_with_same_beta = 1  # Default initialization

        if inputs_unlocked:
            remaining_steps = st.session_state.budget - st.session_state.exploration_step
            st.session_state.steps_with_same_beta = st.number_input(
                'Number of Steps with Same Beta',
                min_value=1,
                max_value=remaining_steps,
                # Clamp the remembered value: after a multi-step run the
                # remaining budget can be smaller than the previous choice,
                # and a default above max_value is rejected by the widget.
                value=min(st.session_state.steps_with_same_beta, remaining_steps),
                step=1
            )

        proceed = st.button('Proceed to Next Step')
        # Store the button press in session state to lock the inputs
        if proceed:
            st.session_state.proceed = True

        st.image(seed_img_path, caption='Seed Points', use_column_width=True)

        # Logic to execute after pressing the 'Proceed' button
        if st.session_state.get('proceed', False):
            # Important: drop the flag so the next run shows the inputs again
            del st.session_state['proceed']
            for _ in range(st.session_state.steps_with_same_beta):
                if st.session_state.exploration_step >= budget:
                    break
                dkl = gpax.viDKL(data_dim, 2)
                dkl.fit(key1, np.array(X_measured), np.array(y_measured), num_steps=100, step_size=0.05)
                obj = gpax.acquisition.UCB(key2, dkl, np.array(X_unmeasured), beta=beta, maximize=True)
                next_point_idx = obj.argmax()
                measured_point = np.array(y_unmeasured)[next_point_idx]

                fig = _plot_result_dkl(np.array(indices_unmeasured), obj)
                acq_func_img_path = os.path.join(save_dir, f"acquisition_function_{st.session_state.exploration_step}.png")
                fig.savefig(acq_func_img_path)
                plt.close(fig)
                st.image(acq_func_img_path, caption=f'Acquisition Function Plot {st.session_state.exploration_step + 1}', use_column_width=True)
                logging.info(f"Acquisition function plot saved to {acq_func_img_path}")

                # Move the selected point from the unmeasured pool to the measured set
                X_measured = np.append(np.array(X_measured), np.array(X_unmeasured)[next_point_idx][None], 0)
                X_unmeasured = np.delete(np.array(X_unmeasured), next_point_idx, 0)
                y_measured = np.append(np.array(y_measured), measured_point)
                y_unmeasured = np.delete(np.array(y_unmeasured), next_point_idx)
                indices_measured = np.append(np.array(indices_measured), np.array(indices_unmeasured)[next_point_idx][None], 0)
                indices_unmeasured = np.delete(np.array(indices_unmeasured), next_point_idx, 0)

                st.session_state.X_measured = X_measured
                st.session_state.X_unmeasured = X_unmeasured
                st.session_state.y_measured = y_measured
                st.session_state.y_unmeasured = y_unmeasured
                st.session_state.indices_measured = indices_measured
                st.session_state.indices_unmeasured = indices_unmeasured

                st.session_state.exploration_step += 1

    if st.session_state.exploration_step < budget:  # before the budget is reached
        final_points_img_path = os.path.join(save_dir, "intermediate_points.png")
        _save_sampled_points_plot(img, indices_measured, seed_points, final_points_img_path)
        st.image(final_points_img_path, caption='Points sampled till now', use_column_width=True)
        logging.info(f"intermediate points saved to {final_points_img_path}")
    else:  # after the budget is reached
        final_points_img_path = os.path.join(save_dir, "final_points.png")
        _save_sampled_points_plot(img, indices_measured, seed_points, final_points_img_path)
        st.image(final_points_img_path, caption='Final Points', use_column_width=True)
        logging.info(f"Final points saved to {final_points_img_path}")

# Script entry point: build the page only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()