# -*- coding: utf-8 -*-
from signatures import *


##############################################################################################
# Experiment: Brownian motion (Figure A6a in the appendix)


# Seed both RNGs used below: numpy (path increments, beta values, noise) and
# the stdlib `random` module (support locations via random.sample).
np.random.seed(0)
random.seed(0)

# Outer repetition loop: the whole simulation is run 30 times; run `e` writes
# its own set of CSV summary files suffixed with `_<e>`.
for e in range(30):
    random_experi_num = 100   # Monte Carlo trials per (true_factor_num, rho) cell
    sample_num = 100          # simulated paths (regression samples) per trial
    sig_degree = 4            # signature truncation degree passed to the helpers
    d = 2                     # path dimension: (W1, W2)
    # Number of columns of each signature vector (sizes the design matrices and
    # the pool of candidate support indices).
    sig_num = calculate_sig_num(sig_degree, d)
    path_length = 100         # number of time steps per simulated path
    T = 1                     # NOTE(review): not referenced below; the step scaling hard-codes sqrt(1/path_length), i.e. assumes T == 1

    true_factor_num_range = [2,4,6,8]   # sizes q of the true (sparse) support
    # Correlations -0.9, -0.8, ..., 0.9. np.arange with a float step may produce
    # values that differ from these decimals by rounding; they are used verbatim
    # as DataFrame column labels (and hence in the CSV headers).
    rho_range = np.arange(-0.9, 1, 0.1)
    # Per-run summary tables: rows = true_factor_num, columns = rho.
    # Suffix _I / _S = results for method="Ito" / method="S" signatures.
    correct_rate_I_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)
    correct_rate_S_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)
    avg_max_precision_I_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)
    avg_max_precision_S_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)
    avg_max_recall_I_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)
    avg_max_recall_S_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)
    avg_max_F1_I_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)
    avg_max_F1_S_list = pd.DataFrame(columns=rho_range, index=true_factor_num_range)

    for true_factor_num in true_factor_num_range:
        # Per-trial results for this support size: rows = trial index, columns = rho.
        select_result_I_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))
        select_result_S_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))
        max_precision_I_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))
        max_precision_S_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))
        max_recall_I_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))
        max_recall_S_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))
        max_F1_I_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))
        max_F1_S_list = pd.DataFrame(columns=rho_range, index=range(random_experi_num))

        for rho in rho_range:

            rho1 = rho   # correlation between the two Brownian components

            for exp in range(random_experi_num):
                print(rho, exp)   # progress output
                # Design matrices: one row of signature features per simulated path.
                signature_list_I = pd.DataFrame(np.zeros((sample_num, sig_num)))
                signature_list_S = pd.DataFrame(np.zeros((sample_num, sig_num)))



                # Random sparse ground truth: true_factor_num distinct feature
                # indices with standard-normal coefficients.
                beta_location = random.sample(range(sig_num), true_factor_num)
                beta_values = np.random.randn(true_factor_num)


                for i in range(sample_num):
                    # BM
                    # Two independent discretized Brownian paths: cumulative sums
                    # of increments with variance 1/path_length.
                    # NOTE(review): the arrays start at the first increment, not at
                    # 0; unless calculate_signature_to_K prepends an origin, the
                    # first increment does not enter the signature — confirm.
                    Z1_list = np.cumsum(np.random.randn(path_length)) * np.sqrt(1/path_length)
                    Z2_list = np.cumsum(np.random.randn(path_length)) * np.sqrt(1/path_length)
                    # W2 = rho*Z1 + sqrt(1-rho^2)*Z2: increments of (W1, W2) have
                    # correlation rho and the same variance as those of Z1, Z2.
                    W1_list = Z1_list
                    W2_list = rho1 * Z1_list + np.sqrt(1-rho1**2) * Z2_list

                    X_all = np.array([W1_list, W2_list]).T   # shape (path_length, 2)

                    # Signature features of the same path under the two methods
                    # ("S" presumably Stratonovich — confirm in signatures.py).
                    signature_I = calculate_signature_to_K(X_all, sig_degree, method="Ito")
                    signature_list_I.loc[i,:] = signature_I
                    signature_S = calculate_signature_to_K(X_all, sig_degree, method="S")
                    signature_list_S.loc[i,:] = signature_S

                # standardize
                # Scale every feature column to unit Euclidean norm across the
                # sample_num paths. NOTE(review): a column that is identically
                # zero would become NaN here.
                X_list_I = np.array(signature_list_I)
                X_list_I = X_list_I / np.sqrt(np.sum(X_list_I**2, axis=0)) 
                X_list_S = np.array(signature_list_S)
                X_list_S = X_list_S / np.sqrt(np.sum(X_list_S**2, axis=0)) 

                # Responses: sparse linear model in the standardized features plus
                # Gaussian noise with variance 0.0001 (std 0.01); the noise is
                # drawn independently for the I and S designs.
                y_list_I =  X_list_I[:,beta_location] @ beta_values + np.random.randn(sample_num)*np.sqrt(0.0001)
                y_list_S =  X_list_S[:,beta_location] @ beta_values + np.random.randn(sample_num)*np.sqrt(0.0001)

                # Full LARS-lasso path; only the coefficient path (third return
                # value) is kept. verbose=True makes sklearn print progress.
                _, _, coefs_I = linear_model.lars_path(X_list_I, y_list_I, method="lasso", verbose=True)
                _, _, coefs_S = linear_model.lars_path(X_list_S, y_list_S, method="lasso", verbose=True)


                # Support-recovery check against the true (beta_location,
                # beta_values); presumably whether some point of the lasso path
                # selects the true support — see check_Lasso_select_result.
                select_result_I = check_Lasso_select_result(coefs_I, beta_location, beta_values)
                select_result_S = check_Lasso_select_result(coefs_S, beta_location, beta_values)

                select_result_I_list.loc[exp, rho] = select_result_I
                select_result_S_list.loc[exp, rho] = select_result_S

                # Best precision / recall / F1 reported by the helper (presumably
                # maximised over the lasso path — confirm in signatures.py).
                max_precision_I, max_recall_I, max_F1_I = check_Lasso_confusion_matrix(coefs_I, beta_location, beta_values)
                max_precision_S, max_recall_S, max_F1_S = check_Lasso_confusion_matrix(coefs_S, beta_location, beta_values)

                max_precision_I_list.loc[exp, rho] = max_precision_I
                max_precision_S_list.loc[exp, rho] = max_precision_S
                max_recall_I_list.loc[exp, rho] = max_recall_I
                max_recall_S_list.loc[exp, rho] = max_recall_S
                max_F1_I_list.loc[exp, rho] = max_F1_I
                max_F1_S_list.loc[exp, rho] = max_F1_S

        # Average each metric over the random_experi_num trials (column-wise,
        # one value per rho) and store it in the row for this support size.
        correct_rate_I = np.sum(select_result_I_list, axis=0) / random_experi_num
        correct_rate_S = np.sum(select_result_S_list, axis=0) / random_experi_num

        correct_rate_I_list.loc[true_factor_num, :] = correct_rate_I
        correct_rate_S_list.loc[true_factor_num, :] = correct_rate_S

        max_precision_I = np.sum(max_precision_I_list, axis=0) / random_experi_num
        max_precision_S = np.sum(max_precision_S_list, axis=0) / random_experi_num

        avg_max_precision_I_list.loc[true_factor_num, :] = max_precision_I
        avg_max_precision_S_list.loc[true_factor_num, :] = max_precision_S

        max_recall_I = np.sum(max_recall_I_list, axis=0) / random_experi_num
        max_recall_S = np.sum(max_recall_S_list, axis=0) / random_experi_num

        avg_max_recall_I_list.loc[true_factor_num, :] = max_recall_I
        avg_max_recall_S_list.loc[true_factor_num, :] = max_recall_S

        max_F1_I = np.sum(max_F1_I_list, axis=0) / random_experi_num
        max_F1_S = np.sum(max_F1_S_list, axis=0) / random_experi_num

        avg_max_F1_I_list.loc[true_factor_num, :] = max_F1_I
        avg_max_F1_S_list.loc[true_factor_num, :] = max_F1_S

    # Persist this run's summary tables (rows = true_factor_num, columns = rho),
    # one CSV per metric and signature method, suffixed with the run index e.
    correct_rate_I_list.to_csv('Experiment_BM_correct_rate_I_'+str(e)+'.csv')
    correct_rate_S_list.to_csv('Experiment_BM_correct_rate_S_'+str(e)+'.csv')
    avg_max_precision_I_list.to_csv('Experiment_BM_avg_max_precision_I_'+str(e)+'.csv')
    avg_max_precision_S_list.to_csv('Experiment_BM_avg_max_precision_S_'+str(e)+'.csv')
    avg_max_recall_I_list.to_csv('Experiment_BM_avg_max_recall_I_'+str(e)+'.csv')
    avg_max_recall_S_list.to_csv('Experiment_BM_avg_max_recall_S_'+str(e)+'.csv')
    avg_max_F1_I_list.to_csv('Experiment_BM_avg_max_F1_I_'+str(e)+'.csv')
    avg_max_F1_S_list.to_csv('Experiment_BM_avg_max_F1_S_'+str(e)+'.csv')
    




