import numpy as np
import time
from sklearn.neural_network import MLPRegressor
from funcs.utils import robust_RLM, robust_SGD
from funcs.utils import counts_int, counts_set, majority_vote, exch_majority_vote

def OracleCP(D, D_test, A, alpha=0.1, params_rlm = {'thres': 1, 'lamb': 1}, params_sgd = {'shuffles': None, 'lr': 0.01, 'epochs': 1}):
    """Oracle conformal prediction that refits using the *true* test label.

    For every test point the model is refit on the training data augmented
    with ``(x_test, y_test)``. The scores are the absolute residuals
    ``|X_aug @ theta - Y_aug|`` on all ``n + 1`` points. The interval is
    ``prediction +/- Q``, with ``Q`` their ``1 - alpha`` quantile
    (``method='higher'``). Because the fit uses ``y_test``, this is an oracle
    baseline rather than a usable predictor.

    Args:
        D: ``(X, Y)`` training features and labels.
        D_test: ``(X_test, Y_test)`` test features and labels.
        A: Fitting routine, ``'RLM'`` (``robust_RLM``) or ``'SGD'``
            (``robust_SGD``).
        alpha: Miscoverage level.
        params_rlm: ``thres`` and ``lamb`` forwarded to ``robust_RLM``.
        params_sgd: ``shuffles``, ``lr`` and ``epochs`` forwarded to
            ``robust_SGD``.

    Returns:
        ``(Coverage, Size, Time)``: fraction of test labels inside their
        interval, mean interval width ``2 * Q``, and wall-clock seconds.

    Raises:
        ValueError: If ``A`` is not ``'RLM'`` or ``'SGD'``.
    """
    if A not in ('RLM', 'SGD'):
        raise ValueError(f"OracleCP supports A='RLM' or 'SGD', got {A!r}")
    X, Y = D[0], D[1]
    n = len(Y)
    X_test, Y_test = D_test[0], D_test[1]
    m = len(Y_test)
    Coverages = np.zeros(m)
    Sizes = np.zeros(m)

    start = time.time()
    for i, x_test in enumerate(X_test):
        y_test = Y_test[i]
        # vstack/append return new arrays, so X and Y are never modified.
        X_aug = np.vstack([X, x_test])
        Y_aug = np.append(Y, y_test)
        if A == 'RLM':
            theta = robust_RLM((X_aug, Y_aug), thres=params_rlm['thres'], lamb=params_rlm['lamb'])
        else:
            theta = robust_SGD((X_aug, Y_aug), shuffles=params_sgd['shuffles'], lr=params_sgd['lr'], epochs=params_sgd['epochs'])
        Y_pred = X_aug @ theta
        S = np.abs(Y_pred - Y_aug)
        Q = np.quantile(S, 1 - alpha, method='higher')
        Coverages[i] = np.abs(Y_pred[n] - y_test) <= Q
        Sizes[i] = 2 * Q

    end = time.time()
    Time = end - start
    Coverage = Coverages.mean()
    Size = Sizes.mean()
    return Coverage, Size, Time

