# -*- coding: utf-8 -*-
"""
Author: Jiamei Wu
file: Reproducible_Experiment1.py
times: 8/3/2022 09:02 AM
"""

# Small experiment on simulated data.


import warnings

import matplotlib
import numpy as np

# Silence all warnings for the experiment run.  This single stdlib call also
# covers warnings raised from NumPy code.  The former extra
# `np.warnings.filterwarnings('ignore')` call relied on `np.warnings`, an
# unofficial alias of this same stdlib module that raises AttributeError on
# NumPy >= 1.24, so it has been dropped.
warnings.filterwarnings('ignore')

# Select the non-interactive Agg backend before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def f(x):
    """Construct noisy linear data (1D example).

    Generates ``y = 3.0 * x + 2.0 + eps`` where ``eps`` is drawn
    independently for every element of ``x`` from a zero-mean normal
    distribution with variance 15.0, using NumPy's global random state.

    Parameters
    ----------
    x : array_like
        Feature values (intended for 1-D input).  Any real dtype or a plain
        Python sequence is accepted; the computation is carried out in
        floating point regardless of the input dtype, so integer features no
        longer truncate the noise.

    Returns
    -------
    ax : ndarray of float32
        Noisy responses.
    eps : ndarray of float32
        The noise that was added.
    pred : ndarray of float32
        Noise-free responses ``3.0 * x + 2.0``.
    """
    true_coefficient = 3.0
    intercept = 2.0
    noise_variance = 15.0  # larger noise variance, to increase the deviation

    features = np.asarray(x, dtype=np.float64)
    # One draw per element; the noise is stored in float32 first so that the
    # returned `ax` is consistent with the returned `eps`.
    eps = np.random.normal(
        0, np.sqrt(noise_variance), size=features.shape
    ).astype(np.float32)
    pred = true_coefficient * features + intercept
    ax = pred + eps
    return ax.astype(np.float32), eps, pred.astype(np.float32)


def tightSVD(Mat, tol=0.000001):
    """Rank-truncated singular value decomposition of ``Mat``.

    Singular values that do not exceed ``tol`` are treated as zero and the
    corresponding singular vectors are dropped.

    Parameters
    ----------
    Mat : ndarray, shape (n, p)
        Matrix to decompose.
    tol : float, optional
        Threshold above which a singular value counts as non-zero
        (default ``1e-6``, the value previously hard-coded).

    Returns
    -------
    lb : ndarray, shape (r, r)
        Diagonal matrix holding the ``r`` retained singular values.
    move : int
        Number of retained singular values ``r``.
    P : ndarray, shape (n, r)
        Retained left singular vectors.
    Q : ndarray, shape (p, r)
        Retained right singular vectors.
    Q_perp : ndarray of shape (p, p - r), or None
        Remaining right singular vectors; ``None`` when ``r`` equals ``p``.
    """
    left, sing_vals, right_t = np.linalg.svd(Mat)  # n*n, min(n,p), p*p
    # svd returns the singular values in descending order, so the number of
    # values above the threshold is the length of the leading retained run.
    move = int(np.count_nonzero(sing_vals > tol))
    right = right_t.T
    P = left[:, :move]  # n*r
    Q = right[:, :move]  # p*r
    Q_perp = right[:, move:] if move < Mat.shape[1] else None
    lb = np.zeros((move, move))
    lb[np.diag_indices(move)] = sing_vals[:move]
    return lb, move, P, Q, Q_perp

def calupQuantile(quans, alpha):
    """Return the upper empirical ``alpha``-quantile of ``quans``.

    The result is the ``ceil(alpha * n)``-th smallest value (1-based) of the
    ``n`` entries.  The rank is clamped to the valid range, so ``alpha`` at
    or below 0 yields the minimum and ``alpha`` above 1 yields the maximum
    instead of wrapping around to the last element or raising IndexError.

    Parameters
    ----------
    quans : ndarray
        1-D array of values.
    alpha : float
        Quantile level.

    Returns
    -------
    scalar
        The selected order statistic, or ``0.0`` when ``quans`` is empty.
    """
    import math

    count = quans.shape[0]
    if count == 0:
        return 0.0
    ordered = np.sort(quans)  # ascending
    rank = math.ceil(alpha * count) - 1
    rank = min(max(rank, 0), count - 1)
    return ordered[rank]

