import pickle
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from ipywidgets import interact
import ipywidgets as widgets
from scipy.interpolate import RegularGridInterpolator
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
import os

from scipy.optimize import curve_fit, lsq_linear
from sklearn.linear_model import LinearRegression
from scipy.optimize import least_squares
from sklearn.preprocessing import PolynomialFeatures, MinMaxScaler, FunctionTransformer
from scipy.optimize import minimize
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from scipy.optimize import LinearConstraint
import warnings
# Silence SciPy's repeated "delta_grad == 0.0 ..." warning (presumably raised by
# the quasi-Newton Hessian update inside scipy.optimize.minimize, e.g. with
# method="trust-constr", when the gradient does not change between steps).
# Filtering by message keeps every other warning visible.
warnings.filterwarnings("ignore", message="delta_grad == 0.0. Check if the approximated function is linear")


class DataLoader:
    """Load raw simulation results and reduce them to a per-parameter summary DataFrame.

    The summary is cached as ``<dir of file_path>/preprocess/summary.pkl``. When that
    cache exists it is loaded directly, unless ``if_force_reload`` is set, in which
    case the raw pickle is re-read and re-summarised.
    """

    def __init__(self, file_path, if_fit=True, if_force_reload=False):
        # NOTE(review): `if_fit` is accepted but not used by this class; kept for API compatibility.
        self.file_path = file_path
        self.preprocess_dir, self.paths = self.prepare_directories_and_paths()
        self.if_hit_cache, self.data = self.load_data(if_force_reload)
        if self.if_hit_cache:
            # The cache already holds the summary DataFrame.
            self.processed_data = self.data
        else:
            self.processed_data = self.cal_average()

    def load_data(self, if_force_reload=False):
        """Return ``(if_hit_cache, data)``.

        If the summary pickle exists and no reload is forced, ``data`` is the cached
        summary DataFrame and ``if_hit_cache`` is True. Otherwise ``data`` is the raw
        pickle contents and ``if_hit_cache`` is False.
        """
        summary_path = self.paths['summary_pkl']
        if os.path.exists(summary_path):
            if if_force_reload:
                print("Loading from Raw Data...")
                # read_data_from_raw reports completion itself.
                return False, self.read_data_from_raw()
            print("Loading DataFrame...")
            # NOTE: read_pickle unpickles arbitrary objects; only use trusted cache files.
            df = pd.read_pickle(summary_path)
            print("DataFrame loaded")
            return True, df
        return False, self.read_data_from_raw()

    def recreate_data_from_index_to_key(self, index_to_key):
        """Return one ``{key: []}`` dict per element of the flattened ``index_to_key`` array."""
        return [{key: []} for key in index_to_key.flatten()]

    def read_data_from_raw(self):
        """Unpickle and return the raw data stored at ``self.file_path``."""
        print("Loading data from raw file...")
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(self.file_path, 'rb') as file:
            data = pickle.load(file)
        print("Raw Data loaded")
        return data

    def prepare_directories_and_paths(self):
        """Ensure the ``preprocess`` directory next to ``file_path`` exists.

        Returns:
            (preprocess_dir, paths) where ``paths['summary_pkl']`` is the cache file.
        """
        base_dir = os.path.dirname(self.file_path)
        preprocess_dir = os.path.join(base_dir, 'preprocess')
        os.makedirs(preprocess_dir, exist_ok=True)
        paths = {
            'summary_pkl': os.path.join(preprocess_dir, 'summary.pkl'),
        }
        return preprocess_dir, paths

    def cal_average(self):
        """Summarise every ``{key: trials}`` entry of ``self.data`` into one DataFrame row.

        The key is unpacked into the ``g_syn``, ``v_rest`` and ``g_leak`` columns. The
        length of the first trial record selects the layout:

        * 4 or 5 fields: only trials with ``record[3] == False`` are used. Fields 0-2
          are averaged, giving 12 columns. With 5 fields ``mean_CV`` is std/mean of
          field 4 across the kept trials. With 4 fields it is an empty array.
        * 6 fields: as above, and field 5 is additionally reported as
          ``mean_leak_na`` / ``var_leak_na``, giving 14 columns.
        * any other length (legacy): records are zero-padded to 3 fields. Means are
          divided by 3 and variances by 9, giving 11 columns and no CV.

        The DataFrame is written to the summary cache and returned.

        Raises:
            ValueError: if ``self.data`` contains no entries.
        """
        base_columns = ["g_syn", "v_rest", "g_leak", "mean_FR", "mean_na", "mean_syn_current"]
        basic_columns = base_columns + ["mean_CV", "var_FR", "var_na", "var_syn_current",
                                        "unique_FR", "count_FR"]
        leak_columns = base_columns + ["mean_CV", "mean_leak_na", "var_FR", "var_na",
                                       "var_syn_current", "var_leak_na", "unique_FR", "count_FR"]
        legacy_columns = base_columns + ["var_FR", "var_na", "var_syn_current",
                                         "unique_FR", "count_FR"]

        result = []
        columns = None
        for entry in self.data:
            for key, lists in entry.items():
                n_fields = len(lists[0])
                if n_fields in (4, 5, 6):
                    # Keep only trials whose 4th field is False.
                    kept = [lst for lst in lists if lst[3] == False]  # noqa: E712
                    if n_fields == 4:
                        padded_lists = [lst[:3] + [0] for lst in kept]
                    elif n_fields == 5:
                        padded_lists = [lst[:3] + [lst[4]] for lst in kept]
                    else:
                        padded_lists = [lst[:3] + lst[4:6] for lst in kept]
                    if not padded_lists:
                        # No usable trial: summarise a single placeholder record instead.
                        padded_lists = [[np.nan] * 5] if n_fields == 6 else [[0, 0, 0, 0]]

                    all_firing_rate = [lst[0] for lst in padded_lists]
                    unique, counts = np.unique(all_firing_rate, return_counts=True)
                    averages = np.mean(padded_lists, axis=0)
                    variances = np.var(padded_lists, axis=0)
                    cv = np.array([]) if n_fields == 4 else np.sqrt(variances[3]) / averages[3]

                    # Rows must line up exactly with the column list chosen for the layout.
                    if n_fields == 6:
                        row = (list(key) + list(averages[:3]) + [cv, averages[-1]]
                               + list(variances[:3]) + [variances[-1], unique, counts])
                        columns = leak_columns
                    else:
                        row = (list(key) + list(averages[:3]) + [cv]
                               + list(variances[:3]) + [unique, counts])
                        columns = basic_columns
                else:
                    # Legacy data: no validity flag, no CV, no firing-rate histogram.
                    padded_lists = [lst + [0] * (3 - len(lst)) for lst in lists]
                    averages = np.mean(padded_lists, axis=0) / 3
                    variances = np.var(padded_lists, axis=0) / 9
                    row = (list(key) + list(averages[:3]) + list(variances[:3])
                           + [np.array([]), np.array([])])
                    columns = legacy_columns
                result.append(row)

        if columns is None:
            raise ValueError("No entries found in the raw data; cannot build a summary.")

        df = pd.DataFrame(result, columns=columns)

        # Ensure "unique_FR" and "count_FR" store NumPy arrays
        df["unique_FR"] = df["unique_FR"].apply(lambda x: np.array(x))
        df["count_FR"] = df["count_FR"].apply(lambda x: np.array(x))

        self.save_arrays(self.paths, df)
        return df

    def save_arrays(self, paths, df):
        """Pickle the summary DataFrame to ``paths['summary_pkl']``."""
        df.to_pickle(paths['summary_pkl'])