def FullCP(D, D_test, A, alpha=0.1, params_rlm = {'thres': 1, 'lamb': 1}, params_sgd = {'shuffles': None, 'lr': 0.01, 'epochs': 1}, n_grid=100):
    """Full conformal prediction over a finite grid of candidate test labels.

    The candidate labels ``Z`` are ``n_grid`` evenly spaced values on
    ``[min(Y) - std(Y), max(Y) + std(Y)]``. A candidate ``z`` *conforms* for
    a test point when, after refitting on the training data augmented with
    ``(x_test, z)``, the test point's absolute residual is at most the
    ``1 - alpha`` quantile (``method='higher'``) of all ``n + 1`` residuals.

    The interval runs from the first conforming candidate found scanning up
    the grid to the first conforming candidate found scanning down from the
    top. Any non-conforming candidates in between are therefore included.

    Args:
        D: ``(X, Y)`` training features and labels.
        D_test: ``(X_test, Y_test)`` test features and labels.
        A: Fitting routine, ``'RLM'`` (``robust_RLM``) or ``'SGD'``
            (``robust_SGD``).
        alpha: Miscoverage level.
        params_rlm: ``thres`` and ``lamb`` forwarded to ``robust_RLM``.
        params_sgd: ``shuffles``, ``lr`` and ``epochs`` forwarded to
            ``robust_SGD``.
        n_grid: Number of candidate labels in the grid (default 100).

    Returns:
        ``(Coverage, Size, Time)``: fraction of test labels inside their
        interval, mean interval width, and wall-clock seconds. A test point
        for which no candidate conforms gets an empty set, counted as
        coverage 0 and width 0.

    Raises:
        ValueError: If ``A`` is not ``'RLM'`` or ``'SGD'``.
    """
    if A not in ('RLM', 'SGD'):
        raise ValueError(f"FullCP supports A='RLM' or 'SGD', got {A!r}")
    X, Y = D[0], D[1]
    n = len(Y)
    X_test, Y_test = D_test[0], D_test[1]
    m = len(Y_test)
    Z = np.linspace(Y.min() - Y.std(), Y.max() + Y.std(), n_grid)
    Coverages = np.zeros(m)
    Sizes = np.zeros(m)

    def conforms(x_test, z):
        # Refit on the training data plus (x_test, z). Report whether the
        # augmented point's score lies within the conformal quantile.
        X_aug = np.vstack([X, x_test])
        Y_aug = np.append(Y, z)
        if A == 'RLM':
            theta = robust_RLM((X_aug, Y_aug), thres=params_rlm['thres'], lamb=params_rlm['lamb'])
        else:
            theta = robust_SGD((X_aug, Y_aug), shuffles=params_sgd['shuffles'], lr=params_sgd['lr'], epochs=params_sgd['epochs'])
        S = np.abs(X_aug @ theta - Y_aug)
        Q = np.quantile(S, 1 - alpha, method='higher')
        return S[n] <= Q

    start = time.time()
    for i, x_test in enumerate(X_test):
        y_test = Y_test[i]

        # Lower end: the first conforming candidate scanning up the grid.
        k = None
        for j, z in enumerate(Z):
            if conforms(x_test, z):
                k = j
                break
        if k is None:
            # No candidate conforms: the prediction set is empty on this grid.
            Coverages[i] = 0
            Sizes[i] = 0
            continue
        lower = Z[k]

        # Upper end: the first conforming candidate scanning down from the top.
        # The scan stops at Z[k]. If a refit is not deterministic and even
        # Z[k] no longer conforms, fall back to the single point Z[k]
        # instead of walking to negative (wrapping) indices.
        upper = lower
        for z in Z[k:][::-1]:
            if conforms(x_test, z):
                upper = z
                break

        Coverages[i] = (y_test >= lower) & (y_test <= upper)
        Sizes[i] = upper - lower

    end = time.time()
    Time = end - start
    Coverage = Coverages.mean()
    Size = Sizes.mean()
    return Coverage, Size, Time

