import os.path as osp

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from torchvision.models._utils import IntermediateLayerGetter

from tqdm import tqdm
import pickle5 as pickle

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

from .utils import soft_cross_entropy, softmax_sigmoid_BCEloss, \
    norm_logits_BCEloss, sigmoid_focal_loss, sigmoid_ASL_loss, ranking_loss, ASL_loss

_tokenizer = _Tokenizer()

import os
import matplotlib.pyplot as plt
import numpy as np
import json
import gc
from clip.model import AttentionPool2d

import pickle5 as pickle
from os.path import join
import json

def load_clip_to_cpu(cfg):
    """Download (if needed) and build the CLIP model named by ``cfg.MODEL.BACKBONE.NAME`` on CPU.

    The checkpoint is first tried as a TorchScript archive; if that raises
    ``RuntimeError`` it is loaded as a plain ``torch.load`` state dict instead.
    Either way the weights are fed to ``clip.build_model``.
    """
    model_path = clip._download(clip._MODELS[cfg.MODEL.BACKBONE.NAME])

    state_dict = None
    try:
        # Preferred path: the JIT archive that CLIP publishes.
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Fallback: a regular pickled state dict.
        state_dict = torch.load(model_path, map_location="cpu")

    weights = state_dict if state_dict else jit_model.state_dict()
    return clip.build_model(weights)


class TextEncoder(nn.Module):
    """CLIP text tower wrapper that accepts either token embeddings or raw token ids.

    The submodules/tensors are taken by reference from ``clip_model``.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Same registration order as the CLIP text tower attributes we borrow.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype
        self.token_embedding = clip_model.token_embedding

    def forward(self, prompts, tokenized_prompts, if_embedding=True, if_sequence=False):
        """Encode prompts with the CLIP text transformer.

        Args:
            prompts: token embeddings ``[batch, n_ctx, d_model]`` when ``if_embedding`` is
                True, otherwise integer token ids ``[batch, n_ctx]``.
            tokenized_prompts: token ids used to locate the EOT position (argmax per row);
                ignored and replaced by ``prompts`` when ``if_embedding`` is False.
            if_embedding: whether ``prompts`` are already embeddings.
            if_sequence: if True, return projected features for every position
                (``[batch, n_ctx, d]``); otherwise only the EOT position (``[batch, d]``).
        """
        if if_embedding:
            embedded = prompts
        else:
            # Raw ids: they also serve as the EOT locator.
            tokenized_prompts = prompts
            embedded = self.token_embedding(prompts).type(self.dtype)

        hidden = embedded + self.positional_embedding.type(self.dtype)
        # The transformer works in (length, batch, dim) layout.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        if if_sequence:
            return hidden @ self.text_projection

        # EOT has the largest token id in each sequence.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        return hidden[rows, eot_positions] @ self.text_projection


class PromptLearner(nn.Module):
    """Learnable prompt contexts for CLIP text prompts.

    Holds two learnable contexts: ``ctx`` (used to build the positive prompts) and
    ``ctx_double`` (used to build the second, "negative" prompts). It also holds three
    learnable scalars (``temperature``, ``spatial_T``, ``ranking_scale``) that
    ``forward`` returns alongside the prompt embeddings.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.Caption.N_CTX
        ctx_init = cfg.TRAINER.Caption.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            # The double (negative) context starts from the same words. Clone it so the
            # two parameters do not share storage.
            ctx_vectors_double = ctx_vectors.clone()
            prompt_prefix = ctx_init

        else:
            # random initialization
            if cfg.TRAINER.Caption.CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)

            if cfg.TRAINER.Caption.CSC:
                print("Initializing class-specific double contexts")
                ctx_vectors_double = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors_double = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors_double, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f'Initial double context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized
        self.ctx_double = nn.Parameter(ctx_vectors_double)  # to be optimized

        # Learnable scalars returned by forward(). How they are consumed (e.g. whether
        # they are exponentiated) is decided by the caller — TODO confirm.
        temperature = torch.tensor(3.0, dtype=dtype)
        self.temperature = nn.Parameter(temperature)
        spatial_T = torch.tensor(3.0, dtype=dtype)
        self.spatial_T = nn.Parameter(spatial_T)
        ranking_scale = torch.tensor(4.0, dtype=dtype)
        self.ranking_scale = nn.Parameter(ranking_scale)

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        # class agnostic token suffix ("." + EOS, no class name)
        prompts_nocls = [prompt_prefix + "."] * len(classnames)
        tokenized_prompts_nocls = torch.cat([clip.tokenize(p) for p in prompts_nocls])
        with torch.no_grad():
            embedding_nocls = clip_model.token_embedding(tokenized_prompts_nocls).type(dtype)
        self.register_buffer("token_suffix_nocls", embedding_nocls[:, 1 + n_ctx:, :])  # EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = cfg.TRAINER.Caption.CLASS_TOKEN_POSITION

    def _arrange(self, ctx, prefix, suffix):
        """Place ``ctx`` and the class tokens (held at the start of ``suffix``) per ``class_token_position``.

        ``prefix`` is ``(n_cls, 1, dim)``, ``ctx`` is ``(n_cls, n_ctx, dim)``, and ``suffix``
        is ``(n_cls, *, dim)``. Returns ``(n_cls, 1 + n_ctx + *, dim)``.

        Raises:
            ValueError: for an unknown ``class_token_position``.
        """
        if self.class_token_position == "end":
            return torch.cat([prefix, ctx, suffix], dim=1)

        if self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            rows = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                rows.append(torch.cat(
                    [
                        prefix[i: i + 1, :, :],  # (1, 1, dim)
                        ctx[i: i + 1, :half_n_ctx, :],  # (1, n_ctx//2, dim)
                        suffix[i: i + 1, :name_len, :],  # (1, name_len, dim)
                        ctx[i: i + 1, half_n_ctx:, :],  # (1, n_ctx - n_ctx//2, dim)
                        suffix[i: i + 1, name_len:, :],  # (1, *, dim)
                    ],
                    dim=1,
                ))
            return torch.cat(rows, dim=0)

        if self.class_token_position == "front":
            rows = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                rows.append(torch.cat(
                    [
                        prefix[i: i + 1, :, :],  # (1, 1, dim)
                        suffix[i: i + 1, :name_len, :],  # (1, name_len, dim)
                        ctx[i: i + 1, :, :],  # (1, n_ctx, dim)
                        suffix[i: i + 1, name_len:, :],  # (1, *, dim)
                    ],
                    dim=1,
                ))
            return torch.cat(rows, dim=0)

        raise ValueError(f"Unknown class_token_position: {self.class_token_position!r}")

    def forward(self, neg_prompt_wcls=True):
        """Build the positive and negative prompt embeddings.

        Args:
            neg_prompt_wcls: if True, the negative prompts contain the class name (laid out
                like the positive prompts but with ``ctx_double``). If False, they are
                class-agnostic: SOS + ``ctx_double`` + "." + EOS.

        Returns:
            ``(prompts, prompts_neg, temperature, spatial_T, ranking_scale)``.
        """
        ctx = self.ctx
        ctx_double = self.ctx_double
        # Generic (shared) contexts are broadcast to every class.
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        if ctx_double.dim() == 2:
            ctx_double = ctx_double.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        prompts = self._arrange(ctx, prefix, self.token_suffix)

        if neg_prompt_wcls:
            prompts_neg = self._arrange(ctx_double, prefix, self.token_suffix)
        else:
            # No class token to place, so the layout is the same for every position.
            prompts_neg = torch.cat([prefix, ctx_double, self.token_suffix_nocls], dim=1)

        return prompts, prompts_neg, self.temperature, self.spatial_T, self.ranking_scale


