# # Collection of Util function for CSPO

# import numpy as np
# import torch


# def get_q_hat(predmodel, score_function, x, y, alpha):
#     """return q_hat"""
#     score_list = []
#     for i in range(x.shape[0]):
#         score = score_function(predmodel, x[i,:], y[i,:])
#         score_list.append(score)
#     q_hat = np.quantile(np.array(score_list), 1-alpha)
#     return q_hat.item()

# def check_coverage(predmodel, score_function, x, y, q_hat):
#     """ Check if our conformal prediction set satisfies the desired coverage rate"""
#     coverge_check_list = []
#     x = torch.from_numpy(x).float()
#     y = torch.from_numpy(y).float()
#     for i in range(x.shape[0]):
#         if score_function(predmodel,x[i,:],y[i,:]) <= q_hat:
#             coverge_check_list.append(1)
#         else:
#             coverge_check_list.append(0)
#     coverage = sum(coverge_check_list)/len(coverge_check_list)
#     return coverage

def truncate_data_list(predmodel_list, score_function, x, y, q_hat):
    """Return indices of samples accepted by every model's threshold.

    Sample ``i`` is kept when, for every model index ``j``,
    ``score_function(predmodel_list[j], x[i, :], y[i, j, :]) <= q_hat[j]``.

    Args:
        predmodel_list: Sequence of models; model ``j`` is paired with
            ``y[:, j, :]`` and threshold ``q_hat[j]``.
        score_function: Callable ``(model, x_row, y_row) -> score``.
        x: Inputs (NumPy array or tensor), indexed as ``x[i, :]``.
        y: Targets (NumPy array or tensor), indexed as ``y[i, j, :]``.
        q_hat: Per-model thresholds, indexable by ``j``.

    Returns:
        list[int]: Row indices of ``x`` that pass all thresholds.
    """
    # as_tensor accepts both NumPy arrays and tensors; .float() gives float32
    # exactly like the previous torch.from_numpy(...).float().
    x = torch.as_tensor(x).float()
    y = torch.as_tensor(y).float()
    # all() short-circuits: once one model rejects a sample, the remaining
    # models are not scored for it.
    return [
        i
        for i in range(x.shape[0])
        if all(
            score_function(model, x[i, :], y[i, j, :]) <= q_hat[j]
            for j, model in enumerate(predmodel_list)
        )
    ]


# def mse_loss(predmodel, x, y):
#     """ Compute MSE Loss """
#     predmodel.eval()
#     with torch.no_grad():
#         pred = predmodel(x)
#         loss = torch.mean((pred - y)**2)
#     predmodel.train()
#     return loss.item()



# Collection of utility functions for CSPO

import numpy as np
import torch


def get_q_hat(predmodel, score_function, x, y, alpha, *, finite_sample=False):
    """Return the conformal threshold ``q_hat`` computed on calibration data.

    Every row ``i`` is scored with ``score_function(predmodel, x[i, :], y[i, :])``.
    The threshold is an empirical quantile of those scores.

    Args:
        predmodel: Model forwarded unchanged to ``score_function``.
        score_function: Callable ``(predmodel, x_row, y_row) -> score``.
        x: Calibration inputs, indexed as ``x[i, :]``.
        y: Calibration targets, indexed as ``y[i, :]``.
        alpha: Miscoverage level; the target coverage is ``1 - alpha``.
        finite_sample: Keyword-only; defaults to False.
            When False, the plain ``1 - alpha`` quantile (NumPy's default
            interpolation) is used, which is the historical behaviour.
            When True, the split-conformal finite-sample level
            ``ceil((n + 1) * (1 - alpha)) / n`` is used with the "higher"
            quantile method. ``inf`` is returned if that level exceeds 1.

    Returns:
        float: The threshold ``q_hat``.
    """
    scores = np.array(
        [score_function(predmodel, x[i, :], y[i, :]) for i in range(x.shape[0])]
    )
    if not finite_sample:
        return np.quantile(scores, 1 - alpha).item()

    n = scores.shape[0]
    level = np.ceil((n + 1) * (1 - alpha)) / n
    if level > 1:
        # Too few calibration points for this alpha: no finite score
        # guarantees the requested coverage.
        return float("inf")
    return np.quantile(scores, level, method="higher").item()

def check_coverage(predmodel, score_function, x, y, q_hat):
    """Return the empirical coverage of the conformal set defined by ``q_hat``.

    Coverage is the fraction of rows ``i`` for which
    ``score_function(predmodel, x[i, :], y[i, :]) <= q_hat``.

    Args:
        predmodel: Model forwarded unchanged to ``score_function``.
        score_function: Callable ``(predmodel, x_row, y_row) -> score``.
        x: Inputs (NumPy array or tensor), indexed as ``x[i, :]``.
        y: Targets (NumPy array or tensor), indexed as ``y[i, :]``.
        q_hat: Score threshold.

    Returns:
        float: Fraction of covered samples, in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``x`` has no rows.
    """
    # as_tensor accepts NumPy arrays (same result as torch.from_numpy) and
    # also tensors, which torch.from_numpy rejects.
    x = torch.as_tensor(x).float()
    y = torch.as_tensor(y).float()
    # 1/0 ints rather than raw comparison results, so that tensor-valued
    # scores still sum to a plain Python number.
    covered = [
        1 if score_function(predmodel, x[i, :], y[i, :]) <= q_hat else 0
        for i in range(x.shape[0])
    ]
    return sum(covered) / len(covered)

def truncate_data(predmodel, score_function, x, y, q_hat):
    """Return indices of samples whose score does not exceed ``q_hat``.

    Args:
        predmodel: Model forwarded unchanged to ``score_function``.
        score_function: Callable ``(predmodel, x_row, y_row) -> score``.
        x: Inputs (NumPy array or tensor), indexed as ``x[i, :]``.
        y: Targets (NumPy array or tensor), indexed as ``y[i, :]``.
        q_hat: Score threshold.

    Returns:
        list[int]: Row indices ``i`` with
        ``score_function(predmodel, x[i, :], y[i, :]) <= q_hat``.
    """
    # as_tensor accepts NumPy arrays (same float32 result as
    # torch.from_numpy(...).float()) and also tensors.
    x = torch.as_tensor(x).float()
    y = torch.as_tensor(y).float()
    return [
        i
        for i in range(x.shape[0])
        if score_function(predmodel, x[i, :], y[i, :]) <= q_hat
    ]

def mse_loss(predmodel, x, y):
    """Return the mean squared error of ``predmodel(x)`` against ``y``.

    The forward pass runs in eval mode under ``torch.no_grad()``, since only
    the scalar value is returned. Afterwards the model's previous training
    mode is restored, even if the forward pass raises.

    Args:
        predmodel: Module-like model exposing ``eval()``/``train()``.
        x: Input tensor passed directly to ``predmodel``.
        y: Target tensor; must broadcast against the prediction.

    Returns:
        float: ``mean((predmodel(x) - y) ** 2)``.
    """
    # Remember the caller's mode. Fall back to True, which matches the old
    # unconditional train() call, if the object has no `training` flag.
    was_training = getattr(predmodel, "training", True)
    predmodel.eval()
    try:
        # No autograd graph is needed just to read out a float.
        with torch.no_grad():
            pred = predmodel(x)
            loss = torch.mean((pred - y) ** 2)
    finally:
        if was_training:
            predmodel.train()
    return loss.item()