def SplitCP(D, D_test, A, alpha=0.1, params_rlm = {'thres': 1, 'lamb': 1}, params_sgd = {'shuffles': None, 'lr': 0.01, 'epochs': 1}, params_nn = {'hidden': 20, 'lr': 0.01, 'epochs': 1},  Multi = False, train_frac=0.7):
    """Split conformal prediction with absolute-residual scores.

    ``D`` is split at random (via ``np.random.choice``) into a fitting part
    of ``int(n * train_frac)`` points and a calibration part holding the
    rest. The model is fit on the fitting part. Each test point's interval
    is ``prediction +/- Q``, where ``Q`` is the ``1 - alpha`` quantile
    (``method='higher'``) of the absolute calibration residuals.

    Args:
        D: ``(X, Y)`` training features and labels.
        D_test: ``(X_test, Y_test)`` test features and labels.
        A: ``'RLM'`` (``robust_RLM``), ``'SGD'`` (``robust_SGD``) or
            ``'NN'`` (``MLPRegressor`` with ``solver='sgd'``).
        alpha: Miscoverage level.
        params_rlm: ``thres`` and ``lamb`` forwarded to ``robust_RLM``.
        params_sgd: ``shuffles``, ``lr`` and ``epochs`` forwarded to
            ``robust_SGD``.
        params_nn: ``hidden`` (an int for a single hidden layer, or a
            sequence of layer widths), ``lr`` (``learning_rate_init``) and
            ``epochs`` (``max_iter``) for the MLP.
        Multi: If True, also return the per-test-point interval limits.
        train_frac: Fraction of ``D`` used for fitting (default 0.7).

    Returns:
        ``(Coverage, Size, Time)``: fraction of covered test labels, the
        common interval width ``2 * Q``, and wall-clock seconds from fitting
        through coverage evaluation. When ``Multi`` is True, a fourth element
        is added: an ``(m, 2)`` float array of ``[lower, upper]`` limits.

    Raises:
        ValueError: If ``A`` is not ``'RLM'``, ``'SGD'`` or ``'NN'``.
    """
    if A not in ('RLM', 'SGD', 'NN'):
        raise ValueError(f"SplitCP supports A='RLM', 'SGD' or 'NN', got {A!r}")
    X, Y = D[0], D[1]
    n = len(Y)

    n_train = int(n * train_frac)
    ind_train = np.random.choice(n, n_train, replace=False)
    ind_calib = np.setdiff1d(np.arange(n), ind_train)
    X_train, Y_train = X[ind_train], Y[ind_train]
    X_calib, Y_calib = X[ind_calib], Y[ind_calib]
    X_test, Y_test = D_test[0], D_test[1]

    start = time.time()
    if A == 'NN':
        hidden = params_nn['hidden']
        # A bare int (as in the default) means one hidden layer; calling
        # tuple() on an int would raise TypeError.
        hidden_layer_sizes = (hidden,) if np.ndim(hidden) == 0 else tuple(hidden)
        model = MLPRegressor(hidden_layer_sizes=hidden_layer_sizes, activation='relu',
                   solver='sgd', shuffle = True, learning_rate='constant', max_iter=params_nn['epochs'],
                   learning_rate_init=params_nn['lr'], tol=1e-4, verbose=0, random_state=42)
        model.fit(X_train, Y_train)
        Y_calib_pred = model.predict(X_calib)
        Y_test_pred = model.predict(X_test)
    else:
        if A == 'RLM':
            theta = robust_RLM((X_train, Y_train), thres=params_rlm['thres'], lamb=params_rlm['lamb'])
        else:
            theta = robust_SGD((X_train, Y_train), shuffles=params_sgd['shuffles'], lr=params_sgd['lr'], epochs=params_sgd['epochs'])
        Y_calib_pred = X_calib @ theta
        Y_test_pred = X_test @ theta

    S = np.abs(Y_calib_pred - Y_calib)
    Q = np.quantile(S, 1 - alpha, method='higher')
    Coverages = np.abs(Y_test_pred - Y_test) <= Q
    end = time.time()

    Time = end - start
    Coverage = Coverages.mean()
    Size = 2 * Q
    if Multi:
        lim = np.column_stack([Y_test_pred - Q, Y_test_pred + Q]).astype(float, copy=False)
        return Coverage, Size, Time, lim
    return Coverage, Size, Time


