import numpy as np
import cvxpy as cp
import pandas as pd
from sklearn.datasets import load_diabetes, fetch_openml

def robust_RLM(D, thres = 1, lamb = 1, kernel = False):
    """Fit a regularized Huber regression with CVXPY.

    Minimizes ``mean(huber(Y - X @ theta, thres)) / 2 + lamb / 2 * R(theta)``
    where ``R(theta) = ||theta||^2`` when ``kernel`` is False and
    ``R(theta) = theta' X theta`` when ``kernel`` is True. In the kernel case
    ``X`` is expected to be a square (kernel) matrix — TODO confirm at callers.

    Args:
        D: tuple ``(X, Y)`` of design matrix and response vector.
        thres: Huber threshold ``M`` passed to ``cp.huber``.
        lamb: non-negative regularization strength.
        kernel: use the quadratic-form penalty ``theta' X theta`` instead of
            the squared Euclidean norm.

    Returns:
        ``theta.value`` after solving (``None`` if CVXPY assigned no value).
    """
    X, Y = D
    p = X.shape[1]

    theta = cp.Variable(p)
    # Regularization weight as a CVXPY Parameter (nonneg) so the objective
    # is DPP-compliant; setting a negative value raises in CVXPY.
    lmd = cp.Parameter(nonneg=True)
    lmd.value = lamb

    loss = cp.mean(cp.huber(Y - X @ theta, thres)) / 2
    if kernel:
        # `theta.T @ X @ theta` is a product of two variable expressions and
        # is not DCP; `quad_form` is the supported way to write it.
        # Symmetrize so round-off asymmetry cannot fail CVXPY's PSD check.
        penalty = cp.quad_form(theta, (X + X.T) / 2)
    else:
        penalty = cp.sum_squares(theta)

    problem = cp.Problem(cp.Minimize(loss + lmd / 2 * penalty))
    problem.solve()

    return theta.value

def robust_SGD(D, shuffles, thres=1, lr=0.001, epochs=1):
    """Run plain stochastic gradient descent on the Huber loss.

    Starting from a zero vector, each visited row takes a step of size
    ``lr`` along the negative Huber-loss gradient for that row.

    Args:
        D: tuple ``(X, Y)``; row ``X[i]`` is paired with response ``Y[i]``.
        shuffles: indexable collection of visiting orders; ``shuffles[e]`` is
            the order used in epoch ``e``. Indices ``>= len(X)`` are skipped.
        thres: Huber threshold; residuals larger in magnitude contribute a
            gradient clipped to ``thres * sign(residual)``.
        lr: constant step size.
        epochs: number of passes; reads ``shuffles[0]`` .. ``shuffles[epochs-1]``.

    Returns:
        The final coefficient vector as a float ndarray.
    """
    X, Y = D
    n_rows, n_features = X.shape
    coef = np.zeros(n_features)

    for e in range(epochs):
        for idx in shuffles[e]:
            if idx < n_rows:
                xi = X[idx]
                resid = Y[idx] - xi @ coef
                # Huber derivative: the raw residual inside the threshold,
                # constant magnitude `thres` outside it.
                step = resid if abs(resid) <= thres else thres * np.sign(resid)
                coef += lr * (step * xi)

    return coef

def kernel_matrix(X1, X2, kernel = "RBF", params_RBF = None, params_poly = None):
    """Compute the Gram matrix ``K[i, j] = k(X1[i], X2[j])``.

    Args:
        X1: array of shape (n1, p).
        X2: array of shape (n2, p).
        kernel: ``"RBF"`` for ``exp(-gamma * ||x - y||^2)`` or ``"poly"`` for
            ``(x . y + coef0) ** degree``.
        params_RBF: optional dict with key ``'gamma'`` (default 0.1).
        params_poly: optional dict with keys ``'degree'`` (default 2) and
            ``'coef0'`` (default 1). Missing keys fall back to the defaults.

    Returns:
        Array of shape (n1, n2).

    Raises:
        ValueError: if ``kernel`` is neither ``"RBF"`` nor ``"poly"``.
    """
    X1 = np.asarray(X1)
    X2 = np.asarray(X2)
    if kernel == "RBF":
        # None instead of a mutable dict default; user keys override defaults.
        params = {'gamma': 0.1, **(params_RBF or {})}
        diff = X1[:, None] - X2[None, :]
        # Sum squares directly rather than squaring a norm (avoids a
        # sqrt/square round trip).
        sq_dist = np.sum(diff ** 2, axis=2)
        return np.exp(-params['gamma'] * sq_dist)
    if kernel == "poly":
        params = {'degree': 2, 'coef0': 1, **(params_poly or {})}
        return (X1 @ X2.T + params['coef0']) ** params['degree']
    raise ValueError(f"unknown kernel {kernel!r}; expected 'RBF' or 'poly'")

