import os
import json
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import transforms
from tqdm import tqdm
from collections import defaultdict
from scipy.stats import hmean, gmean
from PIL import Image

from torchvision.datasets import (
    ImageNet, Caltech101, StanfordCars, CIFAR10, CIFAR100,
    DTD, EuroSAT, Food101, Flowers102, OxfordIIITPet, SUN397, FGVCAircraft,
)
from torchgeo.datasets import RESISC45
from torchvision.datasets.folder import ImageFolder
from datasets.imagenetv2 import ImagenetV2

from templates import *
from utils import *
from templates import ZEROSHOT_TEMPLATES, ZEROSHOT_CLASS_NAMES

# Root directories of the evaluation datasets, keyed by dataset name.
# The "others" entry is the shared root handed to every dataset that does not
# have a dedicated entry of its own.
# NOTE(review): these are machine-specific absolute paths ("youname" looks like
# a placeholder user name) — adjust before running.
DATA_ROOT = {
    "imagenet": "/home/youname/datasets/ImageNet",
    "imagenetv2": "/home/youname/datasets/clip-datasets/imagenetv2",
    "imagenet-a": "/home/youname/datasets/clip-datasets/imagenet-a",
    "imagenet-r": "/home/youname/datasets/clip-datasets/imagenet-r",
    "imagenet-sketch": "/home/youname/datasets/clip-datasets/imagenet-sketch",
    "resisc45": "/home/youname/datasets/clip-datasets/resisc45",
    "others": "/home/youname/datasets/clip-datasets",
}

# Device used as ``map_location`` for cached tensors and as the target of
# ``.to(DEVICE)`` calls in this module (hard-coded to the first CUDA device).
DEVICE = torch.device("cuda:0")


def load_class_names(dataset_name):
    """Return the class names registered for ``dataset_name``.

    ``"imagenetv2"`` and ``"imagenet-sketch"`` are looked up under the
    ``"imagenet"`` key; any other name is used as the key unchanged.

    Raises:
        KeyError: if the resolved key is missing from ``ZEROSHOT_CLASS_NAMES``.
    """
    imagenet_aliases = ("imagenetv2", "imagenet-sketch")
    key = "imagenet" if dataset_name in imagenet_aliases else dataset_name
    return ZEROSHOT_CLASS_NAMES[key]


def load_hand_crafted_templates(dataset_name):
    """Return the prompt templates registered for ``dataset_name``.

    All four ImageNet variants handled here (``imagenetv2``, ``imagenet-a``,
    ``imagenet-r``, ``imagenet-sketch``) share the ``"imagenet"`` templates.

    Raises:
        KeyError: if the resolved key is missing from ``ZEROSHOT_TEMPLATES``.
    """
    imagenet_variants = ("imagenetv2", "imagenet-a", "imagenet-r", "imagenet-sketch")
    key = "imagenet" if dataset_name in imagenet_variants else dataset_name
    return ZEROSHOT_TEMPLATES[key]


def load_pool_templates():
    """Return the de-duplicated union of every dataset's templates.

    The per-dataset lists in ``ZEROSHOT_TEMPLATES`` are concatenated and passed
    through a ``set``, so the order of the returned list is not defined.
    """
    concatenated = sum(ZEROSHOT_TEMPLATES.values(), [])
    return list(set(concatenated))


def load_all_templates():
    """Return the de-duplicated union of every dataset's templates.

    The result is built exactly like the template pool: all lists in
    ``ZEROSHOT_TEMPLATES`` are concatenated and de-duplicated through a
    ``set`` (so the order is not defined).
    """
    concatenated = sum(ZEROSHOT_TEMPLATES.values(), [])
    all_templates = list(set(concatenated))
    # NOTE(review): the extra templates are computed but never merged into the
    # returned list — confirm whether they were meant to be included.
    extras_templates = EXTRA_TEMPLATES - set(all_templates)
    return all_templates


