import numpy as np
import CRPS.CRPS as pscore

def RSE(pred, true):
    """Return the root relative squared error of ``pred`` against ``true``.

    The Euclidean norm of the residual ``true - pred`` is divided by the
    Euclidean norm of ``true`` centred on its global mean.
    """
    residual_norm = np.sqrt(np.sum((true - pred) ** 2))
    spread_norm = np.sqrt(np.sum((true - true.mean()) ** 2))
    return residual_norm / spread_norm


def CORR(pred, true):
    """Return the Pearson correlation between ``pred`` and ``true``.

    The correlation is computed along axis 0 for every remaining position
    and the resulting coefficients are averaged over the last axis.

    Args:
        pred: Predicted values; must have at least two dimensions so that
            the final ``mean(-1)`` has an axis to reduce.
        true: Ground-truth values with the same shape as ``pred``.

    Returns:
        The correlation coefficients averaged over the last axis. Series
        that are constant along axis 0 have zero variance and yield NaN.
    """
    true_centered = true - true.mean(0)
    pred_centered = pred - pred.mean(0)
    covariance = (true_centered * pred_centered).sum(0)
    # The normaliser is the product of the two sums of squares, not the sum
    # of the element-wise product of squares; the latter is not a correlation.
    normaliser = np.sqrt((true_centered ** 2).sum(0) * (pred_centered ** 2).sum(0))
    return (covariance / normaliser).mean(-1)


def MAE(pred, true):
    """Return the mean absolute error between ``pred`` and ``true``."""
    abs_error = np.abs(pred - true)
    return np.mean(abs_error)


def MSE(pred, true):
    """Return the mean squared error between ``pred`` and ``true``."""
    error = pred - true
    return np.mean(np.square(error))


def RMSE(pred, true):
    """Return the root mean squared error, i.e. the square root of ``MSE``."""
    mean_squared = MSE(pred, true)
    return np.sqrt(mean_squared)


def MAPE(pred, true):
    """Return the mean absolute percentage error as a fraction (not x100).

    The error is taken relative to ``true``; zeros in ``true`` therefore
    produce infinite or NaN terms.
    """
    relative_error = (pred - true) / true
    return np.mean(np.abs(relative_error))


def MSPE(pred, true):
    """Return the mean squared percentage error as a fraction (not x100).

    The error is taken relative to ``true``; zeros in ``true`` therefore
    produce infinite or NaN terms.
    """
    relative_error = (pred - true) / true
    return np.mean(np.square(relative_error))

def MSE_MASK(pred, true, mask):
    """Return the mask-weighted mean squared error.

    Args:
        pred: Predicted values.
        true: Ground-truth values, broadcastable with ``pred``.
        mask: Weights broadcastable with the error; with a 0/1 mask only
            the positions marked 1 contribute.

    Returns:
        ``sum(mask * (pred - true) ** 2) / sum(mask)``. The result is NaN or
        infinite when ``mask`` sums to zero.
    """
    # Apply the mask after squaring so each weight enters the numerator
    # linearly, matching both the denominator and MAE_MASK. For 0/1 masks
    # this is identical to squaring the masked error.
    squared_error = np.square(pred - true)
    return np.sum(squared_error * mask) / np.sum(mask)

def MAE_MASK(pred, true, mask):
    """Return the mask-weighted mean absolute error.

    Computes ``sum(|pred - true| * mask) / sum(mask)``; with a 0/1 mask only
    the positions marked 1 contribute.
    """
    weighted_abs_error = np.abs(pred - true) * mask
    total_weight = np.sum(mask)
    return np.sum(weighted_abs_error) / total_weight

def metric_mask(pred, true, mask):
    """Return the masked errors as the tuple ``(mae, mse)``."""
    return MAE_MASK(pred, true, mask), MSE_MASK(pred, true, mask)



def metric(pred, true):
    """Return the point-forecast errors as ``(mae, mse, rmse, mape, mspe)``."""
    scorers = (MAE, MSE, RMSE, MAPE, MSPE)
    return tuple(score(pred, true) for score in scorers)



