import numpy as np
import pandas as pd
import scipy.signal as sp
import scipy.stats as sts

from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d as gfilter
from scipy.optimize import curve_fit
from scipy.optimize import curve_fit
from scipy.special import expit

# Order-parameter column labels, in the order get_pk_info iterates over them.
# Each entry must be a column of the DataFrame passed to get_pk_info. Several
# labels get dedicated handling there (L_eval-dot, sigma_theta, min C_NTK,
# C_infinity, norm_spatial_mean_log_D); every other label goes through the
# generic sigmoid fit.
_ORDER_PARAMS = [
    r'$\dot{L}_{\text{eval}}$', r'$\sigma_{\theta}$', 
    r'$\min C_{NTK}$', 
    r'$\text{CKA}(K_Y, K_{NTK})$', 'edge_alignment', r'$\lambda_0$', 
    r'$C_{\infty}$', 'norm_spatial_mean_log_D', #r'$\beta_{FORT}$',  (disabled)
]


def log_log_grad(xt, t):
    """Return the local log-log slope d log(1 + xt) / d log(1 + t).

    The derivative is estimated with ``np.gradient`` on the (possibly
    non-uniform) grid ``log(1 + t)``. The ``1 +`` offset keeps both
    logarithms finite when ``xt`` or ``t`` is zero.
    """
    log_x = np.log(1 + xt)
    log_t = np.log(1 + t)
    return np.gradient(log_x, log_t)

def find_first_above(signal, threshold):
    """Return the position of the first element of ``signal`` strictly greater than ``threshold``.

    If no element exceeds ``threshold``, the fallback is ``0`` (not ``None``),
    so the result is always usable directly as an index. Note that this makes
    "nothing above" indistinguishable from "element 0 is above".
    """
    hits = np.where(signal > threshold)[0]
    if not hits.size:
        return 0
    return hits[0]

def find_t_to_learn(R, epochs, thresh=1e-2):
    """Per column of ``R``, find the first epoch after which ``|R|**2`` stays <= ``thresh``.

    Parameters
    ----------
    R : 2-D array-like
        Indexed as ``R[:, i]``. Rows are assumed to line up one-to-one with
        ``epochs`` (TODO confirm with callers).
    epochs : sequence
        Epoch value for each row of ``R``.
    thresh : float
        A sample counts as unstable while ``|R|**2 > thresh``.

    Returns
    -------
    idx_stable : ndarray of int, one entry per column
        Row index just after the last unstable sample. It is 0 when the column
        is never unstable. A column that is still unstable at the final row
        never stabilises; it is clamped to the last index,
        ``len(epochs) - 1``, rather than wrapping round to 0.
    t_stable : ndarray, one entry per column
        ``1 + epochs[idx_stable]``.
    """
    epochs = np.asarray(epochs)
    unstable = np.abs(np.asarray(R)) ** 2 > thresh
    n_rows = unstable.shape[0]

    # argmax over the reversed rows finds the *last* unstable row of each column.
    last_unstable = n_rows - 1 - np.argmax(unstable[::-1], axis=0)
    # The first stable row follows the last unstable one; 0 if never unstable.
    idx_stable = np.where(unstable.any(axis=0), last_unstable + 1, 0)
    # A column unstable at the very end would point one past the last row.
    # Clamp it instead of letting negative indexing wrap it to epoch 0, which
    # would claim it was learned immediately.
    idx_stable = np.minimum(idx_stable, len(epochs) - 1)

    t_stable = 1 + epochs[idx_stable]
    return idx_stable, t_stable

def find_pk_eval_grad(epoch, sig):
    """Return the sample index of the last prominent peak of ``-sig``.

    ``-sig`` is min-max rescaled to [0, 1] first, and peaks with prominence
    >= 0.25 on that scale are kept. ``epoch`` is not used; it is kept so the
    signature stays unchanged. An ``IndexError`` is raised when no peak
    qualifies.
    """
    flipped = -sig
    shifted = flipped - flipped.min()
    normalised = shifted / shifted.max()
    peak_indices, _props = sp.find_peaks(normalised, prominence=0.25)
    return peak_indices[-1]