class Analysis:
    """Analyse a DataLoader summary: energy estimates, g_syn -> FR fits and tuning curves.

    ``self.df`` is the loader's ``processed_data`` object itself (not a copy). The
    columns added during construction (``total_energy``, ``perc_syn_energy``,
    ``Fano``) therefore also appear on ``data_loader.processed_data``.
    """

    def __init__(self, data_loader, target_firing_rate, synaptic_weight=10):
        self.data_loader = data_loader

        self.df = data_loader.processed_data
        self.target_firing_rate = target_firing_rate
        self.estimate_total_energy(synaptic_weight=synaptic_weight)
        self.estimate_Fano()
        self.extract_fit_g_syn_FR("fit_g_syn_FR.pkl")

    # ----------------------------------------------------------------- helpers
    def _grid_levels(self):
        """Return (g_leak values ascending, v_rest values descending) present in ``self.df``."""
        g_leak_set = sorted(self.df["g_leak"].unique())
        v_rest_set = sorted(self.df["v_rest"].unique())[::-1]
        return g_leak_set, v_rest_set

    def _subset(self, v_rest, g_leak, sort=True):
        """Return the rows of one (v_rest, g_leak) cell, sorted by ascending g_syn if ``sort``."""
        subset_df = self.df[(self.df["g_leak"] == g_leak) & (self.df["v_rest"] == v_rest)]
        if sort:
            subset_df = subset_df.sort_values(by="g_syn", ascending=True)
        return subset_df

    # ------------------------------------------------------------ fit caching
    def extract_fit_g_syn_FR(self, filename):
        """Load cached g_syn -> FR fit parameters, or fit and cache them if absent.

        filename: cache file name inside ``data_loader.preprocess_dir``. It is used for
            both the existence check and the read/write.
        Sets ``self.all_popt`` to a dict ``{(v_rest, g_leak): popt}``.
        """
        file_path = os.path.join(self.data_loader.preprocess_dir, filename)
        if os.path.exists(file_path):
            # NOTE: pickle.load can execute arbitrary code; only load trusted caches.
            with open(file_path, 'rb') as file:
                self.all_popt = pickle.load(file)
        else:
            self.all_popt = self.fit_g_syn_FR()
            with open(file_path, 'wb') as file:
                pickle.dump(self.all_popt, file)

    # --------------------------------------------------------- derived columns
    def estimate_total_energy(self, synaptic_weight=80):
        """Add ``total_energy`` and ``perc_syn_energy`` columns to ``self.df``.

        total_energy = synaptic_weight * |mean_syn_current| + |mean_na| + |mean_leak_na|
        perc_syn_energy = synaptic part / total_energy (a fraction, not a percentage).
        """
        self.df['total_energy'] = synaptic_weight * np.abs(self.df['mean_syn_current']) + np.abs(self.df['mean_na']) + np.abs(self.df['mean_leak_na'])
        # eps_sig: self.df["mean_syn_current"]
        # eps_bg: self.df["mean_na"] + self.df["mean_leak_na"]
        self.df['perc_syn_energy'] = synaptic_weight * np.abs(self.df['mean_syn_current']) / self.df['total_energy']

    def estimate_Fano(self):
        """Add a ``Fano`` column (var_FR / mean_FR) to ``self.df``."""
        self.df["Fano"] = self.df["var_FR"] / self.df["mean_FR"]

    # ------------------------------------------------------------------ noise
    def fit_FR_noise(self, if_plot=False):
        """Fit ``var_FR = a * FR * (1 - FR)`` per (g_leak, v_rest) cell.

        Returns (and stores as ``self.df_noise``) a DataFrame with columns
        ``v_rest``, ``g_leak``, ``noise`` (the fitted ``a``).
        """
        def _arc(x, a):
            return a * x * (1 - x)

        g_leak_set, v_rest_set = self._grid_levels()
        all_data = []

        if if_plot:
            num_rows = len(g_leak_set)
            num_cols = len(v_rest_set)
            # Only every other row/column is drawn, hence the halved grid.
            # squeeze=False keeps `axes` 2-D even for a single row/column.
            fig, axes = plt.subplots(num_rows // 2 + 1, num_cols // 2 + 1,
                                     figsize=(4 * num_cols, 3 * num_rows),
                                     sharex=True, sharey=True, squeeze=False)

        for i, g_leak in enumerate(g_leak_set):
            for j, v_rest in enumerate(v_rest_set):
                subset_df = self._subset(v_rest, g_leak)
                x = subset_df['mean_FR']
                y = subset_df['var_FR']
                popt, _ = curve_fit(_arc, x, y)

                if if_plot and i % 2 == 0 and j % 2 == 0:
                    ax = axes[i // 2, j // 2]
                    ax.plot(x, y, 'o', markersize=10)
                    ax.plot(x, _arc(x, popt[0]), 'r', linewidth=3)
                    ax.set_xlim([0, 1.0])
                    ax.set_xticklabels([])
                all_data.append([v_rest, g_leak, popt[0]])

        self.df_noise = pd.DataFrame(all_data, columns=["v_rest", "g_leak", "noise"])
        return self.df_noise

    # ---------------------------------------------------------------- plotting
    def plot_data_cube(self):
        """Scatter six per-row quantities over the (g_syn, v_rest, g_leak) grid in 3-D."""
        x, y, z = self.df["g_syn"], self.df["v_rest"], self.df["g_leak"]
        # (subplot position, colour values, colour-bar label, title)
        panels = [
            (231, self.df["mean_FR"], "P(Spike)", "P(Spike)"),
            (232, self.df["perc_syn_energy"], "Syn. Energy Fraction", "Synaptic Energy Fraction"),
            (233, np.abs(self.df["mean_syn_current"] * self.df["mean_FR"]),
             "mean_syn_current [nC]", "Single Spike Energy by Syn. Current"),
            (234, np.abs(self.df["mean_leak_na"]),
             "mean_leak_na [nC]", "Background Na Current for V_rest per Sec."),
            (235, self.df["total_energy"], "Current [nC]", "Total Energy"),
            (236, self.df["mean_CV"], "CV [a.u.]", "Mean CV"),
        ]

        fig = plt.figure(figsize=(20, 14))
        for position, colors, cbar_label, title in panels:
            ax = fig.add_subplot(position, projection="3d")
            scatter = ax.scatter(x, y, z, c=colors, cmap="jet", s=30)
            cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, orientation="horizontal", fraction=0.05, pad=0.15)
            cbar.set_label(cbar_label)
            ax.set_title(title)
            ax.set_xlabel("g_syn [uS]")
            ax.set_ylabel("v_rest [mV]")
            ax.set_zlabel("g_leak [S/cm2]", labelpad=10)
            ax.invert_yaxis()
        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------ tuning curves
    def _gen_meta_tuning(self, x_contrast, std, max, if_plot=False):
        """
        Generate the meta tuning curve: a zero-mean Gaussian scaled to peak ``max``.

        x_contrast: the contrast of the input [-90 deg, 90 deg]
        std: the standard deviation of the tuning curve
        max: the maximum g_syn (peak of the curve)
        """
        meta_tuning = max * np.exp(-0.5 * (x_contrast / std) ** 2)
        if if_plot:
            plt.plot(x_contrast, meta_tuning)
            plt.xlabel('Contrast [deg.]')
            plt.ylabel('g_syn [uS]')
            plt.title('Meta Tuning Curve')
            plt.show()
        return meta_tuning

    def cal_energy_by_g_syn(self, x_contrast, std, target_firing_rate, overshoot=0.05, p=None, if_plot=False, if_strict_homeo=True):
        """Estimate the average total energy per (v_rest, g_leak) cell at a target firing rate.

        For each fitted cell, the g_syn that yields ``target_firing_rate`` is obtained
        from the inverse generalised-logistic fit.
        * ``if_strict_homeo``: total_energy is interpolated at that single g_syn.
        * otherwise: a Gaussian meta tuning of g_syn over ``x_contrast`` (peak scaled
          by ``1 + overshoot``) is built, total_energy is interpolated along it and
          summed with weights ``p`` (uniform ``1 / x_contrast.size`` by default).

        ``if_plot`` is accepted for API compatibility but not used.
        Returns (and stores as ``self.df_energy``) a DataFrame with columns
        ``v_rest``, ``g_leak``, ``average_total_energy``.
        """
        if p is None:
            p = 1 / x_contrast.size
        df_list = []
        for (v_rest, g_leak), params in self.all_popt.items():
            b, slope, nu = params[0], params[1], params[2]
            g_syn_target = self.inverse_genlogistic(target_firing_rate, b, slope, nu)
            sub_df = self._subset(v_rest, g_leak)
            x_g_syn = sub_df["g_syn"].to_numpy()
            y_total_energy = sub_df["total_energy"].to_numpy()
            if if_strict_homeo:
                average_total_energy = np.interp(g_syn_target, x_g_syn, y_total_energy)
            else:
                meta_tuning = self._gen_meta_tuning(x_contrast, std, g_syn_target * (1 + overshoot), if_plot=False)
                y_interp_total_energy = np.interp(meta_tuning, x_g_syn, y_total_energy)
                average_total_energy = np.sum(y_interp_total_energy * p)
            df_list.append([v_rest, g_leak, average_total_energy])

        self.df_energy = pd.DataFrame(df_list, columns=["v_rest", "g_leak", "average_total_energy"])
        return self.df_energy

    def inverse_genlogistic(self, y, b, slope, nu):
        """Invert ``_genlogistic``: return x such that ``(1 + exp(-(x - b) / slope)) ** (-1 / nu) == y``.

        Raises:
            ValueError: if the scalar ``y`` is not strictly between 0 and 1.
        """
        if y <= 0 or y >= 1:  # Ensure valid range
            raise ValueError("y must be between 0 and 1 (exclusive)")
        return b - slope * np.log(y**(-nu) - 1)

    def gen_contrast_FR_tuning_by_fitting(self, x_contrast, std, target_firing_rate, overshoot=0.05, if_plot=False):
        """Derive contrast -> FR tuning per cell from the fitted g_syn -> FR curves.

        For each (v_rest, g_leak) fit: the g_syn ``x`` giving ``target_firing_rate``,
        the fit's slope ``df_dx`` there, a Gaussian fit width ``est_sigma`` and the
        FWHM of the resulting FR tuning curve.

        Returns (and stores as ``self.df_broadening``) a DataFrame with columns
        ``v_rest, g_leak, x, df_dx, est_sigma, fwhm, b, slope, nu``.
        """
        def genlogistic_derivative(x, b, slope, nu):
            # d/dx of (1 + e) ** (-1 / nu), with e = exp(-(x - b) / slope).
            exp_term = np.exp(-(x - b) / slope)
            denominator = (1 + exp_term) ** (1 + 1 / nu)
            return exp_term / (nu * slope * denominator)

        if if_plot:
            g_leak_set, v_rest_set = self._grid_levels()
            num_rows = len(g_leak_set)
            num_cols = len(v_rest_set)
            fig, axes = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 3 * num_rows),
                                     sharex=True, sharey=True, squeeze=False)
        if not hasattr(self, 'all_popt'):
            self.all_popt = self.fit_g_syn_FR(if_plot=True)

        df_list = []
        for (v_rest, g_leak), params in self.all_popt.items():
            # Plain local names: the fitted parameter arrays are never written to.
            b, slope, nu = params[0], params[1], params[2]
            x = self.inverse_genlogistic(target_firing_rate, b, slope, nu)
            df_dx = genlogistic_derivative(x, b, slope, nu)
            meta_tuning = self._gen_meta_tuning(x_contrast, std, x * (1 + overshoot), if_plot=False)
            FR_tuning = self._genlogistic(meta_tuning, b, slope, nu)
            est_sigma, fit_gauss = self._fit_gaussian(x_contrast, FR_tuning, FR_tuning.max())
            fwhm = self._find_FWHM(x_contrast, FR_tuning)  # Full width half maximum

            if if_plot:
                i, j = g_leak_set.index(g_leak), v_rest_set.index(v_rest)
                ax = axes[i, j]
                ax.plot(x_contrast, FR_tuning, label='real')
                ax.plot(x_contrast, fit_gauss, label='fit', linestyle='--')
                ax.set_title(f"g_leak: {g_leak:.2g}, v_rest: {v_rest:.2g}, Mean FR: {np.mean(FR_tuning):.3g}", fontsize=10)
                if i == num_rows - 1:
                    ax.set_xlabel("x_contrast")
                if j == 0:
                    ax.set_ylabel("FR")

            df_list.append([v_rest, g_leak, x, df_dx, est_sigma, fwhm, b, slope, nu])

        self.df_broadening = pd.DataFrame(df_list, columns=["v_rest", "g_leak", "x", "df_dx", "est_sigma", "fwhm", "b", "slope", "nu"])
        return self.df_broadening

    def _find_FWHM(self, x, y):
        """Return the distance between the first and last ``x`` where ``y >= max(y) / 2``."""
        half_max = np.max(y) / 2
        idx = np.where(y >= half_max)[0]
        return x[idx[-1]] - x[idx[0]]

    def gen_contrast_FR_tuning(self, x_contrast, std, target_firing_rate, overshoot=0.05, tol=0.2, if_plot=False):
        """
        Generate FR tuning curves from a zero-mean Gaussian g_syn meta tuning.

        For each (g_leak, v_rest) cell, the sampled g_syn whose mean_FR is closest to
        ``target_firing_rate`` is found. Cells whose closest mean_FR deviates by more
        than ``tol`` (or that have no data) are skipped. The meta tuning peak is set
        by the first matching cell, and each cell's tuning is scaled by its own
        matching g_syn.

        x_contrast: the contrast of the input [-90 deg, 90 deg]
        std: the standard deviation of the tuning curve

        Returns a DataFrame with columns ``v_rest``, ``g_leak``, ``est_sigma``.

        Raises:
            ValueError: if no cell reaches the target firing rate within ``tol``.
        """
        g_leak_set, v_rest_set = self._grid_levels()
        num_rows = len(g_leak_set)
        num_cols = len(v_rest_set)
        if if_plot:
            fig, axes = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 3 * num_rows),
                                     sharex=True, sharey=True, squeeze=False)

        # (v_rest, g_leak) -> g_syn whose mean_FR is closest to the target (within tol).
        matched_g_syn = {}
        for g_leak in g_leak_set:
            for v_rest in v_rest_set:
                subset_df = self._subset(v_rest, g_leak, sort=False)
                if subset_df.empty:
                    continue
                closest_idx = np.abs(subset_df["mean_FR"] - target_firing_rate).idxmin()
                closest_row = subset_df.loc[closest_idx]
                if np.abs(closest_row["mean_FR"] - target_firing_rate) <= tol:
                    matched_g_syn[(v_rest, g_leak)] = closest_row["g_syn"]

        if not matched_g_syn:
            raise ValueError(
                f"No (g_leak, v_rest) combination reaches mean_FR within {tol} of {target_firing_rate}"
            )

        # The first matching cell (in g_leak, then v_rest order) sets the meta tuning peak.
        max_g_syn = next(iter(matched_g_syn.values()))
        meta_tuning = self._gen_meta_tuning(x_contrast, std, max_g_syn * (1 + overshoot), if_plot=False)  # g_syn

        all_est_sigma = {}
        for i, g_leak in enumerate(g_leak_set):
            for j, v_rest in enumerate(v_rest_set):
                if (v_rest, g_leak) not in matched_g_syn:
                    # The target rate is not reachable within tol for this cell.
                    continue
                subset_df = self._subset(v_rest, g_leak)
                scale = matched_g_syn[(v_rest, g_leak)] / max_g_syn
                fr = np.interp(meta_tuning * scale, subset_df['g_syn'], subset_df['mean_FR'])
                est_sigma, fit_gauss = self._fit_gaussian(x_contrast, fr, fr.max())
                all_est_sigma[(v_rest, g_leak)] = est_sigma
                if if_plot:
                    ax = axes[i, j]
                    ax.plot(x_contrast, fr, label='real')
                    ax.plot(x_contrast, fit_gauss, label='fit', linestyle='--')
                    ax.set_title(f"g_leak: {g_leak}, v_rest: {v_rest}, Mean FR: {np.mean(fr):.3g}", fontsize=10)
                    if i == num_rows - 1:
                        ax.set_xlabel("x_contrast")
                    if j == 0:
                        ax.set_ylabel("FR")

        df_all_est_sigma = pd.DataFrame(
            [(v_rest, g_leak, est_sigma) for (v_rest, g_leak), est_sigma in all_est_sigma.items()],
            columns=["v_rest", "g_leak", "est_sigma"]
        )
        if if_plot:
            plt.tight_layout()
            plt.show()

        return df_all_est_sigma

    def _fit_gaussian(self, x, y, init_amp):
        """
        Fit a zero-mean Gaussian ``amp * exp(-0.5 * (x / sigma) ** 2)`` to (x, y).

        x: the x axis
        y: the y axis
        init_amp: initial amplitude guess (the sigma guess is 180)

        Returns (est_sigma, fitted curve evaluated at x).
        """
        def gaussian(x, sigma, amp):
            return amp * np.exp(-0.5 * (x / sigma) ** 2)

        popt, _ = curve_fit(gaussian, x, y, p0=[180, init_amp])
        est_sigma, est_amp = popt[0], popt[1]
        fit_gauss = gaussian(x, est_sigma, est_amp)
        return est_sigma, fit_gauss

    def plot_bar_CV(self, v_rests, g_leak_idxs, g_syn_idxs, categories):
        """Bar-plot mean_CV for two cells selected by v_rest and sorted-value indices.

        v_rests, g_leak_idxs, g_syn_idxs: length-2 sequences. The g_leak / g_syn
        indices index into the sorted unique values of the chosen v_rest slice.
        categories: the two tick labels.
        """
        def _get_CV(df, v_rest, g_leak_idx, g_syn_idx):
            cell = df[(df['v_rest'] == v_rest)]
            g_leak_set = sorted(cell["g_leak"].unique())
            g_leak = g_leak_set[g_leak_idx]
            cell_g_leak = cell[(cell['g_leak'] == g_leak)]
            g_syn_set = sorted(cell_g_leak["g_syn"].unique())
            g_syn = g_syn_set[g_syn_idx]
            CV = cell_g_leak[(cell_g_leak['g_syn'] == g_syn)]['mean_CV']
            return CV.values[0], g_leak, g_syn

        CV1, g_leak1, g_syn1 = _get_CV(self.df, v_rests[0], g_leak_idxs[0], g_syn_idxs[0])
        CV2, g_leak2, g_syn2 = _get_CV(self.df, v_rests[1], g_leak_idxs[1], g_syn_idxs[1])

        plt.figure(figsize=(3, 6))
        plt.bar([1, 1.2], [CV1, CV2], width=0.1)
        plt.xticks([1, 1.2], categories)
        plt.ylabel('CV')
        plt.title('CV comparison')

        print(f"g_leak1: {g_leak1}, g_syn1: {g_syn1}, CV1: {CV1}")
        print(f"g_leak2: {g_leak2}, g_syn2: {g_syn2}, CV2: {CV2}")

    # ---------------------------------------------------------- g_syn -> FR fit
    def _genlogistic(self, x, b, slope, nu):
        """Generalised logistic ``(1 + exp(-(x - b) / slope)) ** (-1 / nu)``."""
        return (1 + np.exp(-(x - b) / slope))**(-1/nu)

    def _fit_single_curve_g_syn_FR(self, v_rest, g_leak, if_plot=False):
        """Fit ``_genlogistic`` to mean_FR vs g_syn for one (v_rest, g_leak) cell.

        Parameters are bounded to b in [0, 20], slope in [0, 20], nu in [0, 1.5].
        No p0 is passed, so curve_fit chooses its own feasible start inside the bounds.
        Returns the fitted ``popt`` array ``[b, slope, nu]``.
        """
        subset_df = self._subset(v_rest, g_leak)
        x = subset_df['g_syn']
        y = subset_df['mean_FR']
        popt, _ = curve_fit(self._genlogistic, x, y, maxfev=10000, bounds=([0, 0, 0], [20, 20, 1.5]))
        if if_plot:
            x_hat = np.linspace(x.min() / 2, x.max() * 1.5, 1000)
            y_hat = self._genlogistic(x_hat, *popt)
            plt.figure()
            plt.plot(x, y, 'o', label='real')
            plt.plot(x_hat, y_hat, label='fit', linestyle='--', color='gray')
            plt.xlabel('g_syn')
            plt.ylabel('mean_FR')
            plt.legend()
            plt.show()
        return popt

    def fit_g_syn_FR(self, if_plot=False):
        """Fit every (v_rest, g_leak) cell and return ``{(v_rest, g_leak): popt}``.

        Keys are ordered by v_rest descending, then g_leak ascending.
        """
        g_leak_set, v_rest_set = self._grid_levels()
        all_popt = {}
        for v_rest in v_rest_set:
            for g_leak in g_leak_set:
                all_popt[(v_rest, g_leak)] = self._fit_single_curve_g_syn_FR(v_rest, g_leak, if_plot)
        return all_popt
    




