import numpy as np
import pandas as pd
import scipy.signal as sp
import scipy.stats as sts

from tqdm import tqdm

from scipy.optimize import curve_fit
from scipy.interpolate import interp1d

from sklearn.metrics import mean_squared_error

# Subset of configuration columns treated as hyper-parameters.
_HPARAMS = [
    'model.hidden_features',
    'model.hidden_layers',
    'dset',
    'model.architecture.scale',
]

# Every configuration column that identifies a run: the architecture, the
# hyper-parameter columns above, and the seed. `extract_pk_information`
# copies these columns from each config row into its output.
_ALL_FILTERS = ['model.architecture', *_HPARAMS, 'seed']

def min_max_norm(x):
    """Rescale ``x`` linearly so its minimum maps to 0 and its maximum to 1.

    A constant input has zero range and therefore produces NaN (0 / 0).
    """
    offset = x - x.min()
    span = offset.max()
    return offset / span

def reverse_rolling_std(sig):
    """Standard deviation of every suffix ``sig[i:]`` (an expanding std run from the end).

    Uses pandas' default ``ddof=1``; positions where the std is undefined
    (e.g. the single-element suffix at the end) are reported as 0.
    """
    flipped = pd.Series(sig[::-1])
    suffix_std = flipped.expanding(min_periods=1).std()
    return suffix_std.fillna(0).to_numpy()[::-1]

def arr_reverse_rolling_std(sigs):
    """Apply ``reverse_rolling_std`` to each signal in ``sigs`` and stack the results."""
    return np.array([reverse_rolling_std(one_sig) for one_sig in sigs])

def reverse_rolling_mean(sig):
    """Mean of every suffix ``sig[i:]`` (an expanding mean run from the end).

    Any NaN results (e.g. suffixes with no valid values) are replaced by 0.
    """
    suffix_means = pd.Series(sig[::-1]).expanding(min_periods=1).mean()
    return suffix_means.fillna(0).to_numpy()[::-1]

def arr_reverse_rolling_mean(sigs):
    """Apply ``reverse_rolling_mean`` to each signal in ``sigs`` and stack the results."""
    return np.array([reverse_rolling_mean(one_sig) for one_sig in sigs])

def log_log_grad(xt, t):
    """Slope d log(1 + xt) / d log(1 + t), estimated with ``np.gradient``."""
    log_signal = np.log(1 + xt)
    log_time = np.log(1 + t)
    return np.gradient(log_signal, log_time)

