import copy
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.stats as sts
import numpy as np
from tqdm import tqdm
class Proj_Model(nn.Module):
    """Classification head whose features and weights are Weibull latent variables.

    ``forward`` rectifies the (optionally linearly transformed) input features
    and the transposed weight matrix of ``self.linear``. It treats every
    rectified value as the mean of a Weibull distribution whose shape parameter
    is predicted by a ``Linear -> Softplus`` layer. It then multiplies the
    resulting (sampled or mean) feature and weight matrices.
    """

    def __init__(self, num_features, num_classes, linear_add=False):
        super(Proj_Model, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes

        # Shape ("kappa") predictors; Softplus keeps the Weibull shape strictly positive.
        self.linear_k = nn.Sequential(nn.Linear(num_features, 1), nn.Softplus())
        self.linear_kw = nn.Sequential(nn.Linear(num_classes, 1), nn.Softplus())

        self.add_linear = linear_add
        if self.add_linear:
            # Optional square feature transform (noted upstream as used for imagenet / cifar-10).
            self.linear_add = nn.Sequential(nn.Linear(num_features, num_features))
        self.linear = nn.Linear(num_features, num_classes)
        # Stored flag only; forward() is controlled by its own `return_all` argument.
        self.return_all = True

    @staticmethod
    def _weibull_scale(mean, shape):
        # Scale lambda such that the Weibull mean lambda * Gamma(1 + 1/shape) equals `mean`.
        return mean / torch.exp(torch.lgamma(1 + 1 / shape))

    def reparameterize(self, lbd, kappa, force_sample=False):
        '''
        Draw, or take the mean of, a Weibull(scale=lbd, shape=kappa) variable.

        In training mode, or when ``force_sample`` is set, this samples with the
        reparameterization z = lbd * (-ln(1 - u)) ** (1 / kappa), u ~ U(0, 1).
        Otherwise it returns the Weibull mean lbd * Gamma(1 + 1 / kappa).
        '''
        if not (self.training or force_sample):
            return lbd * torch.exp(torch.lgamma(1 + kappa.pow(-1)))

        uniform = torch.rand_like(lbd)
        one_minus_u = 1 - uniform
        # Floor at 1e-10 before the log so that u == 1 cannot produce log(0).
        floor = torch.tensor([1e-10]).to(one_minus_u.device)
        exponential = -torch.log(torch.max(one_minus_u, floor))
        return lbd * exponential.pow(1 / kappa)

    def forward(self, X, factor_z=0., factor_w=0.0, force_sample=False, return_all=True):
        features = self.linear_add(X) if self.add_linear else X

        # Feature side: shifted + rectified values act as Weibull means (N x H).
        z_out = F.relu(features - factor_z)
        k = self.linear_k(z_out)
        weibull_lambda = self._weibull_scale(z_out, k)
        pre_out = self.reparameterize(weibull_lambda, k, force_sample)

        # Weight side: same treatment for the transposed weight matrix (H x C).
        z_out_w = F.relu(self.linear.weight.transpose(1, 0) - factor_w)
        k_w = self.linear_kw(z_out_w)
        weibull_lambda_w = self._weibull_scale(z_out_w, k_w)
        pre_out_w = self.reparameterize(weibull_lambda_w, k_w, force_sample)

        # A single unbatched feature vector is promoted to a 1-row batch.
        if pre_out.dim() == 1:
            pre_out = pre_out.unsqueeze(0)
        bias = F.relu(self.linear.bias - factor_w)
        out = torch.mm(pre_out, pre_out_w) + bias

        if not return_all:
            return out
        return out, z_out, weibull_lambda, k, weibull_lambda_w, k_w


class Concat_Proj_Model(nn.Module):
    """Sum the logits of a frozen previous head and a new trainable ``Proj_Model`` head.

    Gradients are disabled on every parameter of ``proj_model_old``. The
    auxiliary outputs returned with ``return_all=True`` come from the new head
    only.
    """

    def __init__(self, num_features, num_classes, proj_model_old, linear_add=False):
        super(Concat_Proj_Model, self).__init__()
        self.proj_model = Proj_Model(num_features, num_classes, linear_add=linear_add)
        self.proj_model_old = proj_model_old
        # Freeze the previous head so only the new head is trained.
        for param in self.proj_model_old.parameters():
            param.requires_grad_(False)

    def forward(self, X, factor_z=0., factor_w=0.0, force_sample=False, return_all=True):
        # The old head is evaluated first, then the new head (this fixes the RNG draw order).
        old_logits = self.proj_model_old(X, factor_z, factor_w, return_all=False, force_sample=force_sample)
        new_logits, *extras = self.proj_model(X, factor_z, factor_w, force_sample=force_sample)
        combined = old_logits + new_logits
        return (combined, *extras) if return_all else combined


def KL_GamWei(Gam_shape, Gam_scale, Wei_shape_res, Wei_scale):
    """KL(Weibull(k, lambda) || Gamma(alpha, beta)), summed over dim 1 and averaged over dim 0.

    Args:
        Gam_shape: Gamma shape ``alpha``.
        Gam_scale: Gamma rate ``beta`` (it enters as ``beta * lambda * Gamma(1 + 1/k)``).
        Wei_shape_res: reciprocal of the Weibull shape, ``1 / k``.
        Wei_scale: Weibull scale ``lambda``.

    All arguments are tensors that broadcast together and have at least two
    dimensions, since the result is ``sum(dim=1).mean()``.

    Returns:
        Scalar tensor with the mean (over dim 0) KL divergence.
    """
    SMALL = 1e-10

    def log_max(input):
        # log with the argument floored at SMALL to avoid log(0) / log of negatives.
        return torch.log(torch.clamp(input, min=SMALL))

    # Full-precision Euler–Mascheroni constant (the previous 0.5772 truncation
    # introduced an error of ~1.6e-5 per term).
    eulergamma = np.euler_gamma
    part1 = Gam_shape * log_max(Wei_scale) - eulergamma * Gam_shape * Wei_shape_res + log_max(Wei_shape_res)
    part2 = - Gam_scale * Wei_scale * torch.exp(torch.lgamma(1 + Wei_shape_res))
    part3 = eulergamma + 1 + Gam_shape * log_max(Gam_scale) - torch.lgamma(Gam_shape)
    # part1 + part2 + part3 is the negative KL.
    KL = part1 + part2 + part3
    return -KL.sum(1).mean()


def batch_uncertain_emb_fc(model_logits,
                    num_classes, target, emb, fc,
                    accurate_pred, testresult):
    """Estimate per-sample uncertainty for a batch of embeddings by MC sampling of ``fc``.

    ``fc`` is run ``sample_num`` (20) times on ``emb`` with ``force_sample=True``.
    For every sample, a paired t-test over those draws compares the logits of
    the two classes with the highest mean softmax probability. The resulting
    p-value is the uncertainty score: a large p-value means the top-2 classes
    are not reliably separated.

    Args:
        model_logits: greedy logits for the batch, shape (batch, num_classes);
            used only for the correctness flags.
        num_classes: number of outputs of ``fc``.
        target: ground-truth labels; its device is used for the new tensors.
        emb: batch of embeddings fed to ``fc``.
        fc: head called as ``fc(emb, force_sample=True)``; may return a tuple
            whose first element is the logits.
        accurate_pred: running tensor of 0/1 correctness flags to append to.
        testresult: running tensor of p-values to append to.

    Returns:
        (testresult, mean_logits, accurate_pred). The first is the extended
        p-value tensor, whose new rows are (batch, 1). The second is the log of
        the MC-averaged softmax for this batch. The third is the extended
        correctness flags.
    """
    def two_sample_test_batch(logits, sample_num):
        # logits: (batch, classes, samples). Compare the two most probable classes
        # with a paired t-test over the sample axis.
        probmean = torch.mean(torch.softmax(logits, 1), 2)
        _, indices = torch.topk(probmean, 2, dim=1)
        aa = logits.gather(1, indices[:, 0].unsqueeze(1).unsqueeze(1).repeat(1, 1, sample_num))
        bb = logits.gather(1, indices[:, 1].unsqueeze(1).unsqueeze(1).repeat(1, 1, sample_num))
        # Hand SciPy explicit NumPy arrays rather than relying on tensor coercion.
        return sts.ttest_rel(aa.detach().cpu().numpy(), bb.detach().cpu().numpy(), axis=2).pvalue

    sample_num = 20
    device = target.device
    logits_ii = np.zeros([emb.size(0), num_classes, sample_num])
    # Copy into a float64 buffer of the expected shape (fails loudly on a shape mismatch).
    logits_greedy = np.zeros([emb.size(0), num_classes])
    logits_greedy[:, :] = model_logits.detach().cpu().numpy()

    model_fc = nn.DataParallel(fc)
    # Sampling is inference only: do not build autograd graphs for the MC passes.
    with torch.no_grad():
        for iii in range(sample_num):
            output = model_fc(emb, force_sample=True)
            sample_logits = output[0] if isinstance(output, tuple) else output
            logits_ii[:, :, iii] = sample_logits.detach().cpu().numpy()

    logits_tsam = torch.from_numpy(logits_ii).to(device)
    mean_logits = F.log_softmax(torch.mean(F.softmax(logits_tsam, dim=1), 2), 1)

    # Correctness is judged from the greedy (non-sampled) logits.
    prediction = torch.argmax(torch.from_numpy(logits_greedy), 1).to(device)
    accurate_pred_i = (prediction == target).type_as(logits_tsam)
    accurate_pred = torch.cat([accurate_pred, accurate_pred_i], 0)
    testresult_i = torch.from_numpy(two_sample_test_batch(logits_tsam, sample_num)).type_as(logits_tsam)
    testresult = torch.cat([testresult, testresult_i], 0)
    return testresult, mean_logits, accurate_pred

def batch_uncertain(model_logits,
                    num_classes, target, inp, model,
                    accurate_pred, testresult):
    """Estimate per-sample uncertainty for a batch of raw inputs by MC sampling of ``model.fc``.

    A deep copy of ``model`` (unwrapped from ``nn.DataParallel`` if needed) gets
    its ``fc`` replaced by ``nn.Identity`` and is used to compute the embeddings
    once. The copied ``fc`` is then run ``sample_num`` (20) times on them with
    ``force_sample=True``. For every sample, a paired t-test compares the logits
    of the two classes with the highest mean softmax probability, and the
    p-value is the uncertainty score. The caller's ``model`` is not modified.

    Args:
        model_logits: greedy logits for the batch, shape (batch, num_classes).
        num_classes: number of outputs of ``model.fc``.
        target: ground-truth labels; its device is used for the new tensors.
        inp: batch of model inputs.
        model: network with an ``fc`` attribute that accepts ``force_sample``.
        accurate_pred: running tensor of 0/1 correctness flags to append to.
        testresult: running tensor of p-values to append to.

    Returns:
        (testresult, mean_logits, accurate_pred), as in ``batch_uncertain_emb_fc``.
    """
    def two_sample_test_batch(logits, sample_num):
        # logits: (batch, classes, samples). Compare the two most probable classes
        # with a paired t-test over the sample axis.
        probmean = torch.mean(torch.softmax(logits, 1), 2)
        _, indices = torch.topk(probmean, 2, dim=1)
        aa = logits.gather(1, indices[:, 0].unsqueeze(1).unsqueeze(1).repeat(1, 1, sample_num))
        bb = logits.gather(1, indices[:, 1].unsqueeze(1).unsqueeze(1).repeat(1, 1, sample_num))
        # Hand SciPy explicit NumPy arrays rather than relying on tensor coercion.
        return sts.ttest_rel(aa.detach().cpu().numpy(), bb.detach().cpu().numpy(), axis=2).pvalue

    sample_num = 20
    device = target.device
    logits_ii = np.zeros([inp.size(0), num_classes, sample_num])
    # Copy into a float64 buffer of the expected shape (fails loudly on a shape mismatch).
    logits_greedy = np.zeros([inp.size(0), num_classes])
    logits_greedy[:, :] = model_logits.detach().cpu().numpy()

    # Work on a copy so the caller's model keeps its real `fc` head.
    # NOTE(review): deep-copying the whole model per batch is expensive.
    tmp_model = copy.deepcopy(model.module if isinstance(model, nn.DataParallel) else model)
    model_fc = tmp_model.fc
    tmp_model.fc = nn.Identity()

    tmp_model = nn.DataParallel(tmp_model)
    model_fc = nn.DataParallel(model_fc)
    # Embedding + sampling passes are inference only: skip autograd bookkeeping.
    with torch.no_grad():
        emb = tmp_model(inp)
        for iii in range(sample_num):
            output = model_fc(emb, force_sample=True)
            sample_logits = output[0] if isinstance(output, tuple) else output
            logits_ii[:, :, iii] = sample_logits.detach().cpu().numpy()
    del tmp_model

    logits_tsam = torch.from_numpy(logits_ii).to(device)
    mean_logits = F.log_softmax(torch.mean(F.softmax(logits_tsam, dim=1), 2), 1)

    # Correctness is judged from the greedy (non-sampled) logits.
    prediction = torch.argmax(torch.from_numpy(logits_greedy), 1).to(device)
    accurate_pred_i = (prediction == target).type_as(logits_tsam)
    accurate_pred = torch.cat([accurate_pred, accurate_pred_i], 0)
    testresult_i = torch.from_numpy(two_sample_test_batch(logits_tsam, sample_num)).type_as(logits_tsam)
    testresult = torch.cat([testresult, testresult_i], 0)
    return testresult, mean_logits, accurate_pred


def get_uncertain_indices(testresult, mean_logits, accurate_pred):
    """Split samples into accuracy/uncertainty quadrants at p-value threshold 0.05.

    A sample is "uncertain" when its test p-value (``testresult``) exceeds 0.05.

    Args:
        testresult: per-sample p-values, e.g. shape (N, 1).
        mean_logits: tensor whose dtype/device the uncertainty flags are cast to.
        accurate_pred: per-sample 0/1 correctness flags, shape (N,).

    Returns:
        base_aic: PAvPU in percent, i.e. (accurate&certain + inaccurate&uncertain) / N * 100;
            NaN when there are no samples.
        acc_uncertain_indices: CPU long indices of accurate but uncertain samples.
        inacc_certain_indices: CPU long indices of inaccurate but certain samples.
        acc_and_certain_indices: CPU long indices of accurate and certain samples.
    """
    uncertain = (testresult > 0.05).type_as(mean_logits)

    ac = (accurate_pred * (1 - uncertain.squeeze())).sum()
    iu = ((1 - accurate_pred) * uncertain.squeeze()).sum()
    n_samples = accurate_pred.size(0)
    base_aic = (ac + iu).item() / n_samples * 100 if n_samples else float('nan')

    accurate_flat = accurate_pred.squeeze().reshape(-1)
    uncertain_flat = uncertain.squeeze().reshape(-1)
    is_acc = accurate_flat == 1
    is_inacc = accurate_flat == 0
    is_uncertain = uncertain_flat == 1
    is_certain = uncertain_flat == 0

    acc_uncertain_indices = torch.where(is_acc & is_uncertain)[0].long()
    inacc_certain_indices = torch.where(is_inacc & is_certain)[0].long()
    # Previously this used the inaccurate&certain mask by mistake.
    acc_and_certain_indices = torch.where(is_acc & is_certain)[0].long()
    print(f"acc_uncertain_indices: {acc_uncertain_indices.shape}, inacc_certain_indices: {inacc_certain_indices.shape}")
    print(base_aic)
    return base_aic, acc_uncertain_indices.cpu(), inacc_certain_indices.cpu(), acc_and_certain_indices.cpu()

def uncertain_cal(testresult, mean_logits, accurate_pred, thresholds=(0.01, 0.05, 0.1)):
    """Compute PAvPU (%) at each p-value threshold.

    For every threshold ``t``, a sample is "uncertain" when its p-value exceeds
    ``t``. PAvPU = (accurate&certain + inaccurate&uncertain) / N * 100.

    Args:
        testresult: per-sample p-values, e.g. shape (N, 1).
        mean_logits: tensor whose dtype/device the uncertainty flags are cast to.
        accurate_pred: per-sample 0/1 correctness flags, shape (N,).
        thresholds: p-value thresholds; the default matches the original three.

    Returns:
        List of 0-dim tensors, one PAvPU value per threshold, in order.
    """
    n_samples = accurate_pred.size(0)
    base_aic = []
    for threshold in thresholds:
        uncertain = (testresult > threshold).type_as(mean_logits).squeeze()
        ac = (accurate_pred * (1 - uncertain)).sum()
        iu = ((1 - accurate_pred) * uncertain).sum()
        base_aic.append((ac + iu) / n_samples * 100)
    return base_aic

def uncertain_estimate_basemodel(model, dataloader):
    '''
    Collect the logits of a deterministic model over ``dataloader``.

    Each batch is a 4-tuple ``(x, y, g, p)``; only ``x`` and ``y`` are used.
    ``x`` is moved to the default CUDA device before calling ``model``.
    ``model(x)`` must return a logits tensor.

    Returns:
        testresult: maximum softmax probability per sample, shape (N, 1).
        pavpus: placeholder ``[0, 0, 0]`` (no sampling-based PAvPU for this model).
        ece: always ``None`` (ECE computation is disabled).
        inp, labels, probs: concatenated CPU inputs, labels and logits.
    '''
    probs, inp, labels = [], [], []
    # Inference only: no autograd graphs are needed.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            x, y, g, p = batch
            logits = model(x.cuda())
            probs.append(logits.detach().cpu())
            labels.append(y.detach().cpu())
            inp.append(x.detach().cpu())

    probs = torch.cat(probs, dim=0)
    labels = torch.cat(labels, dim=0)
    inp = torch.cat(inp, dim=0)
    # ece = cal.get_ece(F.softmax(probs, dim=-1), labels)
    ece = None
    testresult = F.softmax(probs, dim=-1).max(dim=-1).values.unsqueeze(dim=-1)
    pavpus = [0, 0, 0]
    print(f'pavpus: {pavpus[0]:.4f}\t {pavpus[1]:.4f}\t {pavpus[2]:.4f}\t')

    return testresult, pavpus, ece, inp, labels, probs



def uncertain_estimate_emb(fc, embs, y):
    '''
    Estimate per-sample uncertainty for precomputed embeddings.

    Args:
        fc: head with a ``linear`` layer (``linear.weight`` is (num_classes, features))
            and a forward accepting ``force_sample``; may return a tuple whose first
            element is the logits.
        embs: embeddings for all samples.
        y: label tensor; its device is used for the accumulated results.

    Returns:
        testresult: per-sample p-values from ``batch_uncertain_emb_fc``.
        probs: greedy logits on the CPU.
        acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices:
            CPU index tensors from ``get_uncertain_indices``.
    '''
    # Accumulators live on the labels' device so torch.cat with per-batch results works.
    accurate_pred = torch.zeros([0], dtype=torch.float64, device=y.device)
    testresult = torch.zeros([0], dtype=torch.float64, device=y.device)
    n_classes = fc.linear.weight.shape[0]

    # Previously this called fc(x) with an undefined `x`; the embeddings are the input.
    with torch.no_grad():
        output = fc(embs)
    logits = output[0] if isinstance(output, tuple) else output

    testresult, mean_logits, accurate_pred = batch_uncertain_emb_fc(logits, n_classes,
                                                                     y, embs, fc,
                                                                     accurate_pred, testresult)
    probs = logits.detach().cpu()

    # ece = cal.get_ece(F.softmax(probs, dim=-1), labels)

    pavpus = uncertain_cal(testresult, mean_logits, accurate_pred)
    print(f'pavpus: {pavpus[0]:.4f}\t {pavpus[1]:.4f}\t {pavpus[2]:.4f}\t')

    base_aic, acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices = get_uncertain_indices(
        testresult, mean_logits, accurate_pred)
    return testresult, probs, acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices

def uncertain_estimate(model, dataloader):
    '''
    Run MC-sampling uncertainty estimation over every batch of ``dataloader``.

    Each batch is a 4-tuple ``(x, y, g, p)``; ``x`` and ``y`` are moved to the
    default CUDA device. ``model`` must expose an ``fc`` head accepting
    ``force_sample`` (see ``batch_uncertain``). Its output may be a tuple whose
    first element is the logits. The number of classes is taken from the
    logits' second dimension.

    Returns:
        testresult: per-sample p-values accumulated by ``batch_uncertain``.
        pavpus: PAvPU (%) at p-value thresholds 0.01 / 0.05 / 0.1.
        ece: always ``None`` (ECE computation is disabled).
        inp, labels, probs: concatenated CPU inputs, labels and greedy logits.
        acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices:
            CPU index tensors from ``get_uncertain_indices``.
    '''
    accurate_pred = torch.zeros([0], dtype=torch.float64).cuda(non_blocking=True)
    testresult = torch.zeros([0], dtype=torch.float64).cuda(non_blocking=True)
    probs, inp, labels = [], [], []
    # Everything below is inference only; skip autograd bookkeeping.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            x, y, g, p = batch
            x, y = x.cuda(), y.cuda()

            output = model(x)
            logits = output[0] if isinstance(output, tuple) else output
            # Derived from the logits instead of the previously hard-coded 2.
            n_classes = logits.shape[1]
            testresult, mean_logits, accurate_pred = batch_uncertain(logits, n_classes,
                                                                     y, x, model,
                                                                     accurate_pred, testresult)
            probs.append(logits.detach().cpu())
            labels.append(y.detach().cpu())
            inp.append(x.detach().cpu())
    ece = None
    probs = torch.cat(probs, dim=0)
    labels = torch.cat(labels, dim=0)
    inp = torch.cat(inp, dim=0)
    # ece = cal.get_ece(F.softmax(probs, dim=-1), labels)

    pavpus = uncertain_cal(testresult, mean_logits, accurate_pred)
    print(f'pavpus: {pavpus[0]:.4f}\t {pavpus[1]:.4f}\t {pavpus[2]:.4f}\t')

    base_aic, acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices = get_uncertain_indices(
        testresult, mean_logits, accurate_pred)
    return testresult, pavpus, ece, inp, labels, probs, acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices

from torch.utils.data import Dataset
import torch
import numpy as np

class UncertaintyFilteredDataset(Dataset):
    """In-memory dataset over a selected subset of samples.

    Items are ``(x, y, group, confounder)`` tuples. ``p_array`` and
    ``global_indices`` (positions in the source dataset) are kept as
    attributes but are not returned by ``__getitem__``.
    """

    def __init__(self, x, y, group_array, confounder_array, p_array, n_places, global_indices):
        self.x, self.y_array = x, y
        self.group_array, self.confounder_array, self.p_array = group_array, confounder_array, p_array
        self.n_places = n_places
        self.global_indices = global_indices

        # Class count is the number of distinct labels actually present in `y`.
        self.n_classes = len(np.unique(self.y_array))
        self.n_groups = self.n_places * self.n_classes

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        columns = (self.x, self.y_array, self.group_array, self.confounder_array)
        return tuple(column[idx] for column in columns)

def subsample_upweight(model, dataloader, subsample_propo=0.5, upweight_fac=20):
    '''
    Build a re-balanced dataset from the predictions of ``model`` on ``dataloader``.

    - Keeps a random ``subsample_propo`` fraction of the correctly classified
      samples, drawn with ``np.random.permutation`` and therefore following
      NumPy's global seed.
    - Repeats every misclassified sample ``upweight_fac`` times.

    Prints the per-group counts of the selection and returns an
    ``UncertaintyFilteredDataset`` over the selected (possibly repeated) indices.
    '''
    # Only the logits, labels and inputs are used; uncertain_estimate still runs
    # in full and prints its PAvPU statistics.
    _, _, _, inp, labels, probs, _, _, _ = uncertain_estimate(model, dataloader)
    labels = labels.cpu()
    inp = inp.cpu()

    preds = torch.argmax(torch.softmax(probs, dim=1), dim=1)
    correct = preds == labels

    correct_indices = torch.where(correct)[0]
    n_keep = int(correct.sum() * subsample_propo)
    subsample_indices = torch.as_tensor(
        np.random.permutation(correct_indices.numpy())[:n_keep], dtype=torch.long)
    # Tensor.repeat tiles the whole index vector, like `list(...) * upweight_fac`.
    error_indices = torch.where(~correct)[0].repeat(max(upweight_fac, 0))

    group_array = dataloader.dataset.group_array
    confounder_array = dataloader.dataset.confounder_array
    p_array = dataloader.dataset.p_array
    n_places = dataloader.dataset.n_places

    def build_dataset(indices):
        return UncertaintyFilteredDataset(
            x=inp[indices],
            y=labels[indices],
            group_array=group_array[indices],
            confounder_array=confounder_array[indices],
            p_array=p_array[indices],
            n_places=n_places,
            global_indices=indices.clone(),
        )

    used_indices = torch.cat([subsample_indices, error_indices])
    subupweight_dataset = build_dataset(used_indices)
    a_v, a_cnt = np.unique(group_array[used_indices], return_counts=True)
    for v, cnt in zip(a_v, a_cnt):
        print(f'group {v}: {cnt}')
    return subupweight_dataset

def get_uncertain_datasets(model, dataloader, topk_ratio=0.5, held_out_idx=None, return_indices_only=False):
    '''
    Partition the samples of ``dataloader`` by their uncertainty score.

    The score is the per-sample p-value from ``uncertain_estimate``. Samples
    listed in ``held_out_idx`` are excluded from the selections below.
    ``topk = int(n_valid * topk_ratio)``, where ``n_valid`` is the number of
    non-held-out samples. The function selects:

    - per class, the ``topk // n_classes`` lowest- and highest-score samples;
    - globally, the ``topk`` lowest- and highest-score samples (the highest set
      is afterwards randomly truncated to at most 50);
    - all non-held-out samples whose score is ``>= 0.05``;
    - the accuracy/uncertainty quadrants from ``get_uncertain_indices``.

    Returns:
        If ``return_indices_only``: a 9-tuple of index tensors (selected,
        top_uncertain, low_uncertain, remaining, top_threshold, acc_uncertain,
        inacc_certain, au_ic, ac_au). Otherwise: a 7-tuple of
        ``UncertaintyFilteredDataset`` objects (class, high, remain, threshold,
        acc_uncertain, inacc_certain, au_ic).

    Raises:
        ValueError: if ``topk`` rounds down to 0.
    '''
    assert 0. <= topk_ratio <= 1.0
    N = len(dataloader.dataset)
    (sample_uncertain, pavpus, ece, inp, labels, probs,
     acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices) = uncertain_estimate(model, dataloader)

    n_classes = dataloader.dataset.n_classes

    # device-safe
    sample_uncertain = sample_uncertain.cpu()
    labels = labels.cpu()
    inp = inp.cpu()

    preds = torch.argmax(torch.softmax(probs, dim=1), dim=1)   # (N,)
    correct = preds == labels

    # Valid = not held out. A set gives O(1) membership instead of scanning
    # held_out_idx once per sample.
    all_indices = torch.arange(N)
    if held_out_idx is not None and len(held_out_idx) > 0:
        if torch.is_tensor(held_out_idx):
            held_out = set(held_out_idx.reshape(-1).tolist())
        elif isinstance(held_out_idx, np.ndarray):
            held_out = set(held_out_idx.ravel().tolist())
        else:
            held_out = {i.item() if torch.is_tensor(i) else i for i in held_out_idx}
        valid_mask = torch.tensor([i not in held_out for i in range(N)], dtype=torch.bool)
    else:
        valid_mask = torch.ones(N, dtype=torch.bool)

    valid_indices = all_indices[valid_mask]
    topk = int(len(valid_indices) * topk_ratio)
    if topk == 0:
        raise ValueError(f"Top-k=0，误分类样本太少，请降低 topk_ratio 或检查数据。")

    # Report the fraction of misclassified samples among the most uncertain ones.
    valid_scores = sample_uncertain[valid_indices].view(-1)
    kk = min(2000, len(valid_indices))
    topk_indices = valid_indices[torch.argsort(valid_scores, descending=True)[:kk]]
    misclassified_ratio = (~correct[topk_indices]).float().mean().item()
    print(f"Top-{kk} 不确定性样本中，错误分类比例: {misclassified_ratio:.3f}")

    group_array = dataloader.dataset.group_array
    confounder_array = dataloader.dataset.confounder_array
    p_array = dataloader.dataset.p_array
    n_places = dataloader.dataset.n_places

    # ---------------------- per-class selection ----------------------
    tmp_k = topk // n_classes
    selected_indices_low = []
    selected_indices_top = []
    # Guard tmp_k == 0: `x[-0:]` would otherwise select the whole class.
    if tmp_k > 0:
        for c in range(n_classes):
            class_indices = all_indices[(labels == c) & valid_mask]
            if len(class_indices) == 0:
                continue
            class_uncertainties = sample_uncertain[class_indices].view(-1)
            sorted_class_indices = class_indices[torch.argsort(class_uncertainties)]
            selected_indices_low.extend(sorted_class_indices[:tmp_k].tolist())
            selected_indices_top.extend(sorted_class_indices[-tmp_k:].tolist())

    selected_indices = torch.tensor(selected_indices_low + selected_indices_top, dtype=torch.long)

    # -------------------- global low/high uncertainty --------------------
    sorted_uncert_all = torch.argsort(sample_uncertain.view(-1))
    sorted_uncert_valid = sorted_uncert_all[valid_mask[sorted_uncert_all]]  # drop held-out

    low_uncertain_indices = sorted_uncert_valid[:topk]
    top_uncertain_indices = sorted_uncert_valid[-topk:]

    # Everything except the global top-uncertainty set.
    remain_mask = torch.ones(N, dtype=torch.bool)
    remain_mask[top_uncertain_indices] = False
    remaining_indices = all_indices[remain_mask]

    # Map through valid_indices so these are global sample indices (previously
    # they were positions inside valid_indices, wrong when samples are held out).
    top_threshold_indices = valid_indices[valid_scores >= 0.05]

    au_ic_indices = torch.cat([inacc_certain_indices, acc_uncertain_indices], dim=0)
    ac_au_indices = torch.cat([acc_and_certain_indices, top_uncertain_indices], dim=0)
    # Randomly truncate to at most 50 samples (uses torch's global RNG).
    if au_ic_indices.shape[0] > 50:
        au_ic_indices = au_ic_indices[torch.randperm(au_ic_indices.shape[0])[:50]]
    if top_uncertain_indices.shape[0] > 50:
        top_uncertain_indices = top_uncertain_indices[torch.randperm(top_uncertain_indices.shape[0])[:50]]

    if return_indices_only:
        return (selected_indices, top_uncertain_indices, low_uncertain_indices, remaining_indices,
                top_threshold_indices, acc_uncertain_indices, inacc_certain_indices, au_ic_indices,
                ac_au_indices)

    def build_dataset(indices):
        return UncertaintyFilteredDataset(
            x=inp[indices],
            y=labels[indices],
            group_array=group_array[indices],
            confounder_array=confounder_array[indices],
            p_array=p_array[indices],
            n_places=n_places,
            global_indices=torch.as_tensor(indices).clone(),  # keep global indices
        )

    class_uncertainty_dataset = build_dataset(selected_indices)
    high_uncertainty_dataset = build_dataset(top_uncertain_indices)
    remain_uncertainty_dataset = build_dataset(remaining_indices)
    threshold_uncertainty_dataset = build_dataset(top_threshold_indices)
    acc_uncertain_dataset = build_dataset(acc_uncertain_indices)
    inacc_certain_dataset = build_dataset(inacc_certain_indices)
    au_ic_dataset = build_dataset(au_ic_indices)
    return (class_uncertainty_dataset, high_uncertainty_dataset, remain_uncertainty_dataset,
            threshold_uncertainty_dataset, acc_uncertain_dataset, inacc_certain_dataset, au_ic_dataset)



def get_uncertain_datasets_embedding(fc, embs, y, g, p, topk_ratio=0.5, held_out_idx=None, return_indices_only=False):
    '''
    Select sample indices by uncertainty for precomputed embeddings.

    The score is the per-sample p-value from ``uncertain_estimate_emb``. Only
    misclassified, non-held-out samples are "valid" here.
    ``topk = int(n_valid * topk_ratio)``. The function selects:

    - per class, the ``topk // n_classes`` lowest- and highest-score valid samples;
    - globally, the ``topk`` highest-score valid samples;
    - valid samples whose score is ``>= 0.7``;
    - the accuracy/uncertainty quadrants from ``get_uncertain_indices``.

    ``g``, ``p`` and ``return_indices_only`` are accepted for interface
    compatibility but are not used; indices are always returned.

    Returns:
        8-tuple of index tensors (selected, top_uncertain, remaining,
        top_threshold, acc_uncertain, inacc_certain, au_ic, ac_au).

    Raises:
        ValueError: if ``topk`` rounds down to 0.
    '''
    assert 0. <= topk_ratio <= 1.0
    N = embs.shape[0]

    # Previously `y` was not passed, so this call always raised TypeError.
    sample_uncertain, probs, acc_uncertain_indices, inacc_certain_indices, acc_and_certain_indices = \
        uncertain_estimate_emb(fc, embs, y)

    # linear.weight is (num_classes, num_features); shape[-1] was the feature count.
    n_classes = fc.linear.weight.shape[0]

    # device-safe
    sample_uncertain = sample_uncertain.cpu()
    labels = y.cpu()
    probs = probs.cpu()

    preds = torch.argmax(torch.softmax(probs, dim=1), dim=1)   # (N,)
    correct = preds == labels

    # Not-held-out mask; a set gives O(1) membership per sample.
    all_indices = torch.arange(N)
    if held_out_idx is not None and len(held_out_idx) > 0:
        if torch.is_tensor(held_out_idx):
            held_out = set(held_out_idx.reshape(-1).tolist())
        elif isinstance(held_out_idx, np.ndarray):
            held_out = set(held_out_idx.ravel().tolist())
        else:
            held_out = {i.item() if torch.is_tensor(i) else i for i in held_out_idx}
        held_out_mask = torch.tensor([i not in held_out for i in range(N)], dtype=torch.bool)
    else:
        held_out_mask = torch.ones(N, dtype=torch.bool)

    valid_mask = (~correct) & held_out_mask
    valid_indices = all_indices[valid_mask]
    topk = int(len(valid_indices) * topk_ratio)
    if topk == 0:
        raise ValueError(f"Top-k=0，误分类样本太少，请降低 topk_ratio 或检查数据。")

    # Report the fraction of misclassified samples among the most uncertain valid ones.
    valid_scores = sample_uncertain[valid_indices].view(-1)
    kk = min(2000, len(valid_indices))
    topk_indices = valid_indices[torch.argsort(valid_scores, descending=True)[:kk]]
    misclassified_ratio = (~correct[topk_indices]).float().mean().item()
    print(f"Top-{kk} 不确定性样本中，错误分类比例: {misclassified_ratio:.3f}")

    # ---------------------- per-class selection ----------------------
    tmp_k = topk // n_classes
    selected_indices_low = []
    selected_indices_top = []
    # Guard tmp_k == 0: `x[-0:]` would otherwise select the whole class.
    if tmp_k > 0:
        for c in range(n_classes):
            class_indices = all_indices[(labels == c) & valid_mask]
            if len(class_indices) == 0:
                continue
            class_uncertainties = sample_uncertain[class_indices].view(-1)
            sorted_class_indices = class_indices[torch.argsort(class_uncertainties)]
            selected_indices_low.extend(sorted_class_indices[:tmp_k].tolist())
            selected_indices_top.extend(sorted_class_indices[-tmp_k:].tolist())

    selected_indices = torch.tensor(selected_indices_low + selected_indices_top, dtype=torch.long)

    # -------------------- global high uncertainty --------------------
    sorted_uncert_all = torch.argsort(sample_uncertain.view(-1))
    sorted_uncert_valid = sorted_uncert_all[valid_mask[sorted_uncert_all]]
    top_uncertain_indices = sorted_uncert_valid[-topk:]

    # Everything except the global top-uncertainty set.
    remain_mask = torch.ones(N, dtype=torch.bool)
    remain_mask[top_uncertain_indices] = False
    remaining_indices = all_indices[remain_mask]

    # Global sample indices (previously positions inside valid_indices).
    top_threshold_indices = valid_indices[valid_scores >= 0.7]

    au_ic_indices = torch.cat([inacc_certain_indices, acc_uncertain_indices], dim=0)
    ac_au_indices = torch.cat([acc_and_certain_indices, top_uncertain_indices], dim=0)

    return (selected_indices, top_uncertain_indices, remaining_indices, top_threshold_indices,
            acc_uncertain_indices, inacc_certain_indices, au_ic_indices, ac_au_indices)
##xbw
def compute_pavpu_for_mask(testresult, mean_logits, accurate_pred, mask, threshold=0.05):
    """PAvPU (%) restricted to the samples selected by ``mask``.

    A sample is "uncertain" when its p-value exceeds ``threshold``.
    PAvPU = (accurate&certain + inaccurate&uncertain) / n_selected * 100.

    Args:
        testresult: per-sample p-values, e.g. shape (N, 1).
        mean_logits: tensor whose dtype/device the uncertainty flags are cast to.
        accurate_pred: per-sample 0/1 correctness flags, shape (N,).
        mask: boolean mask OR integer index array selecting a group of samples.
        threshold: p-value threshold (default 0.05).

    Returns:
        PAvPU as a Python float, or NaN when ``mask`` selects no samples.
    """
    tr = testresult[mask]
    ml = mean_logits[mask]
    ap = accurate_pred[mask].reshape(-1)
    uncertain = (tr > threshold).type_as(ml).reshape(-1)
    # Count the selected samples; `mask.sum()` would sum the indices of an index array.
    n_selected = ap.numel()
    if n_selected == 0:
        return float('nan')
    ac = (ap * (1 - uncertain)).sum().item()
    iu = ((1 - ap) * uncertain).sum().item()
    return (ac + iu) / n_selected * 100

import torch.nn.functional as F

def get_topk_loss_samples(model, dataloader, loss_fn=F.cross_entropy, topk=5):
    """Find the ``topk`` lowest- and highest-loss samples of ``dataloader``.

    Puts ``model`` in eval mode (left that way). Each batch is a 4-tuple
    ``(x, y, g, p)``; ``x`` and ``y`` are moved to the default CUDA device.
    ``model(x)`` must return a sequence whose first element is the logits.
    ``loss_fn`` must accept ``reduction='none'``.

    Returns:
        lowest: list of ``(global_index, x, y, loss)`` tuples with the smallest losses.
        highest: same for the largest losses (empty when ``topk == 0``).
        inp, labels: all inputs and labels concatenated on the CPU.
    """
    model.eval()

    all_samples = []  # (global index, x, y, loss) per sample
    global_index = 0  # running position of the batch in the whole dataset

    inp, labels = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            x, y, g, p = batch
            x, y = x.cuda(), y.cuda()

            logits, *_ = model(x)
            # One device->host transfer per batch instead of one sync per sample.
            losses = loss_fn(logits, y, reduction='none').cpu().tolist()
            x_cpu, y_cpu = x.detach().cpu(), y.detach().cpu()

            for offset, (xi, yi, loss) in enumerate(zip(x_cpu, y_cpu, losses)):
                all_samples.append((global_index + offset, xi, yi, loss))
            global_index += x.shape[0]

            inp.append(x_cpu)
            labels.append(y_cpu)

    # Stable sort by loss.
    sorted_samples = sorted(all_samples, key=lambda t: t[3])

    lowest = sorted_samples[:topk]
    # `[-topk:]` would return every sample when topk == 0.
    highest = sorted_samples[max(len(sorted_samples) - topk, 0):]

    inp = torch.cat(inp, dim=0)
    labels = torch.cat(labels, dim=0)
    return lowest, highest, inp, labels

def compare_loss_uncertain_samples(model, dataloader, topk=100):
    """Compare the most/least uncertain samples with the highest/lowest-loss ones.

    Runs ``model`` over ``dataloader``. Per-sample uncertainty is accumulated
    through ``batch_uncertain`` and per-sample cross-entropy is recorded.
    The function then prints:

    * the ECE from ``cal.get_ece``;
    * the three values returned by ``uncertain_cal``;
    * overlap statistics (intersection size, overlap ratios, Jaccard
      similarity) for two pairs of index sets:
        - the ``topk`` highest-uncertainty vs. ``topk`` highest-loss samples;
        - the ``topk`` lowest-uncertainty vs. ``topk`` lowest-loss samples.

    Loss indices are positions in loader iteration order. Comparing them with
    positions in ``testresult`` assumes ``batch_uncertain`` appends scores in
    that same order -- NOTE(review): confirm.

    NOTE(review): the model is called without ``model.eval()`` or
    ``torch.no_grad()``, so whatever mode the caller left it in applies
    (e.g. ``Proj_Model.reparameterize`` samples randomly when
    ``self.training``). This was left unchanged because ``batch_uncertain``
    may rely on it.

    Args:
        model: module whose forward returns a sequence with the logits first.
        dataloader: yields batches whose first two items are inputs and labels;
            its dataset must expose ``n_classes``.
        topk: size of each compared set; values <= 0 compare empty sets.
    """
    accurate_pred = torch.zeros([0], dtype=torch.float64).cuda(non_blocking=True)
    testresult = torch.zeros([0], dtype=torch.float64).cuda(non_blocking=True)
    n_classes = dataloader.dataset.n_classes
    probs, labels, losses = [], [], []

    for batch in tqdm(dataloader):
        x, y, *_ = batch
        x, y = x.cuda(), y.cuda()

        logits, *_ = model(x)
        testresult, mean_logits, accurate_pred = batch_uncertain(logits, n_classes,
                                                                 y, x, model,
                                                                 accurate_pred, testresult)
        probs.append(logits.detach().cpu())
        labels.append(y.detach().cpu())
        # Only the per-sample loss is kept; a sample's index is its position
        # in the loader-order concatenation below.
        losses.append(F.cross_entropy(logits, y, reduction='none').detach().reshape(-1).cpu())

    probs = torch.cat(probs, dim=0)
    labels = torch.cat(labels, dim=0)
    ece = cal.get_ece(F.softmax(probs, dim=-1), labels)
    print(f'ECE {ece}')

    pavpus = uncertain_cal(testresult, mean_logits, accurate_pred)
    print(f'pavpus: {pavpus[0]:.4f}\t {pavpus[1]:.4f}\t {pavpus[2]:.4f}\t')

    # Clamp k so both rankings treat non-positive topk the same (empty sets).
    k = max(0, topk)

    # Highest- and lowest-uncertainty sample indices.
    scores = testresult.view(-1)
    top_uncertain_indices = torch.argsort(scores, descending=True)[:k]
    low_uncertain_indices = torch.argsort(scores, descending=False)[:k]

    # Lowest- and highest-loss sample indices (stable: ties keep loader order).
    loss_values = torch.cat(losses, dim=0).double().numpy()
    order = np.argsort(loss_values, kind='stable')
    loss_low_indices = order[:k].tolist()
    # ``max(0, ...)`` avoids the ``[-0:]`` pitfall that selects every sample.
    loss_high_indices = order[max(0, len(order) - k):].tolist()

    def list_compare(a, b):
        """Print overlap statistics between two collections of indices."""
        set1, set2 = set(a), set(b)
        overlap_count = len(set1 & set2)
        union_count = len(set1 | set2)
        ratio_in_list1 = overlap_count / len(set1) if set1 else 0
        ratio_in_list2 = overlap_count / len(set2) if set2 else 0
        jaccard = overlap_count / union_count if union_count > 0 else 0

        print(f'intersection_count {overlap_count}')
        print(f'overlap_ratio_in_list1 :{ratio_in_list1}')
        print(f'overlap_ratio_in_list2 :{ratio_in_list2}')
        print(f'jaccard similarity {jaccard}')

    print('high uncertain compare')
    list_compare(top_uncertain_indices.detach().cpu().tolist(), loss_high_indices)

    print('\n')
    print('low uncertain compare')
    list_compare(low_uncertain_indices.detach().cpu().tolist(), loss_low_indices)

    print('top-k')


def get_high_easy_datasets(model, dataloader, topk=100):
    """Build low-loss, high-loss and high-loss-removed datasets from a loader.

    Ranks every sample with ``get_topk_loss_samples``, which leaves ``model``
    in eval mode. From the ranking it builds three
    ``UncertaintyFilteredDataset`` objects:

    * the ``topk`` lowest-loss samples;
    * the ``topk`` highest-loss samples;
    * every sample except the ``topk`` highest-loss ones.

    Per-sample metadata (``group_array``, ``confounder_array``, ``p_array``)
    is read from ``dataloader.dataset`` and indexed with positions taken from
    loader iteration order. This presumably requires the loader to visit the
    dataset in stored order (no shuffling) -- NOTE(review): confirm.

    Returns:
        ``(low_loss_dataset, high_loss_dataset, del_high_dataset)``.
    """
    lowest_set, highest_set, inp, labels = get_topk_loss_samples(model, dataloader, topk=topk)

    # Each record is (index, x, y, loss); only the sample index is needed here.
    low_idx = [record[0] for record in lowest_set]
    high_idx = [record[0] for record in highest_set]

    # Keep-mask over all samples with the highest-loss ones switched off.
    n_samples = inp.size(0)
    keep = torch.ones(n_samples, dtype=torch.bool)
    keep[high_idx] = False
    kept_idx = torch.nonzero(keep, as_tuple=True)[0]

    source = dataloader.dataset
    group_array = source.group_array
    confounder_array = source.confounder_array
    p_array = source.p_array
    n_places = source.n_places

    def make_dataset(idx):
        return UncertaintyFilteredDataset(
            x=inp[idx],
            y=labels[idx],
            group_array=group_array[idx],
            confounder_array=confounder_array[idx],
            p_array=p_array[idx],
            n_places=n_places,
        )

    # Tuple elements are evaluated left to right: low, high, then kept.
    return (
        make_dataset(low_idx),
        make_dataset(high_idx),
        make_dataset(kept_idx),
    )

class MySubset(torch.utils.data.Subset):
    """``Subset`` that forwards unknown attribute lookups to the wrapped dataset.

    Lets code written against the full dataset (e.g. reading
    ``subset.n_classes``) work unchanged on a subset. The subset also stays
    safe to pickle and ``copy.deepcopy``.
    """

    def __getattr__(self, name):
        # ``__getattr__`` only runs after normal lookup fails. Never delegate
        # dunder names: protocol hooks such as ``__deepcopy__``,
        # ``__setstate__`` or ``__getstate__`` must belong to this object,
        # not to the wrapped dataset.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}")
        # Read ``dataset`` from the instance dict. During unpickling or
        # ``copy.deepcopy`` the instance exists before ``dataset`` is set, and
        # ``self.dataset`` would re-enter ``__getattr__`` until RecursionError.
        try:
            dataset = self.__dict__['dataset']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}") from None
        return getattr(dataset, name)