def calloQuantile(quans, alpha):
    """Return the lower empirical ``alpha``-quantile of ``quans``.

    The result is the ``floor(alpha * n)``-th smallest value (1-based) of the
    ``n`` entries.  When ``floor(alpha * n)`` is 0 the smallest value is
    returned; previously the resulting index of -1 silently selected the
    *largest* value.  Ranks beyond the end are clamped to the maximum.

    Parameters
    ----------
    quans : ndarray
        1-D array of values.
    alpha : float
        Quantile level.

    Returns
    -------
    scalar
        The selected order statistic, or ``0.0`` when ``quans`` is empty.
    """
    import math

    count = quans.shape[0]
    if count == 0:
        return 0.0
    ordered = np.sort(quans)  # ascending
    rank = math.floor(alpha * count) - 1
    rank = min(max(rank, 0), count - 1)
    return ordered[rank]

def CalThHat(X, y, rho, an):
    """Hat matrix and coefficients of a thresholded ridge regression.

    With the truncated SVD ``X = P L Q^T`` (from ``tightSVD``) the ridge
    coefficient map is ``Q (L^2 + rho I)^{-1} L P^T``.  Coefficients whose
    absolute value does not exceed ``an`` are set to zero, and the matching
    rows of the coefficient map are zeroed before the hat matrix is formed.

    Parameters
    ----------
    X : ndarray, shape (n, p)
        Design matrix.
    y : ndarray
        Response vector; assumed 1-D of length n (the row mask below relies
        on it) -- TODO confirm with callers.
    rho : float
        Ridge penalty.
    an : float
        Hard threshold applied to the absolute ridge coefficients.

    Returns
    -------
    ThHat : ndarray, shape (n, n)
        ``X`` times the thresholded coefficient map.
    dotproduct : ndarray, shape (p,)
        Thresholded ridge coefficients.
    """
    lb, r, P, Q, _ = tightSVD(X)
    # lb is diagonal, so the elementwise product lb * lb equals L^2.
    ridge_inv = np.linalg.inv(lb * lb + rho * np.eye(r))
    plamb = np.matmul(P, lb)
    coef_map = np.matmul(np.matmul(Q, ridge_inv), plamb.T)  # p*n
    dotproduct = np.matmul(coef_map, y)
    dropped = ~(np.abs(dotproduct) > an)
    coef_map[dropped, :] = 0
    dotproduct[dropped] = 0
    ThHat = np.matmul(X, coef_map)
    return ThHat, dotproduct


def CalThHatstar(X, y, rho, an):
    """Hat matrix and coefficients of the modified thresholded ridge fit.

    With the truncated SVD ``X = P L Q^T`` (from ``tightSVD``) the
    coefficient map used here is
    ``Q (L^2 + 2 rho I) (L^2 + rho I)^{-2} L P^T``.  The set of coefficients
    to drop is decided from the *plain* ridge coefficients
    ``Q (L^2 + rho I)^{-1} L P^T y``: entries whose absolute value does not
    exceed ``an`` are zeroed.
    NOTE(review): presumably this is the de-biased ridge estimator used by
    ``DeCRR`` -- confirm against the method description.

    Parameters
    ----------
    X : ndarray, shape (n, p)
        Design matrix.
    y : ndarray
        Response vector; assumed 1-D of length n -- TODO confirm with callers.
    rho : float
        Ridge penalty.
    an : float
        Hard threshold applied to the absolute plain-ridge coefficients.

    Returns
    -------
    ThHatstar : ndarray, shape (n, n)
        ``X`` times the thresholded modified coefficient map.
    dotproduct2 : ndarray, shape (p,)
        Thresholded modified coefficients.
    """
    lb, r, P, Q, _ = tightSVD(X)
    identity = np.eye(r)
    # lb is diagonal, so elementwise products below act as matrix squares.
    ridge_inv = np.linalg.inv(lb * lb + rho * identity)
    plamb = np.matmul(P, lb)

    # Plain ridge coefficients, used only to decide which entries survive.
    ridge_map = np.matmul(np.matmul(Q, ridge_inv), plamb.T)
    dropped = ~(np.abs(np.matmul(ridge_map, y)) > an)

    # Modified map: Q (L^2 + 2 rho I) (L^2 + rho I)^{-2} L P^T.
    numerator = np.matmul(Q, lb * lb + 2 * rho * identity)
    ridge_inv_sq = ridge_inv * ridge_inv
    star_map = np.matmul(np.matmul(numerator, ridge_inv_sq), plamb.T)
    star_map[dropped, :] = 0

    ThHatstar = np.matmul(X, star_map)
    dotproduct2 = np.matmul(star_map, y)
    dotproduct2[dropped] = 0
    return ThHatstar, dotproduct2