def load_test_dataset(dataset_name, transform=None):
    """Build the test/validation dataset object for ``dataset_name``.

    Args:
        dataset_name: one of the dataset identifiers handled below
            (e.g. ``"imagenet"``, ``"cifar10"``, ``"sun397"``).
        transform: callable applied to each image; forwarded to the dataset
            constructor (for ``resisc45`` it is applied inside a sample-level
            wrapper and skipped when it is ``None``).

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == "imagenet":
        test_dataset = ImageNet(root=DATA_ROOT["imagenet"], split="val", transform=transform)

    elif dataset_name == "imagenetv2":
        test_dataset = ImagenetV2(root=DATA_ROOT["imagenetv2"], transform=transform)

    elif dataset_name == "imagenet-a":
        test_dataset = ImageFolder(root=DATA_ROOT["imagenet-a"], transform=transform)

    elif dataset_name == "imagenet-r":
        test_dataset = ImageFolder(root=DATA_ROOT["imagenet-r"], transform=transform)

    elif dataset_name == "imagenet-sketch":
        test_dataset = ImageFolder(root=DATA_ROOT["imagenet-sketch"], transform=transform)

    elif dataset_name == "caltech101":
        test_dataset = Caltech101(root=DATA_ROOT["others"], transform=transform, download=True)

    elif dataset_name == "cars196":
        test_dataset = StanfordCars(root=DATA_ROOT["others"], split="test", transform=transform, download=True)

    elif dataset_name == "cifar10":
        test_dataset = CIFAR10(root=DATA_ROOT["others"], train=False, transform=transform, download=True)

    elif dataset_name == "cifar100":
        test_dataset = CIFAR100(root=DATA_ROOT["others"], train=False, transform=transform, download=True)

    elif dataset_name == "dtd":
        test_dataset = DTD(root=DATA_ROOT["others"], split="test", transform=transform, download=True)

    elif dataset_name == "eurosat":
        test_dataset = EuroSAT(root=DATA_ROOT["others"], transform=transform)

    elif dataset_name == "food101":
        test_dataset = Food101(root=DATA_ROOT["others"], split="test", transform=transform, download=True)

    elif dataset_name == "oxford_flowers102":
        test_dataset = Flowers102(root=DATA_ROOT["others"], split="test", transform=transform, download=True)

    elif dataset_name == "oxford_iiit_pet":
        test_dataset = OxfordIIITPet(root=DATA_ROOT["others"], split="test", transform=transform, download=True)

    elif dataset_name == "resisc45":
        def transforms_resisc45(sample):
            # The dataset hands over a dict sample; turn it into the
            # (image, label) pair the rest of this module iterates over.
            image, label = sample["image"], sample["label"]
            image = transforms.ToPILImage()(image.to(torch.uint8))
            if transform is None:
                return image, label
            return transform(image), label
        test_dataset = RESISC45(root=DATA_ROOT["resisc45"], split="test", transforms=transforms_resisc45, download=True)

    elif dataset_name == "sun397":
        test_dataset = SUN397(root=DATA_ROOT["others"], transform=transform, download=True)
        # use test split from https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/sun397_tfds_te.txt
        sun397_te_split = np.loadtxt("./datasets/sun397_tfds_te.txt", dtype=str).tolist()
        # Each split entry has its first character stripped before being
        # joined onto the dataset's data directory.
        test_dataset._image_files = [test_dataset._data_dir / f[1:] for f in sun397_te_split]
        # The class key is the path between the first component and the file name.
        test_dataset._labels = [
            test_dataset.class_to_idx["/".join(path.relative_to(test_dataset._data_dir).parts[1:-1])] for path in test_dataset._image_files
        ]

    elif dataset_name == "fgvc_aircraft":
        test_dataset = FGVCAircraft(root=DATA_ROOT["others"], split="test", transform=transform, download=False)

    else:
        # Previously an unknown name fell through to the return statement and
        # surfaced as an UnboundLocalError.
        raise ValueError(f"Unknown dataset name: {dataset_name!r}")

    return test_dataset


def load_test_image(dataset_name, index):
    """Return the RGB-converted test image at ``index``, printing its class name.

    The dataset is loaded with a transform that only converts each image to
    RGB. For ``'imagenet'`` the file path of the sample is printed as well.
    """
    to_rgb = transforms.Lambda(lambda img: img.convert("RGB"))
    dataset = load_test_dataset(dataset_name, transform=to_rgb)

    if dataset_name == 'imagenet':
        sample_path, _ = dataset.samples[index]
        print(sample_path)

    image, target = dataset[index]
    class_names = load_class_names(dataset_name)
    print(class_names[target])

    return image

def load_test_embeds_and_labels(dataset_name, backbone_name):
    """Return ``(test_embeds, test_labels)`` for a dataset/backbone pair.

    Both tensors are cached under ``./cache/<dataset_name>/<backbone_name>/``.
    When both cache files exist they are loaded onto ``DEVICE``; otherwise the
    test set is run through the CLIP image encoder in batches, and the
    concatenated embeddings and labels are written to the cache.
    """
    cache_dir = os.path.join("./cache", dataset_name, backbone_name)
    embed_path = os.path.join(cache_dir, "test_embeds.pt")
    label_path = os.path.join(cache_dir, "test_labels.pt")

    # Fast path: both cache files are already on disk.
    if os.path.exists(embed_path) and os.path.exists(label_path):
        print("Loading existing checkpoints")
        cached_embeds = torch.load(embed_path, map_location=DEVICE)
        cached_labels = torch.load(label_path, map_location=DEVICE)
        return cached_embeds, cached_labels

    print("No checkpoints found, extracting now")

    preprocessing_steps = [
        transforms.Lambda(lambda x: x.convert("RGB")),
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
    ]
    transform_test = transforms.Compose(preprocessing_steps)

    print(f"Loading dataset {dataset_name}")
    dataset = load_test_dataset(dataset_name, transform_test)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128, num_workers=8, pin_memory=True
    )

    print(f"Loading CLIP (backbone: {backbone_name})")
    clip_model = load_clip_to_cpu(backbone_name)
    clip_model.to(DEVICE)
    print("Turning off gradients in the model")
    for _, param in clip_model.named_parameters():
        param.requires_grad_(False)

    print("Computing embeddings")
    embed_batches = []
    label_batches = []
    for images, labels in tqdm(loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        embed_batches.append(clip_model.encode_image(images))
        label_batches.append(labels)

    test_embeds = torch.cat(embed_batches)
    test_labels = torch.cat(label_batches)

    # Persist both tensors so the next call takes the fast path.
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(test_embeds, embed_path)
    torch.save(test_labels, label_path)

    return test_embeds, test_labels


def load_all_class_embeds(dataset_name, backbone_name, all_templates):
    """Return the stacked class text embeddings, one slice per template.

    Each template's embeddings are cached in a file named after the template
    string itself under ``./cache/<dataset>/<backbone>/class_embeds``. Missing
    entries are computed by formatting every class name into the template and
    encoding the tokenized prompts with CLIP; the model is only loaded the
    first time a template is not found in the cache.

    Returns:
        ``torch.stack`` of the per-template embeddings, in template order.
    """
    cache_dir = os.path.join("./cache", dataset_name, backbone_name, "class_embeds")
    clip_model = None  # loaded lazily on the first cache miss
    per_template_embeds = []

    for template in tqdm(all_templates):
        embed_path = os.path.join(cache_dir, template)

        if not os.path.exists(embed_path):
            if clip_model is None:
                print(f"Loading CLIP (backbone: {backbone_name})")
                clip_model = load_clip_to_cpu(backbone_name)
                clip_model.to(DEVICE)
                print("Turning off gradients in the model")
                for _, param in clip_model.named_parameters():
                    param.requires_grad_(False)

            class_names = load_class_names(dataset_name)
            prompts = [template.format(name) for name in class_names]
            tokens = torch.cat([clip.tokenize(prompt) for prompt in prompts])
            tokens = tokens.to(DEVICE)
            class_embeds = clip_model.encode_text(tokens)

            os.makedirs(cache_dir, exist_ok=True)
            torch.save(class_embeds, embed_path)
        else:
            class_embeds = torch.load(embed_path, map_location=DEVICE)

        per_template_embeds.append(class_embeds)

    return torch.stack(per_template_embeds)


def cosine_similarity(a, b, score=None):
    """Return the matrix of cosine similarities between rows of ``a`` and ``b``.

    Both inputs are L2-normalized along their last dimension. When ``b`` is
    3-dimensional its leading dimension is collapsed with a mean after
    normalization; if ``score`` is given, each slice along that dimension is
    first multiplied by the matching entry of ``score``.

    Returns:
        ``a_normalized @ b_collapsed.T``.
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)

    if b_unit.dim() == 3:
        # Optionally weight each slice, then average over the leading dimension.
        stacked = b_unit if score is None else score.reshape(-1, 1, 1) * b_unit
        b_unit = stacked.mean(dim=0)

    return torch.matmul(a_unit, b_unit.t())