def compute_true_coverage_by_gen_QI(n_bins, dataset_object, all_true_y, all_generated_y):
    """Measure how the true values distribute over the generated quantile bins.

    For every observation the generated samples are reduced to ``n_bins + 1``
    equally spaced percentiles (0 to 100). The true value is then placed in
    the bin delimited by those percentiles, and the share of true values per
    bin is compared with the ideal share ``1 / n_bins``.

    Args:
        n_bins: Number of quantile bins.
        dataset_object: Divisor applied to the per-bin counts. Despite the
            name it is used as a plain number, presumably the number of
            observations -- confirm against the caller.
        all_true_y: True values; transposed before being compared with the
            percentiles (assumed shape (n_obs, 1) -- confirm).
        all_generated_y: Generated samples; squeezed, then reduced along
            axis 1 (assumed shape (n_obs, n_samples, 1) -- confirm).

    Returns:
        A tuple ``(y_true_ratio_by_bin, qice_coverage_ratio, y_true)``: the
        per-bin ratios, the mean absolute deviation of those ratios from
        ``1 / n_bins``, and the transposed true values.
    """
    percent_levels = (100 / n_bins) * np.arange(n_bins + 1)
    samples = all_generated_y.squeeze()
    y_pred_quantiles = np.percentile(samples, q=percent_levels, axis=1)
    y_true = all_true_y.T

    # For each observation, count how many percentile levels lie strictly
    # below the true value: 0 means below every level, n_bins + 1 above all.
    above_level = (y_true - y_pred_quantiles > 0).astype(int)
    levels_exceeded = above_level.sum(axis=0)
    counts = np.array(
        [(levels_exceeded == level).sum() for level in range(n_bins + 2)])

    # Fold the two out-of-range groups into the first and last regular bins.
    bin_counts = counts[1:-1].copy()
    bin_counts[0] += counts[0]
    bin_counts[-1] += counts[-1]

    # The ratios sum to 1 only when dataset_object equals the number of
    # observations that were counted.
    y_true_ratio_by_bin = bin_counts / dataset_object
    ideal_ratio = np.full(n_bins, 1 / n_bins)
    qice_coverage_ratio = np.absolute(ideal_ratio - y_true_ratio_by_bin).mean()
    return y_true_ratio_by_bin, qice_coverage_ratio, y_true



def CRPS_miss(preds_save, trues_save, masks_save):
    """Return the mean CRPS of the last feature over positions whose mask is 1.

    Only the last index of the final axis is evaluated. Each selected
    position is scored separately by passing its sample ensemble and its true
    value to ``pscore`` and keeping the first value returned by ``compute()``.

    Args:
        preds_save: Sample ensembles; the last three axes are read as
            (n_sample, L, N) -- assumed from the reshapes below, confirm
            against the caller.
        trues_save: True values; the last two axes are read as (L, N).
        masks_save: Mask with as many elements as ``trues_save``; entries
            equal to 1 select the positions that are scored.

    Returns:
        The average score over all selected positions (NaN if none are
        selected).
    """
    feature = -1  # only the last feature channel is scored

    ensembles = preds_save.reshape(
        -1, preds_save.shape[-3], preds_save.shape[-2], preds_save.shape[-1])
    targets = trues_save.reshape(-1, trues_save.shape[-2], trues_save.shape[-1])
    # The mask is reshaped with the shape of the true values, not its own.
    mask_grid = masks_save.reshape(-1, trues_save.shape[-2], trues_save.shape[-1])
    selected = mask_grid[:, :, feature].reshape(-1) == 1

    # Move the sample axis last, then flatten every (batch, step) position
    # into one row holding that position's ensemble members.
    members_last = ensembles[:, :, :, feature].transpose(0, 2, 1)
    ensemble_rows = members_last.reshape(-1, members_last.shape[-1])[selected]
    observed = targets[:, :, feature].reshape(-1)[selected]

    print(ensemble_rows.shape, observed.shape)

    scores = np.zeros(len(observed))
    for row, (members, observation) in enumerate(zip(ensemble_rows, observed)):
        scores[row] = pscore(members, observation).compute()[0]
    return scores.mean()