def CrossV(X, y, kfold, rhos, bs):
    """Select the ridge penalty and threshold by K-fold cross-validation.

    For every ``(rho, b)`` pair the thresholded ridge coefficients from
    ``CalThHat`` are fitted on each training fold, and the Euclidean norm of
    the residuals on the held-out fold is accumulated.  The pair with the
    smallest accumulated error is returned.

    Parameters
    ----------
    X : ndarray, shape (n, p)
        Design matrix.
    y : ndarray, shape (n,)
        Responses.
    kfold : int
        Number of folds.
    rhos : ndarray
        1-D grid of candidate ridge penalties.
    bs : ndarray
        1-D grid of candidate thresholds.

    Returns
    -------
    optRidResrho, optRidResb
        The selected penalty and threshold.

    Raises
    ------
    ValueError
        If either grid is empty.
    """
    from sklearn.model_selection import KFold

    size_rho, size_b = rhos.shape[0], bs.shape[0]
    if size_rho == 0 or size_b == 0:
        raise ValueError("rhos and bs must both be non-empty")

    theRidRes = np.zeros((size_rho, size_b))
    splitter = KFold(n_splits=kfold)
    for fold, (train, test) in enumerate(splitter.split(X), start=1):
        trainX, testX = X[train], X[test]
        trainy, testy = y[train], y[test]
        for i in range(size_rho):
            for j in range(size_b):
                coef = CalThHat(trainX, trainy, rhos[i], bs[j])[1]
                theRidRes[i, j] += np.linalg.norm(testy - np.matmul(testX, coef))
        # Progress output: number of folds completed so far (the previous
        # code printed the stale inner loop index, which never changed).
        print(fold)

    best_rho, best_b = np.unravel_index(np.argmin(theRidRes), theRidRes.shape)
    optRidResrho = rhos[best_rho]
    optRidResb = bs[best_b]
    return optRidResrho, optRidResb

def CRR(x_train,
        y_train,
        x_test,
        y_test,
        alpha,
        kfold):
    """Prediction intervals and coverage count using the ``CalThHat`` fit.

    For every test point the training design is augmented with the test
    feature row and the hat matrix ``H`` from ``CalThHat`` is formed.  With
    a candidate label ``y`` for the test point the residual vector is
    ``(I - H)(yn + y * yf) = A + y * B``; the values stored in ``u1`` are the
    candidate labels at which residual ``j`` equals the test residual,
    ``(A[j] - A[n]) / (B[n] - B[j])``.  Entries with ``B[n] - B[j] <= 0``
    (including ``j = n``) are left at 0.  The interval bounds are the
    ``alpha/2`` lower and ``1 - alpha/2`` upper empirical quantiles of
    ``u1``.

    Parameters
    ----------
    x_train : ndarray, shape (n, p)
    y_train : ndarray, shape (n,)
    x_test  : ndarray, shape (m, p)
    y_test  : ndarray, shape (m,)
    alpha   : float
        Significance (miscoverage) level.
    kfold   : int
        Currently unused (the penalty is fixed rather than cross-validated);
        kept for interface compatibility.

    Returns
    -------
    freq : int
        Number of test labels that fall inside their interval.
    y_upper, y_lower : list
        Interval bounds per test point.
    quantile : list
        Interval lengths ``y_upper - y_lower`` per test point.
    pred : list
        Fitted value of the augmented fit at each test point.
    """
    bs = 0        # hard threshold passed to CalThHat
    rho = 3000    # fixed ridge penalty
    n = len(y_train)
    yn = np.r_[y_train, 0]          # training labels, 0 in the test slot
    yf = np.r_[np.zeros(n), 1]      # indicator of the test slot

    freq = 0
    quantile = []
    y_upper = []
    y_lower = []
    pred = []
    for i, y_i in enumerate(y_test):
        print("y_test[",i,"]:",y_i)
        Xnew = np.concatenate((x_train, x_test[i, :].reshape(1, -1)), axis=0)
        # NOTE(review): the true test label enters both the thresholding in
        # CalThHat and the reported prediction below -- confirm this is
        # intended.
        ynew = np.concatenate((y_train, [y_i]))
        Hatmatrix = CalThHat(Xnew, ynew, rho, bs)[0]
        ihat = np.eye(n + 1) - Hatmatrix
        A = np.matmul(ihat, yn)
        B = np.matmul(ihat, yf)

        gap = B[n] - B
        usable = gap > 0
        u1 = np.zeros(n + 1)
        u1[usable] = (A[usable] - A[n]) / gap[usable]

        y_l = calloQuantile(u1, alpha/2)
        y_u = calupQuantile(u1, 1 - alpha/2)
        if y_l <= y_i <= y_u:
            freq += 1
        y_upper.append(y_u)
        y_lower.append(y_l)
        quantile.append(y_u - y_l)
        pred.append(np.matmul(Hatmatrix, ynew)[-1])
        print("inner,",u1)
    return freq, y_upper, y_lower, quantile, pred