class WidthModel:
    """Least-squares polynomial surface for a width target over (v_rest, g_leak).

    Raw inputs are min-max scaled (or passed through a caller-supplied scaler),
    expanded into polynomial features of degree ``deg_width`` that include a bias
    column, and regressed with ``LinearRegression`` without a separate intercept.
    """

    def __init__(self, deg_width, scaler=None):
        self.deg_width = deg_width
        if scaler is None:
            # Default: independent [0, 1] min-max scaling of each input column.
            scaler = MinMaxScaler(feature_range=(0, 1))
        self.scaler = scaler
        self.poly = PolynomialFeatures(degree=self.deg_width, include_bias=True)
        self.model = None

    def fit(self, X, y):
        """Fit the scaler, the polynomial expansion and the regression.

        X: raw 2-D array-like or DataFrame. Non-DataFrame input is labelled
           ['v_rest', 'g_leak'] so the scaler records feature names.
        y: target values (e.g. fwhm / 2).
        Returns self.
        """
        frame = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X, columns=['v_rest', 'g_leak'])
        features = self.poly.fit_transform(self.scaler.fit_transform(frame))
        regressor = LinearRegression(fit_intercept=False)
        self.model = regressor.fit(features, y)
        return self

    def predict(self, X):
        """Predict targets for raw (unscaled) inputs ``X``."""
        if isinstance(X, pd.DataFrame):
            frame = X
        else:
            # Label raw arrays with the names the scaler saw during fit, if recorded.
            if hasattr(self.scaler, "feature_names_in_"):
                names = self.scaler.feature_names_in_
            else:
                names = ['v_rest', 'g_leak']
            frame = pd.DataFrame(X, columns=names)
        features = self.poly.transform(self.scaler.transform(frame))
        return self.model.predict(features)


