import numpy as np
import os
import argparse

from util.datasets import build_dataset, CustomDataset, CustomDataset_selected, CustomDataset_ImgList, build_transform

from sklearn.metrics import roc_curve, auc
from sklearn import metrics


def get_fpr_idx(id_scores, ood_scores, threshold_type=('tpr', 0.95)):
    """Pick an operating point on the ID-vs-OOD ROC curve and list the OOD hits.

    ID samples are labelled 0 and OOD samples 1, so a larger score means
    "more OOD". The AUROC and the FPR at the chosen operating point are
    printed.

    Args:
        id_scores: array of scores for the in-distribution samples
            (expected to be 1-D).
        ood_scores: array of scores for the OOD samples (expected to be 1-D).
        threshold_type: ``(kind, target)`` pair. With ``kind == 'tpr'`` the
            operating point is the ROC point whose TPR is closest to
            ``target``; any other ``kind`` selects the point whose FPR is
            closest to ``target``.

    Returns:
        ``(fpr, ood_samples)``: the FPR at the chosen point and the list of
        positions in ``ood_scores`` flagged as OOD at that point's threshold.
    """
    # Merge the scores and build the matching labels (ID -> 0, OOD -> 1).
    scores = np.concatenate([id_scores, ood_scores])
    labels = np.concatenate([np.zeros_like(id_scores), np.ones_like(ood_scores)])

    # ROC curve and the area under it.
    fpr, tpr, thresholds = roc_curve(labels, scores)
    auroc = metrics.auc(fpr, tpr)
    print(f'AUROC: {auroc}')

    # Index of the ROC point closest to the requested TPR (or FPR) target.
    kind, recall = threshold_type[0], threshold_type[1]
    curve = tpr if kind == 'tpr' else fpr
    idx = np.abs(curve - recall).argmin()

    print(f'FPR@{recall}: {fpr[idx]}')

    threshold = thresholds[idx]

    # roc_curve counts a sample as positive when score >= threshold, so the
    # same comparison is used here to stay consistent with the reported rates.
    ood_samples = np.flatnonzero(np.asarray(ood_scores) >= threshold).tolist()
    return fpr[idx], ood_samples

def compare(idscores1, oodscores1, threshold_type=('tpr', 0.95)):
    """Return the positions of the OOD samples flagged by ``get_fpr_idx``.

    Thin wrapper: forwards to ``get_fpr_idx`` (which prints the AUROC and the
    FPR at the operating point) and drops the FPR from its result.

    Args:
        idscores1: scores of the in-distribution samples.
        oodscores1: scores of the OOD samples.
        threshold_type: ``(kind, target)`` pair passed through unchanged.

    Returns:
        The list of flagged positions in ``oodscores1``.
    """
    _, flagged = get_fpr_idx(idscores1, oodscores1, threshold_type)
    return flagged

def get_ood(path, id, ood, fpr=False, recall=0.05):
    """Load the score archives of an ID/OOD pair and select flagged OOD samples.

    Args:
        path: directory holding the ``<name>.npz`` score archives.
        id: name of the in-distribution archive (without ``.npz``).
        ood: name of the OOD archive (without ``.npz``).
        fpr: if true the operating point is chosen by FPR, otherwise by TPR.
        recall: target rate for the chosen criterion.

    Returns:
        ``(ood_sample, id_data, ood_data)``: the value returned by
        ``compare`` followed by the two loaded archives.
    """
    id_archive = np.load(f'{path}/{id}.npz')
    ood_archive = np.load(f'{path}/{ood}.npz')
    criterion = 'fpr' if fpr else 'tpr'
    # The 'conf' arrays are negated before being handed to compare().
    flagged = compare(-id_archive['conf'], -ood_archive['conf'], [criterion, recall])
    return flagged, id_archive, ood_archive

def get_ood_path_list(ood, sample_ids):
    """Return the image paths of selected samples from an OOD test list.

    The list file is ``benchmark_imglist/<id>/test_<ood>.txt`` where ``<id>``
    is the module-level ``id`` (it is not a parameter of this function).

    Args:
        ood: name of the OOD dataset.
        sample_ids: index array used to select entries of the dataset's
            ``image_paths``.

    Returns:
        The selected paths as a plain list.
    """
    def get_root(file_path):
        # The data root is chosen by substring match on the list-file path.
        medical_keys = ('ct', 'xraybone', 'covid')
        # 'openimage_o' is the spelling used in the imagenet200 OOD list; the
        # older 'openimageo' key never matches 'test_openimage_o.txt' and is
        # kept only for backward compatibility.
        largescale_keys = ('ssb_hard', 'ninco', 'inaturalist', 'openimageo', 'openimage_o')
        if any(key in file_path for key in medical_keys):
            return '/data4/jiangy/OpenOOD-main/data/images_medical'
        if any(key in file_path for key in largescale_keys):
            return '/data4/jiangy/OpenOOD-main/data/images_largescale'
        return '/data4/jiangy/OpenOOD-main/data/images_classic'

    file_path = '/data4/jiangy/OpenOOD-main/data/benchmark_imglist/{}/test_{}.txt'.format(id, ood)
    dataset_ood = CustomDataset_ImgList(root=get_root(file_path), file_path=file_path)
    path_list = np.array(dataset_ood.image_paths)[sample_ids]
    return path_list.tolist()

