import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from numpy import ndarray
import statsmodels.api as sm
import sys
from sklearn.linear_model import Ridge
from sklearn.linear_model import LogisticRegression



def dpdx(xr:ndarray, beta:ndarray):
    """
    function to evaluate the slope of the logistic curve at one observation

    With ``z = xr . beta`` the returned value is ``exp(z) / (1 + exp(z)) ** 2``,
    computed in log space.

    NOTE(review): this is the derivative with respect to the linear predictor
    ``z``; it is not multiplied by ``beta`` -- confirm this is what callers expect.

    :param xr: observation
    :param beta: logistic regression coefficient vector
    :return: derivative at xr
    """
    z = np.dot(xr, beta)
    # Past z = 700 np.exp(z) is about to overflow; there log(1 + exp(z)) is
    # numerically equal to z, so the log-derivative collapses to -z.
    log_slope = -z if z > 700 else z - 2 * np.log(1 + np.exp(z))
    return np.exp(log_slope)

def rsquared(y: ndarray, ypred: ndarray):
    """
    function to compute the coefficient of determination

    Evaluates ``1 - SS_res / SS_tot`` where ``SS_res`` is the squared norm of
    ``y - ypred`` and ``SS_tot`` the squared norm of ``y - mean(y)``.

    :param y: target
    :param ypred: prediction, same length as y
    :return: rsquared
    """

    assert len(y) == len(ypred), 'y and ypred should be same length'
    residual_ss = np.linalg.norm(y - ypred) ** 2
    total_ss = np.linalg.norm(y - np.mean(y)) ** 2
    return 1 - residual_ss / total_ss


def lnL(y: ndarray, ypred: ndarray):
    """
    function to calculate the Bernoulli log likelihood

    Sums ``y * log(p) + (1 - y) * log(1 - p)`` over the observations. Both
    ``p`` and ``1 - p`` are floored at the smallest positive normal float so
    that the logarithm never receives zero.

    :param y: target
    :param ypred: prediction (probabilities), same length as y
    :return: log likelihood; 0 when the inputs are empty
    """

    assert len(y) == len(ypred), 'y and ypred should be same length'
    # np.maximum is used instead of np.clip(..., min=...): the ``min`` keyword
    # of np.clip only exists from NumPy 2.1 on and raises TypeError before that.
    floor = sys.float_info.min
    res = 0
    for target, prob in zip(y, ypred):
        res += (target * np.log(np.maximum(prob, floor))
                + (1 - target) * np.log(np.maximum(1 - prob, floor)))
    return res


def pseudo_rsquared(y: ndarray, ypred: ndarray, lnM:float):
    """
    function to compute a pseudo-rsquared for a logit fit

    Evaluates ``1 - exp((2 / n) * (lnL(y, ypred) - lnM))`` with ``n = len(y)``.

    NOTE(review): the Cox-Snell statistic is usually written with
    ``(null log likelihood - model log likelihood)`` in the exponent; confirm
    the sign convention callers use for ``lnM``.

    :param y: target
    :param ypred: prediction, same length as y
    :param lnM: constant from null model
    :return: pseudo-rsquared
    """

    assert len(y) == len(ypred), 'y and ypred should be same length'
    n_obs = len(y)
    log_lik_gap = lnL(y, ypred) - lnM
    return 1 - np.exp(2 / n_obs * log_lik_gap)

def ols_predict(beta:ndarray, X:ndarray, offset:float):
    """
    function to evaluate a fitted linear model

    Returns ``X @ beta + offset``.

    :param beta: coefficient vector, one entry per column of X
    :param X: observations, one row per observation
    :param offset: offset added to the linear predictions
    :return: predictions
    """

    n_features = len(X[0, :])
    assert len(beta) == n_features, 'input dimensions do not match'
    linear_part = X @ beta
    return linear_part + offset

def logit_predict(beta:ndarray, X:ndarray):
    """
    function to predict probability

    X is passed through statsmodels ``add_constant`` with
    ``has_constant='skip'``: a column of ones is prepended unless X already
    contains a constant column. ``beta`` therefore has to match the augmented
    matrix (intercept first, then one coefficient per original column).

    :param beta: coefficient vector, one entry per column of X after the
                 constant column has been added
    :param X: observations, one row per observation
    :return: probability prediction, array of shape (n_observations, 1)
    """

    X = np.asarray(sm.add_constant(X, has_constant='skip'))
    beta = np.asarray(beta, dtype=float).ravel()
    # The check must run against the augmented matrix: checking before
    # add_constant rejects an intercept-bearing beta and lets through a beta
    # that is one entry short for the dot product below.
    assert len(beta) == X.shape[1], 'input dimensions do not match'
    linear = X @ beta
    # sigmoid(z) = exp(-log(1 + exp(-z))); logaddexp evaluates the inner log
    # without overflowing for large negative z.
    return np.exp(-np.logaddexp(0, -linear)).reshape(-1, 1)