class EnergyModel:
    """Polynomial least-squares model of an energy target over (v_rest, g_leak).

    Inputs are scaled (min-max to [0, 1] unless a scaler is supplied), expanded into
    polynomial features of degree ``deg_energy`` with a bias column, and fitted with
    ``LinearRegression(fit_intercept=False)``.
    """

    def __init__(self, deg_energy, scaler=None):
        self.deg_energy = deg_energy
        self.scaler = MinMaxScaler(feature_range=(0, 1)) if scaler is None else scaler
        self.poly = PolynomialFeatures(degree=self.deg_energy, include_bias=True)
        self.model = None

    def _design_matrix(self, X, refit):
        """Scale ``X`` and expand it to polynomial features, refitting both steps if ``refit``."""
        if refit:
            if not isinstance(X, pd.DataFrame):
                X = pd.DataFrame(X, columns=['v_rest', 'g_leak'])
            return self.poly.fit_transform(self.scaler.fit_transform(X))
        if not isinstance(X, pd.DataFrame):
            # Reuse the scaler's recorded feature names when available.
            default_names = ['v_rest', 'g_leak']
            X = pd.DataFrame(X, columns=getattr(self.scaler, "feature_names_in_", default_names))
        return self.poly.transform(self.scaler.transform(X))

    def fit(self, X, y):
        """
        X: raw input with columns [v_rest, g_leak]
        y: target values (e.g. average_total_energy)
        Returns self.
        """
        self.model = LinearRegression(fit_intercept=False).fit(self._design_matrix(X, refit=True), y)
        return self

    def predict(self, X):
        """
        X: raw input.
        Returns predictions.
        """
        return self.model.predict(self._design_matrix(X, refit=False))


