import numpy as np
from scipy.stats import norm


def path_to_str(path):
    """Serialise a decision path as a dash-separated string.

    Each element of ``path`` is converted with ``str`` and the pieces are
    joined with ``'-'``; an empty path gives the empty string.
    """
    return '-'.join(map(str, path))


def str_to_path(str_path):
    """Parse a dash-separated string back into a list of ints.

    This is the inverse of ``path_to_str`` for paths of non-negative
    integers. The empty string maps to the empty path, so that
    ``str_to_path(path_to_str([])) == []`` instead of raising ``ValueError``
    on ``int('')``.

    Raises:
        ValueError: if any dash-separated token is not a valid integer.
    """
    if not str_path:
        # ''.split('-') yields [''], which int() cannot parse.
        return []
    return [int(tok) for tok in str_path.split('-')]


def sample_C(path, loss_samples, index_last_retrain, now):
    """Look up the loss samples for the last step of a decision path.

    Args:
        path: sequence of decisions; entries equal to 1 mark a retrain.
        loss_samples: nested mapping indexed as
            ``loss_samples[index_model][time_index]``.
        index_last_retrain: model index to use when ``path`` contains no
            retrain.
        now: offset added to positions within ``path`` to obtain the
            indices used in ``loss_samples``.

    Returns:
        ``loss_samples[index_model][time_index]`` where ``index_model`` is
        the position of the last 1 in ``path`` plus ``now`` (or
        ``index_last_retrain`` if there is none) and ``time_index`` is
        ``len(path) + now - 1``.
    """
    # flatnonzero yields a 1-D index array, so taking [-1] gives a true
    # scalar; int() on the 1-element row returned by argwhere is deprecated.
    retrain_steps = np.flatnonzero(np.asarray(path) == 1)
    if retrain_steps.size:
        index_model = int(retrain_steps[-1]) + now
    else:
        index_model = index_last_retrain
    time_index = len(path) + now - 1
    return loss_samples[index_model][time_index]


def get_decision_ci(previous_path, retrain_cost, loss_samples, C_samples_dict,
                    end_T, index_last_retrain, now):
    """Recursively choose between retraining (1) and keeping the model (0).

    Both continuations of ``previous_path`` are explored down to length
    ``end_T``. For each, the sampled cost is the cost returned by the
    recursive call plus the loss samples of the step itself (via
    ``sample_C``), plus ``retrain_cost`` for the retrain branch. The branch
    with the smaller 95th percentile of sampled cost is selected; ties
    resolve to not retraining.

    Args:
        previous_path: list of 0/1 decisions taken so far.
        retrain_cost: cost added once to the retrain branch.
        loss_samples: nested mapping of loss samples, as read by
            ``sample_C``.
        C_samples_dict: currently unused by this function.
        end_T: path length at which the recursion stops.
        index_last_retrain: forwarded to ``sample_C``.
        now: forwarded to ``sample_C``.

    Returns:
        ``(decision, cost_samples)``. ``decision`` is 1 (retrain), 0 (keep),
        or -1 when ``previous_path`` already has length ``end_T`` or more,
        in which case ``cost_samples`` is ``sample_C(previous_path, ...)``.
    """
    # ">=" rather than "==" so an over-long path terminates immediately
    # instead of recursing without ever reaching the stop condition.
    if len(previous_path) >= end_T:
        return -1, sample_C(previous_path, loss_samples, index_last_retrain, now)

    path_if_retrain = previous_path + [1]
    path_if_not_retrain = previous_path + [0]

    _, future_cost_retrain = get_decision_ci(
        path_if_retrain, retrain_cost, loss_samples, C_samples_dict,
        end_T, index_last_retrain, now)
    _, future_cost_keep = get_decision_ci(
        path_if_not_retrain, retrain_cost, loss_samples, C_samples_dict,
        end_T, index_last_retrain, now)

    # NOTE(review): when the child path has length end_T, the recursive call
    # returns the same sample_C(...) value that is added again below, so the
    # final step's loss is counted twice. Confirm whether that is intended.
    step_loss_retrain = sample_C(path_if_retrain, loss_samples, index_last_retrain, now)
    step_loss_keep = sample_C(path_if_not_retrain, loss_samples, index_last_retrain, now)

    C_retrain = future_cost_retrain + step_loss_retrain + retrain_cost
    C_keep = future_cost_keep + step_loss_keep

    quantile_retrain = np.percentile(C_retrain, 95)
    quantile_no_retrain = np.percentile(C_keep, 95)

    # retrain_cost is already part of C_retrain, so it must not be added a
    # second time in the comparison.
    if quantile_retrain < quantile_no_retrain:
        return 1, C_retrain
    return 0, C_keep
    
        

if __name__ == "__main__":
    # Demo: build synthetic Gaussian loss samples and print the first decision.
    # Parameters
    T = 3

    # loss_*[i][j] is indexed by i in 0..T and j in i..T.
    loss_mean = {}
    loss_var = {}
    loss_samples = {}
    for i in range(T+1):
        loss_mean[i] = {}
        loss_var[i] = {}
        loss_samples[i] = {}
        for j in range(i, T+1):
            mean = 0.2 - (i+1)/T
            var = 0.02 * j
            loss_mean[i][j] = mean
            loss_var[i][j] = var
            # np.random.normal expects the standard deviation as its scale
            # argument, so convert the variance before sampling.
            loss_samples[i][j] = np.random.normal(mean, np.sqrt(var), 100)

    # index_last_retrain and now are required by get_decision_ci; the demo
    # starts at time 0 with model 0 as the most recently trained one.
    a, C_samples = get_decision_ci(
        [],
        retrain_cost=0,
        loss_samples=loss_samples,
        C_samples_dict={},
        end_T=T,
        index_last_retrain=0,
        now=0,
    )
    print(a)