def DeCRR(x_train,
          y_train,
          x_test,
          y_test,
          alpha,
          kfold):
    """Prediction intervals and coverage count using the ``CalThHatstar`` fit.

    For every test point the training design is augmented with the test
    feature row and the hat matrix ``H`` from ``CalThHatstar`` is formed.
    With a candidate label ``y`` for the test point the residual vector is
    ``(I - H)(yn + y * yf) = A + y * B``; the values stored in ``u1`` are the
    candidate labels at which residual ``j`` equals the test residual,
    ``(A[j] - A[n]) / (B[n] - B[j])``.  Entries with ``B[n] - B[j] <= 0``
    (including ``j = n``) are left at 0.  The interval bounds are the
    ``alpha/2`` lower and ``1 - alpha/2`` upper empirical quantiles of
    ``u1``.

    Parameters
    ----------
    x_train : ndarray, shape (n, p)
    y_train : ndarray, shape (n,)
    x_test  : ndarray, shape (m, p)
    y_test  : ndarray, shape (m,)
    alpha   : float
        Significance (miscoverage) level.
    kfold   : int
        Currently unused (the penalty is fixed rather than cross-validated);
        kept for interface compatibility.

    Returns
    -------
    freq : int
        Number of test labels that fall inside their interval.
    y_upper, y_lower : list
        Interval bounds per test point.
    quantile : list
        Interval lengths ``y_upper - y_lower`` per test point.
    pred : list
        Fitted value of the augmented fit at each test point.
    """
    bs = 0        # hard threshold passed to CalThHatstar
    rho = 3000    # fixed ridge penalty
    n = len(y_train)
    yn = np.r_[y_train, 0]          # training labels, 0 in the test slot
    yf = np.r_[np.zeros(n), 1]      # indicator of the test slot

    freq = 0
    quantile = []
    y_upper = []
    y_lower = []
    pred = []
    for i, y_i in enumerate(y_test):
        print("y_test[",i,"]:",y_i)
        Xnew = np.concatenate((x_train, x_test[i, :].reshape(1, -1)), axis=0)
        # NOTE(review): the true test label enters both the thresholding in
        # CalThHatstar and the reported prediction below -- confirm this is
        # intended.
        ynew = np.concatenate((y_train, [y_i]))
        Hatmatrix = CalThHatstar(Xnew, ynew, rho, bs)[0]
        ihat = np.eye(n + 1) - Hatmatrix
        A = np.matmul(ihat, yn)
        B = np.matmul(ihat, yf)

        gap = B[n] - B
        usable = gap > 0
        u1 = np.zeros(n + 1)
        u1[usable] = (A[usable] - A[n]) / gap[usable]

        y_l = calloQuantile(u1, alpha/2)
        y_u = calupQuantile(u1, 1 - alpha/2)
        if y_l <= y_i <= y_u:
            freq += 1
        y_upper.append(y_u)
        y_lower.append(y_l)
        quantile.append(y_u - y_l)
        pred.append(np.matmul(Hatmatrix, ynew)[-1])
    return freq, y_upper, y_lower, quantile, pred

        
            
         