class NoiseModel:
    """Polynomial model of a noise level over (v_rest, g_leak) with non-negative coefficients.

    Features: ``v_rest`` is min-max scaled to [0, 1]. ``g_leak`` is min-max scaled to
    [0, 1] and then flipped (``1 - x``). The scaled pair is expanded into polynomial
    features of degree ``deg_noise`` without a bias column. Coefficients are fitted
    with ``scipy.optimize.lsq_linear`` under a lower bound of 0.

    With ``if_exponential_fit=True`` the target is fitted as ``log1p(y)`` and
    predictions are mapped back with ``expm1`` (the exact inverse of ``log1p``).
    """

    class _PositiveLinearModel:
        """Predict ``X @ coef_``, optionally mapped back through ``expm1``."""

        def __init__(self, coef, if_exponential_fit=False):
            self.if_exponential_fit = if_exponential_fit
            self.coef_ = coef

        def predict(self, X):
            linear = X @ self.coef_
            if self.if_exponential_fit:
                # expm1 inverts log1p exactly and, unlike exp(z) - 1, keeps full
                # precision when z is close to 0.
                return np.expm1(linear)
            return linear

    def __init__(self, deg_noise, if_exponential_fit=False):
        self.deg_noise = deg_noise
        self.if_exponential_fit = if_exponential_fit
        # v_rest scaled normally to [0, 1]; g_leak scaled to [0, 1] then reversed.
        reverse_pipeline = Pipeline([
            ('minmax', MinMaxScaler(feature_range=(0, 1))),
            ('reverse', FunctionTransformer(lambda x: 1 - x))
        ])
        self.scaler = ColumnTransformer(transformers=[
            ('v_rest_scaler', MinMaxScaler(feature_range=(0, 1)), ['v_rest']),
            ('g_leak_scaler', reverse_pipeline, ['g_leak'])
        ])
        # No bias column: the model has no constant term.
        self.poly = PolynomialFeatures(degree=self.deg_noise, include_bias=False)
        self.log_transformer = FunctionTransformer(np.log1p, validate=False)

        self.model = None

    def fit(self, X, y):
        """
        X: raw DataFrame (or 2-D array) with columns ['v_rest', 'g_leak']
        y: target values (e.g. noise; note the original analysis used noise - 1)
        Returns self.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X, columns=['v_rest', 'g_leak'])

        X_trans = self.scaler.fit_transform(X)
        X_poly = self.poly.fit_transform(X_trans)
        if self.if_exponential_fit:
            # Fit in log1p space; predictions are mapped back with expm1.
            y = self.log_transformer.fit_transform(y)

        # Lower bound 0 => non-negative coefficients.
        res = lsq_linear(X_poly, y, bounds=(0, np.inf))

        self.model = self._PositiveLinearModel(res.x, if_exponential_fit=self.if_exponential_fit)
        return self

    def predict(self, X):
        """
        X: raw DataFrame (or 2-D array) with columns ['v_rest', 'g_leak'].
        Returns predictions.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X, columns=['v_rest', 'g_leak'])
        X_trans = self.scaler.transform(X)
        X_poly = self.poly.transform(X_trans)
        return self.model.predict(X_poly)