def MultiSplitCP(D, D_test, A, alpha=0.1, method = "M", tau = 0.5, K = 30, params_rlm = {'thres': 1, 'lamb': 1}, params_sgd = {'shuffles': None, 'lr': 0.01, 'epochs': 1}, params_nn = None):
    """Aggregate ``K`` split-conformal intervals per test point by voting.

    ``SplitCP`` is run ``K`` times, each with a fresh random split and
    miscoverage ``alpha * (1 - tau)``. The ``K`` intervals of each test point
    are then combined with ``majority_vote`` (``method="M"``) or
    ``exch_majority_vote`` (``method="E"``) from ``funcs.utils``.

    Args:
        D: ``(X, Y)`` training features and labels.
        D_test: ``(X_test, Y_test)`` test features and labels.
        A: Fitting routine passed to ``SplitCP``.
        alpha: Target miscoverage level.
        method: ``"M"`` or ``"E"``, selecting the aggregation function.
        tau: Each split uses level ``alpha * (1 - tau)``.
        K: Number of independent splits.
        params_rlm: Forwarded to ``SplitCP``.
        params_sgd: Forwarded to ``SplitCP``.
        params_nn: Optional MLP settings forwarded to ``SplitCP``. If None,
            ``SplitCP``'s own default is used.

    Returns:
        ``(Coverage, Size, Time)``: fraction of test labels inside their
        aggregated interval, mean aggregated width, and wall-clock seconds
        for the ``K`` splits plus aggregation.

    Raises:
        ValueError: If ``method`` is neither ``"M"`` nor ``"E"``.
    """
    # Validate up front: an unknown method used to leave every aggregated
    # interval at [0, 0] and silently report meaningless results.
    if method not in ("M", "E"):
        raise ValueError(f'method must be "M" or "E", got {method!r}')
    # Forward params_nn only when given, so SplitCP otherwise keeps its default.
    nn_kwargs = {} if params_nn is None else {'params_nn': params_nn}
    Y_test = np.asarray(D_test[1])
    m = len(Y_test)
    lims = np.zeros((m, K, 2))

    start = time.time()
    for k in range(K):
        *_, lim = SplitCP(D, D_test, A, alpha=alpha*(1-tau), params_rlm=params_rlm, params_sgd=params_sgd, Multi=True, **nn_kwargs)
        lims[:, k, :] = lim

    vote = majority_vote if method == "M" else exch_majority_vote
    agg_lims = np.zeros((m, 2))
    for i in range(m):
        agg_lims[i, :] = vote(lims[i, :, :])
    end = time.time()

    Coverages = (Y_test >= agg_lims[:, 0]) & (Y_test <= agg_lims[:, 1])
    Sizes = agg_lims[:, 1] - agg_lims[:, 0]

    Coverage = Coverages.mean()
    Size = Sizes.mean()
    Time = end - start
    return Coverage, Size, Time