# Aggregate the per-run correct-rate CSVs (Experiment_BM_correct_rate_{I,S}_<run>.csv):
# for each support size q in {2, 4, 6}, build one table per signature method
# with 30 rows (runs) and one float-labelled column per rho value.
# Run 0 is read first only to learn the column layout.
data_I = pd.read_csv('Experiment_BM_correct_rate_I_0.csv', index_col=0)
data_S = pd.read_csv('Experiment_BM_correct_rate_S_0.csv', index_col=0)


def _zero_rate_table(template):
    """Return a 30-row all-zero DataFrame whose columns are *template*'s, cast to float."""
    n_cols = template.shape[1]
    return pd.DataFrame(np.zeros((30, n_cols)), columns=template.columns.astype(float))


data_I_all_2 = _zero_rate_table(data_I)
data_S_all_2 = _zero_rate_table(data_S)
data_I_all_4 = _zero_rate_table(data_I)
data_S_all_4 = _zero_rate_table(data_S)
data_I_all_6 = _zero_rate_table(data_I)
data_S_all_6 = _zero_rate_table(data_S)

# Row `exp` of each table <- the row labelled q in run `exp`'s CSV.
for exp in range(30):
    data_I = pd.read_csv('Experiment_BM_correct_rate_I_'+str(exp)+'.csv', index_col=0)
    data_S = pd.read_csv('Experiment_BM_correct_rate_S_'+str(exp)+'.csv', index_col=0)
    for q, table_I, table_S in ((2, data_I_all_2, data_S_all_2),
                                (4, data_I_all_4, data_S_all_4),
                                (6, data_I_all_6, data_S_all_6)):
        table_I.iloc[exp, :] = data_I.loc[q, :].values
        table_S.iloc[exp, :] = data_S.loc[q, :].values
    
    
    

