# Collection of utility functions for CSPO

import numpy as np
import torch


def get_q_hat(predmodel, score_function, x, y, alpha):
    """Return the empirical ``1 - alpha`` quantile of per-row scores.

    Each row pair ``(x[i, :], y[i, :])`` is scored with
    ``score_function(predmodel, x_row, y_row)``. The resulting scores are
    reduced with ``np.quantile`` at level ``1 - alpha``, using NumPy's
    default interpolation.

    Args:
        predmodel: Model forwarded unchanged to ``score_function``.
        score_function: Callable ``(predmodel, x_row, y_row) -> score``.
        x: Inputs indexed as ``x[i, :]``; rows are passed as-is, without
            conversion.
        y: Targets indexed as ``y[i, :]``, aligned with ``x``.
        alpha: Miscoverage level; the quantile level is ``1 - alpha``.

    Returns:
        The quantile as a Python scalar.
    """
    scores = np.array([
        score_function(predmodel, x[row, :], y[row, :])
        for row in range(x.shape[0])
    ])
    level = 1 - alpha
    return np.quantile(scores, level).item()

def check_coverage(predmodel, score_function, x, y, q_hat):
    """Check the empirical coverage of the conformal prediction set.

    A row counts as covered when
    ``score_function(predmodel, x[i, :], y[i, :]) <= q_hat``.

    Args:
        predmodel: Model forwarded unchanged to ``score_function``.
        score_function: Callable ``(predmodel, x_row, y_row) -> score``.
        x: Inputs indexed as ``x[i, :]``. Accepts a NumPy array, a tensor,
            or a nested list; rows are passed as float32 tensors.
        y: Targets indexed as ``y[i, :]``, aligned with ``x``. Accepts the
            same types as ``x``.
        q_hat: Score threshold.

    Returns:
        The fraction of covered rows, as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``x`` has no rows (coverage is undefined).
    """
    # as_tensor matches from_numpy for NumPy input (it shares memory), and
    # also accepts tensors and lists, which from_numpy rejects.
    x = torch.as_tensor(x).float()
    y = torch.as_tensor(y).float()
    n_rows = x.shape[0]
    if n_rows == 0:
        raise ValueError("check_coverage requires at least one sample")
    n_covered = sum(
        1
        for i in range(n_rows)
        if score_function(predmodel, x[i, :], y[i, :]) <= q_hat
    )
    return n_covered / n_rows

def truncate_data(predmodel, score_function, x, y, q_hat):
    """Return the indices of rows whose score is within ``q_hat``.

    A row ``i`` is kept when
    ``score_function(predmodel, x[i, :], y[i, :]) <= q_hat``.

    Args:
        predmodel: Model forwarded unchanged to ``score_function``.
        score_function: Callable ``(predmodel, x_row, y_row) -> score``.
        x: Inputs indexed as ``x[i, :]``. Accepts a NumPy array, a tensor,
            or a nested list; rows are passed as float32 tensors.
        y: Targets indexed as ``y[i, :]``, aligned with ``x``. Accepts the
            same types as ``x``.
        q_hat: Score threshold.

    Returns:
        A list of kept row indices in ascending order. The list is empty
        when no row qualifies or when ``x`` has no rows.
    """
    # as_tensor matches from_numpy for NumPy input (it shares memory), and
    # also accepts tensors and lists, which from_numpy rejects.
    x = torch.as_tensor(x).float()
    y = torch.as_tensor(y).float()
    return [
        i
        for i in range(x.shape[0])
        if score_function(predmodel, x[i, :], y[i, :]) <= q_hat
    ]

def mse_loss(predmodel, x, y):
    """Compute the mean squared error of ``predmodel`` on ``(x, y)``.

    The forward pass runs in eval mode with gradient tracking disabled.
    Afterwards the model's previous train/eval mode is restored, even if
    the forward pass raises.

    Args:
        predmodel: Model called as ``predmodel(x)`` that supports
            ``.eval()`` and ``.train(mode)``. If it has no ``training``
            attribute, it is assumed to have been in train mode.
        x: Input passed directly to the model; no conversion is done.
        y: Targets; subtracted from the model output, so they must
            broadcast against it.

    Returns:
        The mean of ``(pred - y) ** 2`` as a Python float.
    """
    # Remember the caller's mode. The original code always forced train()
    # afterwards, which silently flipped a model that was in eval mode.
    was_training = getattr(predmodel, "training", True)
    predmodel.eval()
    try:
        # The result is reduced to a Python float, so no autograd graph is
        # needed.
        with torch.no_grad():
            pred = predmodel(x)
            return torch.mean((pred - y) ** 2).item()
    finally:
        predmodel.train(was_training)