def counts_int(M, a, b, w):
    """Return the weighted share of intervals in ``M`` containing the midpoint of [a, b].

    Args:
        M: 2-D array whose rows start with ``[lower, upper]`` endpoints.
        a, b: endpoints of the query segment; only ``(a + b) / 2`` is tested.
        w: weights forwarded to ``np.average`` (``None`` means equal weights).

    Returns:
        ``sum_i w_i * 1{lower_i <= mid <= upper_i} / sum_i w_i``.
    """
    mid = 0.5 * (a + b)
    lo, hi = M[:, 0], M[:, 1]
    contains = np.logical_and(lo <= mid, hi >= mid)
    return np.average(contains, weights=w)

def counts_set(M, a, b):
    """Count intervals, over a list of interval sets, containing the midpoint of [a, b].

    Args:
        M: list whose entries are ``None`` or 2-D arrays with rows
            ``[lower, upper]``.
        a, b: endpoints of the query segment; only ``(a + b) / 2`` is tested.

    Returns:
        The total number of rows, summed over all entries of ``M``, with
        ``lower <= mid <= upper``. ``None`` or empty entries contribute 0.
    """
    mid = 0.5 * (a + b)

    def _hits(S):
        # Number of rows of one set whose interval covers `mid`.
        if S is None or len(S) == 0:
            return 0
        return np.count_nonzero((S[:, 0] <= mid) & (mid <= S[:, 1]))

    per_set = np.fromiter((_hits(S) for S in M), dtype=int, count=len(M))
    return np.sum(per_set)

def majority_vote(M, w = None, tau=0.5):
    """Aggregate intervals by (weighted) majority vote.

    The line is cut at every distinct endpoint in ``M``. An elementary
    segment ``[breaks[s], breaks[s+1]]`` is kept when ``counts_int`` (the
    weighted share of intervals covering its midpoint) is strictly greater
    than ``tau``. Runs of adjacent kept segments are merged into one interval.

    Args:
        M: array of shape (k, 2) whose rows are ``[lower, upper]``.
        w: optional weights of length k; ``None`` or empty means uniform.
        tau: strict threshold on the weighted coverage share.

    Returns:
        Array of shape (r, 2) of merged intervals in increasing order, or
        ``None`` if no segment passes the vote.
    """
    # `None` default instead of a shared mutable list; an explicit empty
    # sequence still means "uniform", as before.
    if w is None or len(w) == 0:
        w = np.ones(M.shape[0])
        w = w / np.sum(w)

    breaks = np.unique(M.flatten())  # np.unique returns sorted values
    n_seg = len(breaks) - 1
    covered = [counts_int(M, breaks[s], breaks[s + 1], w) > tau for s in range(n_seg)]

    lower, upper = [], []
    s = 0
    while s < n_seg:
        if not covered[s]:
            s += 1
            continue
        start = s
        while s < n_seg and covered[s]:
            s += 1
        # Segments start..s-1 are covered, so the merged interval ends at
        # breaks[s]. When the run reaches the last segment, s == n_seg is
        # still a valid index (the largest endpoint); the previous loop read
        # one past the end and raised IndexError here.
        lower.append(breaks[start])
        upper.append(breaks[s])

    if not lower:
        return None
    return np.column_stack((lower, upper))
    