# desired miscoverage error
alpha = 0.1


# save figures?
save_figures = False

# NOTE(review): the next three settings are not used by anything visible in
# this script; they look like leftovers from a random-forest experiment.
n_estimators = 100
min_samples_leaf = 40
max_features = 1 # 1D signal
# seed for NumPy's global random state (see np.random.seed below)
random_state = 0

# number of training examples
n_train = 200
# number of test examples (to evaluate average coverage and length)
n_test = 100
max_show = n_test

# Seed the global RNG so that the simulated data, and therefore every figure
# below, is identical from run to run.  Previously `random_state` was
# declared but never applied.
np.random.seed(random_state)

# training features
x_train = np.random.uniform(0, 5.0, size=n_train).astype(np.float32)

# test features
x_test = np.random.uniform(0, 5.0, size=n_test).astype(np.float32)

# generate labels (y_test also keeps its noise and noise-free signal)
y_train = f(x_train)[0]
y_test, eps_test, pred3 = f(x_test)

# reshape the features into (n_samples, 1) design matrices
x_train = np.reshape(x_train, (n_train, 1))
x_test = np.reshape(x_test, (n_test, 1))

# CRR: compute the intervals on the test set and plot them.
freq1, y_upper1, y_lower1, quantile1, pred1 = CRR(x_train, y_train, x_test, y_test, alpha, 5)
if y_upper1 is not None:
    y_u_ = np.array(y_upper1).astype(y_test.dtype)
if y_lower1 is not None:
    y_l_ = np.array(y_lower1).astype(y_test.dtype)
if pred1 is not None:
    pred_ = np.array(pred1)

shade_color = 'lightblue'
method_name = "CRR"
filename = "illustration_crr.png"
title = "CRR"

fig = plt.figure()
# sort by feature value so the band and the curve are drawn left to right
inds = np.argsort(np.squeeze(x_test))
plt.plot(x_test[inds,:], y_test[inds], 'k.', alpha=.2, markersize=10,
            fillstyle='none', label=u'Observations')

has_band = (y_upper1 is not None) and (y_lower1 is not None)
if has_band:
    # closed polygon: upper bound left-to-right, lower bound right-to-left
    plt.fill(np.concatenate([x_test[inds], x_test[inds][::-1]]),
                np.concatenate([y_u_[inds], y_l_[inds][::-1]]),
                alpha=.3, fc=shade_color, ec='None',
                label = method_name + ' prediction interval')

# Draw the prediction curve BEFORE the legend is built and the figure is
# saved; otherwise the curve is missing from both.
if pred1 is not None:
    if pred_.ndim == 2:
        plt.plot(x_test[inds,:], pred_[inds,0], 'k', lw=2, alpha=0.9,
                    label=u'Predicted low and high quantiles')
        plt.plot(x_test[inds,:], pred_[inds,1], 'k', lw=2, alpha=0.9)
    else:
        plt.plot(x_test[inds,:], pred_[inds], 'k--', lw=2, alpha=0.9,
                    label=u'Predicted value')

if has_band:
    plt.xlabel('$X$')
    plt.ylabel('$Y$')
    plt.legend(loc='upper left')
    plt.title(title)
    if save_figures and (filename is not None):
        plt.savefig(filename, bbox_inches='tight', dpi=300)

plt.show()

# DeCRR: compute the intervals on the test set and plot them.

freq2, y_upper2, y_lower2, quantile2, pred2 = DeCRR(x_train, y_train, x_test, y_test, alpha, 5)
x_ = x_test[:]
y_ = y_test[:]
if y_upper2 is not None:
    y_u_ = np.array(y_upper2).astype(y_.dtype)
if y_lower2 is not None:
    y_l_ = np.array(y_lower2).astype(y_.dtype)