def arr_log_log_grad(xts, ts):
    """Apply ``log_log_grad`` pairwise to ``zip(xts, ts)`` and stack the results."""
    return np.array([log_log_grad(signal, times) for signal, times in zip(xts, ts)])

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, applied elementwise."""
    denom = 1 + np.exp(-x)
    return 1 / denom

def sigmoid_to_fit(x, A, B, mu, w):
    """Sigmoid from level ``A`` to level ``B``, centred at ``mu`` with width scale ``w``.

    For ``w > 0`` the curve tends to ``A`` for x far below ``mu`` and to ``B``
    for x far above it.
    """
    z = (x - mu) / w
    return A + (B - A) * sigmoid(z)

def jac_sigmoid_to_fit(x, A, B, mu, w):
    """
    Jacobian function for the sigmoid_to_fit model.

    Parameters:
        x : array-like
            The independent variable.
        A, B, mu, w : float
            Parameters of the sigmoid function.

    Returns:
        J : 2D array
            Jacobian matrix (m x 4), where m is the number of data points;
            columns are the partial derivatives w.r.t. (A, B, mu, w).

    Notes:
        The logistic derivative is computed as ``s * (1 - s)`` with
        ``s = 1 / (1 + exp(-z))``, ``z = (x - mu) / w``. The algebraically
        equal form ``exp(-z) / (1 + exp(-z))**2`` evaluates to
        ``inf / inf = NaN`` once ``exp(-z)`` overflows (z below about -709,
        e.g. when the optimiser shrinks ``w``). That would put NaNs into the
        Jacobian handed to ``curve_fit``.
    """
    x = np.asarray(x, dtype=float)
    z = (x - mu) / w

    # Same logistic formula as `sigmoid`. If exp(-z) overflows to inf, s becomes
    # exactly 0, which is the correct limit, so the overflow warning is silenced.
    with np.errstate(over='ignore'):
        s = 1.0 / (1.0 + np.exp(-z))
    sigmoid_derivative = s * (1.0 - s)  # finite for every z, never NaN

    # Partial derivatives
    dA = 1.0 - s                                  # w.r.t. A
    dB = s                                        # w.r.t. B
    dmu = -(B - A) * sigmoid_derivative / w       # w.r.t. mu (dz/dmu = -1/w)
    dw = -(B - A) * sigmoid_derivative * z / w    # w.r.t. w  (dz/dw = -(x-mu)/w**2)

    # Combine into Jacobian matrix, shape (len(x), 4)
    return np.vstack((dA, dB, dmu, dw)).T

## Detectors

def detect_ldot(raw_sig, t, t_interpolator):
    """
    Locate the last prominent minimum of ``raw_sig`` and a window around it.

    The signal is negated and min-max normalised, so minima of ``raw_sig``
    become peaks. The last peak with prominence >= 0.2 is kept. The window is
    the peak position offset by one peak width (as reported by
    ``scipy.signal.find_peaks``) on each side. Positions are mapped from
    sample index to time with ``t_interpolator``; if it raises ``ValueError``
    (e.g. an out-of-range position with interp1d's default bounds check),
    ``t[0]`` / ``t[-1]`` are used instead.

    Returns:
        (pk_idx, pk, left, right): peak sample index, peak time, window start
        and window end -- or four Nones when no qualifying peak exists.
    """
    sig = min_max_norm(-raw_sig)
    _pks, _pk_info = sp.find_peaks(sig, prominence=0.2, width=(None, None))

    # Must be checked before indexing: _pks[-1] on an empty array raises
    # IndexError rather than reaching the "no peak" return.
    if len(_pks) == 0:
        return None, None, None, None

    pk_idx = _pks[-1]
    pk_width = _pk_info['widths'][-1]

    pk = t[pk_idx]

    try:
        left = t_interpolator(pk_idx - pk_width).item()
    except ValueError:
        left = t[0]
    try:
        right = t_interpolator(pk_idx + pk_width).item()
    except ValueError:
        right = t[-1]

    return pk_idx, pk, left, right

def detect_sigma(raw_sig, t, t_interpolator, idx_guess):
    """
    Select the peak of ``raw_sig`` nearest the sample index ``idx_guess``.

    Peaks are searched on the min-max normalised signal with prominence >= 0.1.
    The window is the peak position +/- one peak width, mapped to time via
    ``t_interpolator``; ``t[0]`` / ``t[-1]`` are used if it raises ``ValueError``.

    Returns:
        (pk, left, right) as times, or three Nones if no peak is found.
    """
    peaks, props = sp.find_peaks(min_max_norm(raw_sig), prominence=0.1, width=(None, None))
    if peaks.size == 0:
        return None, None, None

    nearest = np.abs(peaks - idx_guess).argmin()
    centre = peaks[nearest]
    width = props['widths'][nearest]

    def time_at(position, fallback):
        # Map a fractional sample index to time, with a fallback outside the range.
        try:
            return t_interpolator(position).item()
        except ValueError:
            return fallback

    pk = t[centre]
    left = time_at(centre - width, t[0])
    right = time_at(centre + width, t[-1])
    return pk, left, right

def detect_k_inf(raw_sig, t, t_interpolator, idx_guess):
    """
    Detect a transition from the peak of ``raw_sig`` nearest ``idx_guess``.

    Peaks of the min-max normalised signal are found with widths measured at
    ``rel_height=0.25`` (see ``scipy.signal.find_peaks``). The peak nearest the
    sample index ``idx_guess`` is used. Unlike the other detectors, the window
    starts *at* the peak and extends to the right:

        left  = t[peak]
        pk    = time at (peak + 1 width)
        right = time at (peak + 2 widths)

    Fractional positions are mapped to time with ``t_interpolator``. Both
    ``pk`` and ``right`` lie to the right of the peak, so when the
    interpolator raises ``ValueError`` (position past its range) both fall
    back to ``t[-1]``.

    Returns:
        (pk, left, right), or three Nones if no peak is found.
    """
    sig = min_max_norm(raw_sig)
    _pks, _pk_info = sp.find_peaks(sig, width=(None, None), rel_height=0.25)

    if len(_pks) == 0:
        return None, None, None

    pk_to_use = np.abs(_pks - idx_guess).argmin()
    pk_idx = _pks[pk_to_use]
    pk_width = _pk_info['widths'][pk_to_use]

    left = t[pk_idx]

    try:
        pk = t_interpolator(pk_idx + pk_width).item()
    except ValueError:
        # The position is to the right of pk_idx, so out of range means past
        # the end of the series (t[0] would place pk before left).
        pk = t[-1]
    try:
        right = t_interpolator(pk_idx + 2 * pk_width).item()
    except ValueError:
        right = t[-1]

    return pk, left, right

def default_detect(raw_sig, t, t_interpolator, idx_guess):
    """
    Generic detector based on the peak of the log-log slope of ``raw_sig``.

    The slope d log(1 + x) / d log(1 + t) is min-max normalised, and the peak
    (any prominence) nearest the sample index ``idx_guess`` is selected. The
    window is that peak +/- one peak width, mapped to time via
    ``t_interpolator``; ``t[0]`` / ``t[-1]`` are used if it raises ``ValueError``.

    Returns:
        (pk, left, right) as times, or three Nones if no peak is found.
    """
    slope = min_max_norm(log_log_grad(raw_sig, t))
    peaks, props = sp.find_peaks(slope, width=(None, None))
    if peaks.size == 0:
        return None, None, None

    nearest = np.abs(peaks - idx_guess).argmin()
    centre = peaks[nearest]
    width = props['widths'][nearest]

    def time_at(position, fallback):
        # Map a fractional sample index to time, with a fallback outside the range.
        try:
            return t_interpolator(position).item()
        except ValueError:
            return fallback

    pk = t[centre]
    left = time_at(centre - width, t[0])
    right = time_at(centre + width, t[-1])
    return pk, left, right


def find_crossing(t, y, k_cutoff, idx_guess=None):
    """
    Find where ``y`` crosses the level ``k_cutoff``, interpolating linearly in time.

    Parameters:
        t : array-like
            Sample times, same length as ``y``.
        y : array-like
            Signal values. NaNs are filled by linear interpolation (by position)
            before crossings are searched.
        k_cutoff : float
            The level whose crossing is sought.
        idx_guess : float or None
            Despite its name, this is compared against crossing *times*. When
            given, the crossing closest to it is returned; otherwise the first
            crossing is returned.

    Returns:
        (t_cross, sig_val): the interpolated crossing time, and the expanding
        (cumulative) standard deviation of ``y`` at the sample just before the
        crossing.

    Raises:
        IndexError: if ``y`` never crosses ``k_cutoff``.
    """
    # Cumulative std of the raw signal, reported alongside each crossing.
    expanding_std = pd.Series(y).expanding(1).std().fillna(0).values
    times = np.asarray(t)
    values = pd.Series(y).interpolate(method='linear').to_numpy()

    # A crossing lies between samples i and i+1 when "below the cutoff" flips.
    below = values < k_cutoff
    indices = np.flatnonzero(below[1:] != below[:-1])
    if indices.size == 0:
        raise IndexError(f"signal never crosses k_cutoff={k_cutoff}")

    t0, t1 = times[indices], times[indices + 1]
    x0, x1 = values[indices], values[indices + 1]
    t_solutions = t0 + (k_cutoff - x0) * (t1 - t0) / (x1 - x0)
    sig_vals = expanding_std[indices]

    best = 0 if idx_guess is None else int(np.argmin(np.abs(t_solutions - idx_guess)))
    return t_solutions[best], sig_vals[best]


def detect_min_cntk(raw_sig, t, t_interpolator, idx_guess):
    """
    Detect the zero crossing of ``raw_sig`` and an uncertainty band around it.

    ``idx_guess`` is a sample index, as for the other detectors, but
    ``find_crossing`` compares its guess against crossing *times*. It is
    therefore converted with ``t[idx_guess]`` first; ``None`` is passed
    through, which selects the first crossing.

    Returns:
        (t_pk, t_left, t_right): the zero-crossing time, and the times where
        ``raw_sig`` crosses ``+unc`` and ``-unc`` closest to ``t_pk``. Here
        ``unc`` is the expanding std of ``raw_sig`` reported by
        ``find_crossing`` at the zero crossing. ``t_interpolator`` is unused;
        it is accepted to match the common detector signature.
    """
    t_guess = None if idx_guess is None else t[idx_guess]
    t_pk, unc = find_crossing(t, raw_sig, 0, idx_guess=t_guess)
    t_left, _ = find_crossing(t, raw_sig, unc, idx_guess=t_pk)
    t_right, _ = find_crossing(t, raw_sig, -unc, idx_guess=t_pk)

    return t_pk, t_left, t_right


def fit_reg_sigmoid(raw_sig, t, idx_guess):
    """
    Fit ``sigmoid_to_fit`` to the min-max normalised ``raw_sig`` against ``t``.

    Bounds: A, B in [0, 1]; mu in [0, 1e4]; w in [0, 1e3] (mu and w in units
    of ``t``). Initial guess: mu = t[idx_guess], w = 20, and (A, B) = (0, 1)
    if the normalised signal ends above where it starts, else (1, 0).

    Returns:
        (popt, pcov, err): fitted [A, B, mu, w], their covariance from
        ``curve_fit``, and the mean squared error of the fit on the
        normalised signal.
    """
    sig = min_max_norm(raw_sig)
    start_level, end_level = (0, 1) if sig[-1] > sig[0] else (1, 0)
    p0 = [start_level, end_level, t[idx_guess], 20]
    bounds = ([0, 0, 0, 0], [1, 1, 1e4, 1e3])

    popt, pcov = curve_fit(
        sigmoid_to_fit, t, sig, p0=p0, bounds=bounds, jac=jac_sigmoid_to_fit,
    )
    err = mean_squared_error(sig, sigmoid_to_fit(t, *popt))
    return popt, pcov, err

def fit_log_sigmoid(raw_sig, t, idx_guess):
    """
    Fit ``sigmoid_to_fit`` to the min-max normalised ``raw_sig`` against log10(1 + t).

    Bounds (in log10(1 + t) units for mu and w): A, B in [0, 1]; mu in [0, 4];
    w in [0, 1]. Initial guess: mu = log10(1 + t[idx_guess]), w = 0.25, and
    (A, B) = (0, 1) if the normalised signal ends above where it starts,
    else (1, 0).

    Returns:
        (popt, pcov, err): fitted [A, B, mu, w], their covariance from
        ``curve_fit``, and the mean squared error of the fit on the
        normalised signal.
    """
    def to_log10p1(v):
        return np.log(1 + v) / np.log(10)

    sig = min_max_norm(raw_sig)
    log_t = to_log10p1(t)
    mu0 = to_log10p1(t[idx_guess])
    start_level, end_level = (0, 1) if sig[-1] > sig[0] else (1, 0)

    popt, pcov = curve_fit(
        sigmoid_to_fit,
        log_t,
        sig,
        p0=[start_level, end_level, mu0, 0.25],
        bounds=([0, 0, 0, 0], [1, 1, 4, 1]),
        jac=jac_sigmoid_to_fit,
    )
    err = mean_squared_error(sig, sigmoid_to_fit(log_t, *popt))
    return popt, pcov, err

def detect_via_fit(raw_sig, t, t_interpolator, idx_guess):
    """
    Detect a transition by fitting a sigmoid in linear time (``fit_reg_sigmoid``).

    The transition centre is the fitted ``mu`` and the window is ``mu +/- 2 w``.
    Fits whose MSE exceeds 0.01 are rejected and reported as three NaNs.
    ``t_interpolator`` is unused; it is accepted to match the common detector
    signature.
    """
    popt, _pcov, err = fit_reg_sigmoid(raw_sig, t, idx_guess)

    if err > 0.01:
        return np.nan, np.nan, np.nan

    mu, w = popt[2], popt[3]
    half_window = 2 * w
    return mu, mu - half_window, mu + half_window

def detect_via_log_fit(raw_sig, t, t_interpolator, idx_guess):
    """
    Detect a transition by fitting a sigmoid in log10(1 + t) (``fit_log_sigmoid``).

    The centre ``mu`` and the window ``mu +/- 2 w`` are found in log10(1 + t)
    space and mapped back to ``t`` with the inverse ``10**v - 1``. The inverse
    is applied to all three values, so the start and end are on the same
    time axis as the peak. Fits whose MSE exceeds 0.01 are rejected and
    reported as three NaNs. ``t_interpolator`` is unused; it is accepted to
    match the common detector signature.
    """
    popt, _pcov, err = fit_log_sigmoid(raw_sig, t, idx_guess)

    if err > 0.01:
        return np.nan, np.nan, np.nan

    _A, _B, _mu, _w = popt

    _pk = 10 ** _mu - 1
    _left = 10 ** (_mu - 2 * _w) - 1
    _right = 10 ** (_mu + 2 * _w) - 1

    return _pk, _left, _right

def detect_magma_via_fit(raw_sig, t, t_interpolator, idx_guess):
    """Run ``detect_via_fit`` on the expanding (cumulative) std of ``raw_sig``.

    The first sample, whose std is undefined, is set to 0.
    """
    running_std = pd.Series(raw_sig).expanding(1).std().fillna(0).to_numpy()
    return detect_via_fit(running_std, t, t_interpolator, idx_guess)

def detect_magma_via_log_fit(raw_sig, t, t_interpolator, idx_guess):
    """Run ``detect_via_log_fit`` on the expanding (cumulative) std of ``raw_sig``.

    The first sample, whose std is undefined, is set to 0.
    """
    running_std = pd.Series(raw_sig).expanding(1).std().fillna(0).to_numpy()
    return detect_via_log_fit(running_std, t, t_interpolator, idx_guess)

def detect_magma_via_default(raw_sig, t, t_interpolator, idx_guess):
    """Run ``default_detect`` on the expanding (cumulative) std of ``raw_sig``.

    The first sample, whose std is undefined, is set to 0.
    """
    running_std = pd.Series(raw_sig).expanding(1).std().fillna(0).to_numpy()
    return default_detect(running_std, t, t_interpolator, idx_guess)


# Maps each signal name (a key of `param_arrays` in `extract_pk_information`)
# to the detector used for it. Every detector is called as
# detector(signal, T, t_interpolator, idx_guess) and returns (pk, start, end).
DETECTORS = {
    r'$\sigma_{\theta}$': detect_sigma,
    r'$\min C_{NTK}$': detect_min_cntk,
    r'$\text{CKA}(K_Y, K_{NTK})$': detect_via_log_fit,
    'edge_alignment': detect_via_log_fit,
    r'$\lambda_0$': detect_via_log_fit,
    'norm_spatial_mean_log_D': detect_magma_via_log_fit,
    r'$\mu_r K_{\infty}$': detect_k_inf,
}


def create_sig_info_res(sig_name, pk, left, right):
    """Package a detection result as ``{'<sig_name> pk'|'start'|'end': value}``."""
    return {
        f'{sig_name} pk': pk,
        f'{sig_name} start': left,
        f'{sig_name} end': right,
    }


def extract_pk_information(config_df, param_arrays, Ldot, T):
    """
    Extracts the pk information from the given configuration dataframe and parameter arrays.

    For each row of ``config_df``, the loss-velocity curve ``Ldot[i]`` is
    analysed with ``detect_ldot``. Its peak sample index is then passed as the
    initial guess to the detector registered in ``DETECTORS`` for every other
    signal in ``param_arrays``.

    Parameters:
        config_df (pd.DataFrame): Configuration table; must contain the columns
            in ``_ALL_FILTERS``. Its index labels (from ``iterrows``) are used
            to index ``Ldot`` and each array in ``param_arrays``.
        param_arrays (dict): Maps signal name -> array indexed like ``Ldot``.
        Ldot (np.ndarray): Array containing the loss velocity values, i.e. dL/dt.
        T (np.ndarray): Array containing the epoch values.

    Returns:
        pd.DataFrame: one row per configuration, holding the configuration
        columns plus ``'<signal> pk'``, ``'<signal> start'`` and
        ``'<signal> end'`` for the loss velocity and every other signal.
        Signals whose detector raises are recorded as None; some detectors
        themselves return NaN for rejected fits.
    """
    ldot_name = '$\\dot{L}_{\\text{eval}}$'
    # Maps a (fractional) sample index to an epoch value.
    t_interpolator = interp1d(np.arange(len(T)), T, kind='linear')
    pk_information = []
    for i, row in tqdm(config_df.iterrows(), total=len(config_df)):
        _info = row[_ALL_FILTERS].to_dict()

        _ldot_pk_idx, _ldot_pk, _ldot_left, _ldot_right = detect_ldot(Ldot[i], T, t_interpolator)
        _info.update(
            create_sig_info_res(ldot_name, _ldot_pk, _ldot_left, _ldot_right)
        )

        for sig_name, sig_arrays in param_arrays.items():
            if sig_name == ldot_name:
                continue
            try:
                _detector = DETECTORS[sig_name]
                _pk, _left, _right = _detector(sig_arrays[i], T, t_interpolator, _ldot_pk_idx)
            except Exception:
                # Deliberately best-effort: an unknown signal name, a missing
                # L-dot peak (idx guess None) or a failed fit marks this signal
                # as undetected instead of aborting the whole sweep.
                _pk, _left, _right = None, None, None
            _info.update(create_sig_info_res(sig_name, _pk, _left, _right))

        pk_information.append(_info)

    return pd.DataFrame(pk_information)

    