def _save_overlay_figure(image_chw, heatmap, title, out_path):
    """Save a PNG showing ``image_chw`` next to a copy overlaid with ``heatmap``.

    ``image_chw`` is a channel-first numpy image; it is clipped to [0, 1] for display.
    The figure is always closed, even if plotting or saving fails.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 10))
    try:
        rgb = np.transpose(np.clip(image_chw, 0, 1), (1, 2, 0))
        axes[0].imshow(rgb)
        axes[1].imshow(rgb)
        # Overlay the attention map with transparency.
        im = axes[1].imshow(heatmap, cmap='jet', alpha=0.5)
        fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
        axes[1].set_title(title)
        fig.savefig(out_path, format='png')
    finally:
        plt.close(fig)


def visualize_with_spatial_logits(batch_idx, image, spatial_logits, label, probs,
                                  save_path='./tai_overlay_images_local', label_names=None, cell_size=7):
    """Save heat-map overlays of per-class spatial logits on top of the input images.

    Two sets of PNGs are written:
      * ``<save_path>/label``: one figure per (image, class whose ``label`` entry is 1);
      * ``<save_path>/top5``: one figure per (image, 5 highest-``probs`` classes).

    Args:
        batch_idx: batch index; only used in output file names.
        image: tensor ``(B, C_img, H, W)``; clipped to [0, 1] for display.
        spatial_logits: tensor ``(cell_size * cell_size, B, C)`` of per-location class
            logits (grid flattened row-major). Each class map is upsampled to ``(H, W)``.
        label: tensor ``(B, C)``; entries equal to 1 mark positive classes.
        probs: tensor ``(B, C)`` of class scores used for titles and top-5 selection.
        save_path: output root directory; sub-directories are created if missing.
        label_names: optional list of ``C`` class names; defaults to the 80 COCO classes.
        cell_size: side length of the square logit grid (default 7).
    """
    if label_names is None:
        label_names = [
            "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
            "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
            "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
            "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
            "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
            "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
            "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
            "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
            "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
            "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
            "hair drier", "toothbrush"
        ]

    probs = probs.detach().cpu().numpy()
    image = image.detach().cpu().numpy()
    label = label.detach().cpu().numpy()
    batch_size, height, width = image.shape[0], image.shape[2], image.shape[3]

    # (G*G, B, C) -> (G, G, B, C) -> (B, C, G, G), then upsample each class map to (H, W).
    grid_logits = (spatial_logits.detach().cpu().float()
                   .reshape(cell_size, cell_size, batch_size, -1)
                   .permute(2, 3, 0, 1))
    resized_spatial_logits = F.interpolate(grid_logits, size=(height, width), mode='bilinear',
                                           align_corners=False).numpy()

    label_dir = os.path.join(save_path, "label")
    top5_dir = os.path.join(save_path, "top5")
    os.makedirs(label_dir, exist_ok=True)
    os.makedirs(top5_dir, exist_ok=True)

    # Ground-truth positive classes.
    for i in range(batch_size):
        label_indexs = (label[i] == 1).nonzero()[0]
        print(label_indexs)
        for label_idx in label_indexs:
            name = label_names[label_idx]
            _save_overlay_figure(
                image[i], resized_spatial_logits[i, label_idx],
                f'Prob:{probs[i][label_idx]} /Label {name} with PosLogits (Resized)',
                os.path.join(label_dir, f"batch_{batch_idx}_image_{i}_label_{name}.png"),
            )

    # Top-5 predicted classes (highest score first).
    for i in range(batch_size):
        top_5_indices = probs[i].argsort()[-5:][::-1]
        for label_idx in top_5_indices:
            name = label_names[label_idx]
            gt = label[i][label_idx]
            _save_overlay_figure(
                image[i], resized_spatial_logits[i, label_idx],
                f'Prob:{probs[i][label_idx]} /{name} Label:{gt} with PosLogits (Resized)',
                os.path.join(top5_dir, f"batch_{batch_idx}_image_{i}_prob_{name}_{gt}.png"),
            )


def json2list(file, withclass=False):
    """Load per-class text descriptions from a JSON file into a 2-D numpy array.

    The JSON is read as ``{class_name: {request: [description, ...], ...}, ...}``.
    For each class, descriptions from all of its requests are concatenated in file order.

    Args:
        file: path to a UTF-8 JSON file.
        withclass: if True, append ``" " + class_name`` to every description.

    Returns:
        ``np.ndarray`` with one row per class, in JSON key order.

    Raises:
        ValueError: if classes have differing numbers of descriptions (the result would
            not be a regular 2-D array).
    """
    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    priors_list = []
    for class_name, requests in data.items():
        descriptions = []
        for description_list in requests.values():
            if withclass:
                descriptions.extend(desc + " " + class_name for desc in description_list)
            else:
                descriptions.extend(description_list)
        priors_list.append(descriptions)

    row_lengths = {len(row) for row in priors_list}
    if len(row_lengths) > 1:
        raise ValueError(
            f"{file}: classes have differing numbers of descriptions {sorted(row_lengths)}")

    return np.array(priors_list)


def write_bg_npy(json_file, text_encoder):
    """Encode the first row of descriptions in ``json_file`` and save sequence features.

    Uses ``json2list(json_file)[0]`` as the prompts and encodes them with
    ``text_encoder`` in sequence mode. The result is saved to
    ``<json name without '.json'>.npy`` in the current working directory.
    """
    prompts = json2list(json_file)[0]
    # torch.cat keeps the batch axis even for a single prompt; stack + squeeze dropped it.
    tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
    print(tokenized_prompts.shape)  # [num_prompts, 77]
    with torch.no_grad():
        text_features = text_encoder(tokenized_prompts, None, if_embedding=False, if_sequence=True)
        print(text_features.shape)  # [num_prompts, 77, dim]

    out_name = json_file.split('/')[-1].split('.json')[0]

    np.save(out_name + '.npy', text_features.cpu().numpy())


def write_prior_npy(json_file, text_encoder, withclass=False):
    """Encode every description in ``json_file`` and save them grouped by class.

    Descriptions from ``json2list(json_file, withclass)`` (shape ``(n_cls, n_desc)``) are
    flattened, encoded with ``text_encoder`` in sequence mode, and reshaped back to
    ``(n_cls, n_desc, seq_len, dim)``. The result is saved to
    ``<json name without '.json'>.npy`` in the current working directory.
    """
    prior_text_origin = json2list(json_file, withclass)  # [n_cls, n_desc]
    prompts = prior_text_origin.flatten()  # [n_cls * n_desc]

    # torch.cat keeps the batch axis even for a single prompt; stack + squeeze dropped it.
    tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
    print(tokenized_prompts.shape)  # [n_cls * n_desc, 77]
    with torch.no_grad():
        text_features = text_encoder(tokenized_prompts, None, if_embedding=False, if_sequence=True)
        print(text_features.shape)  # [n_cls * n_desc, 77, dim]
        prior_dict = text_features.reshape(prior_text_origin.shape[0], -1, text_features.shape[-2],
                                           text_features.shape[-1])  # [n_cls, n_desc, 77, dim]

    out_name = json_file.split('/')[-1].split('.json')[0]

    np.save(out_name + '.npy', prior_dict.cpu().numpy())


def read_npy(numpy_file, device):
    """Load a ``.npy`` file into a (copied) torch tensor placed on ``device``."""
    loaded = np.load(numpy_file)
    print('load complete:', numpy_file)
    tensor = torch.tensor(loaded)
    return tensor.to(device)


def write_prototype_npy(text_encoder):
    """Encode ``"a photo of <class>"`` for the 80 COCO classes and save to ``COCO_proto.npy``.

    Features are taken in sequence mode (one row of per-token features per class).
    """
    coco_classes = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
        "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
        "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
        "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    ]

    sentences = ['a photo of ' + name for name in coco_classes]
    token_ids = torch.cat([clip.tokenize(sentence) for sentence in sentences])
    print(token_ids.shape)  # [80,77]

    with torch.no_grad():
        features = text_encoder(token_ids, None, if_embedding=False, if_sequence=True)
        print(features.shape)  # [80,77,dim]

    np.save('COCO_proto.npy', features.cpu().numpy())
    
    
    
    
################################################################## Low quality
def write_nus_llm_low_quality_caption(text_encoder, device):
    """Encode NUS-WIDE LLM-generated sentences with the CLIP text encoder and save them.

    Concept names are read from ``Concepts81.txt`` (one per line), and a JSON list of
    sentences is shuffled. A sentence "matches" a concept when the concept's first BPE
    token occurs among the sentence's tokens. Two files are written to the CWD:

    * ``nus_llm_caption_selection_rn50_56000_low_quality_sentence.npy``: for every
      concept, the EOT features of its first ``min_num`` matching sentences. ``min_num``
      is the smallest match count over all concepts, capped at 100000.
    * ``nus_llm_multi_label_caption_selection_rn50_56000_low_quality.npy``: the EOT
      features of the first 56000 task-list entries, reshaped to ``(1600, 35, dim)``.
      The task list holds each sentence once, plus once more per matched concept.

    Args:
        text_encoder: a ``TextEncoder``; it is moved to ``device``.
        device: torch device used for encoding.

    Raises:
        ValueError: if the task list has fewer than 56000 entries.
    """
    import json
    import os
    import random

    num_task = 56000  # 1600 * 35
    chunk_size = 5000

    text_encoder = text_encoder.to(device)

    concepts_path = os.path.join('/data/CVPR23_medfm/DATA/NUSWIDE', 'Concepts81.txt')
    with open(concepts_path, 'r') as f:
        label_names = [line.strip() for line in f]

    # Token index 1 is the first BPE token after SOT. Matching on it alone is a cheap,
    # approximate keyword test.
    first_token = {name: clip.tokenize(name)[:, 1] for name in label_names}

    with open('/data/CVPR23_medfm/DATA/priors/nus_wide_sentences_57000.json', 'rb') as f:
        prompts = json.load(f)
    # NOTE(review): unseeded shuffle, so the selected sentences differ between runs.
    random.shuffle(prompts)

    caption_dict = {name: [] for name in label_names}
    task_list = []
    for sentence in prompts:
        tokens = clip.tokenize(sentence)
        task_list.append(tokens)
        for name in label_names:
            if first_token[name] in tokens:
                caption_dict[name].append(tokens)
                task_list.append(tokens)

    for name in label_names:
        print(name, len(caption_dict[name]))
    min_num = min([100000] + [len(caption_dict[name]) for name in label_names])

    # Per-concept features: the same number (min_num) of sentences for every concept.
    per_class_features = []
    with torch.no_grad():
        for name in label_names:
            # torch.cat keeps the batch axis even when min_num == 1.
            cap = torch.cat(caption_dict[name][:min_num])
            class_features = text_encoder(cap.to(device), None, if_embedding=False, if_sequence=False)
            per_class_features.append(class_features.cpu().numpy())
    np.save('nus_llm_caption_selection_rn50_56000_low_quality_sentence.npy', per_class_features)

    # Task-level features, encoded in fixed-size chunks (presumably to bound GPU memory).
    with torch.no_grad():
        tokenize_mat = torch.cat(task_list)[:num_task].to(device)
        if tokenize_mat.shape[0] < num_task:
            raise ValueError(f"need {num_task} task sentences, got {tokenize_mat.shape[0]}")
        text_features = torch.cat(
            [text_encoder(tokenize_mat[start:start + chunk_size], None,
                          if_embedding=False, if_sequence=False)
             for start in range(0, num_task, chunk_size)],
            dim=0,
        )
        print('tokenize_mat', tokenize_mat.shape)
        print('text_features', text_features.shape)
        text_features = text_features.reshape(1600, 35, -1)

    np.save('nus_llm_multi_label_caption_selection_rn50_56000_low_quality.npy', text_features.cpu().numpy())

def write_coco_llm_low_quality_caption(text_encoder, device):
    """Encode COCO LLM-generated sentences with the CLIP text encoder and save them.

    A JSON list of sentences is shuffled. A sentence "matches" a COCO class when the
    class name's first BPE token occurs among the sentence's tokens. The task list
    holds every sentence once, plus once more per matched class. Per-class match
    counts are printed.

    The EOT features of the first 32000 task-list entries, reshaped to
    ``(1600, 20, dim)``, are saved to
    ``coco_llm_multi_label_caption_selection_32000_low_quality_rn50.npy`` in the CWD.

    Args:
        text_encoder: a ``TextEncoder``; it is moved to ``device``.
        device: torch device used for encoding.

    Raises:
        ValueError: if the task list has fewer than 32000 entries.
    """
    import json
    import random

    label_names = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
        "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
        "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
        "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    ]
    num_task = 32000  # 1600 * 20
    chunk_size = 5000

    text_encoder = text_encoder.to(device)

    # Token index 1 is the first BPE token after SOT. Matching on it alone is a cheap,
    # approximate keyword test.
    first_token = {name: clip.tokenize(name)[:, 1] for name in label_names}

    with open('/data/CVPR23_medfm/DATA/priors/coco_sentences_32000.json', 'rb') as f:
        prompts = json.load(f)
    # NOTE(review): unseeded shuffle, so the selected sentences differ between runs.
    random.shuffle(prompts)

    match_counts = {name: 0 for name in label_names}
    task_list = []
    for sentence in prompts:
        tokens = clip.tokenize(sentence)
        task_list.append(tokens)
        for name in label_names:
            if first_token[name] in tokens:
                match_counts[name] += 1
                task_list.append(tokens)

    for name in label_names:
        print(name, match_counts[name])

    # Task-level features, encoded in fixed-size chunks (presumably to bound GPU memory).
    with torch.no_grad():
        tokenize_mat = torch.cat(task_list)[:num_task].to(device)
        if tokenize_mat.shape[0] < num_task:
            raise ValueError(f"need {num_task} task sentences, got {tokenize_mat.shape[0]}")
        text_features = torch.cat(
            [text_encoder(tokenize_mat[start:start + chunk_size], None,
                          if_embedding=False, if_sequence=False)
             for start in range(0, num_task, chunk_size)],
            dim=0,
        )
        print('tokenize_mat', tokenize_mat.shape)
        print('text_features', text_features.shape)
        text_features = text_features.reshape(1600, 20, -1)

    np.save('coco_llm_multi_label_caption_selection_32000_low_quality_rn50.npy', text_features.cpu().numpy())

def write_voc_llm_low_quality_caption(text_encoder, device):
    """Encode shuffled LLM-generated VOC sentences with CLIP and save the features.

    Sentences are read from ``voc_sentences_40000.json`` (a JSON list of
    strings), shuffled in place with the global ``random`` state and tokenized
    with ``clip.tokenize``. A sentence "mentions" a name when the name's first
    BPE token (column 1 of ``clip.tokenize(name)``) occurs anywhere in the
    sentence's tokens, so multi-word names such as ``'dining table'`` are
    matched on their first word only.

    Two arrays are written to the current working directory:

    * ``voc_llm_caption_selection_40000_low_quality_sentence_w_syn.npy``:
      encoder outputs for the first ``min_num`` sentences mentioning each of
      the 20 VOC classes, stacked in class order, where ``min_num`` is the
      size of the smallest class bucket.
    * ``voc_llm_multi_label_caption_selection_40000_low_quality_w_syn.npy``:
      encoder outputs for the first 40000 entries of the task list, reshaped
      to ``(2000, 20, embed_dim)``. Each sentence enters the task list once,
      plus once per matched synonym; if no synonym matched, it is instead
      added once more per matched class name.

    Args:
        text_encoder: text encoder called as
            ``text_encoder(tokens, None, if_embedding=False, if_sequence=False)``;
            it is moved to ``device``.
        device: torch device on which encoding runs.

    Raises:
        ValueError: if some VOC class is mentioned by no sentence, or fewer
            than 40000 task entries were collected.
    """
    import random

    text_encoder = text_encoder.to(device)

    label_names = ['airplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair',
                   'cow', 'dining table', 'dog', 'horse',
                   'motor bike', 'person', 'potted plant',
                   'sheep', 'sofa', 'train', 'tv']
    label_names_task = [
        'aeroplane', 'air craft', 'jet', 'plane', 'air plane',
        'bicycle', 'bike', 'cycle',
        'bird',
        'boat', 'raft', 'dinghy',
        'bottle',
        'bus', 'autobus', 'coach', 'charabanc', 'double decker', 'jitney',
        'motor bus', 'motor coach', 'omnibus',
        'car', 'taxi', 'auto', 'automobile', 'motor car',
        'cat', 'kitty',
        'chair', 'arm chair', 'bench',
        'cow',
        'table', 'dining table', 'dinner table', 'din table',
        'dog', 'pup', 'puppy', 'doggy',
        'horse', 'colt', 'equus',
        'motor bike', 'motor cycle',
        'person', 'human', 'people', 'man', 'woman', 'passenger',
        'potted plant', 'house plant', 'bonsai', 'pot plant',
        'sheep',
        'sofa', 'couch',
        'train', 'rail way', 'railroad',
        'tvmonitor', 'monitor', 'tv', 'television', 'telly',
    ]

    # Keep only the first BPE token of each name (column 0 is presumably
    # CLIP's start-of-text token).
    tokenizer_label = {name: clip.tokenize(name)[:, 1] for name in label_names}
    tokenizer_label_task = {name: clip.tokenize(name)[:, 1] for name in label_names_task}

    with open('/data/CVPR23_medfm/DATA/priors/voc_sentences_40000.json', 'rb') as f:
        prompts = json.load(f)
    random.shuffle(prompts)

    caption_dict = {name: [] for name in label_names}
    task_list = []
    for sentence in prompts:
        tokens = clip.tokenize(sentence)
        n_syn = sum(1 for name in label_names_task if tokenizer_label_task[name] in tokens)
        # One copy per matched synonym, plus one unconditional copy.
        task_list.extend([tokens] * (n_syn + 1))
        for name in label_names:
            if tokenizer_label[name] in tokens:
                caption_dict[name].append(tokens)
                if n_syn == 0:
                    task_list.append(tokens)

    for name in label_names:
        print(name, len(caption_dict[name]))
    min_num = min(len(caption_dict[name]) for name in label_names)
    if min_num == 0:
        empty = [name for name in label_names if not caption_dict[name]]
        raise ValueError(f'no sentence mentions VOC classes: {empty}')

    # Per-class features: the same number (min_num) of sentences for every class.
    class_features = []
    with torch.no_grad():
        for name in label_names:
            # torch.cat keeps a 2-D (rows, context) matrix even when min_num == 1.
            cap = torch.cat(caption_dict[name][:min_num], 0).to(device)
            feats = text_encoder(cap, None, if_embedding=False, if_sequence=False)
            class_features.append(feats.cpu().numpy())
    np.save('voc_llm_caption_selection_40000_low_quality_sentence_w_syn.npy',
            np.stack(class_features))

    # Task features: first 40000 task entries grouped as (2000, 20, embed_dim).
    num_groups, group_size, chunk = 2000, 20, 5000
    num_task = num_groups * group_size
    if len(task_list) < num_task:
        raise ValueError(f'need {num_task} task sentences, got {len(task_list)}')
    with torch.no_grad():
        # Truncate on the CPU before moving to the device.
        tokenize_mat = torch.cat(task_list[:num_task], 0).to(device)
        print('tokenize_mat', tokenize_mat.shape)
        # Encode in chunks of `chunk` rows (presumably to bound peak device memory).
        text_features = torch.cat([
            text_encoder(tokenize_mat[start:start + chunk], None,
                         if_embedding=False, if_sequence=False).squeeze()
            for start in range(0, num_task, chunk)
        ], 0)
        print('text_features', text_features.shape)
        text_features = text_features.reshape(num_groups, group_size, text_features.shape[-1])

    np.save('voc_llm_multi_label_caption_selection_40000_low_quality_w_syn.npy',
            text_features.cpu().numpy())
    
    
#####################################################################################################    
def write_multi_space_caption_npy(text_encoder, device):
    """Encode 32000 randomly sampled COCO train captions and save them as .npy.

    Caption ids are drawn without replacement (global ``random`` state) from
    the list stored in ``coco_caption_text_embed_sampled_idx.pkl``, and the
    chosen positions are kept in ascending order. Each id is resolved to its
    caption text through ``captions_train2017.json`` and tokenized with
    ``clip.tokenize``. The encoder outputs are reshaped to
    ``(1600, 20, embed_dim)`` and written to
    ``COCO_multi_label_caption_selection_50_32000.npy`` in the current
    working directory.

    Args:
        text_encoder: text encoder called as
            ``text_encoder(tokens, None, if_embedding=False, if_sequence=False)``;
            it is moved to ``device``.
        device: torch device on which encoding runs.
    """
    import random

    text_encoder = text_encoder.to(device)

    coco_root = '/data/CVPR23_medfm/DATA/COCO'
    coco_caption_json_file = os.path.join(coco_root, "annotations/captions_train2017.json")
    with open(coco_caption_json_file, 'r') as f:
        caption_info = json.load(f)
    # Later annotations with a repeated id overwrite earlier ones.
    anno_id2path = {anno["id"]: anno for anno in caption_info["annotations"]}

    # NOTE(review): pickle.load executes arbitrary code from the file; only use on trusted data.
    with open('/data/CVPR23_medfm/TaI-DPT/coco_caption_text_embed_sampled_idx.pkl', 'rb') as f:
        sample_capid = pickle.load(f)

    num_groups, group_size, chunk = 1600, 20, 5000
    num_total = num_groups * group_size
    # Sample over every position of sample_capid, including index 0.
    random_samples = sorted(random.sample(range(len(sample_capid)), num_total))

    with torch.no_grad():
        tokenize_mat = torch.cat(
            [clip.tokenize(anno_id2path[sample_capid[k]]['caption']) for k in random_samples], 0
        ).to(device)
        print('tokenize_mat', tokenize_mat.shape)

        # Encode in chunks of `chunk` rows (presumably to bound peak device memory).
        text_features = torch.cat([
            text_encoder(tokenize_mat[start:start + chunk], None,
                         if_embedding=False, if_sequence=False).squeeze()
            for start in range(0, num_total, chunk)
        ], 0)
        print('text_features', text_features.shape)
        text_features = text_features.reshape(num_groups, group_size, text_features.shape[-1])

    np.save('COCO_multi_label_caption_selection_50_32000.npy', text_features.cpu().numpy())

def write_multi_space_nus_caption_npy(text_encoder, device):
    """Encode 7000 randomly sampled COCO train captions and save them as .npy.

    NOTE(review): despite the ``nus`` in the name, this reads COCO captions and
    writes a ``COCO_``-prefixed file; confirm this is intended.

    Caption ids are drawn without replacement (global ``random`` state) from
    the list stored in ``coco_caption_text_embed_sampled_idx.pkl``, and the
    chosen positions are kept in ascending order. Each id is resolved to its
    caption text through ``captions_train2017.json``. All 7000 captions are
    encoded in a single call, reshaped to ``(100, 70, embed_dim)`` and written
    to ``COCO_multi_label_caption_selection_50_v3.npy`` in the current
    working directory.

    Args:
        text_encoder: text encoder called as
            ``text_encoder(tokens, None, if_embedding=False, if_sequence=False)``;
            it is moved to ``device``.
        device: torch device on which encoding runs.
    """
    import random

    text_encoder = text_encoder.to(device)

    coco_root = '/data/CVPR23_medfm/DATA/COCO'
    coco_caption_json_file = os.path.join(coco_root, "annotations/captions_train2017.json")
    with open(coco_caption_json_file, 'r') as f:
        caption_info = json.load(f)
    # Later annotations with a repeated id overwrite earlier ones.
    anno_id2path = {anno["id"]: anno for anno in caption_info["annotations"]}

    # NOTE(review): pickle.load executes arbitrary code from the file; only use on trusted data.
    with open('/data/CVPR23_medfm/TaI-DPT/coco_caption_text_embed_sampled_idx.pkl', 'rb') as f:
        sample_capid = pickle.load(f)

    num_groups, group_size = 100, 70
    num_total = num_groups * group_size
    # Sample over every position of sample_capid, including index 0.
    random_samples = sorted(random.sample(range(len(sample_capid)), num_total))

    with torch.no_grad():
        tokenize_mat = torch.cat(
            [clip.tokenize(anno_id2path[sample_capid[k]]['caption']) for k in random_samples], 0
        ).to(device)
        print('tokenize_mat', tokenize_mat.shape)

        text_features = text_encoder(tokenize_mat, None, if_embedding=False, if_sequence=False).squeeze()
        print('text_features', text_features.shape)
        text_features = text_features.reshape(num_groups, group_size, text_features.shape[-1])

    np.save('COCO_multi_label_caption_selection_50_v3.npy', text_features.cpu().numpy())
    
def write_nus_llm_caption_no_random_npy(text_encoder, device):
    """Encode LLM-generated NUS-WIDE sentences with CLIP and save features + sentences.

    Sentence records (dicts with a ``'sentence'`` key) are loaded from
    ``NUSWIDE_40000.json`` and ``NUSWIDE_oneclass_expension_1110.json``,
    concatenated, and shuffled in place with the global ``random`` state.
    NOTE(review): despite ``no_random`` in the name, the sentences are
    shuffled; confirm this is intended.

    Class names are read one per line from ``Concepts81.txt``. Only the first
    BPE token of each name (column 1 of ``clip.tokenize(name)``) is matched
    against a sentence's tokens. Every sentence enters the task list once,
    plus once more for each class name it mentions.

    The first 57600 task entries are encoded, reshaped to
    ``(1600, 36, embed_dim)`` and saved to
    ``nus_llm_multi_label_caption_selection_50_fig4_57600.npy``. Their
    sentence strings, in the same row order and with duplicates included, are
    saved to ``nus_llm_multi_label_caption_selection_50_fig4_57600_sentence.json``.
    Both files go to the current working directory.

    Args:
        text_encoder: text encoder called as
            ``text_encoder(tokens, None, if_embedding=False, if_sequence=False)``;
            it is moved to ``device``.
        device: torch device on which encoding runs.

    Raises:
        ValueError: if fewer than 57600 task entries were collected.
    """
    import random

    text_encoder = text_encoder.to(device)

    # One class name per line (stripped; blank lines are kept as '').
    with open(os.path.join('/data/CVPR23_medfm/DATA/NUSWIDE', 'Concepts81.txt'), 'r') as f:
        label_names = [line.strip() for line in f]
    # Column 0 of clip.tokenize's output is presumably the start-of-text token.
    tokenizer_label = {name: clip.tokenize(name)[:, 1] for name in label_names}

    with open('/data/CVPR23_medfm/DATA/priors/NUSWIDE_40000.json', 'rb') as f:
        prompts = json.load(f)
    with open('/data/CVPR23_medfm/DATA/priors/NUSWIDE_oneclass_expension_1110.json', 'rb') as f:
        prompts1 = json.load(f)
    prompts = prompts + prompts1
    random.shuffle(prompts)

    # task_list and task_sentences stay row-aligned so the saved JSON matches the .npy rows.
    task_list = []
    task_sentences = []
    for record in prompts:
        sentence = record['sentence']
        tokens = clip.tokenize(sentence)
        n_copies = 1 + sum(1 for name in label_names if tokenizer_label[name] in tokens)
        task_list.extend([tokens] * n_copies)
        task_sentences.extend([sentence] * n_copies)

    num_groups, group_size, chunk = 1600, 36, 5000
    num_task = num_groups * group_size
    if len(task_list) < num_task:
        raise ValueError(f'need {num_task} task sentences, got {len(task_list)}')

    with torch.no_grad():
        # Truncate on the CPU before moving to the device.
        tokenize_mat = torch.cat(task_list[:num_task], 0).to(device)
        print('tokenize_mat', tokenize_mat.shape)
        # Encode in chunks of `chunk` rows (presumably to bound peak device memory).
        text_features = torch.cat([
            text_encoder(tokenize_mat[start:start + chunk], None,
                         if_embedding=False, if_sequence=False).squeeze()
            for start in range(0, num_task, chunk)
        ], 0)
        print('text_features', text_features.shape)
        text_features = text_features.reshape(num_groups, group_size, text_features.shape[-1])

    np.save('nus_llm_multi_label_caption_selection_50_fig4_57600.npy', text_features.cpu().numpy())

    with open('nus_llm_multi_label_caption_selection_50_fig4_57600_sentence.json', 'w') as f:
        json.dump(task_sentences[:num_task], f)
      
    
def write_nus_llm_caption_npy(text_encoder, device):
    """Encode LLM-generated NUS-WIDE captions with ``text_encoder`` and save them.

    Concept names are read one per line from ``Concepts81.txt``; captions are
    the ``'sentence'`` fields of the records in
    ``NUSWIDE_oneclass_expension_1110.json``.  Two files are written to the
    current directory:

    * ``nus_llm_caption_selection_101_57600_one_cls.npy``: for each concept, the
      embeddings of the first ``min_num`` captions containing the concept's
      matching token, where ``min_num`` is the smallest per-concept count
      (capped at 100000), so every concept contributes the same number.
    * ``nus_llm_multi_label_caption_selection_101_57600_one_cls.npy``: 57600
      captions from the shuffled task pool, encoded and reshaped to
      ``(1600, 36, embed_dim)``.

    Args:
        text_encoder: ``TextEncoder`` module; it is moved to ``device``.
        device: torch device used for encoding.

    Raises:
        ValueError: if a concept matched no caption, or the task pool holds
            fewer than 57600 captions.
    """
    import random

    num_groups, group_size = 1600, 36
    num_task = num_groups * group_size  # 57600 captions in the task file
    chunk_size = 5000  # rows per encoder call, bounds peak GPU memory

    text_encoder = text_encoder.to(device)

    def encode_in_chunks(tokens):
        # Encode token rows ``chunk_size`` at a time; returns (N, embed_dim).
        feats = []
        for start in range(0, tokens.shape[0], chunk_size):
            chunk = tokens[start:start + chunk_size]
            out = text_encoder(chunk, None, if_embedding=False, if_sequence=False)
            # reshape instead of squeeze(): a 1-row chunk stays 2-D.
            feats.append(out.reshape(chunk.shape[0], -1))
        return torch.cat(feats, 0)

    concepts_path = os.path.join('/data/CVPR23_medfm/DATA/NUSWIDE', 'Concepts81.txt')
    with open(concepts_path, 'r') as f:
        label_names = [line.strip() for line in f]

    # Only this caption file feeds the selection.  The Open Images pickle and
    # NUSWIDE_40000.json were overwritten before use, so they are not read.
    pa = '/data/CVPR23_medfm/DATA/priors/NUSWIDE_oneclass_expension_1110.json'
    with open(pa, 'rb') as f:
        prompts = json.load(f)
    random.shuffle(prompts)

    # Each concept is matched on column 1 of its CLIP tokenization only
    # (presumably the first BPE token after the start-of-text marker), so a
    # multi-word concept matches on its first word piece.
    tokenizer_label = {name: clip.tokenize(name)[:, 1] for name in label_names}
    caption_dict = {name: [] for name in label_names}

    # Every caption enters the task pool once, plus once more per concept it
    # matches.
    task_list = []
    for prompt in prompts:
        temp_sentence = clip.tokenize(prompt['sentence'])
        task_list.append(temp_sentence)
        for name in label_names:
            if tokenizer_label[name] in temp_sentence:
                caption_dict[name].append(temp_sentence)
                task_list.append(temp_sentence)

    min_num = 100000
    for name in label_names:
        min_num = min(min_num, len(caption_dict[name]))
        print(name, len(caption_dict[name]))

    empty = [name for name in label_names if not caption_dict[name]]
    if empty:
        raise ValueError(f'no captions matched concepts: {empty}')

    ### class
    text_features_cat1 = []
    for name in label_names:
        # cat keeps (min_num, context_len) even when min_num == 1, where
        # stack(...).squeeze() would collapse it to 1-D.
        cap = torch.cat(caption_dict[name][:min_num], 0)
        with torch.no_grad():
            text_features = text_encoder(cap.to(device), None, if_embedding=False, if_sequence=False)
            text_features_cat1.append(text_features.cpu().numpy())

    np.save('nus_llm_caption_selection_101_57600_one_cls.npy', text_features_cat1)

    ## task
    random.shuffle(task_list)
    if len(task_list) < num_task:
        raise ValueError(f'need {num_task} task captions, found {len(task_list)}')

    with torch.no_grad():
        # Slice before moving to the device so only the rows used are copied.
        tokenize_mat = torch.cat(task_list[:num_task], 0).to(device)
        text_features = encode_in_chunks(tokenize_mat)

        print('tokenize_mat', tokenize_mat.shape)
        print('text_features', text_features.shape)

        # Infer the embedding width instead of hard-coding it; it presumably
        # depends on the CLIP backbone.
        text_features = text_features.reshape(num_groups, group_size, -1)

    saved_np = text_features.cpu().numpy()
    np.save('nus_llm_multi_label_caption_selection_101_57600_one_cls.npy', saved_np)
    
def write_nus_caption_npy(text_encoder, device):
    """Encode NUS-WIDE-related captions from the tokenized Open Images corpus.

    ``all_caption_tokenized_open_images.pkl`` holds pre-tokenized captions,
    indexed as one caption per row.  A caption is added to the task pool once
    for every NUS-WIDE concept synonym whose matching token it contains.  The
    pool is shuffled, the first 20000 captions are encoded, and the result is
    saved as ``nus_multi_label_caption_selection_101_v4_20000.npy`` with shape
    ``(200, 100, embed_dim)``.

    Args:
        text_encoder: ``TextEncoder`` module; it is moved to ``device``.
        device: torch device used for encoding.

    Raises:
        ValueError: if the corpus has 20000 rows or fewer, or fewer than 20000
            pool entries were collected.
    """
    import random

    num_groups, group_size = 200, 100
    num_task = num_groups * group_size  # 20000 captions in the task file
    chunk_size = 5000  # rows per encoder call, bounds peak GPU memory

    text_encoder = text_encoder.to(device)

    def encode_in_chunks(tokens):
        # Encode token rows ``chunk_size`` at a time; returns (N, embed_dim).
        feats = []
        for start in range(0, tokens.shape[0], chunk_size):
            chunk = tokens[start:start + chunk_size]
            out = text_encoder(chunk, None, if_embedding=False, if_sequence=False)
            # reshape instead of squeeze(): a 1-row chunk stays 2-D.
            feats.append(out.reshape(chunk.shape[0], -1))
        return torch.cat(feats, 0)

    # Concept names plus synonyms.  NOTE(review): 'forsted' looks like a
    # misspelling of 'frosted'; kept verbatim because changing it changes which
    # captions are selected.
    voc_classname_synonyms = [
        'airport', 'air port', 'air field', 'runway',
        'animal',
        'beach', 'plage', 'coast', 'seashore',
        'bear',
        'birds', 'bird',
        'boats', 'boat', 'raft', 'dinghy',
        'book',
        'bridge',
        'buildings', 'building',
        'cars', 'car',
        'castle',
        'cat', 'kitty',
        'cityscape', 'city', 'skyscraper',
        'clouds', 'cloud',
        'computer', 'desktop', 'laptop',
        'coral',
        'cow',
        'dancing', 'dance',
        'dog', 'pup', 'puppy', 'doggy',
        'earthquake', 'collapse building', 'break building', 'broken building',
        'elk', 'deer',
        'fire',
        'fish',
        'flags', 'flag',
        'flowers', 'flower',
        'food',
        'fox',
        'frost', 'forsted',
        'garden',
        'glacier', 'ice',
        'grass',
        'harbor', 'port', 'harbour',
        'horses', 'horse',
        'house',
        'lake',
        'leaf',
        'map',
        'military', 'army', 'troops', 'troop',
        'moon',
        'mountain', 'hill',
        'nighttime', 'night time', 'night',
        'ocean', 'sea',
        'person', 'human', 'people', 'man', 'woman', 'passenger',
        'plane', 'aeroplane', 'air craft', 'jet', 'air plane',
        'plants', 'plant',
        'police',
        'protest',
        'railroad', 'rail road', 'rail way',
        'rainbow',
        'reflection',
        'road', 'path', 'way',
        'rocks', 'rock',
        'running', 'run',
        'sand',
        'sign',
        'sky',
        'snow',
        'soccer', 'football',
        'sports', 'sport',
        'statue',
        'street',
        'sun',
        'sunset',
        'surf',
        'swimmers', 'swimmer', 'swimming', 'swim',
        'tattoo', 'tattooing',
        'temple',
        'tiger',
        'tower',
        'town',
        'toy',
        'train',
        'tree',
        'valley',
        'vehicle',
        'water',
        'waterfall',
        'wedding', 'engagement', 'bride', 'groom',
        'whales', 'whale',
        'window',
        'zebra',
    ]

    # Each synonym is matched on column 1 of its CLIP tokenization only
    # (presumably the first BPE token after the start-of-text marker).
    synonym_tokens = [clip.tokenize(name)[:, 1] for name in voc_classname_synonyms]

    with open('/data/CVPR23_medfm/TaI-DPT/all_caption_tokenized_open_images.pkl', 'rb') as f:
        prompts = pickle.load(f)

    # A caption is appended once per matching synonym, so captions matching
    # several synonyms are over-represented in the pool.
    task_list = []
    for ii in range(prompts.shape[0]):
        temp_sentence = prompts[ii, :]
        for token in synonym_tokens:
            if token in temp_sentence:
                task_list.append(temp_sentence)

    ## task
    # The sampled indices are unused.  The draw is kept only so the RNG state
    # before the shuffle below, and therefore a seeded run's output, is unchanged.
    random.sample(range(1, prompts.shape[0]), num_task)

    random.shuffle(task_list)
    if len(task_list) < num_task:
        raise ValueError(f'need {num_task} task captions, found {len(task_list)}')

    with torch.no_grad():
        # Slice before moving to the device so only the rows used are copied.
        tokenize_mat = torch.stack(task_list[:num_task]).to(device)
        text_features = encode_in_chunks(tokenize_mat)

        print('tokenize_mat', tokenize_mat.shape)
        print('text_features', text_features.shape)

        # Infer the embedding width instead of hard-coding 512.
        text_features = text_features.reshape(num_groups, group_size, -1)

    saved_np = text_features.cpu().numpy()
    np.save('nus_multi_label_caption_selection_101_v4_20000.npy', saved_np)
    
def write_voc_caption_npy(text_encoder, device):
    """Select and encode COCO captions for the Pascal VOC classes.

    Captions come from ``all_caption_tokenized.pkl`` (pre-tokenized, indexed
    as one caption per row).  Two files are written to the current directory:

    * ``VOC_caption_selection_100_50.npy``: for each VOC class, the embeddings
      of the first ``min_num`` captions containing the class name's token(s),
      where ``min_num`` is the smallest per-class count (capped at 100000).
    * ``VOC_multi_label_caption_selection_50_v5_40000.npy``: 40000 captions
      from the shuffled synonym-matched pool, encoded and reshaped to
      ``(200, 200, embed_dim)``.

    Args:
        text_encoder: ``TextEncoder`` module; it is moved to ``device``.
        device: torch device used for encoding.

    Raises:
        ValueError: if a class matched no caption, or fewer than 40000 pool
            entries were collected (``random.sample`` also raises it when the
            sampled-id pickle has 40000 entries or fewer).
    """
    import random

    num_groups, group_size = 200, 200
    num_task = num_groups * group_size  # 40000 captions in the task file
    chunk_size = 5000  # rows per encoder call, bounds peak GPU memory

    text_encoder = text_encoder.to(device)

    def encode_in_chunks(tokens):
        # Encode token rows ``chunk_size`` at a time; returns (N, embed_dim).
        feats = []
        for start in range(0, tokens.shape[0], chunk_size):
            chunk = tokens[start:start + chunk_size]
            out = text_encoder(chunk, None, if_embedding=False, if_sequence=False)
            # reshape instead of squeeze(): a 1-row chunk stays 2-D.
            feats.append(out.reshape(chunk.shape[0], -1))
        return torch.cat(feats, 0)

    voc_object_categories = [
        'airplane', 'bicycle', 'bird', 'boat',
        'bottle', 'bus', 'car', 'cat', 'chair',
        'cow', 'dining table', 'dog', 'horse',
        'motor bike', 'person', 'potted plant',
        'sheep', 'sofa', 'train', 'tv',
    ]
    voc_classname_synonyms = [
        'aeroplane', 'air craft', 'jet', 'plane', 'air plane',
        'bicycle', 'bike', 'cycle',
        'bird',
        'boat', 'raft', 'dinghy',
        'bottle',
        'bus', 'autobus', 'coach', 'charabanc', 'double decker', 'jitney',
        'motor bus', 'motor coach', 'omnibus',
        'car', 'taxi', 'auto', 'automobile', 'motor car',
        'cat', 'kitty',
        'chair', 'arm chair', 'bench',
        'cow',
        'table', 'dining table', 'dinner table', 'din table',
        'dog', 'pup', 'puppy', 'doggy',
        'horse', 'colt', 'equus',
        'motor bike', 'motor cycle',
        'person', 'human', 'people', 'man', 'woman', 'passenger',
        'potted plant', 'house plant', 'bonsai', 'pot plant',
        'sheep',
        'sofa', 'couch',
        'train', 'rail way', 'railroad',
        'tvmonitor', 'monitor', 'tv', 'television', 'telly',
    ]

    # Only its length is used, by the RNG-preserving draw further down.
    with open('/data/CVPR23_medfm/TaI-DPT/coco_caption_text_embed_sampled_idx.pkl', 'rb') as f:
        sample_capid = pickle.load(f)

    label_names = voc_object_categories

    # Tokens that must all appear in a caption for it to count for a class.
    # One-word names use column 1 of the CLIP tokenization. Two-word names use
    # columns 1-2, each checked anywhere in the caption, not as a phrase.
    # Names of three or more words get no entry and never match.
    label_tokens = {}
    for name in label_names:
        n_words = len(name.split(' '))
        if n_words == 1:
            label_tokens[name] = (clip.tokenize(name)[:, 1],)
        elif n_words == 2:
            label_tokens[name] = tuple(clip.tokenize(name)[0, 1:3])

    synonym_tokens = [clip.tokenize(name)[:, 1] for name in voc_classname_synonyms]

    caption_dict = {name: [] for name in label_names}
    task_list = []

    with open('/data/CVPR23_medfm/TaI-DPT/all_caption_tokenized.pkl', 'rb') as f:
        prompts = pickle.load(f)

    for ii in range(prompts.shape[0]):
        temp_sentence = prompts[ii, :]
        # Task pool: one copy per matching synonym.
        for token in synonym_tokens:
            if token in temp_sentence:
                task_list.append(temp_sentence)
        # Per-class pool.
        for name in label_names:
            required = label_tokens.get(name)
            if required and all(tok in temp_sentence for tok in required):
                caption_dict[name].append(temp_sentence)

    ## class
    min_num = 100000
    for name in label_names:
        print(name, len(caption_dict[name]))
        min_num = min(min_num, len(caption_dict[name]))

    empty = [name for name in label_names if not caption_dict[name]]
    if empty:
        raise ValueError(f'no captions matched classes: {empty}')

    text_features_cat1 = []
    for name in label_names:
        # Rows are 1-D, so stack gives (min_num, context_len) even for one row.
        cap = torch.stack(caption_dict[name][:min_num]).to(device)
        with torch.no_grad():
            text_features = text_encoder(cap, None, if_embedding=False, if_sequence=False)
            text_features_cat1.append(text_features.cpu().numpy())

    np.save('VOC_caption_selection_100_50.npy', text_features_cat1)

    ## task
    # The sampled indices are unused.  The draw is kept only so the RNG state
    # before the shuffle below, and therefore a seeded run's output, is unchanged.
    random.sample(range(1, len(sample_capid)), num_task)
    print('lelelelele', len(task_list))

    random.shuffle(task_list)
    if len(task_list) < num_task:
        raise ValueError(f'need {num_task} task captions, found {len(task_list)}')

    with torch.no_grad():
        # Slice before moving to the device so only the rows used are copied.
        tokenize_mat = torch.stack(task_list[:num_task]).to(device)
        text_features = encode_in_chunks(tokenize_mat)

        print('tokenize_mat', tokenize_mat.shape)
        print('text_features', text_features.shape)

        # Infer the embedding width instead of hard-coding 1024.
        text_features = text_features.reshape(num_groups, group_size, -1)

    saved_np = text_features.cpu().numpy()
    np.save('VOC_multi_label_caption_selection_50_v5_40000.npy', saved_np)
    
    
def write_caption_embedding_npy(text_encoder, device):
    """Encode sampled COCO captions that mention a COCO class name.

    Walks the caption ids in ``coco_caption_text_embed_sampled_idx.pkl`` and
    looks up each caption text in ``captions_train2017.json``.  A caption is
    kept once for every class name whose matching token it contains.  The
    first 40000 kept captions are encoded and written to
    ``COCO_caption_selection_fig4_40000_101.npy`` as a ``(40000, embed_dim)``
    array.  Their texts, in the same order, go to
    ``COCO_caption_selection_fig4_40000_sentence_101.json``.

    Args:
        text_encoder: ``TextEncoder`` module; it is moved to ``device``.
        device: torch device used for encoding.

    Raises:
        ValueError: if fewer than 40000 caption matches were collected.
    """
    num_captions = 40000
    chunk_size = 5000  # rows per encoder call, bounds peak GPU memory

    label_names = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
        "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
        "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
        "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    ]

    text_encoder = text_encoder.to(device)

    def encode_in_chunks(tokens):
        # Encode token rows ``chunk_size`` at a time; returns (N, embed_dim).
        feats = []
        for start in range(0, tokens.shape[0], chunk_size):
            chunk = tokens[start:start + chunk_size]
            out = text_encoder(chunk, None, if_embedding=False, if_sequence=False)
            # reshape instead of squeeze(): a 1-row chunk stays 2-D.
            feats.append(out.reshape(chunk.shape[0], -1))
        return torch.cat(feats, 0)

    coco_root = '/data/CVPR23_medfm/DATA/COCO'
    coco_caption_json_file = os.path.join(coco_root, "annotations/captions_train2017.json")
    with open(coco_caption_json_file, 'r') as f:
        caption_info = json.load(f)
    anno_id2path = {ann["id"]: ann for ann in caption_info["annotations"]}

    with open('/data/CVPR23_medfm/TaI-DPT/coco_caption_text_embed_sampled_idx.pkl', 'rb') as f:
        sample_capid = pickle.load(f)

    # Each class is matched on column 1 of its CLIP tokenization only
    # (presumably the first BPE token after the start-of-text marker).
    tokenizer_label = {name: clip.tokenize(name)[:, 1] for name in label_names}
    # Per-class match counts for the diagnostic print below.
    match_counts = dict.fromkeys(label_names, 0)

    overall_token = []
    overall_sen = []
    for p in sample_capid:
        caption = anno_id2path[p]['caption']
        temp_sentence = clip.tokenize(caption)
        for name in label_names:
            if tokenizer_label[name] in temp_sentence[0, :]:
                # NOTE(review): a caption matching k class names is stored k
                # times; confirm this weighting is intended.
                overall_sen.append(caption)
                overall_token.append(temp_sentence)
                match_counts[name] += 1

    for name in label_names:
        print(name, match_counts[name])

    print('len', len(overall_sen))
    if len(overall_token) < num_captions:
        raise ValueError(f'need {num_captions} captions, found {len(overall_token)}')

    with torch.no_grad():
        # Slice before moving to the device so only the rows used are copied.
        tokenize_mat = torch.cat(overall_token[:num_captions], 0).to(device)
        text_features = encode_in_chunks(tokenize_mat)

    np.save('COCO_caption_selection_fig4_40000_101.npy', text_features.cpu().numpy())
    overall_sen_filtered = overall_sen[:num_captions]

    with open('COCO_caption_selection_fig4_40000_sentence_101.json', 'w') as f:
        json.dump(overall_sen_filtered, f)
    

    
def write_caption_npy(text_encoder, device, num_per_class=16):
    """Encode COCO captions that mention each COCO class and save two .npy files.

    Captions come from the annotation ids stored in
    ``coco_caption_text_embed_sampled_idx.pkl``. A caption is assigned to a
    class when its CLIP token ids contain the class name's token at position 1
    of ``clip.tokenize`` (the first token after start-of-text), so multi-word
    names such as "traffic light" are matched on their first word only. For
    each class, the first ``num_per_class`` matching captions are encoded twice:

    * ``COCO_caption_selection_101.npy`` holds the sentence-level features
      (``if_sequence=False``), stacked per class.
    * ``COCO_caption_selection_local_101.npy`` holds the per-token feature
      (``if_sequence=True``) at the position where the class token occurs,
      stacked as ``(num_classes, num_per_class, D)``.

    Both arrays follow ``label_names`` order. The previous implementation
    prepended each class to the local array. That saved it in reverse class
    order, out of step with the sentence-level array.

    Args:
        text_encoder: the text encoder. It is used where it currently lives
            (it is NOT moved to ``device``) and is fed the CPU token ids
            produced by ``clip.tokenize``.
        device: unused; kept for interface compatibility.
        num_per_class: number of captions encoded per class (default 16).

    Raises:
        ValueError: if some class has fewer than ``num_per_class`` matching
            captions, because the per-class arrays could not be stacked.
    """
    label_names = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
        "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
        "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
        "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    ]

    coco_root = '/data/CVPR23_medfm/DATA/COCO'
    coco_caption_json_file = os.path.join(coco_root, "annotations/captions_train2017.json")
    with open(coco_caption_json_file, 'r') as f:
        caption_info = json.load(f)
    anno_id2path = {anno["id"]: anno for anno in caption_info["annotations"]}

    with open('/data/CVPR23_medfm/TaI-DPT/coco_caption_text_embed_sampled_idx.pkl', 'rb') as f:
        sample_capid = pickle.load(f)

    # Token id at position 1 of each class name (first token after <SOT>).
    tokenizer_label = {name: clip.tokenize(name)[:, 1] for name in label_names}

    # caption_dict[name]: (1, 77) token tensors of matching captions.
    # caption_dict_ind[name]: position of the class token in each of them.
    caption_dict = {name: [] for name in label_names}
    caption_dict_ind = {name: [] for name in label_names}
    for p in sample_capid:
        temp_sentence = clip.tokenize(anno_id2path[p]['caption'])
        tokens = temp_sentence[0, :]
        for name in label_names:
            hits = (tokens == tokenizer_label[name]).nonzero(as_tuple=True)[0]
            if len(hits) > 0:
                caption_dict[name].append(temp_sentence)
                caption_dict_ind[name].append(hits[0])

    for l_n in label_names:
        print(l_n, len(caption_dict[l_n]))

    short = [name for name in label_names if len(caption_dict[name]) < num_per_class]
    if short:
        raise ValueError(
            f"fewer than {num_per_class} matching captions for: {', '.join(short)}")

    global_feats = []
    local_feats = []
    with torch.no_grad():
        for name in label_names:
            cap = torch.cat(caption_dict[name][:num_per_class], 0)  # (num_per_class, 77)
            cap_ind = torch.stack(caption_dict_ind[name][:num_per_class])

            sentence_feat = text_encoder(cap, None, if_embedding=False, if_sequence=False)
            global_feats.append(sentence_feat.cpu().numpy())

            # Per-token features; row i is read at the class-token position of caption i.
            seq_feat = text_encoder(cap, None, if_embedding=False, if_sequence=True)
            rows = torch.arange(cap.shape[0], device=seq_feat.device)
            local_feats.append(seq_feat[rows, cap_ind.to(seq_feat.device)].cpu().numpy())

    local_arr = np.stack(local_feats)
    print('text_features', local_arr.shape)
    np.save('COCO_caption_selection_local_101.npy', local_arr)
    np.save('COCO_caption_selection_101.npy', np.stack(global_feats))
        
#     for idx, k in enumerate(cap_ind):
#         caption_dict[j][idx]
#     print(torch.stack(caption_dict[j][:100])[:,:,k].shape)
        
        
#     caption_dict = {}
#     for l_n in label_names:
#         caption_dict[l_n] = []

#     for ii, p in enumerate(sample_capid):
#         with torch.no_grad():
#             temp_sentence = clip.tokenize(anno_id2path[p]['caption'])
#             text_features = text_encoder(temp_sentence, None, if_embedding=False, if_sequence=True)
#         for j in label_names:
#             temp_token = tokenizer_label[j]

#             if temp_token in temp_sentence[0,:]:

#                 a=(temp_sentence[0,:]== temp_token).nonzero(as_tuple=True)

#                 index = a[0][0]
#                 with torch.no_grad():
#                     temp_text_feat = text_features[:,index,:]
#                     print(temp_text_feat.shape)
#                     if len(caption_dict[j]) == 0:
#                         caption_dict[j] = temp_text_feat
#                     elif len(caption_dict[j]) > 100:
#                         continue
#                     else:
#                         print(caption_dict[j].shape)
#                         print(temp_text_feat.shape)
#                         caption_dict[j]= torch.cat((temp_text_feat,caption_dict[j]),0)
#                     #caption_dict[j].append(text_features[:,index,:])
#     for j in label_names:
#         print(caption_dict[j].shape)

#     proto_prior_list = []
#     for label in label_names:
#         prototype_cls = 'a photo of ' + label
#         proto_prior_list.append(prototype_cls)

#     proto_prior_list = np.array(proto_prior_list)

#     tokenized_prompts_list = [clip.tokenize(p) for p in proto_prior_list]
#     tokenized_prompts = torch.stack(tokenized_prompts_list)
#     tokenized_prompts = torch.squeeze(tokenized_prompts)
#     print(tokenized_prompts.shape)  # [80,77]
#     with torch.no_grad():
#         text_features = text_encoder(tokenized_prompts, None, if_embedding=False, if_sequence=True)
#         print(text_features.shape)  # [80,77,512]

#     np.save('COCO_proto.npy', text_features.cpu().numpy())


######################
def write_caption_syn_npy(text_encoder, device):
    """Encode COCO captions that mention a COCO class (or a synonym) and save them.

    Captions come from the annotation ids stored in
    ``coco_caption_text_embed_sampled_idx.pkl``. A caption is recorded once for
    every class it matches, i.e. when its CLIP token ids contain the token at
    position 1 of ``clip.tokenize`` for any of that class's synonyms. The pooled
    list is shuffled, and its first 32000 entries are encoded in chunks of 5000.
    The features are saved, reshaped to ``(1600, 20, D)``, to
    ``COCO_multi_label_caption_selection_101_32000_syn.npy``.

    Previously the class name was tokenized once per synonym instead of the
    synonym itself. Synonyms were therefore ignored, and every hit was appended
    ``len(synonyms)`` times.

    Args:
        text_encoder: text encoder called as
            ``text_encoder(tokens, None, if_embedding=False, if_sequence=False)``;
            it is moved to ``device``.
        device: device for the encoder and the token ids.

    Raises:
        ValueError: if fewer than 32000 (caption, class) matches are found.
    """
    import random

    coco_classname_synonyms = [
        ['person', 'human', 'people', 'man', 'woman', 'passenger'],
        ['bicycle', 'bike', 'cycle'],
        ['car', 'taxi', 'auto', 'automobile', 'motor car'],
        ['motor cycle','motor bike', 'motorcycle'],
        ['aeroplane','airplane', "air craft", "jet", "plane", "air plane"],
        ['bus', 'autobus', 'coach', 'charabanc', 'double decker', 'jitney', 'motor bus', 'motor coach', 'omnibus'],
        ['train', 'rail way', 'railroad'],
        ['truck'],
        ['boat', 'raft', 'dinghy'],
        ['traffic light'],
        ['fire hydrant', 'fire tap', 'hydrant'],
        ['stop sign', 'halt sign'],
        ['parking meter'],
        ['bench'],
        ['bird'],
        ['cat', 'kitty'],
        ['dog', 'pup', 'puppy', 'doggy'],
        ['horse', 'colt', 'equus'],
        ['sheep'],
        ['cow'],
        ['elephant'],
        ['bear'],
        ['zebra'],
        ['giraffe', 'camelopard'],
        ['backpack', 'back pack', 'knapsack', 'packsack', 'rucksack', 'haversack'],
        ['umbrella'],
        ['handbag', 'hand bag', 'pocketbook', 'purse'],
        ['tie', 'necktie'],
        ['suitcase'],
        ['frisbee'],
        ['skis', 'ski'],
        ['snowboard'],
        ['sports ball', 'sport ball', 'ball', 'football', 'soccer', 'tennis', 'basketball', 'baseball'],
        ['kite'],
        ['baseball bat', 'baseball game bat'],
        ['baseball glove', 'baseball mitt', 'baseball game glove'],
        ['skateboard'],
        ['surfboard'],
        ['tennis racket'],
        ['bottle'],
        ['wine glass', 'vino glass'],
        ['cup'],
        ['fork'],
        ['knife'],
        ['spoon'],
        ['bowl'],
        ['banana'],
        ['apple'],
        ['sandwich'],
        ['orange'],
        ['broccoli'],
        ['carrot'],
        ['hot dog'],
        ['pizza'],
        ['donut', 'doughnut'],
        ['cake'],
        ['chair', 'arm chair'],
        ['couch', 'sofa'],
        ['potted plant', 'house plant', 'bonsai', 'pot plant'],
        ['bed'],
        ['dining table', 'dinner table', 'table', 'din table'],
        ['toilet', 'commode'],
        ['tv', 'tvmonitor', 'monitor', 'television', 'telly'],
        ['laptop'],
        ['mouse'],
        ['remote'],
        ['keyboard'],
        ['cell phone', 'phone', 'mobile phone'],
        ['microwave'],
        ['oven', 'roaster'],
        ['toaster'],
        ['sink'],
        ['refrigerator', 'icebox'],
        ['book'],
        ['clock'],
        ['vase'],
        ['scissors'],
        ['teddy bear', 'teddy'],
        ['hair drier', 'blowing machine', 'hair dryer', 'dryer', 'blow dryer', 'blown dry', 'blow dry'],
        ['toothbrush'],
    ]
    label_names = [syn[0] for syn in coco_classname_synonyms]

    text_encoder = text_encoder.to(device)

    coco_root = '/data/CVPR23_medfm/DATA/COCO'
    coco_caption_json_file = os.path.join(coco_root, "annotations/captions_train2017.json")
    with open(coco_caption_json_file, 'r') as f:
        caption_info = json.load(f)
    anno_id2path = {anno["id"]: anno for anno in caption_info["annotations"]}

    with open('/data/CVPR23_medfm/TaI-DPT/coco_caption_text_embed_sampled_idx.pkl', 'rb') as f:
        sample_capid = pickle.load(f)

    # One token id per synonym. NOTE(review): only position 1 (first token
    # after <SOT>) is kept, so multi-word synonyms such as "hand bag" match on
    # their first word only; consider full-phrase matching.
    tokenizer_label = {
        name: [clip.tokenize(syn)[:, 1] for syn in syns]
        for name, syns in zip(label_names, coco_classname_synonyms)
    }

    class_counts = dict.fromkeys(label_names, 0)
    task_list = []  # (1, 77) token tensors, one entry per (caption, matched class)
    for p in sample_capid:
        temp_sentence = clip.tokenize(anno_id2path[p]['caption'])
        tokens = temp_sentence[0, :]
        for name in label_names:
            # any() records a caption at most once per class.
            if any(syn_token in tokens for syn_token in tokenizer_label[name]):
                class_counts[name] += 1
                task_list.append(temp_sentence)

    for l_n in label_names:
        print(l_n, class_counts[l_n])

    num_samples, group_size, chunk_size = 32000, 20, 5000
    if len(task_list) < num_samples:
        raise ValueError(
            f"need {num_samples} (caption, class) matches, found {len(task_list)}")

    random.shuffle(task_list)
    with torch.no_grad():
        tokenize_mat = torch.cat(task_list[:num_samples], 0).to(device)  # (num_samples, 77)
        print('tokenize_mat', tokenize_mat.shape)

        # Encode in fixed-size chunks to bound peak memory.
        feats = []
        for start in range(0, num_samples, chunk_size):
            chunk = tokenize_mat[start:start + chunk_size]
            out = text_encoder(chunk, None, if_embedding=False, if_sequence=False)
            feats.append(out.reshape(chunk.shape[0], -1))
        text_features = torch.cat(feats, 0)
        print('text_features', text_features.shape)

        # Feature dim is inferred (-1) so the backbone width is not hard-coded.
        text_features = text_features.reshape(num_samples // group_size, group_size, -1)

    np.save('COCO_multi_label_caption_selection_101_32000_syn.npy', text_features.cpu().numpy())

    
    
######################
def write_caption_voc_syn_npy(text_encoder, device):
    """Encode COCO captions that mention a VOC class (or a synonym) and save them.

    Captions come from the COCO annotation ids stored in
    ``coco_caption_text_embed_sampled_idx.pkl``. A caption is recorded once for
    every VOC class it matches, i.e. when its CLIP token ids contain the token
    at position 1 of ``clip.tokenize`` for any of that class's synonyms.

    The first 40000 entries, in collection order, are encoded in chunks of 5000
    and written to
    ``VOC_multi_label_caption_selection_101_v5_40000_syn_supple.npy`` as
    ``(200, 200, D)``. The matching raw captions are written to
    ``/data/CVPR23_medfm/TaI-DPT/text_set_voc.json``, with ``text_set[i]``
    corresponding to feature row ``i``.

    Previously the class name was tokenized once per synonym instead of the
    synonym itself. Synonyms were therefore ignored, and every hit was appended
    ``len(synonyms)`` times.

    Args:
        text_encoder: text encoder called as
            ``text_encoder(tokens, None, if_embedding=False, if_sequence=False)``;
            it is moved to ``device``.
        device: device for the encoder and the token ids.

    Raises:
        ValueError: if fewer than 40000 (caption, class) matches are found.
    """
    voc_classname_synonyms = [
        ['airplane','aeroplane', "air craft", "jet", "plane", "air plane"],
        ['bicycle', 'bike', 'cycle'],
        ['bird'],
        ['boat', 'raft', 'dinghy'],
        ['bottle'],
        ['bus', 'autobus', 'coach', 'charabanc', 'double decker', 'jitney', 'motor bus', 'motor coach', 'omnibus'],
        ['car', 'taxi', 'auto', 'automobile', 'motor car'],
        ['cat', 'kitty'],
        ['chair', 'arm chair', 'bench'],
        ['cow'],
        ['dining table','diningtable', 'table', 'dinner table', 'din table'],
        ['dog', 'pup', 'puppy', 'doggy'],
        ['horse', 'colt', 'equus'],
        ['motor bike','motorbike', 'motor cycle'],
        ['person', 'human', 'people', 'man', 'woman', 'passenger'],
        ['potted plant','pottedplant', 'house plant', 'bonsai', 'pot plant'],
        ['sheep'],
        ['sofa', 'couch'],
        ['train', 'rail way', 'railroad'],
        ['tv','tvmonitor', 'monitor', 'television', 'telly']
    ]
    label_names = [syn[0] for syn in voc_classname_synonyms]

    text_encoder = text_encoder.to(device)

    coco_root = '/data/CVPR23_medfm/DATA/COCO'
    coco_caption_json_file = os.path.join(coco_root, "annotations/captions_train2017.json")
    with open(coco_caption_json_file, 'r') as f:
        caption_info = json.load(f)
    anno_id2path = {anno["id"]: anno for anno in caption_info["annotations"]}

    with open('/data/CVPR23_medfm/TaI-DPT/coco_caption_text_embed_sampled_idx.pkl', 'rb') as f:
        sample_capid = pickle.load(f)

    # One token id per synonym. NOTE(review): only position 1 (first token
    # after <SOT>) is kept, so multi-word synonyms match on their first word only.
    tokenizer_label = {
        name: [clip.tokenize(syn)[:, 1] for syn in syns]
        for name, syns in zip(label_names, voc_classname_synonyms)
    }

    class_counts = dict.fromkeys(label_names, 0)
    task_list = []  # (1, 77) token tensors, one entry per (caption, matched class)
    text_set = []   # raw caption strings, aligned 1:1 with task_list
    for p in sample_capid:
        caption = anno_id2path[p]['caption']
        temp_sentence = clip.tokenize(caption)
        tokens = temp_sentence[0, :]
        for name in label_names:
            # any() records a caption at most once per class.
            if any(syn_token in tokens for syn_token in tokenizer_label[name]):
                class_counts[name] += 1
                task_list.append(temp_sentence)
                text_set.append(caption)

    for l_n in label_names:
        print(l_n, class_counts[l_n])

    num_samples, group_size, chunk_size = 40000, 200, 5000
    if len(task_list) < num_samples:
        raise ValueError(
            f"need {num_samples} (caption, class) matches, found {len(task_list)}")

    # task_list is used in collection order (no shuffle), which keeps text_set
    # aligned row-for-row with the saved features.
    with torch.no_grad():
        tokenize_mat = torch.cat(task_list[:num_samples], 0).to(device)  # (num_samples, 77)
        print('tokenize_mat', tokenize_mat.shape)

        # Encode in fixed-size chunks to bound peak memory.
        feats = []
        for start in range(0, num_samples, chunk_size):
            chunk = tokenize_mat[start:start + chunk_size]
            out = text_encoder(chunk, None, if_embedding=False, if_sequence=False)
            feats.append(out.reshape(chunk.shape[0], -1))
        text_features = torch.cat(feats, 0)
        print('text_features', text_features.shape)

        # Feature dim is inferred (-1) so the backbone width is not hard-coded.
        text_features = text_features.reshape(num_samples // group_size, group_size, -1)

    with open('/data/CVPR23_medfm/TaI-DPT/text_set_voc.json', 'w') as outfile:
        json.dump(text_set[:num_samples], outfile, indent=4)

    np.save('VOC_multi_label_caption_selection_101_v5_40000_syn_supple.npy', text_features.cpu().numpy())
    
    
def write_caption_nus_syn_npy(text_encoder, device):
    """Encode LLM-generated NUS-WIDE prompts per class and as a pooled set.

    Prompts are read from two JSON files, each a list of dicts with a
    ``'sentence'`` key. They are concatenated and shuffled. A prompt is recorded
    once for every class it matches, i.e. when its CLIP token ids contain the
    token at position 1 of ``clip.tokenize`` for any of that class's synonyms.
    Two files are written:

    * ``nus_llm_caption_selection_50_57600_one_cls_syn.npy`` holds, for every
      class, the features of its first ``min_num`` matching prompts, where
      ``min_num`` is the smallest per-class count.
    * ``nus_llm_multi_label_caption_selection_50_57600_one_cls_syn.npy`` holds
      the pooled matches, shuffled and truncated to 57600. They are encoded in
      chunks of 5000 and reshaped to ``(1600, 36, D)``.

    Previously the class name was tokenized once per synonym instead of the
    synonym itself. Synonyms were therefore ignored, and every hit was appended
    ``len(synonyms)`` times.

    Args:
        text_encoder: text encoder called as
            ``text_encoder(tokens, None, if_embedding=False, if_sequence=False)``;
            it is moved to ``device``.
        device: device for the encoder and the token ids.

    Raises:
        ValueError: if some class has no matching prompt, or if fewer than
            57600 (prompt, class) matches are found.
    """
    import random

    nus_classname_synonyms = [
        ['airport', 'air port', 'air field', 'runway'],
        ['animal'],
        ['beach', 'plage', 'coast', 'seashore'],
        ['bear'],
        ['birds', 'bird'],
        ['boats', 'boat', 'raft', 'dinghy'],
        ['book'],
        ['bridge'],
        ['buildings', 'building'],
        ['cars', 'car'],
        ['castle'],
        ['cat', 'kitty'],
        ['city','cityscape',  'skyscraper'],
        ['clouds', 'cloud'],
        ['computer', 'desktop', 'laptop'],
        ['coral'],
        ['cow'],
        ['dancing', 'dance'],
        ['dog', 'pup', 'puppy', 'doggy'],
        ['broken building','earthquake', 'collapse building', 'break building'],
        ['deer', 'elk'],
        ['fire'],
        ['fish'],
        ['flags', 'flag'],
        ['flowers','flower'],
        ['food'],
        ['fox'],
        ['forsted','frost'],  # 'ice' 'frost'  NOTE(review): 'forsted' looks like a typo for 'frosted'
        ['garden'],
        ['ice','glacier'], # 'iceberg'
        ['grass'],
        ['harbor', 'port', 'harbour'],
        ['horses', 'horse'],
        ['house'],
        ['lake'],
        ['leaf'],
        ['map'],
        ['military', 'army' , 'troops', 'troop'],
        ['moon'],
        ['mountain', 'hill'],
        ['nighttime', 'night time', 'night'],
        ['ocean', 'sea'],
        ['person', 'human', 'people', 'man', 'woman', 'passenger'],
        ['plane', 'aeroplane', "air craft", "jet", "air plane"],
        ['plants', 'plant'],
        ['police'],
        ['protest'],
        ['rail road','railroad', 'rail way'],
        ['rainbow'],
        ['reflection'],
        ['road', 'path', 'way'],
        ['rocks', 'rock'],
        ['running', 'run'],
        ['sand'],
        ['sign'],
        ['sky'],
        ['snow'],
        ['soccer', 'football'],
        ['sports', 'sport'],
        ['statue'],
        ['street'],
        ['sun'],
        ['sunset'],
        ['surf'],
        ['swim', 'swimmers', 'swimming','swimmer'],
        ['tattoo', 'tattooing'],
        ['temple'],
        ['tiger'],
        ['tower'],
        ['town'],
        ['toy'],
        ['train'],
        ['tree'],
        ['valley'],
        ['vehicle'],
        ['water'],
        ['waterfall'],
        ['wedding', 'engagement', 'bride', 'groom'],
        ['whale', 'whales'],
        ['window'],
        ['zebra'],
    ]
    label_names = [syn[0] for syn in nus_classname_synonyms]

    text_encoder = text_encoder.to(device)

    # One token id per synonym. NOTE(review): only position 1 (first token
    # after <SOT>) is kept, so multi-word synonyms match on their first word only.
    tokenizer_label = {
        name: [clip.tokenize(syn)[:, 1] for syn in syns]
        for name, syns in zip(label_names, nus_classname_synonyms)
    }

    pa = '/data/CVPR23_medfm/DATA/priors/NUSWIDE_40000.json'
    with open(pa, 'rb') as f:
        prompts = json.load(f)
    pa = '/data/CVPR23_medfm/DATA/priors/NUSWIDE_oneclass_expension_1110.json'
    with open(pa, 'rb') as f:
        prompts = prompts + json.load(f)
    random.shuffle(prompts)

    caption_dict = {name: [] for name in label_names}  # (1, 77) token tensors per class
    task_list = []  # one entry per (prompt, matched class)
    for prompt in prompts:
        temp_sentence = clip.tokenize(prompt['sentence'])
        tokens = temp_sentence[0, :]
        for name in label_names:
            # any() records a prompt at most once per class.
            if any(syn_token in tokens for syn_token in tokenizer_label[name]):
                caption_dict[name].append(temp_sentence)
                task_list.append(temp_sentence)

    for l_n in label_names:
        print(l_n, len(caption_dict[l_n]))

    empty = [name for name in label_names if not caption_dict[name]]
    if empty:
        raise ValueError(f"no matching prompts for: {', '.join(empty)}")
    min_num = min(len(caption_dict[name]) for name in label_names)

    # Per-class features: the first `min_num` prompts of every class.
    text_features_cat1 = []
    with torch.no_grad():
        for name in label_names:
            cap = torch.cat(caption_dict[name][:min_num], 0)  # (min_num, 77), also for min_num == 1
            feats = text_encoder(cap.to(device), None, if_embedding=False, if_sequence=False)
            text_features_cat1.append(feats.cpu().numpy())
    np.save('nus_llm_caption_selection_50_57600_one_cls_syn.npy', np.stack(text_features_cat1))

    num_samples, group_size, chunk_size = 57600, 36, 5000
    if len(task_list) < num_samples:
        raise ValueError(
            f"need {num_samples} (prompt, class) matches, found {len(task_list)}")

    random.shuffle(task_list)
    with torch.no_grad():
        tokenize_mat = torch.cat(task_list[:num_samples], 0).to(device)  # (num_samples, 77)
        print('tokenize_mat', tokenize_mat.shape)

        # Encode in fixed-size chunks to bound peak memory.
        feats = []
        for start in range(0, num_samples, chunk_size):
            chunk = tokenize_mat[start:start + chunk_size]
            out = text_encoder(chunk, None, if_embedding=False, if_sequence=False)
            feats.append(out.reshape(chunk.shape[0], -1))
        text_features = torch.cat(feats, 0)
        print('text_features', text_features.shape)

        # Feature dim is inferred (-1) so the backbone width is not hard-coded.
        text_features = text_features.reshape(num_samples // group_size, group_size, -1)

    np.save('nus_llm_multi_label_caption_selection_50_57600_one_cls_syn.npy', text_features.cpu().numpy())
    

def replicate_cls(cls_tensor, batch_size):
    """Tile a per-class tensor along a new leading batch dimension.

    Args:
        cls_tensor: tensor expected to be [cls, D].
        batch_size: number of copies along the new dimension 0.

    Returns:
        A [batch_size, cls, D] broadcast view of ``cls_tensor``. It is produced
        by ``expand``, so no data is copied.
    """
    target_shape = (batch_size, -1, -1)
    batched = cls_tensor[None]
    return batched.expand(target_shape)


def get_eot_feature(input):  # [C,N,L,D]
    """Reduce a rank-4 [C, N, L, D] tensor over its L axis.

    For every (c, n, d) this picks ``input[c, n, l*, d]``, where
    ``l* = argmax_l input[c, n, l, d]``. Each channel ``d`` chooses its own
    position independently, so the result is the element-wise maximum over L.
    It is NOT the feature at a single shared (e.g. end-of-text) position.

    Returns:
        Tensor of shape [C, N, D].
    """
    n_cls, n_items, seq_len, width = input.shape  # unpacking enforces rank 4
    best_pos = input.argmax(dim=2, keepdim=True)  # [C, N, 1, D]
    picked = torch.gather(input, 2, best_pos)
    return picked.squeeze(2)


def get_eot_feature_2d(input):  # [N, L, D]
    """Rank-3 counterpart of ``get_eot_feature``: reduce [N, L, D] over L.

    Each output entry ``[n, d]`` is ``input[n, l*, d]``, where
    ``l* = argmax_l input[n, l, d]``. This is the element-wise maximum over
    the L axis, chosen per channel.

    Returns:
        Tensor of shape [N, D].
    """
    n_items, seq_len, width = input.shape  # unpacking enforces rank 3
    best_pos = input.argmax(dim=1, keepdim=True)  # [N, 1, D]
    return torch.gather(input, 1, best_pos).squeeze(1)


class DenseCLIP(nn.Module):
    """CLIP-based multi-label classifier with learned prompts and prior-guided cross-attention.

    * Test mode (``if_test=True``): an image is encoded by the ResNet-style
      CLIP visual backbone. The attention-pooled global feature and the
      per-location features are scored against learned class prompts.
    * Training mode: a tokenized caption stands in for the image. Its EOT
      sequence output is the global feature. A 49-token "local" map is
      synthesized per sample from precomputed class-prior and background
      embeddings, chosen according to ``label``.

    ``agg_text`` refines the local prompt embeddings with the class prior
    embeddings. ``agg_img`` refines the local tokens with the class
    prototype embeddings.

    NOTE(review): construction has file-system side effects. It writes
    embeddings via ``write_caption_voc_syn_npy`` / ``write_bg_npy`` and reads
    several hard-coded ``.npy`` / ``.json`` paths.
    """

    # Toggles for the two cross-attention refinements (both enabled).
    _usage_text_att = 1
    _usage_image_att = 1

    def __init__(self, cfg, classnames, clip_model, device, return_interm_layers=False):
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)

        self.model = clip_model
        self.return_interm_layers = return_interm_layers
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {"layer4": "0"}
        self.visual_encoder = IntermediateLayerGetter(self.model.visual, return_layers)
        # Skip the first positional-embedding row (presumably the pooled-token
        # slot of CLIP's AttentionPool2d).
        self.positional_embedding = self.model.visual.attnpool.positional_embedding[1::]
        # The attention-pool value/output projections are reused in the test
        # path to project every spatial location, not only the pooled token.
        self.v_linear_weight = self.model.visual.attnpool.v_proj.weight
        self.v_linear_bias = self.model.visual.attnpool.v_proj.bias
        self.c_linear_weight = self.model.visual.attnpool.c_proj.weight
        self.c_linear_bias = self.model.visual.attnpool.c_proj.bias

        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.cfg = cfg

        # ---- Load per-class prior knowledge embeddings -------------------
        print('start embedding generation')
        # Regenerates caption embeddings on every construction (side effect).
        write_caption_voc_syn_npy(self.text_encoder, device)
        print('complete write_caption_voc_syn_npy')

        # Text-side class priors, consumed by get_eot_feature (rank 4).
        self.prior_dict = read_npy('COCO1.npy', device)
        print('self.prior_dict:', self.prior_dict.shape)

        # Image-side class priors, used to synthesize local tokens in training.
        self.image_raw = json2list('/data/CVPR23_medfm/DATA/priors/COCO2.json', True)  # [80,50]
        self.image_dict = read_npy('COCO2.npy', device)
        print('self.image_dict:', self.image_dict.shape)

        # Background priors (rank 3), used to fill non-object grid tokens.
        self.bg_raw = json2list('/data/CVPR23_medfm/DATA/priors/COCO_bg.json')  # [1, 50]
        write_bg_npy('/data/CVPR23_medfm/DATA/priors/COCO_bg.json', self.text_encoder)
        self.bg_dict = read_npy('COCO_bg.npy', device)
        print('self.bg_dict:', self.bg_dict.shape)

        # ---- Cross-attention aggregation layers --------------------------
        self.agg_text = MCA(dim=1024)  # text-side aggregation
        self.agg_img = MCA(dim=1024)  # image-side aggregation

        # ---- Class prototypes ("a photo of [CLS name]") -------------------
        self.prototype = read_npy('COCO_proto.npy', device)
        print('self.prototype:', self.prototype.shape)

        self.device = device

    def encode_image(self, x):
        """Run the stem and layer1-layer4 of the wrapped visual encoder.

        Returns the final (layer4) feature map. The test path unpacks it as
        [B, C, H, W].
        """
        encoder = self.visual_encoder
        x = x.type(encoder.conv1.weight.dtype)
        stem = (
            (encoder.conv1, encoder.bn1),
            (encoder.conv2, encoder.bn2),
            (encoder.conv3, encoder.bn3),
        )
        for conv, bn in stem:
            x = encoder.relu(bn(conv(x)))
        x = encoder.avgpool(x)
        for layer in (encoder.layer1, encoder.layer2, encoder.layer3, encoder.layer4):
            x = layer(x)
        return x

    def forward(self, image=None, captions=None, if_test=False, label=None):
        """Dispatch to the image (test) path or the caption (training) path.

        Returns:
            Test mode: ``(logits_global, logits_local, logits_neg,
            image_features @ text_features.T)``.
            Training mode: ``(logits_global, logits_local, image_features,
            text_features)``.
        """
        if if_test:
            return self._forward_test(image)
        return self._forward_train(captions, label)

    def _forward_test(self, image):
        """Score a batch of real images against the learned class prompts."""
        use_text_att = self._usage_text_att
        use_image_att = self._usage_image_att

        if use_image_att:
            proto_features = get_eot_feature_2d(self.prototype)  # [cls, D]
            proto_feat = replicate_cls(proto_features, image.shape[0])  # [B, cls, D]

        image_feat = self.encode_image(image)
        b, c, h, w = image_feat.shape

        # Project every spatial location with the attention-pool v/c projections.
        x = image_feat.reshape(b, c, h * w).permute(2, 0, 1)  # [HW, B, C]
        x = F.linear(x, self.v_linear_weight, self.v_linear_bias)
        x = F.linear(x, self.c_linear_weight, self.c_linear_bias)

        image_features = x  # [HW, B, D]
        image_feature_, _ = self.model.visual.attnpool(image_feat, if_pos=False)  # [B, D]

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)

        if use_image_att:
            proto_feat = proto_feat / proto_feat.norm(dim=-1, keepdim=True)
            local_feat = image_features.permute(1, 0, 2)  # [B, HW, D]
            output = self.agg_img(x_q=local_feat, x_k=proto_feat, x_v=proto_feat)
            image_features = (local_feat + output).permute(1, 0, 2)  # [HW, B, D]

        prompts, prompts_double, temperature, spatial_T, rk_scale = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)
        text_features_neg = self.text_encoder(prompts_double, tokenized_prompts)

        if use_text_att:
            prior_features = get_eot_feature(self.prior_dict)  # [cls, N, D]
            text_features_neg = text_features_neg.unsqueeze(1)  # [cls, 1, D]
            output = self.agg_text(x_q=text_features_neg, x_k=prior_features, x_v=prior_features)
            text_features_neg = (text_features_neg + output).squeeze(1)  # [cls, D]

        image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        text_features_neg = text_features_neg / text_features_neg.norm(dim=-1, keepdim=True)

        logit_scale = temperature.exp() if self.cfg.TRAIN.IF_LEARN_SCALE else 4.0
        logits_ = logit_scale * image_feature_ @ text_features.t()  # [B, cls]
        logits_neg = image_features @ text_features_neg.t()  # [HW, B, cls]

        tmp_scale = (spatial_T.exp() if self.cfg.TRAIN.IF_LEARN_spatial_SCALE
                     else self.cfg.TRAIN.spatial_SCALE_image)
        # Softmax over spatial locations, then a probability-weighted sum.
        prob_spatial = F.softmax(logits_neg * tmp_scale, dim=0)
        logits_local = torch.sum(logit_scale * logits_neg * prob_spatial, dim=0)  # [B, cls]

        return logits_, logits_local, logits_neg, image_features @ text_features.t()

    def _forward_train(self, captions, label):
        """Score tokenized captions, using prior-synthesized local tokens."""
        use_text_att = self._usage_text_att
        use_image_att = self._usage_image_att

        seq_feat = self.text_encoder(captions, None, if_embedding=False, if_sequence=True)  # [B, L, D]
        # Pick each caption's feature at the position of its largest token id.
        batch_idx = torch.arange(seq_feat.shape[0], device=captions.device)
        global_caption = seq_feat[batch_idx, captions.argmax(dim=-1)]  # [B, D]

        proto_features = get_eot_feature_2d(self.prototype)  # [cls, D]
        prior_features = get_eot_feature(self.prior_dict)  # [cls, N, D]

        # Text side: global prompt -> prompts, local prompt -> prompts_double.
        prompts, prompts_double, temperature, spatial_T, rk_scale = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)
        text_features_neg = self.text_encoder(prompts_double, tokenized_prompts)

        if use_text_att:
            text_features_neg = text_features_neg.unsqueeze(1)  # [cls, 1, D]
            output = self.agg_text(x_q=text_features_neg, x_k=prior_features, x_v=prior_features)
            text_features_neg = (text_features_neg + output).squeeze(1)  # [cls, D]

        # Image side: synthesize a local token map from the *image* priors.
        image_prior_features = get_eot_feature(self.image_dict)  # [cls, N, D]
        bg_features = get_eot_feature_2d(self.bg_dict)  # [N, D]
        image_feature_local = self._compose_grid_tokens(label, image_prior_features, bg_features)  # [B, HW, D]
        image_feature_global = global_caption  # [B, D]

        if use_image_att:
            proto_feat = replicate_cls(proto_features, label.shape[0])  # [B, cls, D]
            proto_feat = proto_feat / proto_feat.norm(dim=-1, keepdim=True)
            output = self.agg_img(x_q=image_feature_local, x_k=proto_feat, x_v=proto_feat)
            image_feature_local = image_feature_local + output  # [B, HW, D]

        image_feature_ = image_feature_global / image_feature_global.norm(dim=-1, keepdim=True)
        image_features = image_feature_local / image_feature_local.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        text_features_neg = text_features_neg / text_features_neg.norm(dim=-1, keepdim=True)

        logit_scale = temperature.exp() if self.cfg.TRAIN.IF_LEARN_SCALE else 4.0
        logits_ = logit_scale * image_feature_ @ text_features.t()  # [B, cls]
        logits_neg = image_features.permute(1, 0, 2) @ text_features_neg.t()  # [HW, B, cls]

        tmp_scale = (spatial_T.exp() if self.cfg.TRAIN.IF_LEARN_spatial_SCALE
                     else self.cfg.TRAIN.spatial_SCALE_text)
        prob_spatial = F.softmax(logits_neg * tmp_scale, dim=0)  # [HW, B, cls]
        logits_local = torch.sum(logit_scale * logits_neg * prob_spatial, dim=0)  # [B, cls]

        return logits_, logits_local, image_features, text_features

    def _compose_grid_tokens(self, label, image_features, bg_features, num_tokens=49, max_fg_tokens=39):
        """Build a synthetic [B, num_tokens, D] local-token map from prior embeddings.

        For each sample, every positive class in ``label`` (entries == 1) gets
        ``max_fg_tokens // num_positive`` tokens, drawn at random from that
        class's rows of ``image_features`` ([cls, N, D]). The remaining slots
        are filled with random rows of ``bg_features`` ([N, D]). Tokens are
        interleaved by rejection sampling with ``np.random``. A sample with no
        positive class is filled entirely with background tokens.
        """
        dim = image_features.shape[-1]
        images_batch = torch.empty(label.shape[0], num_tokens, dim, device=self.device)
        for idx, batch_label in enumerate(label):
            num_class = int(np.sum(batch_label.detach().cpu().numpy()))
            max_num = max_fg_tokens // num_class if num_class > 0 else 0

            label_candidates = torch.where(batch_label == 1)[0].detach().cpu().numpy()
            # Remaining foreground quota per positive class id.
            add_candidates = {label_candidates[i]: max_num for i in range(num_class)}
            candidate_classes = list(add_candidates)
            remaining_bg = num_tokens - max_num * num_class

            write_idx = 0
            while write_idx < num_tokens:
                selected_cls = np.random.randint(0, num_class + 1)  # 0: background, 1..num_class: class
                if selected_cls == 0:
                    if remaining_bg:
                        selected_idx = np.random.randint(0, bg_features.shape[0])
                        images_batch[idx, write_idx] = bg_features[selected_idx]
                        write_idx += 1
                        remaining_bg -= 1
                else:
                    cls_id = candidate_classes[selected_cls - 1]
                    if add_candidates[cls_id] != 0:
                        selected_idx = np.random.randint(0, image_features[cls_id].shape[0])
                        images_batch[idx, write_idx] = image_features[cls_id][selected_idx]
                        write_idx += 1
                        add_candidates[cls_id] -= 1
        return images_batch


##TODO - cross attention
class MCA(nn.Module):
    """Multi-head cross-attention: queries from ``x_q`` attend over ``x_k`` / ``x_v``.

    Args:
        dim: embedding width of all inputs; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the q/k/v projections carry a bias (zero-initialised).
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim=192, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self._reset_parameters()

    def _reset_parameters(self):
        """Deterministically (seed 0) re-initialise the projection weights.

        Every instance of a given ``dim`` gets identical weights, so a freshly
        built layer matches across runs. Seeding happens inside
        ``torch.random.fork_rng``, which restores the global CPU RNG state
        afterwards. Building an MCA therefore no longer resets randomness for
        the rest of the program. Biases are zeroed: ``xavier_*`` init is
        undefined for 1-D tensors, and the previous code raised for
        ``qkv_bias=True``.
        """
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(0)
            nn.init.xavier_uniform_(self.q.weight)
            nn.init.xavier_uniform_(self.k.weight)
            nn.init.xavier_uniform_(self.v.weight)
            nn.init.xavier_uniform_(self.proj.weight)
        for bias in (self.q.bias, self.k.bias, self.v.bias, self.proj.bias):
            if bias is not None:
                nn.init.constant_(bias, 0.)

    def _split_heads(self, t, batch, length, width):
        """Reshape [B, N, C] into [B, heads, N, C // heads]."""
        return t.reshape(batch, length, self.num_heads, width // self.num_heads).permute(0, 2, 1, 3)

    def forward(self, x_q, x_k, x_v):
        """Attend from ``x_q`` ([B, N_q, C]) over ``x_k`` / ``x_v`` ([B, N_kv, C]).

        Returns:
            Tensor of shape [B, N_q, C].
        """
        B, N_q, C = x_q.shape
        N_kv = x_k.shape[1]

        q = self._split_heads(self.q(x_q), B, N_q, C)  # [B, h, N_q, d]
        k = self._split_heads(self.k(x_k), B, N_kv, C)  # [B, h, N_kv, d]
        v = self._split_heads(self.v(x_v), B, N_kv, C)  # [B, h, N_kv, d]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, h, N_q, N_kv]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# kl_loss = nn.KLDivLoss(reduction="batchmean")
# ce_loss = torch.nn.CrossEntropyLoss()

@TRAINER_REGISTRY.register()
class Caption_ours(TrainerX):
    """Dassl trainer for ``DenseCLIP``: trains on tokenized captions, evaluates on images.

    Only ``prompt_learner`` is registered with the trainer (optimizer,
    scheduler, checkpointing). ``build_model`` freezes every other parameter.
    """

    def model_inference(self, input):
        """Run the model on images in evaluation mode (``if_test=True``)."""
        return self.model(input, if_test=True)

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline."""
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch_idx, batch in enumerate(tqdm(data_loader)):
            input, label = self.parse_batch_test(batch)
            output, output_pos, image_features_, text_features_ = self.model_inference(input)
            self.evaluator.process(output, label, output_pos)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]

    def check_cfg(self, cfg):
        assert cfg.TRAINER.Caption.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        """Build DenseCLIP, freeze non-prompt parameters, and set up the optimizer."""
        print('==================== Building model in Caption_distill_double ======================')
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.Caption.PREC == "fp32" or cfg.TRAINER.Caption.PREC == "amp":
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = DenseCLIP(cfg, classnames, clip_model, self.device)

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # NOTE(review): agg_img / agg_text are frozen by the loop above (their
        # parameter names lack "prompt_learner") yet are handed to
        # build_optimizer, and they are not registered for checkpointing.
        # Confirm this is intended.
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM, self.model.agg_img, self.model.agg_text)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.Caption.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def _compute_loss(self, output, output_local, label):
        """Return the training loss selected by ``cfg.TRAIN.LOSSFUNC``.

        ``output`` holds the global logits and ``output_local`` the spatially
        aggregated logits. Some loss choices use only ``output``.
        """
        lossfunc = self.cfg.TRAIN.LOSSFUNC
        if lossfunc == 'sigmoid':
            return norm_logits_BCEloss(output, label.float()) + norm_logits_BCEloss(output_local, label.float())
        if lossfunc == 'focal':
            return sigmoid_focal_loss(output, label)
        if lossfunc == 'asl':
            return ASL_loss(output, label) + ASL_loss(output_local, label.float())
        if lossfunc == 'ranking':
            return ranking_loss(output, label)
        if lossfunc == 'double_ranking':
            a = ranking_loss(output, label, scale_=1.0, margin_=1)
            b = ranking_loss(output_local, label, scale_=1.0, margin_=1)
            print(a, b)
            scale_a = 2.0
            return a * scale_a + b
        return soft_cross_entropy(output, label)

    def forward_backward(self, batch):
        """Run one training step and return ``{"loss": float}``.

        The batch's ``"img"`` entry is passed to the model as ``captions``
        (training-mode input). Both precision modes use the same model call
        and the config-selected loss. Previously the amp branch called
        ``self.model(image)``, which reached the training path with
        ``captions=None``.
        """
        image, label = self.parse_batch_train(batch)

        prec = self.cfg.TRAINER.Caption.PREC
        if prec == "amp":
            with autocast():
                output, output_local, _, _ = self.model(None, image, False, label)
                loss = self._compute_loss(output, output_local, label)
            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()
        else:
            output, output_local, _, _ = self.model(None, image, False, label)
            loss = self._compute_loss(output, output_local, label)
            self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move the batch's ``"img"`` and ``"label"`` tensors to the trainer device."""
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    def load_model(self, directory, epoch=None):
        """Load registered model weights from ``directory`` (best model unless ``epoch`` is given)."""
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            if "token_prefix" in state_dict:
                del state_dict["token_prefix"]

            if "token_suffix" in state_dict:
                del state_dict["token_suffix"]

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)
            print(state_dict)