def min_max_norm(x):
    """Rescale ``x`` linearly so that its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    x : numpy array or pandas Series
        Anything supporting ``.min()``, ``.max()`` and broadcasting arithmetic.

    Returns
    -------
    Same type as ``x``, with values in [0, 1]. A constant input has zero
    range, so it is returned as all zeros. Dividing would give ``0 / 0``
    NaNs (plus a RuntimeWarning), and those NaNs make downstream ``curve_fit``
    calls fail.
    """
    shifted = x - x.min()
    span = shifted.max()
    if span == 0:
        # Constant signal: shifted is already all zeros.
        return shifted
    return shifted / span

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, evaluated element-wise.

    Uses ``scipy.special.expit``, which is numerically stable. The naive
    ``np.exp(-x)`` overflows (RuntimeWarning) for ``x`` below about -709. That
    happens readily when a caller scales ``x`` by a small width, as
    ``sigmoid_to_fit`` does during bounded fits.
    """
    return expit(x)

def sigmoid_to_fit(x, A, B, mu, w):
    """Four-parameter logistic curve used as the ``curve_fit`` model.

    For ``w > 0`` it rises or falls from level ``A`` (x well below ``mu``) to
    level ``B`` (x well above ``mu``). It is centred at ``mu``, with
    horizontal scale ``w``.
    """
    z = (x - mu) / w
    return A + (B - A) * sigmoid(z)

def find_crossing(t, y, k_cutoff, t_guess=None):
    """Locate where ``y(t)`` crosses the level ``k_cutoff``, by linear interpolation.

    Parameters
    ----------
    t : array-like
        Sample times, same length as ``y``.
    y : array-like
        Signal values. NaNs are filled by pandas linear interpolation, which
        works by position and ignores the spacing of ``t``, before searching.
        Crossings whose interpolated time is still NaN (e.g. next to leading
        NaNs that interpolation cannot fill) are discarded.
    k_cutoff : float
        Level whose crossings are sought.
    t_guess : float, optional
        When given, the crossing closest to ``t_guess`` is returned (same units
        as ``t``); otherwise the first crossing is returned.

    Returns
    -------
    (t_cross, spread)
        ``t_cross`` is the interpolated crossing time. ``spread`` is the
        expanding (cumulative, ddof=1) standard deviation of the raw ``y`` at
        the sample just before the crossing, with 0 where it is undefined.

    Raises
    ------
    IndexError
        If ``y`` never crosses ``k_cutoff``. Callers such as get_pk_info
        catch this type.
    """
    # Running spread of the raw signal, used by callers as an uncertainty.
    spread_all = pd.Series(y).expanding(1).std().fillna(0).values

    # np.asarray stops pandas from re-aligning y on its own index.
    filled = pd.Series(np.asarray(y), index=t).interpolate(method='linear')
    times = filled.index.values
    values = filled.to_numpy()

    # A crossing lies between samples i and i+1 when the "below" flag flips.
    below = values < k_cutoff
    idx = np.where(np.diff(below))[0]

    t0, t1 = times[idx], times[idx + 1]
    x0, x1 = values[idx], values[idx + 1]
    crossings = t0 + (k_cutoff - x0) * (t1 - t0) / (x1 - x0)
    spreads = spread_all[idx]

    finite = np.isfinite(crossings)
    crossings, spreads = crossings[finite], spreads[finite]
    if crossings.size == 0:
        raise IndexError(f"signal never crosses k_cutoff={k_cutoff!r}")

    best = 0 if t_guess is None else int(np.argmin(np.abs(crossings - t_guess)))
    return crossings[best], spreads[best]

def critical_points_min_cntk(t, y, t_init=None):
    """Return ``(t_pk, t_left, t_right)`` for the zero crossing of ``y`` and its band.

    ``t_pk`` is the crossing of ``y`` through 0 chosen by ``find_crossing``
    (the one closest to ``t_init`` when that is given). The spread returned
    alongside it is used as a half-width ``h``. ``t_left`` and ``t_right`` are
    the crossings of ``+h`` and ``-h`` closest to ``t_pk``.

    ``find_crossing`` raises ``IndexError`` when ``y`` never crosses the
    requested level; that propagates to the caller unchanged.
    """
    t_pk, half_width = find_crossing(t, y, 0, t_guess=t_init)
    band_edges = []
    for level in (half_width, -half_width):
        edge_t, _spread = find_crossing(t, y, level, t_guess=t_pk)
        band_edges.append(edge_t)
    t_left, t_right = band_edges
    return t_pk, t_left, t_right


# Column labels that get_pk_info treats specially; must match _ORDER_PARAMS entries.
_L_EVAL = r'$\dot{L}_{\text{eval}}$'
_SIGMA_THETA = r'$\sigma_{\theta}$'
_MIN_CNTK = r'$\min C_{NTK}$'
_C_INF = r'$C_{\infty}$'
_LOG_D = 'norm_spatial_mean_log_D'

# Box constraints on the sigmoid_to_fit parameters (A, B, mu, w).
_SIGMOID_LOWER = [0, 0, 0, 0]
_SIGMOID_UPPER = [1, 1, 1e4, 1e3]
# Initial guess for the sigmoid width w.
_SIGMOID_W0 = 20


def _peak_window(t, sig, pick_second=False):
    """Return ``(pk, left, right)`` epochs for a prominent peak of ``sig``.

    Peaks need prominence >= 0.5 (callers pass min-max normalised signals).
    The first qualifying peak is used, or the second when ``pick_second`` is
    set and there are at least two. ``left``/``right`` lie one full
    ``find_peaks`` width either side of the peak. They are mapped from sample
    positions to epochs by linear interpolation and clamped to
    ``[t[0], t[-1]]``. All three values are NaN when no peak qualifies.
    """
    peaks, props = sp.find_peaks(sig, prominence=0.5, width=(None, None))
    if len(peaks) == 0:
        return np.nan, np.nan, np.nan
    k = 1 if pick_second and len(peaks) > 1 else 0
    centre, width = peaks[k], props['widths'][k]
    positions = np.arange(len(t))
    # np.interp returns t[0] / t[-1] outside [0, len(t) - 1]. This is the same
    # as interp1d with an explicit fallback to the end points.
    left = float(np.interp(centre - width, positions, t))
    right = float(np.interp(centre + width, positions, t))
    return t[centre], left, right


def _fit_sigmoid(t, sig, mu_guess, decreasing):
    """Fit ``sigmoid_to_fit`` to ``sig`` sampled at ``t`` and return ``(A, B, mu, w)``.

    The initial guess goes from 1 to 0 when ``decreasing``, otherwise from 0
    to 1, with midpoint ``mu_guess`` and width ``_SIGMOID_W0``. Raises
    ``RuntimeError`` (from ``curve_fit``) if the fit does not converge within
    10000 function evaluations.
    """
    a0, b0 = (1, 0) if decreasing else (0, 1)
    popt, _pcov = curve_fit(
        sigmoid_to_fit, t, sig,
        p0=[a0, b0, mu_guess, _SIGMOID_W0],
        maxfev=10000,
        bounds=(_SIGMOID_LOWER, _SIGMOID_UPPER),
    )
    return popt


def get_pk_info(p_df):
    """Estimate the start / peak / end epochs of each order parameter in ``_ORDER_PARAMS``.

    Parameters
    ----------
    p_df : pandas.DataFrame
        Must contain an ``epoch`` column plus one column per entry of
        ``_ORDER_PARAMS``, with one row per sample.

    Returns
    -------
    dict
        For every order parameter ``p``: ``'{p} pk'``, ``'{p} start'`` and
        ``'{p} end'``, all in epoch units. Values are NaN where the peak or
        crossing search found nothing.

    Per-parameter handling:
    * L_eval-dot (negated) and sigma_theta: prominent peak of the min-max
      normalised signal; start/end are one peak width either side.
    * min C_NTK: the zero crossing and the crossings of +/- its local spread,
      via ``critical_points_min_cntk``. This is seeded with the L_eval-dot
      peak epoch.
    * norm_spatial_mean_log_D: sigmoid fit to the forward expanding std;
      start/pk/end = mu - 3w, mu - 2w, mu - w.
    * C_infinity: sigmoid fit to the min-max normalised backward expanding
      std; start/pk/end = mu, mu + 2w, mu + 4w.
    * anything else: sigmoid fit to the min-max normalised signal;
      pk = mu, start/end = mu -/+ 2w.
    """
    t = p_df.epoch.values
    # Midpoint guess shared by all sigmoid fits. The fits use t as their x
    # axis, so this must be the epoch (not the row position) where L_eval-dot
    # is smallest. It is clipped into the mu bounds so the initial point is
    # always feasible for curve_fit.
    guess_row = p_df[_L_EVAL].argmin()
    mu_guess = float(np.clip(t[guess_row], _SIGMOID_LOWER[2], _SIGMOID_UPPER[2]))
    # Epoch of the L_eval-dot peak, used to choose among several C_NTK
    # crossings. find_crossing compares it with crossing *times*, so it must
    # be an epoch. None means "take the first crossing".
    eval_peak_t = None

    info = {}
    for p in _ORDER_PARAMS:
        values = p_df[p].values

        if p in (_L_EVAL, _SIGMA_THETA):
            is_eval = p == _L_EVAL
            sig = min_max_norm(-values if is_eval else values)
            pk, left, right = _peak_window(t, sig, pick_second=is_eval)
            if is_eval and not np.isnan(pk):
                eval_peak_t = pk

        elif p == _MIN_CNTK:
            try:
                pk, left, right = critical_points_min_cntk(t, values, t_init=eval_peak_t)
            except IndexError:  # y never crosses one of the required levels
                pk, left, right = np.nan, np.nan, np.nan

        elif p == _LOG_D:
            # Forward expanding std of the raw signal.
            # NOTE(review): unlike the other fits this sig is not min-max
            # normalised, although A and B are bounded to [0, 1]. Confirm
            # this is intended.
            sig = pd.Series(values).expanding(1).std().fillna(0).values
            _, _, mu, w = _fit_sigmoid(t, sig, mu_guess, decreasing=not sig[-1] > sig[0])
            left, pk, right = mu - 3 * w, mu - 2 * w, mu - w

        elif p == _C_INF:
            # Expanding std accumulated from the last sample backwards.
            backward_std = pd.Series(values[::-1]).expanding(1).std().fillna(0).values[::-1]
            sig = min_max_norm(backward_std)
            _, _, mu, w = _fit_sigmoid(t, sig, mu_guess, decreasing=True)
            left, pk, right = mu, mu + 2 * w, mu + 4 * w

        else:
            sig = min_max_norm(values)
            _, _, mu, w = _fit_sigmoid(t, sig, mu_guess, decreasing=not sig[-1] > sig[0])
            pk, left, right = mu, mu - 2 * w, mu + 2 * w

        info[f'{p} pk'] = pk
        info[f'{p} start'] = left
        info[f'{p} end'] = right

    return info