import numpy as np 
from sklearn.model_selection import train_test_split


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: score tensor; ``output.topk(..., 1, ...)`` is taken along
            dim 1, so presumably shape (batch, num_classes) — confirm at caller.
        target: label tensor with ``target.size(0)`` samples.
        topk: tuple of k values to evaluate.

    Returns:
        List of 0-d tensors, one per k (same order as ``topk``), each the
        percentage of samples whose label is among the k highest scores.
    """
    n_samples = target.size(0)
    largest_k = max(topk)

    # Indices of the largest_k best scores per sample, sorted descending, then
    # transposed so row j holds every sample's (j+1)-th best guess.
    top_idx = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    scale = 100.0 / n_samples
    return [hits[:k].float().sum().mul_(scale) for k in topk]


def marginalCoverageSize(S,targets):
    """Return (marginal coverage, average prediction-set size).

    Args:
        S: per-sample prediction sets; ``S[i].shape[0]`` is used as the size
            of the i-th set and ``label in S[i]`` as membership.
        targets: true labels with ``targets.shape[0]`` samples; each element
            must support ``.item()``.

    Returns:
        Tuple ``(fraction of samples whose label is in its set,
        mean set size)``. Raises ZeroDivisionError for empty ``targets``.
    """
    n_samples = targets.shape[0]
    hits = sum(1 for i in range(n_samples) if targets[i].item() in S[i])
    total_size = sum(S[i].shape[0] for i in range(n_samples))
    return float(hits) / n_samples, total_size / n_samples


def conditionalCoverageSize(S,targets):
    """Return the average size of the prediction sets that contain the true label.

    Args:
        S: per-sample prediction sets; ``S[i].shape[0]`` is used as the size
            of the i-th set and ``label in S[i]`` as membership.
        targets: true labels with ``targets.shape[0]`` samples; each element
            must support ``.item()``.

    Returns:
        Mean of ``S[i].shape[0]`` over the samples where ``targets[i]`` is in
        ``S[i]``. Returns ``float('nan')`` when no sample is covered, because
        the conditional average is undefined in that case.
    """
    num = 0
    size = 0
    for i in range(targets.shape[0]):
        if targets[i].item() in S[i]:
            size += S[i].shape[0]
            num += 1
    if num == 0:
        # No covered sample: avoid ZeroDivisionError and signal "undefined"
        # in the same way numpy does for the mean of an empty selection.
        return float("nan")
    return size / num





###################################################
# Worst-slice coverage rate and size
###################################################
def wsc(X, y, S, delta=0.1, M=1000, random_state=2020):
    """Search for the worst-covered slab of the data along random directions.

    For each of ``M`` random unit directions ``v``, the samples are sorted by
    their projection ``X @ v``. The function then finds the contiguous run
    ``[ai, bi]`` of sorted samples with the lowest empirical coverage, where a
    run must extend past ``ai + round(delta * n)``, so it spans at least
    roughly a ``delta`` fraction of the data.

    Args:
        X: feature matrix; ``X.shape[1]`` is used as the dimension of the
            random directions.
        y: labels, ``len(y)`` samples.
        S: per-sample prediction sets; ``y[i] in S[i]`` defines coverage.
        delta: minimum slab mass (fraction of samples) — see above.
        M: number of random directions to try.
        random_state: seed for ``np.random.default_rng``.

    Returns:
        ``(wsc_star, v_star, a_star, b_star)``: the lowest slab coverage found,
        the direction achieving it, and the projection values bounding the
        slab (``a_star <= X @ v_star <= b_star``).
    """
    rng = np.random.default_rng(random_state)
    n = len(y)
    # The coverage indicator does not depend on the direction, so build it once.
    cover = np.array([y[i] in S[i] for i in range(n)])

    def wsc_v(v):
        """Return (min slab coverage, lower bound, upper bound) along direction v."""
        z = np.dot(X, v)
        z_order = np.argsort(z)
        z_sorted = z[z_order]
        cover_ordered = cover[z_order]
        ai_max = int(np.round((1.0 - delta) * n))
        min_span = int(np.round(delta * n))
        ai_best = 0
        # Last valid index. Initialising to n made z_sorted[bi_best] raise
        # IndexError whenever no slab fell below coverage 1.
        bi_best = n - 1
        cover_min = 1
        for ai in np.arange(0, ai_max):
            bi_min = np.minimum(ai + min_span, n)
            # Running coverage of every slab starting at ai.
            coverage = np.cumsum(cover_ordered[ai:n]) / np.arange(1, n - ai + 1)
            # Slabs that are too short must never be selected.
            coverage[np.arange(0, bi_min - ai)] = 1
            bi_star = ai + np.argmin(coverage)
            cover_star = coverage[bi_star - ai]
            if cover_star < cover_min:
                ai_best = ai
                bi_best = bi_star
                cover_min = cover_star
        return cover_min, z_sorted[ai_best], z_sorted[bi_best]

    def sample_sphere(n_dirs, p):
        """Draw n_dirs directions uniformly on the unit sphere in R^p."""
        v = rng.normal(size=(p, n_dirs))
        v /= np.linalg.norm(v, axis=0)
        return v.T

    V = sample_sphere(M, p=X.shape[1])
    results = [wsc_v(V[m]) for m in range(M)]

    idx_star = np.argmin(np.array([r[0] for r in results]))
    wsc_star, a_star, b_star = results[idx_star]
    v_star = V[idx_star]
    return wsc_star, v_star, a_star, b_star

def wsc_unbiased(X, y, S, delta=0.1, M=1000, test_size=0.75, random_state=2020, verbose=False):
    """Estimate worst-slab coverage and set size without selection bias.

    The data are split with ``train_test_split``. The worst slab
    ``(v, a, b)`` is searched for with ``wsc`` on the first split only, and
    coverage and mean set size are then measured on the held-out split
    (a ``test_size`` fraction of the data).

    Args:
        X, y, S: features, labels and per-sample prediction sets.
        delta, M: forwarded to ``wsc``.
        test_size: fraction of samples used for evaluation.
        random_state: seed for both the split and ``wsc``.
        verbose: accepted for interface compatibility; not used here.

    Returns:
        ``(coverage, mean_set_size)`` on held-out samples whose projection
        onto ``v`` lies in ``[a, b]``. Both are NaN if no held-out sample
        falls in the slab.
    """
    def _slab_metrics(X_eval, y_eval, S_eval, v, a, b):
        """Return coverage and mean set size of the samples with a <= X_eval @ v <= b."""
        n_eval = len(y_eval)
        covered = np.array([y_eval[i] in S_eval[i] for i in range(n_eval)])
        set_sizes = np.array([len(S_eval[i]) for i in range(n_eval)])
        proj = np.dot(X_eval, v)
        in_slab = (proj >= a) & (proj <= b)
        return np.mean(covered[in_slab]), np.mean(set_sizes[in_slab])

    X_fit, X_eval, y_fit, y_eval, S_fit, S_eval = train_test_split(
        X, y, S, test_size=test_size, random_state=random_state
    )
    # Locate the adversarial slab on the fitting split only ...
    _, v_star, a_star, b_star = wsc(X_fit, y_fit, S_fit, delta=delta, M=M, random_state=random_state)
    # ... and measure it on data that played no part in choosing it.
    return _slab_metrics(X_eval, y_eval, S_eval, v_star, a_star, b_star)

def cal_MacroCoverAndCoverViolation(correct_array,targets,num_classes,alpha):
    """Compute the macro-averaged coverage and the class coverage-violation rate.

    Args:
        correct_array: per-sample coverage indicators (1/True if the sample's
            label is in its prediction set), aligned with ``targets``.
        targets: per-sample class labels in ``range(num_classes)``.
        num_classes: total number of classes.
        alpha: target miscoverage level; a class violates coverage when its
            mean coverage is below ``1 - alpha``.

    Returns:
        ``(macro_cover, violation_rate)``. ``macro_cover`` sums each present
        class's mean coverage weighted by ``1 / num_classes``. Classes with no
        samples contribute 0, so they are effectively counted as uncovered.
        ``violation_rate`` is the number of violating classes divided by
        ``num_classes``, rounded to 2 decimals.
    """
    hits = np.array(correct_array)
    labels = np.array(targets)
    weight = 1 / num_classes
    macro_cover = 0
    n_violations = 0
    for cls in range(num_classes):
        cls_hits = hits[labels == cls]
        if len(cls_hits) == 0:
            # Class absent from this sample: skip it (it still counts in the denominator).
            continue
        class_cover = np.mean(cls_hits)
        if class_cover < 1 - alpha:
            n_violations += 1
        macro_cover += weight * class_cover

    return macro_cover, round(n_violations / num_classes, 2)


def cal_MacroIneff(size_array,targets,num_classes):
    """Compute the macro-averaged prediction-set size (inefficiency).

    Args:
        size_array: per-sample prediction-set sizes, aligned with ``targets``.
        targets: per-sample class labels in ``range(num_classes)``.
        num_classes: total number of classes.

    Returns:
        Sum over present classes of the class's mean set size weighted by
        ``1 / num_classes``. Classes with no samples contribute 0.
    """
    sizes = np.array(size_array)
    labels = np.array(targets)
    weight = 1 / num_classes
    macro_ineff = 0
    for cls in range(num_classes):
        cls_sizes = sizes[labels == cls]
        if len(cls_sizes) > 0:
            macro_ineff += weight * np.mean(cls_sizes)
    return macro_ineff
    
    
        
    
    
    
    