def predict(image_embeds, class_embeds):
    """Return, for each image embedding, the index of the most similar class."""
    similarities = cosine_similarity(image_embeds, class_embeds)
    return similarities.argmax(dim=-1)


def evaluate(y_true, y_pred, p=True):
    """Compute overall and per-class accuracy statistics (in percent).

    Args:
        y_true: ground-truth labels (torch tensor or numpy array).
        y_pred: predicted labels (torch tensor or numpy array).
        p: when exactly ``True``, the report is printed and a summary dict is
            returned; for any other value only the per-class accuracy array
            (ordered by sorted label) is returned and nothing is printed.

    Returns:
        Either the dict of rounded metrics (``p is True``) or the numpy array
        of per-class accuracies.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.cpu().numpy()

    acc = 100.0 * (y_true == y_pred).mean()

    # Group the per-sample hit/miss outcomes by ground-truth label.
    outcomes = defaultdict(list)
    for label, pred in zip(y_true, y_pred):
        outcomes[label].append((label == pred).item())

    cls_accs = np.array(
        [100.0 * sum(outcomes[label]) / len(outcomes[label]) for label in sorted(outcomes)]
    )

    # Sort once; every "worst@k" figure is the mean of the k lowest accuracies.
    ascending = np.sort(cls_accs)
    worst1 = cls_accs.min()
    worst5 = ascending[:5].mean()
    worst10 = ascending[:10].mean()
    worst20 = ascending[:20].mean()
    worst50 = ascending[:50].mean()
    worst100 = ascending[:100].mean()

    macc = cls_accs.mean()
    hacc = hmean(cls_accs)
    gacc = gmean(cls_accs)

    if p is not True:
        return cls_accs

    print(f"Accuracy: {round(acc, 2)}%")
    print(f"Number of classes: {len(cls_accs)}")
    print(f"Worst@1: {round(worst1, 2)}% index: {cls_accs.argmin()}")
    print(f"Worst@5: {round(worst5, 2)}% Acc: {ascending[:5]}")
    print(f"Worst@10: {round(worst10, 2)}%")
    print(f"Worst@20: {round(worst20, 2)}% Acc: {ascending[:20]}")
    print(f"Worst@50: {round(worst50, 2)}%")
    print(f"Worst@100: {round(worst100, 2)}%")
    print(f"Mean Class Acc: {round(macc, 2)}%")
    print(f"Harmonic Mean: {round(hacc, 2)}%")
    print(f"Geometric Mean: {round(gacc, 2)}%")
    print("Worst Classes: ", np.argsort(cls_accs)[:10])

    return {
        "Accuracy": round(acc, 2),
        "Worst@1": round(worst1, 2),
        "Worst@5": round(worst5, 2),
        "Worst@10": round(worst10, 2),
        "Worst@20": round(worst20, 2),
        "Worst@50": round(worst50, 2),
        "Worst@100": round(worst100, 2),
        "Harmonic Mean": round(hacc, 2),
        "Geometric Mean": round(gacc, 2),
        "Mean Class Acc": round(macc, 2),
    }


def get_cmm_score(test_embeds, class_embeds, test_labels=None, worst_k=1):
    """Return the mean of the ``worst_k`` smallest per-class margins.

    Samples are grouped by ``test_labels`` (or by their argmax prediction when
    no labels are given). For every class, the margin is the group's mean
    similarity to its own class minus the largest mean similarity to any other
    class. Classes without samples keep an all-zero row.
    """
    num_classes = class_embeds.shape[0]
    logits = cosine_similarity(test_embeds, class_embeds)
    assigned = logits.argmax(dim=-1) if test_labels is None else test_labels

    # Row c holds the mean logits of the samples assigned to class c.
    class_means = torch.zeros(num_classes, num_classes)
    for c in range(num_classes):
        members = assigned == c
        if members.sum() != 0:
            class_means[c] = logits[members].mean(dim=0)

    margins = torch.zeros(num_classes)
    for c in range(num_classes):
        own = class_means[c, c].clone()
        # Mask the diagonal so the max below only sees the other classes.
        class_means[c, c] = -10000
        margins[c] = own - class_means[c].max()

    ascending, _ = torch.sort(margins, descending=False)
    return ascending[:worst_k].mean()


def get_all_cmm_score(test_embeds, class_embeds, test_labels=None, worst_k=1):
    """Return the per-class margin vector (one entry per class).

    Samples are grouped by ``test_labels`` (or by their argmax prediction when
    no labels are given). For every class, the margin is the group's mean
    similarity to its own class minus the largest mean similarity to any other
    class. Classes without samples keep an all-zero row.

    Args:
        test_embeds: sample embeddings passed to ``cosine_similarity``.
        class_embeds: class embeddings; ``shape[0]`` is taken as the number of
            classes.
        test_labels: optional grouping labels; argmax predictions when ``None``.
        worst_k: accepted for signature parity with the scalar variants; it
            does not influence the returned vector.

    Returns:
        1-D tensor of length ``class_embeds.shape[0]``.
    """
    class_num = class_embeds.shape[0]
    logits = cosine_similarity(test_embeds, class_embeds)
    if test_labels is None:
        pseudo_label = logits.argmax(dim=-1)
    else:
        pseudo_label = test_labels

    # Row c holds the mean logits of the samples assigned to class c.
    H_matrix = torch.zeros(class_num, class_num)
    for cls in range(class_num):
        mask = (pseudo_label == cls)
        if mask.sum() != 0:
            H_matrix[cls] = torch.mean(logits[mask], 0)

    Cmm = torch.zeros(class_num)
    for cls in range(class_num):
        diag = torch.clone(H_matrix[cls][cls])
        # Mask the diagonal so the max below only sees the other classes.
        H_matrix[cls][cls] = -10000
        Cmm[cls] = diag - torch.max(H_matrix[cls])
    return Cmm


def get_cmm_score_modified(test_embeds, class_embeds, test_labels=None, worst_k=1):
    """Return the mean of the ``worst_k`` smallest margins on re-centered logits.

    Before grouping, each class column of the logits is shifted by the
    difference between that column's mean and the mean over all column means.
    Samples are then grouped by ``test_labels`` (or by argmax of the shifted
    logits), and the per-class margin is the group's mean own-class logit
    minus the largest mean logit of any other class. Classes without samples
    keep an all-zero row.
    """
    num_classes = class_embeds.shape[0]
    raw_logits = cosine_similarity(test_embeds, class_embeds)

    column_means = raw_logits.mean(dim=0)
    offset = column_means - column_means.mean()
    logits = raw_logits - offset.unsqueeze(0)

    assigned = logits.argmax(dim=-1) if test_labels is None else test_labels

    # Row c holds the mean shifted logits of the samples assigned to class c.
    class_means = torch.zeros(num_classes, num_classes)
    for c in range(num_classes):
        members = assigned == c
        if members.sum() != 0:
            class_means[c] = logits[members].mean(dim=0)

    margins = torch.zeros(num_classes)
    for c in range(num_classes):
        own = class_means[c, c].clone()
        # Mask the diagonal so the max below only sees the other classes.
        class_means[c, c] = -10000
        margins[c] = own - class_means[c].max()

    ascending, _ = torch.sort(margins, descending=False)
    return ascending[:worst_k].mean()

def get_overall_cmm_score_modified(test_embeds, class_embeds, test_labels=None, worst_k=1):
    """Return the per-class margin vector computed on re-centered logits.

    Each class column of the logits is shifted by the difference between that
    column's mean and the mean over all column means. Samples are grouped by
    ``test_labels`` (or by argmax of the shifted logits); the margin of a class
    is the group's mean own-class logit minus the largest mean logit of any
    other class.

    Classes without any assigned sample keep an all-zero row (margin 0), which
    matches the other margin helpers in this module; taking the mean of an
    empty selection would otherwise produce NaN.

    Args:
        test_embeds: sample embeddings passed to ``cosine_similarity``.
        class_embeds: class embeddings; ``shape[0]`` is taken as the number of
            classes.
        test_labels: optional grouping labels; argmax predictions when ``None``.
        worst_k: accepted for signature parity; it does not influence the
            returned vector.

    Returns:
        1-D tensor of length ``class_embeds.shape[0]``.
    """
    class_num = class_embeds.shape[0]
    logits = cosine_similarity(test_embeds, class_embeds)
    mean_logits = torch.mean(logits, 0)
    d = mean_logits - torch.mean(mean_logits)
    logits = logits - d.repeat(logits.shape[0], 1)
    if test_labels is None:
        pseudo_label = logits.argmax(dim=-1)
    else:
        pseudo_label = test_labels

    # Row c holds the mean shifted logits of the samples assigned to class c.
    H_matrix = torch.zeros(class_num, class_num)
    for cls in range(class_num):
        mask = (pseudo_label == cls)
        if mask.sum() != 0:  # skip empty classes instead of averaging nothing
            H_matrix[cls] = torch.mean(logits[mask], 0)

    Cmm = torch.zeros(class_num)
    for cls in range(class_num):
        diag = torch.clone(H_matrix[cls][cls])
        # Mask the diagonal so the max below only sees the other classes.
        H_matrix[cls][cls] = -10000
        Cmm[cls] = diag - torch.max(H_matrix[cls])
    return Cmm


def get_cmm_score_modified_2(test_embeds, class_embeds, test_labels=None, worst_k=1):
    """Return the mean of the ``worst_k`` smallest per-class margins.

    Pseudo-labels are taken from re-centered logits (each class column shifted
    by the difference between its mean and the mean over all column means),
    but the class-mean matrix and the margins are computed from the original,
    unshifted logits. When ``test_labels`` is given it is used for grouping
    instead of the pseudo-labels.

    Classes without any assigned sample keep an all-zero row (margin 0), which
    matches the other margin helpers in this module; taking the mean of an
    empty selection would otherwise produce NaN.

    Returns:
        Scalar tensor: mean of the ``worst_k`` smallest margins.
    """
    class_num = class_embeds.shape[0]
    logits = cosine_similarity(test_embeds, class_embeds)
    mean_logits = torch.mean(logits, 0)
    d = mean_logits - torch.mean(mean_logits)
    # Shifted copy used only to derive the pseudo-labels.
    plogits = logits - d.repeat(logits.shape[0], 1)
    if test_labels is None:
        pseudo_label = plogits.argmax(dim=-1)
    else:
        pseudo_label = test_labels

    # Row c holds the mean raw logits of the samples assigned to class c.
    H_matrix = torch.zeros(class_num, class_num)
    for cls in range(class_num):
        mask = (pseudo_label == cls)
        if mask.sum() != 0:  # skip empty classes instead of averaging nothing
            H_matrix[cls] = torch.mean(logits[mask], 0)

    Cmm = torch.zeros(class_num)
    for cls in range(class_num):
        diag = torch.clone(H_matrix[cls][cls])
        # Mask the diagonal so the max below only sees the other classes.
        H_matrix[cls][cls] = -10000
        Cmm[cls] = diag - torch.max(H_matrix[cls])

    sorted_cmm, _ = torch.sort(Cmm, descending=False)
    return torch.mean(sorted_cmm[:worst_k])


def get_cmm_score_zpe(test_embeds, class_embeds, selected_pretrain_embed_mean, test_labels=None, worst_k=1):
    """Return the mean of the ``worst_k`` smallest margins on normalized logits.

    The logits are shifted per class by the average of two terms: the
    similarity between ``selected_pretrain_embed_mean`` (reshaped to a single
    row) and each class embedding, and the mean test logit of each class.
    Samples are grouped by ``test_labels`` (or by argmax of the shifted
    logits); the margin of a class is the group's mean own-class logit minus
    the largest mean logit of any other class.

    Classes without any assigned sample keep an all-zero row (margin 0), which
    matches the other margin helpers in this module; taking the mean of an
    empty selection would otherwise produce NaN.

    Returns:
        Scalar tensor: mean of the ``worst_k`` smallest margins.
    """
    class_num = class_embeds.shape[0]
    logits = cosine_similarity(test_embeds, class_embeds)

    # Per-class bias estimates: one from the pretraining mean embedding, one
    # from the test samples themselves.
    e_pretrain = cosine_similarity(selected_pretrain_embed_mean.reshape(1, -1), class_embeds)
    e_test = logits.mean(dim=0)

    logits = logits - (e_pretrain + e_test) / 2

    if test_labels is None:
        pseudo_label = logits.argmax(dim=-1)
    else:
        pseudo_label = test_labels

    # Row c holds the mean shifted logits of the samples assigned to class c.
    H_matrix = torch.zeros(class_num, class_num)
    for cls in range(class_num):
        mask = (pseudo_label == cls)
        if mask.sum() != 0:  # skip empty classes instead of averaging nothing
            H_matrix[cls] = torch.mean(logits[mask], 0)

    Cmm = torch.zeros(class_num)
    for cls in range(class_num):
        diag = torch.clone(H_matrix[cls][cls])
        # Mask the diagonal so the max below only sees the other classes.
        H_matrix[cls][cls] = -10000
        Cmm[cls] = diag - torch.max(H_matrix[cls])

    sorted_cmm, _ = torch.sort(Cmm, descending=False)
    return torch.mean(sorted_cmm[:worst_k])


def get_cmmPE_predict(dataset_name, backbone_name):
    """Predict test labels with a score-weighted ensemble of prompt templates.

    Every template of the dataset is scored with ``get_cmm_score_modified``
    (no labels, ``worst_k`` = number of classes // 10). The scores are shifted
    so their minimum is not below zero, templates whose score is not strictly
    above the median get weight zero, and the remaining scores weight the
    per-template class embeddings inside ``cosine_similarity``.

    Returns:
        Tensor of predicted class indices (argmax over the last dimension).
    """
    test_embeds, _ = load_test_embeds_and_labels(dataset_name, backbone_name)
    templates = ZEROSHOT_TEMPLATES[dataset_name]
    print("len(all_templates)", len(templates))
    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, templates)

    worst_k = all_class_embeds.shape[1] // 10
    per_template_scores = [
        get_cmm_score_modified(test_embeds, class_embeds, None, worst_k)
        for class_embeds in all_class_embeds
    ]
    scores = torch.stack(per_template_scores).cuda().half()

    # Shift only when the smallest score is negative.
    scores -= min(torch.min(scores), 0)
    threshold = scores.median()
    kept_scores = scores * (scores > threshold)

    similarities = cosine_similarity(test_embeds, all_class_embeds, kept_scores)
    return similarities.argmax(dim=-1)



def get_all_cmm_score_from_logits(logits, test_labels=None, worst_k=1):
    """Return the per-class margin vector computed from precomputed logits.

    Samples (rows of ``logits``) are grouped by ``test_labels``, or by their
    argmax prediction when no labels are given. For every class, the margin is
    the group's mean logit for its own class minus the largest mean logit of
    any other class. Classes without samples keep an all-zero row.

    Args:
        logits: 2-D tensor; ``shape[1]`` is taken as the number of classes.
        test_labels: optional grouping labels; argmax predictions when ``None``.
        worst_k: accepted for signature parity with the scalar variant; it
            does not influence the returned vector.

    Returns:
        1-D tensor with one margin per class.
    """
    class_num = logits.shape[1]

    if test_labels is None:
        pseudo_label = logits.argmax(dim=-1)
    else:
        pseudo_label = test_labels

    # Row c holds the mean logits of the samples assigned to class c.
    H_matrix = torch.zeros(class_num, class_num)
    for cls in range(class_num):
        mask = (pseudo_label == cls)
        if mask.sum() != 0:
            H_matrix[cls] = torch.mean(logits[mask], 0)

    Cmm = torch.zeros(class_num)
    for cls in range(class_num):
        diag = torch.clone(H_matrix[cls][cls])
        # Mask the diagonal so the max below only sees the other classes.
        H_matrix[cls][cls] = -10000
        Cmm[cls] = diag - torch.max(H_matrix[cls])
    return Cmm

def get_cmm_score_from_logits(logits, test_labels=None, worst_k=1):
    """Return the mean of the ``worst_k`` smallest per-class margins.

    Samples (rows of ``logits``) are grouped by ``test_labels``, or by their
    argmax prediction when no labels are given. For every class, the margin is
    the group's mean logit for its own class minus the largest mean logit of
    any other class. Classes without samples keep an all-zero row.

    Args:
        logits: 2-D tensor; ``shape[1]`` is taken as the number of classes.
        test_labels: optional grouping labels. Defaults to ``None`` (argmax
            pseudo-labels), like the other margin helpers in this module.
        worst_k: how many of the smallest margins are averaged.

    Returns:
        Scalar tensor: mean of the ``worst_k`` smallest margins.
    """
    class_num = logits.shape[1]

    if test_labels is None:
        pseudo_label = logits.argmax(dim=-1)
    else:
        pseudo_label = test_labels

    # Row c holds the mean logits of the samples assigned to class c.
    H_matrix = torch.zeros(class_num, class_num)
    for cls in range(class_num):
        mask = (pseudo_label == cls)
        if mask.sum() != 0:
            H_matrix[cls] = torch.mean(logits[mask], 0)

    Cmm = torch.zeros(class_num)
    for cls in range(class_num):
        diag = torch.clone(H_matrix[cls][cls])
        # Mask the diagonal so the max below only sees the other classes.
        H_matrix[cls][cls] = -10000
        Cmm[cls] = diag - torch.max(H_matrix[cls])

    sorted_cmm, _ = torch.sort(Cmm, descending=False)
    return torch.mean(sorted_cmm[:worst_k])


def calc_pred_from_logits(all_logits, weights, batch_size=5000):
    """Return argmax predictions of a weighted sum of per-template logits.

    ``all_logits`` is indexed as ``[template, instance, class]``. The instance
    axis is processed in chunks of ``batch_size``; each chunk is moved to the
    device that holds ``weights``, multiplied by the per-template weights and
    summed over the template axis.

    Args:
        all_logits: 3-D tensor of logits, one slice per template.
        weights: tensor with one weight per template.
        batch_size: number of instances processed per chunk.

    Returns:
        1-D tensor with the predicted class index of every instance, on the
        same device as ``weights``.
    """
    num_instance = all_logits.shape[1]
    template_weights = weights.reshape(-1, 1, 1)

    weighted_logits = []
    for start in range(0, num_instance, batch_size):
        stop = min(start + batch_size, num_instance)
        # Follow the weights' device rather than assuming CUDA is available.
        chunk = all_logits[:, start:stop, :].to(weights.device)
        weighted_logits.append((template_weights * chunk).sum(dim=0))

    weighted_logits = torch.cat(weighted_logits, dim=0)
    return weighted_logits.argmax(dim=1)



# CMM_DES

def get_CMM_DES_pred(dataset_name,backbone_name,test_embeds,test_labels,hand_craft=True, selection=False):
    """Predict labels with a template ensemble whose weakest classes use description prompts.

    For every template, the ``worst_k`` classes with the smallest margin
    (``worst_k`` = number of classes // 10, pseudo-labels only) get their logit
    column replaced by the mean similarity to the prompt embeddings stored for
    that class in the prompt dictionary. The modified logits of all templates
    are then combined with softmax weights derived from each template's score.

    Args:
        dataset_name: key into ``ZEROSHOT_TEMPLATES`` / ``ZEROSHOT_CLASS_NAMES``.
        backbone_name: backbone identifier used for the embedding caches.
        test_embeds: test image embeddings.
        test_labels: accepted but not used; scoring relies on pseudo-labels.
        hand_craft: use the dataset's own templates when true, otherwise the
            de-duplicated union of all datasets' templates.
        selection: when true, templates whose softmax weight is below the
            median weight are given weight zero.

    Returns:
        Tensor of predicted class indices.
    """
    if hand_craft:
        all_templates = ZEROSHOT_TEMPLATES[dataset_name]
    else:
        all_templates = list(set(sum(ZEROSHOT_TEMPLATES.values(), [])))

    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)

    prompt_dict_file = os.path.join("cache_prompt_embed", dataset_name, backbone_name, "prompts_dict.pt")

    from prompt_engine import load_all_dict

    all_prompt_embed_dict = load_all_dict(save_file=prompt_dict_file)

    print("loading finished")

    worst_k = all_class_embeds.shape[1] // 10

    scores = [None] * len(all_class_embeds)
    all_logits = []
    classname_list = ZEROSHOT_CLASS_NAMES[dataset_name]

    success_count, all_load_count = 0, 0

    for prefix_i in tqdm(range(len(all_templates))):
        logits = cosine_similarity(test_embeds, all_class_embeds[prefix_i])
        score_over_class = get_all_cmm_score_from_logits(logits, test_labels=None, worst_k=None)

        _, sorted_index = torch.sort(score_over_class, descending=False)
        worst_index = sorted_index[:worst_k].tolist()

        logits_des = logits.clone()

        for c_i in worst_index:
            all_load_count += 1
            try:
                # NOTE(review): the key always uses the first template, not
                # all_templates[prefix_i] — confirm against how the prompt
                # dictionary is built.
                prefix_class_key = (all_templates[0], classname_list[c_i])
                class_prompts = all_prompt_embed_dict[prefix_class_key]
                embed_col = [class_prompts[key_i].reshape(1, -1) for key_i in class_prompts]
                prompts_embed = torch.cat(embed_col, dim=0)
                prompts_embed = prompts_embed.to("cuda")

                similarity_matrix_chunk = cosine_similarity(test_embeds, prompts_embed)
                class_similarity = similarity_matrix_chunk.mean(dim=1)

                logits_des[:, c_i] = class_similarity
            except Exception:
                # Best effort: a class without usable description embeddings
                # keeps its template logits. Unlike a bare ``except``, this
                # lets KeyboardInterrupt/SystemExit propagate.
                continue
            else:
                success_count += 1

        all_logits.append(logits_des.reshape(1, logits_des.shape[0], logits_des.shape[1]).cpu())
        scores[prefix_i] = get_cmm_score_from_logits(logits_des, test_labels=None, worst_k=worst_k)

    # Guard against worst_k == 0 (fewer than 10 classes) or no templates.
    success_rate = success_count / all_load_count if all_load_count else 0.0
    print("success rate: ", success_rate)

    scores = torch.stack(scores).cuda().half()
    all_logits = torch.cat(all_logits, dim=0)

    softmax_scores = torch.softmax(scores.reshape(1, -1), dim=1).reshape(-1)

    if selection:
        th = softmax_scores.median()
        softmax_scores = softmax_scores * (softmax_scores >= th)

    test_preds = calc_pred_from_logits(all_logits=all_logits, weights=softmax_scores)

    return test_preds