# Figure A6a: mean consistency rate across the repeated runs versus rho, for
# true support sizes q = 2, 4, 6.  Solid lines / circles: Ito signature (I);
# dashed lines / triangles: "S" signature.  Only the Ito lines get legend
# entries.  Shaded bands are mean +/- 1.96 * standard error across runs.


def _plot_mean_with_band(runs, marker, linestyle, color, label=None):
    """Plot the column means of *runs* and shade a 95% normal-approximation band.

    *runs* has one row per repeated run and one column per rho value.  The band
    is ``mean +/- 1.96 * std / sqrt(n_runs)``, and both bounds use the std of
    *runs* itself (so the band is symmetric around this dataset's mean).
    """
    mean = runs.mean()
    half_width = 1.96 * runs.std() / np.sqrt(len(runs))
    line_kwargs = dict(marker=marker, markersize=10, markeredgewidth=2, linewidth=2,
                       markerfacecolor='none', linestyle=linestyle, color=color)
    if label is not None:
        line_kwargs['label'] = label
    plt.plot(mean, **line_kwargs)
    plt.fill_between(runs.columns, mean - half_width, mean + half_width,
                     color=color, alpha=0.3)


plt.figure(figsize=(9.5,5))

linecolor = ['tab:blue', 'tab:orange', 'tab:green']

# Draw I then S for each q, in the same order as before (keeps z-order and
# legend order unchanged).
for color, label, runs_I, runs_S in zip(
        linecolor,
        ['$q=2$', '$q=4$', '$q=6$'],
        [data_I_all_2, data_I_all_4, data_I_all_6],
        [data_S_all_2, data_S_all_4, data_S_all_6]):
    _plot_mean_with_band(runs_I, "o", "-", color, label=label)
    _plot_mean_with_band(runs_S, "^", "--", color)


plt.legend(fontsize=20, loc="upper right")
plt.xlabel("$\\rho$", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.ylim(-0.05, 1.05)
plt.xlim(-1, 1)
plt.ylabel("Consistency Rate", fontsize=20)
plt.grid()
plt.savefig('Figure_A6a.pdf', bbox_inches = 'tight' , dpi=150, pad_inches = 0.05)