def RO_StabCP(D, D_test, A, alpha=0.1, params_rlm = {'thres': 1, 'lamb': 1}, params_sgd = {'shuffles': None, 'lr': 0.01, 'epochs': 1}, params_nn = {'hidden': 20, 'lr': 0.01, 'epochs': 1}, kernel = False):
    """Stability-based conformal prediction with one refit per test point.

    For every test point the model is refit on the training data augmented
    with ``(x_test, zhat)``. Here ``zhat = 0`` is a placeholder label; the
    true test label is never used for fitting. The training residuals are
    inflated by per-point stability bounds ``tau_aug[:n]`` before their
    ``1 - alpha`` quantile (``method='higher'``) is taken. The test point's
    own bound ``tau_aug[n]`` is then added, giving ``prediction +/- Q``.
    ("RO" presumably stands for replace-one — TODO confirm.)

    The bounds use ``nu`` (row norms of ``X``, with ``||x_test||`` appended
    last) and ``rho_new = thres * ||x_test||``:

    * RLM: ``4 * nu * rho_new / (n + 1) / lamb``, further divided by
      ``eigvalsh(X).min()`` when ``kernel`` is True.
    * SGD: ``2 * epochs * lr * rho_new * nu``.
    * NN: the constant ``2 * lr * epochs``.

    Args:
        D: ``(X, Y)`` training features and labels.
        D_test: ``(X_test, Y_test)`` test features and labels.
        A: ``'RLM'`` (``robust_RLM``), ``'SGD'`` (``robust_SGD``) or
            ``'NN'`` (``MLPRegressor`` with ``solver='sgd'``).
        alpha: Miscoverage level.
        params_rlm: ``thres`` (used for both RLM and SGD bounds) and ``lamb``.
        params_sgd: ``shuffles``, ``lr`` and ``epochs`` for ``robust_SGD``.
        params_nn: ``hidden`` (an int for a single hidden layer, or a
            sequence of widths), ``lr`` and ``epochs`` for the MLP.
        kernel: RLM only. ``eigvalsh`` requires ``X`` to be square and reads
            it as symmetric (presumably a kernel Gram matrix — TODO confirm).

    Returns:
        ``(Coverage, Size, Time)``: fraction of covered test labels, mean
        interval width ``2 * Q``, and wall-clock seconds.

    Raises:
        ValueError: If ``A`` is not ``'RLM'``, ``'SGD'`` or ``'NN'``.
    """
    if A not in ("RLM", "SGD", "NN"):
        raise ValueError(f"RO_StabCP supports A='RLM', 'SGD' or 'NN', got {A!r}")
    X, Y = D[0], D[1]
    n = len(Y)
    X_test, Y_test = D_test[0], D_test[1]
    m = len(Y_test)
    zhat = 0  # placeholder label for the test point in the augmented fit
    if A == "NN":
        hidden = params_nn['hidden']
        # A bare int (as in the default) means one hidden layer; calling
        # tuple() on an int would raise TypeError.
        hidden_layer_sizes = (hidden,) if np.ndim(hidden) == 0 else tuple(hidden)
    else:
        nu = np.linalg.norm(X, axis=1)
        eps = params_rlm['thres']
        if A == "RLM":
            lamb = params_rlm['lamb']
        else:
            eta = params_sgd['lr']
            R = params_sgd['epochs']
    Coverages = np.zeros(m)
    Sizes = np.zeros(m)

    start = time.time()
    # These quantities do not depend on the test point, so compute them once.
    # They stay inside the timed region because the loop used to compute them.
    if A == "RLM" and kernel:
        min_eig = np.linalg.eigvalsh(X).min()
    elif A == "NN":
        tau_aug = 2*np.repeat(params_nn['lr'] * params_nn['epochs'], n+1)

    for i, x_test in enumerate(X_test):
        y_test = Y_test[i]
        X_aug = np.vstack([X, x_test])
        Y_aug = np.append(Y, zhat)
        if A == "NN":
            model = MLPRegressor(hidden_layer_sizes=hidden_layer_sizes, activation='relu',
                   solver='sgd', shuffle = True, learning_rate='constant', max_iter=params_nn['epochs'],
                   learning_rate_init=params_nn['lr'], tol=1e-4, verbose=0, random_state=42)
            model.fit(X_aug, Y_aug)
            Y_pred = model.predict(X_aug)
        else:
            nu_new = np.linalg.norm(x_test)
            nu_aug = np.append(nu, nu_new)
            rho_new = eps*nu_new
            if A == "RLM":
                tau_aug = 4*(nu_aug*rho_new)/(n+1)/lamb
                if kernel:
                    tau_aug /= min_eig
                theta = robust_RLM((X_aug, Y_aug), thres=params_rlm['thres'], lamb=params_rlm['lamb'])
            else:
                tau_aug = 2*R*eta*rho_new*nu_aug
                theta = robust_SGD((X_aug, Y_aug), shuffles=params_sgd['shuffles'], lr=params_sgd['lr'], epochs=params_sgd['epochs'])
            Y_pred = X_aug @ theta

        # Scores are computed on the n training points only; the test
        # point's bound is added to the quantile itself.
        S = np.abs(Y_pred[:n] - Y_aug[:n])
        U = S + tau_aug[:n]
        Q = np.quantile(U, 1 - alpha, method='higher') + tau_aug[n]

        Coverages[i] = np.abs(Y_pred[n] - y_test) <= Q
        Sizes[i] = 2 * Q

    end = time.time()
    Time = end - start
    Coverage = Coverages.mean()
    Size = Sizes.mean()
    return Coverage, Size, Time


