import streamlit as st
from src.logger_n_config import init_logging, load_config, parse_arguments
from src.mic_sim import Microscope
import logging
from atomai.utils import get_coord_grid, extract_patches_and_spectra, extract_subimages
from src.scalarizer_live import (
    scalarizer_sum,
    scalarizer_max,
    scalarizer_mean,
    scalarizer_inv_mean,)
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
import numpy as np
import gpax
import Pyro5.api

import time

######################################### Autoscript related functions #########################################

from autoscript_tem_microscope_client import TemMicroscopeClient
from autoscript_tem_microscope_client.enumerations import *
from autoscript_tem_microscope_client.structures import *
import numpy as np
# General packages
import os, time, sys, math

# General image processing packages
from matplotlib import pyplot as plot
import numpy as np
import cv2 as cv

#########################################

# The microscope client is not created at import time: main() connects to the microscope server through a Pyro5 proxy.




# def fake_get_spectrum(x, y):
#     return np.random.rand(100)

        

# Initialize session-state defaults. Streamlit reruns this whole script on every
# widget interaction, so each key is only seeded when it is missing and values
# written by earlier runs are never overwritten.
_SESSION_DEFAULTS = {
    "exploration_step": 0,
    "X_measured": None,
    "X_unmeasured": None,
    "y_measured": None,
    "y_unmeasured": None,
    "y_measured_unnor": None,
    "indices_measured": None,
    "indices_unmeasured": None,
    "img": None,
    "seed_img_path": None,
    "sc_f": None,
    "save_dir": 'out_data_sl/',  # default output directory
    "steps_with_same_beta": 1,  # default number of steps run with one beta
    "budget": 20,  # default number of exploration steps
    "HAADF_exposure": 4e-5,  # default HAADF exposure
    "HAADF_resolution": 128,  # default HAADF image resolution (pixels per side)
    # In the visible code data_path is only recorded in the config/session state;
    # the survey image itself is acquired from the microscope.
    "data_path": 'in_data_sl/Plasmonic_EELS_FITO0_edgehole_01.npy',
    "window_size": 16,  # default patch (window) size for feature extraction
    "seed_points": 30,  # default number of random seed measurements
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


@st.cache_data
def load_image_and_features(img, window_size):
    """Extract one min-max normalized patch per grid coordinate of ``img``.

    Cached by Streamlit, so repeated calls with the same image and window size
    are not recomputed.

    Args:
        img: 2-D survey image.
        window_size: side length of the square patches passed to
            ``extract_subimages``.

    Returns:
        ``(img, features, coords)``. ``features`` is the patch stack (first
        channel only), scaled to [0, 1] using the global min/max over all
        patches. ``coords`` is the integer array of patch coordinates.
    """
    coordinates = get_coord_grid(img, step=1, return_dict=False)
    features_all, coords, _ = extract_subimages(img, coordinates, window_size)
    # Assumes extract_subimages returns (n, window, window, channels); keep channel 0.
    features_all = features_all[:, :, :, 0]
    coords = np.array(coords, dtype=int)
    # np.ptp instead of the ndarray.ptp method, which was removed in NumPy 2.0.
    value_range = np.ptp(features_all)
    if value_range == 0:
        # A constant patch stack would otherwise divide by zero and yield NaNs.
        features = np.zeros(features_all.shape)
    else:
        features = (features_all - features_all.min()) / value_range
    return img, features, coords

@st.cache_data
def plot_and_save_image(image, path, caption, use_column_width=True):
    """Save ``image`` with a colour bar to ``path``, show it in Streamlit and log it.

    A dedicated figure is created and always closed, so the plot never lands on
    (or leaks) whatever pyplot figure happened to be current.

    NOTE(review): because of ``st.cache_data``, a repeated call with identical
    arguments presumably skips the body. In that case the PNG is not rewritten
    and the log line is not emitted, and Streamlit replays the cached elements.
    Confirm this is intended.
    """
    fig, ax = plt.subplots()
    try:
        mappable = ax.imshow(image, origin='lower')
        fig.colorbar(mappable, ax=ax)
        fig.savefig(path)
    finally:
        plt.close(fig)
    st.image(path, caption=caption, use_column_width=use_column_width)
    logging.info(f"{caption} saved to {path}")

# TODO: set to the Pyro5 URI of the microscope server (e.g. "PYRO:obj@host:port").
# The empty placeholder is kept from the original code.
_MICROSCOPE_URI = ""

# Scalarizer functions selectable in the UI, keyed by their display name.
_SCALARIZERS = {
    'scalarizer_sum': scalarizer_sum,
    'scalarizer_max': scalarizer_max,
    'scalarizer_mean': scalarizer_mean,
    'scalarizer_inv_mean': scalarizer_inv_mean,
}


def _minmax_norm(x):
    """Scale ``x`` linearly into [0, 1]; a constant input maps to all zeros.

    Uses ``np.ptp`` because the ``ndarray.ptp`` method was removed in NumPy 2.0.
    The zero-range guard avoids NaNs, which would break the GP fit.
    """
    x = np.asarray(x)
    value_range = np.ptp(x)
    if value_range == 0:
        return np.zeros(x.shape)
    return (x - x.min()) / value_range


def _connect_microscope():
    """Open a Pyro5 proxy to the microscope server and activate its camera."""
    microscope = Pyro5.api.Proxy(_MICROSCOPE_URI)
    microscope.activate_camera()
    return microscope


def _measure_scalar(microscope, index, sc_f):
    """Move the beam to ``index``, acquire an EELS spectrum and scalarize it.

    ``index`` is a (row, col) pair. ``set_beam_pos`` receives the column first,
    then the row.
    """
    microscope.set_beam_pos(int(index[1]), int(index[0]))
    # NOTE(review): EELS exposure is hard-coded; the HAADF_exposure setting is not used here.
    microscope.acquire_camera(exposure=0.1)
    array_list, shape, dtype = microscope.get_eels()
    spectrum = np.array(array_list, dtype=dtype).reshape(shape)
    return sc_f(spectrum)


def _save_seed_plot(img, indices_measured, path):
    """Save the survey image with the seed locations overlaid as red crosses."""
    fig, ax = plt.subplots()
    try:
        image = ax.imshow(img, origin='lower')
        ax.set_title('Seed Points')
        ax.scatter(indices_measured[:, 1], indices_measured[:, 0], c="r", marker="X", s=30)
        # The colour bar belongs to the image; the scatter has a single fixed colour.
        fig.colorbar(image, ax=ax)
        fig.savefig(path)
    finally:
        plt.close(fig)


def _plot_acquisition(indices, obj):
    """Return a figure of normalized acquisition values, with its maximum marked by 'x'."""
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    obj = _minmax_norm(obj)
    points = ax.scatter(indices[:, 1], indices[:, 0], s=32, c=obj, marker='s')
    next_point = indices[obj.argmax()]  # maximize the acquisition function
    ax.scatter(next_point[1], next_point[0], marker='x', c='k')
    ax.set_title("Acquisition function values")
    fig.colorbar(points, ax=ax, label='Objective Value')
    return fig


def _save_sampled_points(img, indices_measured, seed_points, path):
    """Save the survey image with the non-seed points, coloured by acquisition order."""
    explored = np.asarray(indices_measured)[seed_points:]
    fig, ax = plt.subplots()
    try:
        ax.imshow(np.asarray(img), origin="lower", cmap='gray')
        points = ax.scatter(
            explored[:, 1],
            explored[:, 0],
            c=np.arange(len(explored)),
            s=50,
            cmap="Reds",
        )
        fig.colorbar(points, ax=ax)
        fig.savefig(path)
    finally:
        plt.close(fig)


def main():
    """Streamlit entry point for the human-in-the-loop (hAE) exploration.

    Pressing "Start Process" runs the following steps:
      1. Acquire a HAADF survey image and extract patch features.
      2. Measure scalarized EELS spectra at randomly chosen seed locations.
      3. Store everything in ``st.session_state``.

    On every later rerun the exploration resumes from session state. The user
    picks a UCB beta and how many steps to run with it. Each step does:
      1. Fit a viDKL model.
      2. Measure the point that maximizes the acquisition function.
      3. Move that point from the unmeasured set to the measured set.

    This continues until ``budget`` steps have been taken.
    """
    # Set up page title and description
    st.title("h(human)AE interface")
    st.image("app/static/hAE.png", use_column_width=True)
    st.markdown("---")
    st.markdown("### Configure the parameters for the experiment")

    # Organize the input fields in two columns
    col1, col2 = st.columns(2)
    with col1:
        save_dir = st.text_input('Save Directory', st.session_state.save_dir)
        data_path = st.text_input('Data Path', st.session_state.data_path)
    with col2:
        window_size = st.number_input('Window Size', value=st.session_state.window_size, step=1)
        budget = st.number_input('Budget', value=st.session_state.budget, step=1)

    # Scientific format: the default float display would show 4e-5 as "0.00".
    HAADF_exposure = st.number_input(
        'HAADF Exposure', value=st.session_state.HAADF_exposure, step=1e-5, format="%.2e")
    HAADF_resolution = st.number_input('HAADF Resolution', value=st.session_state.HAADF_resolution, step=1)
    seed_points = st.number_input('Seed Points', value=st.session_state.seed_points, step=1)

    # Persist the current form values. seed_points is deliberately NOT written
    # here: after seeding, the session value is the actual number of seed
    # measurements, which the exploration phase relies on.
    st.session_state.save_dir = save_dir
    st.session_state.data_path = data_path
    st.session_state.window_size = window_size
    st.session_state.budget = budget
    st.session_state.HAADF_exposure = HAADF_exposure
    st.session_state.HAADF_resolution = HAADF_resolution

    option = st.selectbox('Choose the scalarizer function', tuple(_SCALARIZERS))
    sc_f = _SCALARIZERS[option]

    # Display selected parameters at the top
    st.markdown("### Selected Parameters")
    st.markdown(f"**Budget:** {budget}, **Window Size:** {window_size}, **Scalarizer Function:** {option}, **HAADF Exposure:** {HAADF_exposure}, **HAADF Resolution:** {HAADF_resolution}, **Seed Points:** {seed_points}")

    start_process = st.button('Start Process')
    st.markdown("---")

    if start_process:
        st.success("Process started with the current configuration!")
        config = {
            'settings': {
                'save_dir': save_dir,
                'data_path': data_path,
                'window_size': window_size,
                'budget': budget,
                'HAADF_exposure': HAADF_exposure,
                'HAADF_resolution': HAADF_resolution,
                'seed_points': seed_points
            }
        }

        dir_existed = os.path.isdir(save_dir)
        os.makedirs(save_dir, exist_ok=True)
        init_logging(save_dir=save_dir, config=config)
        if dir_existed:
            logging.info(f"Directory {save_dir} already exists.")
        else:
            logging.info(f"Directory {save_dir} created.")

        # Acquire and normalize the HAADF survey image.
        microscope = _connect_microscope()
        array_list, shape, dtype = microscope.get_ds(HAADF_resolution, HAADF_resolution)
        image_data = _minmax_norm(np.array(array_list, dtype=dtype).reshape(shape))
        img, features, indices_all = load_image_and_features(image_data, window_size)  # cached

        survey_img_path = os.path.join(save_dir, "survey_image.png")
        plot_and_save_image(img, survey_img_path, 'Survey Image')

        st.write(f"Features shape: {features.shape}")
        logging.info(f"Features shape: {features.shape}")

        n, d1, d2 = features.shape
        X = features.reshape(n, d1 * d2)

        # Randomly pick `seed_points` patches (fixed random_state for
        # reproducibility) as the initial training set.
        X_measured, X_unmeasured, indices_measured, indices_unmeasured = train_test_split(
            X, indices_all, train_size=seed_points, shuffle=True, random_state=3)
        seed_points = len(X_measured)

        logging.info(f"Seed points: {seed_points}")
        st.write(f"Seed points: {seed_points}")

        y_measured_unnor = []
        for i, index in enumerate(indices_measured):
            logging.info(f"Measuring seed point {i} at ({index[1]}, {index[0]})")
            y_measured_unnor.append(_measure_scalar(microscope, index, sc_f))
        y_measured = _minmax_norm(y_measured_unnor)

        seed_img_path = os.path.join(save_dir, "seed_points.png")
        _save_seed_plot(img, indices_measured, seed_img_path)
        st.image(seed_img_path, caption='Seed Points', use_column_width=True)
        logging.info(f"Seed points saved to {seed_img_path}")

        # Save initial data to session state
        st.session_state.y_measured_unnor = y_measured_unnor
        st.session_state.X_measured = X_measured
        st.session_state.X_unmeasured = X_unmeasured
        st.session_state.y_measured = y_measured
        st.session_state.indices_measured = indices_measured
        st.session_state.indices_unmeasured = indices_unmeasured
        st.session_state.img = img
        st.session_state.save_dir = save_dir
        st.session_state.seed_points = seed_points
        st.session_state.budget = budget
        st.session_state.microscope = microscope
        st.session_state.seed_img_path = seed_img_path
        st.session_state.sc_f = sc_f

    if st.session_state.X_measured is None:
        return  # nothing seeded yet

    y_measured_unnor = st.session_state.y_measured_unnor
    X_measured = st.session_state.X_measured
    X_unmeasured = st.session_state.X_unmeasured
    y_measured = st.session_state.y_measured
    indices_measured = st.session_state.indices_measured
    indices_unmeasured = st.session_state.indices_unmeasured
    img = st.session_state.img
    save_dir = st.session_state.save_dir
    seed_points = st.session_state.seed_points
    budget = st.session_state.budget
    seed_img_path = st.session_state.seed_img_path
    sc_f = st.session_state.sc_f

    # A fresh proxy is opened on every rerun rather than reusing the one stored
    # in session state. NOTE(review): presumably this is because Pyro5 proxies
    # are bound to the thread that created them. Confirm.
    microscope = _connect_microscope()

    data_dim = X_measured.shape[-1]
    key1, key2 = gpax.utils.get_keys()

    remaining = budget - st.session_state.exploration_step
    if remaining > 0:  # budget not yet reached
        st.write("{}/{}".format(st.session_state.exploration_step + 1, budget))
        st.markdown("### Policy intervention")

        if 'beta' not in st.session_state:
            st.session_state.beta = 0.25
        beta_options = [0.25, 0.50, 0.75, 1.0]
        st.session_state.beta = st.selectbox(
            'Beta Value for Step',
            options=beta_options,
            index=beta_options.index(st.session_state.beta),
        )
        beta = st.session_state.beta

        if 'steps_with_same_beta' not in st.session_state:
            st.session_state.steps_with_same_beta = 1
        # Clamp the default: Streamlit raises if value > max_value, which would
        # otherwise happen near the end of the budget.
        st.session_state.steps_with_same_beta = st.number_input(
            'Number of Steps with Same Beta',
            min_value=1,
            max_value=remaining,
            value=min(st.session_state.steps_with_same_beta, remaining),
            step=1,
        )

        proceed = st.button('Proceed to Next Step')
        st.image(seed_img_path, caption='Seed Points', use_column_width=True)

        if proceed:
            for _ in range(st.session_state.steps_with_same_beta):
                if st.session_state.exploration_step >= budget:
                    break
                dkl = gpax.viDKL(data_dim, 2)
                dkl.fit(key1, np.array(X_measured), np.array(y_measured), num_steps=100, step_size=0.05)
                obj = gpax.acquisition.UCB(key2, dkl, np.array(X_unmeasured), beta=beta, maximize=True)
                next_point_idx = int(np.argmax(obj))
                measured_point = _measure_scalar(
                    microscope, np.asarray(indices_unmeasured)[next_point_idx], sc_f)

                fig = _plot_acquisition(np.array(indices_unmeasured), obj)
                acq_func_img_path = os.path.join(
                    save_dir, f"acquisition_function_{st.session_state.exploration_step}.png")
                try:
                    fig.savefig(acq_func_img_path)
                finally:
                    plt.close(fig)
                st.image(acq_func_img_path, caption=f'Acquisition Function Plot {st.session_state.exploration_step + 1}', use_column_width=True)
                logging.info(f"Acquisition function plot saved to {acq_func_img_path}")

                # Move the chosen point from the unmeasured to the measured set.
                X_measured = np.append(np.asarray(X_measured), np.asarray(X_unmeasured)[next_point_idx][None], 0)
                X_unmeasured = np.delete(np.asarray(X_unmeasured), next_point_idx, 0)
                y_measured_unnor = np.append(np.asarray(y_measured_unnor), measured_point)
                y_measured = _minmax_norm(y_measured_unnor)
                indices_measured = np.append(np.asarray(indices_measured), np.asarray(indices_unmeasured)[next_point_idx][None], 0)
                indices_unmeasured = np.delete(np.asarray(indices_unmeasured), next_point_idx, 0)

                st.session_state.X_measured = X_measured
                st.session_state.X_unmeasured = X_unmeasured
                st.session_state.y_measured = y_measured
                st.session_state.y_measured_unnor = y_measured_unnor
                st.session_state.indices_measured = indices_measured
                st.session_state.indices_unmeasured = indices_unmeasured

                st.session_state.exploration_step += 1

    # Re-check after the loop: the steps above may have exhausted the budget.
    if st.session_state.exploration_step < budget:
        points_img_path = os.path.join(save_dir, "intermediate_points.png")
        _save_sampled_points(img, indices_measured, seed_points, points_img_path)
        st.image(points_img_path, caption='Points sampled till now', use_column_width=True)
        logging.info(f"intermediate points saved to {points_img_path}")
    else:
        points_img_path = os.path.join(save_dir, "final_points.png")
        _save_sampled_points(img, indices_measured, seed_points, points_img_path)
        st.image(points_img_path, caption='Final Points', use_column_width=True)
        logging.info(f"Final points saved to {points_img_path}")


if __name__ == "__main__":
    main()