def get_args():
    """Parse the command-line options of this script.

    Options:
        --id:  str, default 'cifar10'.
        --pre: float, default 0.008.
        --ft:  float, default 0.06.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    option_specs = (
        ('--id', 'cifar10', str),
        ('--pre', 0.008, float),
        ('--ft', 0.06, float),
    )
    cli = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        cli.add_argument(flag, default=default_value, type=value_type)
    return cli.parse_args()

# Parse the CLI once; the module-level names below are read by the rest of
# the script.
args = get_args()
recall_pre, recall_ft = args.pre, args.ft  # CLI defaults: 0.008 and 0.06
id = args.id

# OOD test sets per ID dataset; any id not listed here (including 'cifar10')
# uses the default list passed to .get() below.
_OOD_LISTS = {
    'cifar100': ['cifar10', 'tin', 'mnist', 'svhn', 'texture', 'places365'],
    'mnist': ['notmnist', 'fashionmnist', 'texture', 'cifar10', 'tin', 'places365'],
    'covid': ['cifar10', 'ct', 'xraybone', 'mnist', 'texture'],
    'imagenet200': ['ssb_hard', 'ninco', 'inaturalist', 'textures', 'openimage_o'],
}
ood_list = _OOD_LISTS.get(id, ['cifar100', 'tin', 'mnist', 'svhn', 'texture', 'places365'])

# Accumulators filled while iterating over the OOD sets.
all_path, all_ood1, all_ood_pre = [], [], []

# Score directories of the two evaluation runs
# (presumably fine-tuned vs. pre-trained model -- confirm against the run configs).
npz_path = f'/data4/jiangy/OpenOOD-main/results/vit-b-16_eval/{id}_vit-b-16/s2/test_ood_ood_mds_20_ood/scores'
npz_path_pre = f'/data4/jiangy/OpenOOD-main/results/vit-b-16_eval_pre/{id}_vit-b-16/s2/test_ood_ood_mds_1_ood/scores'
# For every OOD set, keep the samples that are flagged with the "pre" scores
# but not with the "ft" scores, and collect their image paths and "ft" scores.
for dataset in ood_list:
    ood_sample_pre, data_test_pre, data_ood_pre = get_ood(npz_path_pre, id, dataset, fpr=True, recall=recall_pre)
    ood_sample_ft, data_test_ft, data_ood_ft = get_ood(npz_path, id, dataset, fpr=True, recall=recall_ft)
    diff = set(ood_sample_pre) - set(ood_sample_ft)
    print('id: {} -- ood: {}'.format(id, dataset))
    print(len(ood_sample_pre), len(ood_sample_ft), len(diff))
    if not diff:
        print(f'*** Not found OOD samples in {dataset}')
        continue
    # Build the index array once so paths and scores are selected with the
    # exact same indices in the exact same order.
    diff_ids = np.array(list(diff))
    path_list = get_ood_path_list(dataset, diff_ids)
    print('imglist: ', len(path_list), path_list[-1])
    all_path.extend(path_list)
    all_ood1.append(data_ood_ft['conf'][diff_ids])

print('all')
if all_ood1:
    # Overall evaluation of the ID test scores against all collected samples.
    compare(-data_test_ft['conf'], -np.concatenate(all_ood1))
else:
    # np.concatenate raises ValueError on an empty list, which happens when
    # every dataset was skipped above.
    print('*** No hidden OOD samples found in any dataset; skipping overall evaluation')

print(len(all_path))

# int() truncates, and products such as 0.29 * 100 evaluate to
# 28.999999999999996, so strip the floating-point noise before truncating.
out_dir = './hiddenList/recall_{:.1f}_{:d}'.format(recall_pre * 100, int(round(recall_ft * 100, 6)))
print(out_dir)
os.makedirs(out_dir, exist_ok=True)

# Some ID datasets mix images from several data roots; prefix every path with
# its root folder and write the classic-root entries first.
if id == 'covid':
    path_classic = []
    path_medical = []
    for item in all_path:
        if 'ct' in item or 'xraybone' in item:
            path_medical.append('images_medical/' + item)
        else:
            path_classic.append('images_classic/' + item)
    path_combine = path_classic + path_medical
elif id == 'imagenet200':
    path_classic = []
    path_largescale = []
    for item in all_path:
        if 'texture' in item:
            path_classic.append('images_classic/' + item)
        else:
            path_largescale.append('images_largescale/' + item)
    path_combine = path_classic + path_largescale
else:
    path_combine = all_path

# One "<path> -1" line per selected image.
with open('{}/{}_hidden_e20_new.txt'.format(out_dir, id), 'w') as f:
    f.writelines(item + ' -1' + '\n' for item in path_combine)