def f_Ols(X:pd.DataFrame, y:pd.Series, alpha:float=0.05, max_iter:int=100):
    """
    function to fit a ridge-regularised linear model without intercept

    Uses scikit-learn ``Ridge`` with the ``lsqr`` solver and a fixed
    ``random_state`` of 42.

    :param X: observations
    :param y: target
    :param alpha: regularization factor
    :param max_iter: maximum iterations allowed
    :return: coefficients
             converged (solver iteration count below max_iter)
    """

    ridge = Ridge(
        alpha=alpha,
        solver='lsqr',
        fit_intercept=False,
        max_iter=max_iter,
        random_state=42,
    ).fit(X, y)

    converged = ridge.n_iter_ < max_iter
    return ridge.coef_, converged

def f_Logit(X:pd.DataFrame, y:pd.Series, alpha:float=0.05, max_iter:int=100):
    """
    function to fit an L2-penalised logistic regression

    Uses scikit-learn ``LogisticRegression`` with the ``saga`` solver,
    ``C = 1 / alpha`` and a fixed ``random_state`` of 42.

    :param X: observations
    :param y: target
    :param alpha: regularization factor (its inverse is passed as C)
    :param max_iter: maximum iterations allowed
    :return: coefficients (intercept first, then the flattened slopes)
             converged (first iteration count below max_iter)
    """

    classifier = LogisticRegression(
        penalty='l2',
        C=1/alpha,
        solver='saga',
        max_iter=max_iter,
        random_state=42,
    )
    classifier.fit(X, y)

    intercept = np.array(classifier.intercept_)
    slopes = classifier.coef_.flatten()
    coefficients = np.concatenate((intercept, slopes))
    converged = classifier.n_iter_[0] < max_iter
    return coefficients, converged


def pairwise_d(datapoints: list):
    """
    function to sum the absolute differences between all pairs of points

    Every ordered pair (i, j) with i != j is accumulated and the total is
    halved, so each unordered pair contributes once.

    :param datapoints: members of one cluster
    :return: sum of distances (0 for a single point)
    """

    if len(datapoints) == 1:
        return 0
    total = 0
    for row, left in enumerate(datapoints):
        for col, right in enumerate(datapoints):
            if row != col:
                total += abs(left - right)
    return total / 2


def gap_stat(clustered_dps:dict, low:float, high:float, n_simu:int=100):
    """
    function to compute the gap statistic of a one-dimensional clustering

    For each of ``n_simu`` rounds, ``50 * k`` points are drawn uniformly
    between ``low`` and ``high`` (k = number of clusters), clustered with
    KMeans into k groups, and the log of their pooled within-cluster
    dispersion is recorded. The gap is the mean of those reference values
    minus the log dispersion of ``clustered_dps``.

    :param clustered_dps: clusters, mapping a label to the points of that cluster
    :param low: lower bound of simulated data points
    :param high: higher bound of simulated data points
    :param n_simu: number of times to draw data points from simulation
    :return: gap-statistics (inf when the observed dispersion is zero)
             standard deviation of the reference values times sqrt(1 + 1/n_simu)
    """

    n_clusters = len(clustered_dps)
    reference_log_w = []
    for _ in range(n_simu):
        draws = np.random.uniform(low, high, 50*n_clusters)
        column = draws.reshape(-1,1)
        kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit(column)
        labels = kmeans.predict(column)
        dispersion = 0
        # pd.unique keeps labels in order of first appearance
        for label in pd.unique(labels):
            members = draws[labels == label].tolist()
            dispersion += pairwise_d(members)/(2*len(members))

        reference_log_w.append(np.log(dispersion))

    expected_log_w = np.mean(reference_log_w)

    observed_w = 0
    for members in clustered_dps.values():
        observed_w += pairwise_d(members) / (2 * len(members))

    spread = np.std(reference_log_w)*np.sqrt(1+1/n_simu)
    if observed_w == 0:
        # log(0) is undefined; report an infinite gap instead
        return np.inf, spread
    return expected_log_w - np.log(observed_w), spread