def exch_majority_vote(M, tau=0.5, rng=None):
    """Merge intervals with a randomized, sequential majority vote.

    The rows of ``M`` are randomly permuted. The first permuted interval forms
    the first set. Each longer prefix (``i + 1`` rows, ``i >= 1``) is
    aggregated with ``majority_vote`` under uniform weights. The result
    keeps every elementary segment (between consecutive endpoints of all these
    ``k`` sets) whose midpoint ``counts_set`` counts exactly ``k`` times. That
    means the midpoint lies in each set, provided each set's intervals are
    disjoint (as ``majority_vote`` produces them). Adjacent kept segments are
    merged.

    Args:
        M: array of shape (k, 2) with rows ``[lower, upper]``.
        tau: threshold forwarded to ``majority_vote``.
        rng: optional object with a ``permutation`` method (e.g. a
            ``np.random.Generator``); ``None`` uses the global ``np.random``.

    Returns:
        ``M`` itself when ``k == 1``. Otherwise an array of shape (r, 2) of
        merged intervals, or ``None`` if any prefix vote or the final
        selection is empty (also ``None`` when ``k == 0``).
    """
    k = M.shape[0]
    if k == 1:
        return M
    if k == 0:
        return None

    perm = np.random.permutation(k) if rng is None else rng.permutation(k)
    permM = M[perm, :]

    newM = [permM[0, :].reshape(1, 2)]
    for i in range(1, k):
        weights = np.full(i + 1, 1 / (i + 1))
        voted = majority_vote(permM[:i + 1, :], weights, tau)
        if voted is None or len(voted) == 0:
            return None
        newM.append(voted)

    breaks = np.unique(np.concatenate([S.flatten() for S in newM]))  # sorted
    n_seg = len(breaks) - 1
    covered = [counts_set(newM, breaks[s], breaks[s + 1]) == k for s in range(n_seg)]

    lower, upper = [], []
    s = 0
    while s < n_seg:
        if not covered[s]:
            s += 1
            continue
        start = s
        while s < n_seg and covered[s]:
            s += 1
        # Run covers segments start..s-1; s == n_seg (the largest endpoint)
        # is a valid index, so a run reaching the last segment no longer
        # reads past the end of `breaks`.
        lower.append(breaks[start])
        upper.append(breaks[s])

    if not lower:
        return None
    return np.column_stack((lower, upper))


def BH(pval, q = 0.1):
    """Benjamini-Hochberg step-up procedure.

    Sorts the p-values and finds the largest rank ``j`` (1-based) with
    ``p_(j) <= q * j / n``. It then rejects the ``j`` hypotheses with the
    smallest p-values.

    Args:
        pval: 1-D sequence of p-values.
        q: target false discovery rate level.

    Returns:
        Integer array of 0-based positions in ``pval`` of the rejected
        hypotheses, in increasing p-value order. It is empty (integer dtype,
        so it is still usable as an index) when nothing is rejected.
    """
    pval = np.asarray(pval, dtype=float).reshape(-1)
    ntest = len(pval)

    order = np.argsort(pval, kind="stable")
    thresholds = q * np.arange(1, ntest + 1) / ntest
    passed = np.flatnonzero(pval[order] <= thresholds)

    if passed.size == 0:
        return np.array([], dtype=int)
    # Step-up: reject everything up to the largest passing rank, even if some
    # smaller ranks individually fail their threshold.
    return order[: passed[-1] + 1]

def DGP_lin(n,m,p,sigma=1,cov="indep",rho=0.5):
    """Simulate train/test data from a linear model with Gaussian noise.

    Features are i.i.d. standard normal (``cov="indep"``) or have AR(1)
    correlation ``rho ** |j - k|`` (``cov="ar1"``); both are then divided by
    ``sqrt(p)``. The response is ``Y = X @ theta + sigma * eps``, where
    ``theta_j`` is proportional to ``(1 - j/p) ** 5`` (j = 1..p) and rescaled
    to Euclidean norm ``1 / sqrt(p)``. Randomness comes from the global
    ``np.random`` state, drawn in the order: train features, test features,
    train noise, test noise.

    Args:
        n: number of training rows.
        m: number of test rows.
        p: number of features. Use ``p >= 2``: for ``p == 1`` every
            ``theta_j`` is 0 and the normalization divides by zero.
        sigma: noise standard deviation.
        cov: ``"indep"`` or ``"ar1"``.
        rho: AR(1) correlation used when ``cov == "ar1"`` (default 0.5);
            requires ``|rho| < 1`` for the Cholesky factorization.

    Returns:
        ``((X, Y), (X_test, Y_test))``.

    Raises:
        ValueError: if ``cov`` is not ``"indep"`` or ``"ar1"``.
    """
    if cov not in ("indep", "ar1"):
        raise ValueError(f"cov must be 'indep' or 'ar1', got {cov!r}")

    Z = np.random.randn(n, p)
    Z_test = np.random.randn(m, p)
    if cov == "ar1":
        Sigma = np.fromfunction(lambda i, j: rho ** np.abs(i - j), (p, p))
        L = np.linalg.cholesky(Sigma)
        X = Z @ L.T
        X_test = Z_test @ L.T
    else:
        X, X_test = Z, Z_test
    X = X / np.sqrt(p)
    X_test = X_test / np.sqrt(p)

    theta = (1 - np.arange(1, p + 1) / p) ** 5
    theta = theta / np.linalg.norm(theta) / np.sqrt(p)

    Y = X @ theta + np.random.randn(n) * sigma
    Y_test = X_test @ theta + np.random.randn(m) * sigma
    return (X, Y), (X_test, Y_test)