if pred2 is not None:
    pred_ = np.array(pred2)

shade_color = 'tomato'
method_name = "DeCRR:"
filename = "illustration_decrr.png"
title = "DeCRR"

fig = plt.figure()
# sort by feature value so the band and the curve are drawn left to right
inds = np.argsort(np.squeeze(x_))
plt.plot(x_[inds,:], y_[inds], 'k.', alpha=.2, markersize=10,
            fillstyle='none', label=u'Observations')

has_band = (y_upper2 is not None) and (y_lower2 is not None)
if has_band:
    # closed polygon: upper bound left-to-right, lower bound right-to-left
    plt.fill(np.concatenate([x_[inds], x_[inds][::-1]]),
                np.concatenate([y_u_[inds], y_l_[inds][::-1]]),
                alpha=.3, fc=shade_color, ec='None',
                label = method_name + ' prediction interval')

# Draw the prediction curve BEFORE the legend is built and the figure is
# saved; otherwise the curve is missing from both.
if pred2 is not None:
    if pred_.ndim == 2:
        plt.plot(x_[inds,:], pred_[inds,0], 'k', lw=2, alpha=0.9,
                 label=u'Predicted low and high quantiles')
        plt.plot(x_[inds,:], pred_[inds,1], 'k', lw=2, alpha=0.9)
    else:
        plt.plot(x_[inds,:], pred_[inds], 'k--', lw=2, alpha=0.9,
                 label=u'Predicted value')

if has_band:
    plt.xlabel('$X$')
    plt.ylabel('$Y$')
    plt.legend(loc='upper left')
    plt.title(title)
    if save_figures and (filename is not None):
        plt.savefig(filename, bbox_inches='tight', dpi=300)

plt.show()

# Oracle RR: the noise-free signal shifted by the empirical alpha/2 and
# 1 - alpha/2 quantiles of the test noise.
eps_l = calloQuantile(eps_test, alpha/2)
eps_u = calupQuantile(eps_test, 1-alpha/2)
true_coefficient = 3.0
intercept = 2.0
print(pred3)
y_upper3 = pred3 + eps_u
y_lower3 = pred3 + eps_l

x_ = x_test[:]
y_ = y_test[:]
if y_upper3 is not None:
    y_u_ = np.array(y_upper3[:]).astype(y_test.dtype)
if y_lower3 is not None:
    y_l_ = np.array(y_lower3[:]).astype(y_test.dtype)
if pred3 is not None:
    pred_ = np.array(pred3[:])


shade_color = 'grey'
method_name = "Oracle:"
filename = "illustration_oracle_rr.png"
title = "Oracle"

fig = plt.figure()
# sort by feature value so the band and the curve are drawn left to right
inds = np.argsort(np.squeeze(x_))
plt.plot(x_[inds,:], y_[inds], 'k.', alpha=.2, markersize=10,
            fillstyle='none', label=u'Observations')

has_band = (y_upper3 is not None) and (y_lower3 is not None)
if has_band:
    # closed polygon: upper bound left-to-right, lower bound right-to-left
    plt.fill(np.concatenate([x_[inds], x_[inds][::-1]]),
                np.concatenate([y_u_[inds], y_l_[inds][::-1]]),
                alpha=.3, fc=shade_color, ec='None',
                label = method_name + ' prediction interval')

# Draw the prediction curve BEFORE the legend is built and the figure is
# saved; otherwise the curve is missing from both.
if pred3 is not None:
    if pred_.ndim == 2:
        plt.plot(x_[inds,:], pred_[inds,0], 'k', lw=2, alpha=0.9,
                 label=u'Predicted low and high quantiles')
        plt.plot(x_[inds,:], pred_[inds,1], 'k', lw=2, alpha=0.9)
    else:
        plt.plot(x_[inds,:], pred_[inds], 'k--', lw=2, alpha=0.9,
                 label=u'Predicted value')

if has_band:
    plt.xlabel('$X$')
    plt.ylabel('$Y$')
    plt.legend(loc='upper left')
    plt.title(title)
    if save_figures and (filename is not None):
        plt.savefig(filename, bbox_inches='tight', dpi=300)

plt.show()