class NoiseModelPolyExponential:
    """
    Noise model of the form y = P(x) * exp(L(x)).

    ``x`` holds the two scaled inputs: v_rest min-max scaled to [0, 1], and
    g_leak min-max scaled to [0, 1] and then reversed (1 - x).
    ``P`` is a polynomial of degree ``deg_noise`` with a constant term and
    ``L`` is a linear form without intercept. The fit constrains every
    coefficient of both parts to be non-negative.
    """

    def __init__(self, deg_noise):
        self.deg_noise = deg_noise
        # v_rest is scaled to [0, 1]; g_leak is scaled to [0, 1] then reversed.
        reverse_pipeline = Pipeline([
            ('minmax', MinMaxScaler(feature_range=(0, 1))),
            # A named function instead of a lambda: lambdas cannot be pickled,
            # which would make the whole fitted model unpicklable.
            ('reverse', FunctionTransformer(self._reverse_unit_interval))
        ])
        self.scaler = ColumnTransformer(transformers=[
            ('v_rest_scaler', MinMaxScaler(feature_range=(0, 1)), ['v_rest']),
            ('g_leak_scaler', reverse_pipeline, ['g_leak'])
        ])

        # include_bias=True gives P(x) a constant term.
        self.poly = PolynomialFeatures(degree=self.deg_noise, include_bias=True)

        # Fitted coefficients; None until fit() has run.
        self.poly_coeffs_ = None
        self.exp_coeffs_ = None

    @staticmethod
    def _reverse_unit_interval(x):
        """Map values scaled to [0, 1] onto 1 - x."""
        return 1 - x

    def _model_func(self, params, X_scaled, X_poly):
        """Return P(x) * exp(L(x)) for the packed parameter vector ``params``.

        ``params`` is [polynomial coefficients..., linear exponent weights...];
        the split point is the number of columns of ``X_poly``.
        """
        num_poly_features = X_poly.shape[1]
        poly_coeffs = params[:num_poly_features]
        exp_coeffs = params[num_poly_features:]

        # P(x) = c_0*1 + c_1*x_1 + c_2*x_2 + ... (bias column included in X_poly)
        poly_part = X_poly @ poly_coeffs
        # exp(L(x)) = exp(w_1*x_1 + w_2*x_2) on the scaled inputs.
        exp_part = np.exp(X_scaled @ exp_coeffs)

        return poly_part * exp_part

    def _residuals(self, params, X_scaled, X_poly, y_true):
        """Return predicted minus true values (the vector least_squares minimizes)."""
        y_pred = self._model_func(params, X_scaled, X_poly)
        return y_pred - y_true

    def fit(self, X, y):
        """
        Fit the model by bounded non-linear least squares.

        Model: y = [polynomial(x_scaled)] * exp(linear(x_scaled))

        Args:
            X (pd.DataFrame or array-like): inputs with columns ['v_rest', 'g_leak'].
            y (pd.Series, np.ndarray or array-like): target values (e.g. noise - 1).
                A single-column 2-D target is flattened to 1-D.

        Returns:
            self

        Raises:
            ValueError: if the number of targets differs from the number of rows in X.

        Emits a RuntimeWarning when the optimizer reports non-convergence.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X, columns=['v_rest', 'g_leak'])
        # least_squares needs a 1-D residual vector; flattening also keeps a
        # column-shaped y from broadcasting against the 1-D predictions.
        y = np.asarray(y, dtype=float).ravel()

        # Scale features and build the polynomial basis.
        X_scaled = self.scaler.fit_transform(X)
        X_poly = self.poly.fit_transform(X_scaled)
        if y.shape[0] != X_poly.shape[0]:
            # A length-1 y would otherwise broadcast silently to every row.
            raise ValueError(
                f"y has {y.shape[0]} values but X has {X_poly.shape[0]} rows."
            )

        num_poly_features = X_poly.shape[1]
        num_params = num_poly_features + X_scaled.shape[1]

        # Start every coefficient at 1 (inside the non-negative bounds).
        p0 = np.ones(num_params)
        res = least_squares(
            self._residuals,
            p0,
            args=(X_scaled, X_poly, y),
            bounds=(0, np.inf),  # non-negative coefficients
        )
        if not res.success:
            warnings.warn(
                f"NoiseModelPolyExponential.fit did not converge: {res.message}",
                RuntimeWarning,
            )

        self.poly_coeffs_ = res.x[:num_poly_features]
        self.exp_coeffs_ = res.x[num_poly_features:]
        return self

    def predict(self, X):
        """Predict with the fitted model; raises RuntimeError before fit()."""
        if self.poly_coeffs_ is None or self.exp_coeffs_ is None:
            raise RuntimeError("You must fit the model before making predictions.")

        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X, columns=['v_rest', 'g_leak'])

        # Apply the transformations fitted in fit().
        X_scaled = self.scaler.transform(X)
        X_poly = self.poly.transform(X_scaled)

        params = np.concatenate([self.poly_coeffs_, self.exp_coeffs_])
        return self._model_func(params, X_scaled, X_poly)
# -----------------------------
# Container class
# -----------------------------

class LinearFit:
    def __init__(self, df_width, df_energy, df_noise, degrees, if_exponential_fit=False):
        """
        Merge the three measurement frames and set up one model per quantity.

        Args:
            df_width, df_energy, df_noise: DataFrames that share the 'v_rest' and
                'g_leak' columns and carry, respectively, 'fwhm',
                'average_total_energy' and 'noise' (the targets fit_all reads).
            degrees: indexable of polynomial degrees ordered (width, energy,
                noise); only the first three entries are read.
            if_exponential_fit: stored on the instance; the noise model built
                here (NoiseModelPolyExponential) is not given this flag.
        """
        self.df_width, self.df_energy, self.df_noise = df_width, df_energy, df_noise
        self.deg_width, self.deg_energy, self.deg_noise = degrees[0], degrees[1], degrees[2]
        self.if_exponential_fit = if_exponential_fit

        # Inner join of the three frames on (v_rest, g_leak).
        self.df = self._combine()

        # Width and energy share one [0, 1] min-max scaler fitted on the merged inputs.
        self.common_scaler = MinMaxScaler(feature_range=(0, 1)).fit(self.df[['v_rest', 'g_leak']])

        self.width_model = WidthModel(self.deg_width, scaler=self.common_scaler)
        self.energy_model = EnergyModel(self.deg_energy, scaler=self.common_scaler)
        # The noise model builds its own scaling pipeline internally.
        self.noise_model = NoiseModelPolyExponential(self.deg_noise)

    def _combine(self):
        # Merge the data frames on v_rest and g_leak.
        df_temp = pd.merge(self.df_width, self.df_energy, on=['v_rest', 'g_leak'], how='inner')
        df = pd.merge(df_temp, self.df_noise, on=['v_rest', 'g_leak'], how='inner')
        return df

    def fit_all(self, if_plot=False):
        """
        Fit the three models.
        For width: target = fwhm/2
        For energy: target = average_total_energy
        For noise: target = noise (here we fit on noise-1, so that later predictions add 1)
        Optionally, generate 3D visualizations.
        Returns a dict of the models.
        """
        X = self.df[['v_rest', 'g_leak']]
        self.width_model.fit(X, self.df['fwhm'] / 2)
        self.energy_model.fit(X, self.df['average_total_energy'])
        # For noise, note that we use the noise model’s own transformer.
        self.noise_model.fit(self.df[['v_rest', 'g_leak']], self.df['noise'] - 1)

        if if_plot:
            v_rest_range = np.linspace(self.df['v_rest'].min(), self.df['v_rest'].max(), 31)
            g_leak_range = np.linspace(self.df['g_leak'].min(), self.df['g_leak'].max(), 31)
            #v_rest_range = sorted(self.df['v_rest'].unique())
            #g_leak_range = sorted(self.df['g_leak'].unique())

            v_rest_grid, g_leak_grid = np.meshgrid(v_rest_range, g_leak_range)
            X_pred = np.column_stack((v_rest_grid.ravel(), g_leak_grid.ravel()))
            X_pred_df = pd.DataFrame(X_pred, columns=['v_rest', 'g_leak'])
            
            width_pred = self.width_model.predict(X_pred)
            energy_pred = self.energy_model.predict(X_pred)
            # For noise, our predict expects a DataFrame.
            noise_pred = self.noise_model.predict(X_pred_df) + 1  # add back the shift
            
            width_pred = width_pred.reshape(v_rest_grid.shape)
            energy_pred = energy_pred.reshape(v_rest_grid.shape)
            noise_pred = noise_pred.reshape(v_rest_grid.shape)
            
            fig = plt.figure(figsize=(30, 30), dpi=50)
            
            # Plot Width Model
            
            ax1 = fig.add_subplot(131, projection='3d')
            # Plot the surface and keep a reference to the mappable
            surf = ax1.plot_surface(g_leak_grid*10**6, v_rest_grid, width_pred, cmap='jet', alpha=0.7)

            # Overlay scatter points
            ax1.scatter(self.df['g_leak']*10**6, self.df['v_rest'], self.df['fwhm'] / 2, c='r', marker='o')
            ax1.invert_yaxis()
            ax1.set_title("Fitted Width Model")
            ax1.set_xlabel("g_leak [uS/cm2]")
            ax1.set_ylabel("v_rest [mV]")
            ax1.set_zlabel("fwhm / 2 [deg.]")
            


            # Plot Energy Model
            def c2ATP(x):
                return x * (10**(-9)) / (3*1.602e-19)
            
            ax2 = fig.add_subplot(132, projection='3d')
            surf = ax2.plot_surface(g_leak_grid*10**6, v_rest_grid, energy_pred*c2ATP(1), cmap='jet', alpha=0.7)
            ax2.scatter(self.df['g_leak']*10**6, self.df['v_rest'] , self.df['average_total_energy'].to_numpy()*c2ATP(1), c='r', marker='o')
            ax2.invert_yaxis()
            ax2.set_title("Fitted Energy Model")
            ax2.set_xlabel("g_leak [uS/cm2]")
            ax2.set_ylabel("v_rest [mV]")
            ax2.set_zlabel("Total_energy [APT/s]")
            
            # Plot Noise Model
            ax3 = fig.add_subplot(133, projection='3d')
            surf = ax3.plot_surface(g_leak_grid*10**6, v_rest_grid, noise_pred, cmap='jet', alpha=0.7)
            ax3.scatter(self.df['g_leak']*10**6, self.df['v_rest'],  self.df['noise'], c='r', marker='o')
            ax3.invert_yaxis()
            ax3.set_xticklabels([])
            ax3.set_yticklabels([])
            ax3.set_zticklabels([])
            ax3.set_title("Fitted Noise Model")
            #ax3.set_xlabel("g_leak [uS/cm2]")
            #ax3.set_ylabel("v_rest [mV]")
            #ax3.set_zlabel("Eta")
            
            plt.show()


        return {
            "common_scaler": self.common_scaler,
            "width_model": self.width_model,
            "energy_model": self.energy_model,
            "noise_model": self.noise_model,
        }

    def find_optimal_path(self, energy_levels=None, num_levels=50):
        """
        Find the optimal (v_rest, g_leak) that minimizes the noise model along contours
        of constant energy. The optimization is performed in the normalized space defined by the common scaler.
        
        For each target energy value, we solve:
            minimize noise(x)
            subject to: energy(x) = E_target,
        where x is in the normalized space.
        The returned dictionary has raw (unnormalized) v_rest and g_leak, along with the target energy,
        predicted noise, and predicted width.
        """
        # Use the common scaler for the optimization.
        v_rest_min, v_rest_max = self.df['v_rest'].min(), self.df['v_rest'].max()
        g_leak_min, g_leak_max = self.df['g_leak'].min(), self.df['g_leak'].max()

        if energy_levels is None:
            # Ensure we pass a DataFrame to avoid warnings.
            energy_preds = self.energy_model.predict(self.df[['v_rest', 'g_leak']])
            E_min, E_max = energy_preds.min(), energy_preds.max()
            energy_levels = np.linspace(E_min, E_max, num_levels)
        
        optimal_path = []
        
        # Optimization functions assume x is in normalized space.
        def noise_objective(x):
            raw_x = self.common_scaler.inverse_transform(x.reshape(1, -1))
            raw_df = pd.DataFrame(raw_x, columns=self.common_scaler.feature_names_in_)
            return self.noise_model.predict(raw_df)[0]
        
        bounds = [(-0.1, 1.1), (-0.1, 1.1)]
        x0 = np.array([1, 0])
        
        # Extract energy model coefficients in normalized space
        coef = self.energy_model.model.coef_
        linear_A = np.array([coef[1], coef[2]])
        offset = coef[0]
        
        for E_target in energy_levels:
            constraint = LinearConstraint(linear_A.reshape(1, -1), lb=E_target-offset, ub=E_target-offset)
            res = minimize(noise_objective, 
                           x0, 
                           method='trust-constr', 
                           bounds=bounds, 
                           constraints=[constraint],
                           tol=1e-12
                           )
            x0 = res.x  # Use the previous solution as the initial guess
            if res.success:
                raw_vals = self.common_scaler.inverse_transform(res.x.reshape(1, -1))[0]
                width_val = self.width_model.predict(np.array([raw_vals]))[0]
                optimal_path.append({
                    'v_rest': raw_vals[0],
                    'g_leak': raw_vals[1],
                    'if_valid': raw_vals[0] >= v_rest_min and raw_vals[0] <= v_rest_max and raw_vals[1] >= g_leak_min and raw_vals[1] <= g_leak_max,
                    'energy': E_target,
                    'noise': res.fun+1,
                    'width': width_val
                })
            else:
                optimal_path.append({
                    'energy': E_target,
                    'v_rest': np.nan,
                    'g_leak': np.nan,
                    'noise': np.nan,
                    'width': np.nan
                })
        return optimal_path
    
   

    

    def plot_noise_energy_contour(self, noise_levels=40, times=5, show_colorbar=False):
        """
        Plot predicted noise as a filled contour with energy contours and the optimal path.

        The background is the predicted noise (model prediction + 1) on a 50x50
        grid over the data range, after ``times`` rounds of x -> log(x) + 1 to
        compress its dynamic range (values below 1 can turn NaN after repeated
        rounds). White lines mark 20 evenly spaced energy levels. The path from
        ``find_optimal_path(num_levels=100)`` is stored in ``self.optimal_points``
        and its valid, successfully optimized points are drawn on top.

        Args:
            noise_levels: ``levels`` argument passed to ``contourf``.
            times: number of log(x) + 1 rounds applied to the noise.
            show_colorbar: when True, add a colorbar whose tick labels are mapped
                back to the original (untransformed) noise scale.

        Returns:
            The matplotlib Axes (x-axis inverted).
        """
        def log_transform(x, times=3):
            # Apply x -> log(x) + 1 `times` times.
            for _ in range(times):
                x = np.log(x) + 1
            return x

        def inverse_log_transform(x, times=3):
            # Inverse of log_transform: apply x -> exp(x - 1) `times` times.
            for _ in range(times):
                x = np.exp(x - 1)
            return x

        v_rest_range = np.linspace(self.df['v_rest'].min(), self.df['v_rest'].max(), 50)
        g_leak_range = np.linspace(self.df['g_leak'].min(), self.df['g_leak'].max(), 50)
        v_rest_grid, g_leak_grid = np.meshgrid(v_rest_range, g_leak_range)
        X_grid = np.column_stack((v_rest_grid.ravel(), g_leak_grid.ravel()))
        X_grid_df = pd.DataFrame(X_grid, columns=['v_rest', 'g_leak'])

        # The noise model was fit on noise - 1; add the shift back.
        noise_pred = (self.noise_model.predict(X_grid_df) + 1).reshape(v_rest_grid.shape)
        energy_pred = self.energy_model.predict(X_grid).reshape(v_rest_grid.shape)

        fig, ax = plt.subplots(figsize=(6, 6), facecolor='none')

        noise_transformed = log_transform(noise_pred, times=times)
        contour_noise = ax.contourf(
            v_rest_grid, g_leak_grid, noise_transformed,
            cmap='jet', levels=noise_levels, alpha=0.9
        )

        energy_levels = np.linspace(energy_pred.min(), energy_pred.max(), 20)
        ax.contour(
            v_rest_grid, g_leak_grid, energy_pred,
            levels=energy_levels, colors='white'
        )

        ax.set_facecolor('none')

        if show_colorbar:
            # Ticks sit on the transformed scale; labels show the original noise.
            cb = fig.colorbar(contour_noise, ax=ax, label='Noise')
            ticks_transformed = np.linspace(
                np.nanmin(noise_transformed), np.nanmax(noise_transformed), 6
            )
            tick_labels_original = inverse_log_transform(ticks_transformed, times=times)
            cb.set_ticks(ticks_transformed)
            cb.set_ticklabels([f"{val:.3f}" for val in tick_labels_original])

        ax.invert_xaxis()

        self.optimal_points = self.find_optimal_path(num_levels=100)
        # Filter once so the v_rest / g_leak sequences stay paired; drop failed
        # optimizations (NaN) and points flagged outside the data range.
        path_points = [
            pt for pt in self.optimal_points
            if not (np.isnan(pt['v_rest']) or np.isnan(pt['g_leak'])) and pt.get('if_valid', False)
        ]
        ax.plot(
            [pt['v_rest'] for pt in path_points],
            [pt['g_leak'] for pt in path_points],
            '-', lw=5, label='Optimal Path', color="C2",
        )

        return ax
    
    def plot_noise_energy_contour_bg(self, noise_levels=40, times=5):
        """
        Plot only the noise background used by ``plot_noise_energy_contour``.

        Draws the predicted noise (model prediction + 1) on a 50x50 grid over
        the data range as a filled contour, after ``times`` rounds of
        x -> log(x) + 1. No energy contours, optimal path or colorbar are drawn.

        Args:
            noise_levels: ``levels`` argument passed to ``contourf``.
            times: number of log(x) + 1 rounds applied to the noise.

        Returns:
            The matplotlib Axes (x-axis inverted).
        """
        def log_transform(x, times=3):
            # Apply x -> log(x) + 1 `times` times.
            for _ in range(times):
                x = np.log(x) + 1
            return x

        v_rest_range = np.linspace(self.df['v_rest'].min(), self.df['v_rest'].max(), 50)
        g_leak_range = np.linspace(self.df['g_leak'].min(), self.df['g_leak'].max(), 50)
        v_rest_grid, g_leak_grid = np.meshgrid(v_rest_range, g_leak_range)
        X_grid = np.column_stack((v_rest_grid.ravel(), g_leak_grid.ravel()))
        X_grid_df = pd.DataFrame(X_grid, columns=['v_rest', 'g_leak'])

        # The noise model was fit on noise - 1; add the shift back.
        noise_pred = (self.noise_model.predict(X_grid_df) + 1).reshape(v_rest_grid.shape)

        _, ax = plt.subplots(figsize=(6, 6), facecolor='none')

        ax.contourf(
            v_rest_grid, g_leak_grid, log_transform(noise_pred, times=times),
            cmap='jet', levels=noise_levels, alpha=0.9
        )

        ax.set_facecolor('none')
        ax.invert_xaxis()

        return ax