def LOO_StabCP(D, D_test, A, alpha=0.1, params_rlm = {'thres': 1, 'lamb': 1}, params_sgd = {'shuffles': None, 'lr': 0.01, 'epochs': 1}, params_nn = {'hidden': 20, 'lr': 0.01, 'epochs': 1}, kernel = False):
    """Stability-based conformal prediction with a single fit on ``D``.

    The model is fit once on the training data; no test point is added. The
    training residuals ``S`` are inflated by per-point stability bounds
    ``tau_aug[:n]`` before their ``1 - alpha`` quantile
    (``method='higher'``) is taken. The test point's bound ``tau_aug[n]`` is
    then added, giving ``prediction +/- Q``. ("LOO" presumably stands for
    leave-one-out — TODO confirm.)

    The bounds use ``nu`` (row norms of ``X``, with ``||x_test||`` appended
    last) and ``rho_new = thres * ||x_test||``:

    * RLM: ``2 * nu * (rho_new + mean(thres * nu_train)) / (n + 1) / lamb``,
      further divided by ``eigvalsh(X).min()`` when ``kernel`` is True.
    * SGD: ``epochs * lr * rho_new * nu``.
    * NN: the constant ``lr * epochs``.

    Args:
        D: ``(X, Y)`` training features and labels.
        D_test: ``(X_test, Y_test)`` test features and labels.
        A: ``'RLM'`` (``robust_RLM``), ``'SGD'`` (``robust_SGD``) or
            ``'NN'`` (``MLPRegressor`` with ``solver='sgd'``).
        alpha: Miscoverage level.
        params_rlm: ``thres`` (used for both RLM and SGD bounds) and ``lamb``.
        params_sgd: ``shuffles``, ``lr`` and ``epochs`` for ``robust_SGD``.
        params_nn: ``hidden`` (an int for a single hidden layer, or a
            sequence of widths), ``lr`` and ``epochs`` for the MLP.
        kernel: RLM only. ``eigvalsh`` requires ``X`` to be square and reads
            it as symmetric (presumably a kernel Gram matrix — TODO confirm).

    Returns:
        ``(Coverage, Size, Time)``: fraction of covered test labels, mean
        interval width ``2 * Q``, and wall-clock seconds.

    Raises:
        ValueError: If ``A`` is not ``'RLM'``, ``'SGD'`` or ``'NN'``.
    """
    if A not in ("RLM", "SGD", "NN"):
        raise ValueError(f"LOO_StabCP supports A='RLM', 'SGD' or 'NN', got {A!r}")
    X, Y = D[0], D[1]
    n = len(Y)
    X_test, Y_test = D_test[0], D_test[1]
    m = len(Y_test)
    if A == "NN":
        hidden = params_nn['hidden']
        # A bare int (as in the default) means one hidden layer; calling
        # tuple() on an int would raise TypeError.
        hidden_layer_sizes = (hidden,) if np.ndim(hidden) == 0 else tuple(hidden)
    else:
        nu = np.linalg.norm(X, axis=1)
        eps = params_rlm['thres']
        if A == "RLM":
            lamb = params_rlm['lamb']
            rhobar = np.mean(eps*nu)
        else:
            eta = params_sgd['lr']
            R = params_sgd['epochs']
    Coverages = np.zeros(m)
    Sizes = np.zeros(m)

    start = time.time()
    if A == "NN":
        model = MLPRegressor(hidden_layer_sizes=hidden_layer_sizes, activation='relu',
                   solver='sgd', shuffle = True, learning_rate='constant', max_iter=params_nn['epochs'],
                   learning_rate_init=params_nn['lr'], tol=1e-4, verbose=0, random_state=42)
        model.fit(X, Y)
        Y_pred = model.predict(X)
        Y_test_pred = model.predict(X_test)
        # The NN bound does not depend on the test point.
        tau_aug = np.repeat(params_nn['lr'] * params_nn['epochs'], n + 1)
    else:
        if A == "RLM":
            theta = robust_RLM((X, Y), thres=params_rlm['thres'], lamb=params_rlm['lamb'])
            if kernel:
                # Depends only on X, so compute it once rather than per test point.
                min_eig = np.linalg.eigvalsh(X).min()
        else:
            theta = robust_SGD((X, Y), shuffles=params_sgd['shuffles'], lr=params_sgd['lr'], epochs=params_sgd['epochs'])
        # Assumes theta is a 1-D coefficient vector, as implied by comparing
        # X @ theta elementwise with Y below — TODO confirm.
        Y_pred = X @ theta
        Y_test_pred = X_test @ theta
    S = np.abs(Y_pred - Y)

    for i, x_test in enumerate(X_test):
        if A != "NN":
            nu_new = np.linalg.norm(x_test)
            nu_aug = np.append(nu, nu_new)
            rho_new = eps*nu_new
            if A == "RLM":
                tau_aug = 2*nu_aug*(rho_new + rhobar)/(n+1)/lamb
                if kernel:
                    tau_aug /= min_eig
            else:
                tau_aug = R*eta*rho_new*nu_aug

        U = S + tau_aug[:n]
        Q = np.quantile(U, 1 - alpha, method='higher') + tau_aug[n]

        Coverages[i] = np.abs(Y_test_pred[i] - Y_test[i]) <= Q
        Sizes[i] = 2 * Q

    end = time.time()
    Time = end - start
    Coverage = Coverages.mean()
    Size = Sizes.mean()
    return Coverage, Size, Time