def DGP_nonlin(n,m,p,sigma=1,cov="indep",rho=0.5):
    """Simulate train/test data from a non-linear (exponential) model.

    Features are i.i.d. standard normal (``cov="indep"``) or have AR(1)
    correlation ``rho ** |j - k|`` (``cov="ar1"``); both are then divided by
    ``sqrt(p)``. The response is ``Y = exp(X / 10) @ theta + sigma * eps``
    (elementwise exp), where ``theta_j`` is proportional to ``(1 - j/p) ** 5``
    (j = 1..p) and rescaled to Euclidean norm ``1 / sqrt(p)``. Randomness
    comes from the global ``np.random`` state, drawn in the order: train
    features, test features, train noise, test noise.

    Args:
        n: number of training rows.
        m: number of test rows.
        p: number of features. Use ``p >= 2``: for ``p == 1`` every
            ``theta_j`` is 0 and the normalization divides by zero.
        sigma: noise standard deviation.
        cov: ``"indep"`` or ``"ar1"``.
        rho: AR(1) correlation used when ``cov == "ar1"`` (default 0.5);
            requires ``|rho| < 1`` for the Cholesky factorization.

    Returns:
        ``((X, Y), (X_test, Y_test))``.

    Raises:
        ValueError: if ``cov`` is not ``"indep"`` or ``"ar1"``.
    """
    if cov not in ("indep", "ar1"):
        raise ValueError(f"cov must be 'indep' or 'ar1', got {cov!r}")

    Z = np.random.randn(n, p)
    Z_test = np.random.randn(m, p)
    if cov == "ar1":
        Sigma = np.fromfunction(lambda i, j: rho ** np.abs(i - j), (p, p))
        L = np.linalg.cholesky(Sigma)
        X = Z @ L.T
        X_test = Z_test @ L.T
    else:
        X, X_test = Z, Z_test
    X = X / np.sqrt(p)
    X_test = X_test / np.sqrt(p)

    theta = (1 - np.arange(1, p + 1) / p) ** 5
    theta = theta / np.linalg.norm(theta) / np.sqrt(p)

    Y = np.exp(X / 10) @ theta + np.random.randn(n) * sigma
    Y_test = np.exp(X_test / 10) @ theta + np.random.randn(m) * sigma
    return (X, Y), (X_test, Y_test)
    
def get_data(dataname):
    """Load and standardize a regression benchmark.

    Every feature column is centered and divided by its (population) standard
    deviation, then all features are further divided by ``sqrt(p)``. This
    makes the average squared row norm equal to 1. The response is centered
    and scaled to unit standard deviation.

    Args:
        dataname: ``'boston'`` (fetched with ``fetch_openml``) or
            ``'diabetes'`` (``load_diabetes``).

    Returns:
        ``(X, Y)`` as float ndarrays.

    Raises:
        ValueError: if ``dataname`` is not one of the supported names.
    """
    if dataname == 'boston':
        dataset = fetch_openml(name='boston', version=1)
    elif dataname == 'diabetes':
        dataset = load_diabetes()
    else:
        raise ValueError(f"unknown dataname {dataname!r}; expected 'boston' or 'diabetes'")

    X = np.array(dataset.data, dtype=float)
    Y = np.array(dataset.target, dtype=float)

    mx = np.mean(X, axis=0)
    my = np.mean(Y)
    sx = np.std(X, axis=0)
    sy = np.std(Y)
    # Guard constant columns/targets against division by zero.
    sx = np.where(sx == 0, 1.0, sx)
    sy = sy if sy != 0 else 1.0

    X = (X - mx) / sx / np.sqrt(X.shape[1])
    Y = (Y - my) / sy
    return X, Y
