import numpy as np

from src.ci_recursive import  get_decision_ci


def find_optimal_schedule(loss, thetas, retrain_cost, now, index_last_trained_model, loss_is_01, unc=False, loss_var=None):
    """Select a retraining schedule, optionally taking loss uncertainty into account.

    Args:
        loss: Loss table forwarded unchanged to the selected strategy.
        thetas: Candidate schedules forwarded unchanged to the selected strategy.
        retrain_cost: Cost charged for each retraining.
        now: Current time index.
        index_last_trained_model: Index of the most recently trained model.
        loss_is_01: Forwarded to the uncertainty-aware strategy only; the
            deterministic strategy does not receive it.
        unc: When truthy, delegate to ``find_optimal_schedule_CI``; otherwise
            delegate to ``find_optimal_schedule_deterministic``.
        loss_var: Forwarded to the uncertainty-aware strategy only.

    Returns:
        Whatever the selected strategy returns.
    """
    if not unc:
        return find_optimal_schedule_deterministic(
            loss, thetas, retrain_cost, now, index_last_trained_model)

    return find_optimal_schedule_CI(
        loss, thetas, retrain_cost, now, index_last_trained_model, loss_is_01,
        loss_var=loss_var)


def find_optimal_schedule_deterministic(loss, thetas, retrain_cost, now, index_last_trained_model):
    """Return the candidate schedule with the smallest total cost.

    Each candidate in ``thetas`` is scored with ``compute_cost`` (with
    ``unc=False``) and the ``'C'`` entry of the result is used as its cost.

    Returns:
        The first candidate whose cost is strictly lower than every earlier
        one, or ``None`` when no candidate has a cost below ``np.inf``
        (including when ``thetas`` is empty).
    """
    best_theta = None
    best_cost = np.inf
    for candidate in thetas:
        candidate_cost = compute_cost(
            loss, candidate, retrain_cost,
            now=now,
            index_last_trained_model=index_last_trained_model,
            unc=False,
        )['C']
        # Strict comparison: on ties the earliest candidate is kept.
        if candidate_cost < best_cost:
            best_theta, best_cost = candidate, candidate_cost
    return best_theta


def alpha_beta_from_mean_var(mean, var):
    """Moment-match a Beta distribution to a given mean and variance.

    The inputs are first clamped so that the resulting parameters are valid:

    * ``mean`` is restricted to ``[0.01, 0.99]``;
    * ``var`` is capped just below ``mean * (1 - mean)``, the upper bound on
      the variance of a Beta distribution with that mean;
    * a non-positive ``var`` is replaced by a small positive floor, since it
      would otherwise cause a division by zero or negative parameters.

    Args:
        mean: Target mean.
        var: Target variance.

    Returns:
        Tuple ``(alpha, beta)`` of strictly positive shape parameters.
    """
    min_var = 1e-6  # floor used only when the supplied variance is <= 0

    mean = min(max(mean, 0.01), 0.99)

    max_var = mean * (1 - mean)
    if var >= max_var:
        var = max_var - 0.001
    if var <= 0:
        var = min_var

    # Method of moments: alpha = mean * nu, beta = (1 - mean) * nu,
    # with nu = mean * (1 - mean) / var - 1.
    nu = max_var / var - 1
    alpha = mean * nu
    beta = (1 - mean) * nu
    return alpha, beta

def compute_cost(loss, theta, retrain_cost, now, index_last_trained_model, unc, loss_var=None):
    """Compute the total cost of following one retraining schedule.

    The schedule is walked step by step. At offset ``t`` the absolute step is
    ``t + now``; a flag equal to ``1`` switches the active model to that step.
    The loss charged for the step is ``loss[active_model][t + now]``.

    Args:
        loss: Two-level indexable, read as ``loss[model_index][step]``.
        theta: Sequence of per-step retraining flags.
        retrain_cost: Cost charged per retraining; multiplied by ``np.sum(theta)``.
        now: Absolute index of the first step of the schedule.
        index_last_trained_model: Model in use until the first flag equal to 1.
        unc: When truthy (and ``loss_var`` is given), variances are summed too.
        loss_var: Optional two-level indexable laid out like ``loss``.

    Returns:
        Dict with keys ``'C'`` (retraining cost plus summed loss),
        ``'mean_loss'`` (summed loss), ``'var_loss'`` (summed variance, 0 when
        not accumulated) and ``'cost_retraining'``.
    """
    retraining_total = retrain_cost * np.sum(theta)
    accumulate_variance = unc and loss_var is not None

    total = retraining_total
    loss_sum = 0
    variance_sum = 0
    active_model = index_last_trained_model
    for offset, retrain_flag in enumerate(theta):
        step = offset + now
        if retrain_flag == 1:
            # A retraining at this step makes the freshly trained model active.
            active_model = step
        step_loss = loss[active_model][step]
        if accumulate_variance:
            variance_sum += loss_var[active_model][step]
        total += step_loss
        loss_sum += step_loss

    return {
        'C': total,
        'mean_loss': loss_sum,
        'var_loss': variance_sum,
        'cost_retraining': retraining_total,
    }

def find_optimal_schedule_CI(loss, thetas, retrain_cost, now, index_last_trained_model, loss_is_01, loss_var=None):
    """Choose a schedule from sampled losses instead of point estimates.

    For every ``(model_index, t)`` entry of ``loss``, 5000 samples are drawn:
    from a Beta distribution moment-matched to the entry's mean and variance
    when ``loss_is_01`` is truthy, otherwise from a normal distribution with
    that mean and variance. The samples are handed to ``get_decision_ci``
    (defined in ``src.ci_recursive``), and its first return value becomes the
    first entry of the returned schedule.

    Args:
        loss: Mapping ``model_index -> {t: mean loss}``.
        thetas: Candidate schedules; only the length of ``thetas[0]`` is used.
        retrain_cost: Forwarded to ``get_decision_ci``.
        now: Forwarded to ``get_decision_ci``.
        index_last_trained_model: Forwarded to ``get_decision_ci``.
        loss_is_01: Selects Beta (truthy) or normal (falsy) sampling.
        loss_var: Mapping laid out like ``loss`` holding the variance of each
            entry. Required whenever ``loss`` has entries, despite the
            ``None`` default.

    Returns:
        List of length ``len(thetas[0])``: the decision returned by
        ``get_decision_ci`` followed by zeros.

    Raises:
        ValueError: If a variance is negative on the normal-sampling path.
    """
    num_samples = 5000
    loss_samples = {}
    for model_index, per_step_means in loss.items():
        model_samples = {}
        for t, mean in per_step_means.items():
            var = loss_var[model_index][t]
            if loss_is_01:
                alpha, beta = alpha_beta_from_mean_var(mean, var)
                samples = np.random.beta(alpha, beta, num_samples)
            else:
                if var < 0:
                    raise ValueError(
                        f"negative variance {var} for model {model_index} at step {t}")
                # np.random.normal takes the standard deviation as its scale,
                # so the variance has to be square-rooted first.
                samples = np.random.normal(mean, np.sqrt(var), num_samples)
            model_samples[t] = samples
        loss_samples[model_index] = model_samples

    horizon = len(thetas[0])
    retrain, _ = get_decision_ci(
        [], retrain_cost, loss_samples, {}, horizon, index_last_trained_model, now)

    # Only the first decision is taken from get_decision_ci; the remaining
    # entries of the schedule are filled with zeros.
    return [retrain] + [0] * (horizon - 1)