import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
from tqdm import tqdm
import numpy as np
import math
import time

# Prompt template per dataset name.  The "{}" placeholder is filled with the
# class name via str.format when the trainers below build their text prompts.
_GENERIC_TEMPLATE = "a photo of a {}."

CUSTOM_TEMPLATES = {
    # Datasets with a domain-specific phrasing.
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": _GENERIC_TEMPLATE,
    "Food101": "a photo of {}, a type of food.",
    "SUN397": _GENERIC_TEMPLATE,
    "Caltech101": _GENERIC_TEMPLATE,
    "UCF101": "a photo of a person doing {}.",
    # ImageNet and its variants all share the generic phrasing.
    "ImageNet": _GENERIC_TEMPLATE,
    "ImageNetSketch": _GENERIC_TEMPLATE,
    "ImageNetV2": _GENERIC_TEMPLATE,
    "ImageNetA": _GENERIC_TEMPLATE,
    "ImageNetR": _GENERIC_TEMPLATE,
}
# Object category names (80 entries).  Within this file the list is only
# referenced by the commented-out per-class loop in read_npy.
# NOTE(review): these look like the MS-COCO detection categories — confirm
# before relying on the ordering.
label_names = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
    "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush"
]

# Module-level switches read by ZeroshotCLIP_dense.  They are plain constants
# (not taken from cfg), so changing an experiment means editing this file.

# When True, ZeroshotCLIP_dense.test prints the wall-clock seconds spent in
# each model_inference call.
MEASURE_TIME = False
# Chooses how model_inference aggregates the per-location logits: True uses
# the winner-take-all branch, False a softmax-weighted sum over locations.
WTA = True
# Chooses which family of precomputed caption-feature .npy files build_model
# loads (the files with "101" in their name when True).
RESNET101 = True
# Multiplier applied to the caption/image similarities before each softmax of
# the three-step caption attention, for the globally pooled image feature.
TEMP1 = 40.0
# Same multiplier, used when the attention is run on local window features.
TEMP2 = 30.0
# Selects the caption-file group in build_model ("COCO", "NUSWIDE" or "VOC").
# NOTE(review): this is independent of cfg.DATASET.NAME — keep them in sync.
DATASET = "NUSWIDE"

# When True, model_inference scores pooled local windows of the feature map;
# when False it scores the projected per-location features directly.
TASK_ADAPT_LOCAL = True


@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP classifier that uses one text prompt per class."""

    def build_model(self):
        """Load CLIP and precompute the normalised text feature of each class.

        Sets ``self.text_features`` (one L2-normalised row per class name) and
        ``self.clip_model``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # Only a missing dictionary entry should trigger the generic fallback;
        # a bare ``except`` would also hide unrelated errors (and Ctrl-C).
        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        # The text side is fixed, so encode it once without tracking gradients.
        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def model_inference(self, image):
        """Return ``(logits, None, None)`` for a batch of images.

        ``logits`` are the scaled dot products between the L2-normalised image
        features and the precomputed class text features.
        """
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()

        return logits, None, None


def read_npy(numpy_file, device):
    """Load a ``.npy`` array and return it as a float16 tensor on ``device``.

    Prints a short confirmation once the file has been read.

    NOTE(security): the file is read with ``allow_pickle=True``; unpickling can
    execute arbitrary code, so only point this at trusted feature files.
    """
    loaded = np.load(numpy_file, allow_pickle=True)
    print('load complete:', numpy_file)

    # Convert to half precision first, then move to the target device.
    half_features = torch.tensor(loaded).half()
    return half_features.to(device)


@TRAINER_REGISTRY.register()
class ZeroshotCLIP_dense(TrainerX):
    """Zero-shot CLIP that scores a global feature and local feature-map windows.

    Besides the class text prompts, the model uses caption features that were
    precomputed offline and are loaded from ``.npy`` files in ``build_model``.
    Behaviour is controlled by the module constants ``DATASET``, ``RESNET101``,
    ``TEMP1``, ``TEMP2``, ``TASK_ADAPT_LOCAL``, ``WTA`` and ``MEASURE_TIME``.
    """

    # Caption-feature files keyed by (DATASET, RESNET101).  Every entry is a
    # sequence of (attribute name, file name); the files are loaded in order.
    _CAPTION_FILES = {
        ("COCO", True): (
            ("caption", 'COCO_caption_selection_101.npy'),
            ("caption_local", 'COCO_caption_selection_local_101.npy'),
            ("_multi_label_caption", 'COCO_multi_label_caption_selection_101_v3.npy'),
        ),
        ("COCO", False): (
            ("caption", 'COCO_caption_selection.npy'),
            ("caption_local", 'COCO_caption_selection_local.npy'),
            ("_multi_label_caption", 'COCO_multi_label_caption_selection_50_v3.npy'),
        ),
        ("NUSWIDE", True): (
            ("caption", 'nus_llm_caption_selection_101_27_10000.npy'),
            ("_multi_label_caption", 'nus_llm_multi_label_caption_selection_101_10000.npy'),
        ),
        ("NUSWIDE", False): (
            ("caption", 'nus_llm_caption_selection_50_27_10000.npy'),
            # NOTE(review): this is the same file as "caption" above, unlike
            # every other configuration — looks like a copy-paste slip, but the
            # intended file name is not visible here.  Left unchanged; confirm.
            ("_multi_label_caption", 'nus_llm_caption_selection_50_27_10000.npy'),
        ),
        ("VOC", True): (
            ("caption", 'VOC_caption_selection_100_101.npy'),
            ("_multi_label_caption", 'VOC_multi_label_caption_selection_101_v5_40000.npy'),
        ),
        ("VOC", False): (
            ("caption", 'VOC_caption_selection.npy'),
            ("_multi_label_caption", 'VOC_multi_label_caption_selection_50_v3.npy'),
        ),
    }

    def build_model(self):
        """Load CLIP, the class text features and the caption feature files.

        Sets ``self.text_features``, ``self.clip_model``, ``self.visual_encoder``
        (the visual trunk up to ``layer4``), the attention-pool projection
        weights, and the L2-normalised caption tensors selected by ``DATASET``
        and ``RESNET101``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # Only a missing dictionary entry should trigger the generic fallback.
        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

        # Keep handles to the trunk and to the value/output projections of the
        # attention pool so single locations can be projected without pooling.
        attnpool = self.clip_model.visual.attnpool
        self.visual_encoder = IntermediateLayerGetter(self.clip_model.visual, {"layer4": "0"})
        self.positional_embedding = attnpool.positional_embedding[1::]
        self.v_linear_weight = attnpool.v_proj.weight
        self.v_linear_bias = attnpool.v_proj.bias
        self.c_linear_weight = attnpool.c_proj.weight
        self.c_linear_bias = attnpool.c_proj.bias

        caption_files = self._CAPTION_FILES.get((DATASET, bool(RESNET101)))
        if caption_files is None:
            # Nothing is loaded in this case (as before); model_inference will
            # fail later because the caption tensors are missing.
            print('!! WARNING: No caption feature files configured for DATASET={!r}'.format(DATASET))
            caption_files = ()

        for attr_name, file_name in caption_files:
            features = read_npy(file_name, self.device)
            setattr(self, attr_name, features / features.norm(dim=-1, keepdim=True))

    def encode_image(self, x):
        """Run the visual trunk (stem + layer1..layer4) and return the feature map."""
        encoder = self.visual_encoder
        x = x.type(encoder.conv1.weight.dtype)

        # Three-convolution stem followed by average pooling.
        stem_layers = (
            (encoder.conv1, encoder.bn1),
            (encoder.conv2, encoder.bn2),
            (encoder.conv3, encoder.bn3),
        )
        for conv, bn in stem_layers:
            x = encoder.relu(bn(conv(x)))
        x = encoder.avgpool(x)

        for stage in (encoder.layer1, encoder.layer2, encoder.layer3, encoder.layer4):
            x = stage(x)
        return x

    @staticmethod
    def _three_step_attention(caption, image_feature_, temp):
        """Attend over a 4-D caption bank in three nested softmax steps.

        Args:
            caption: tensor of shape ``[I, J, K, C]``.
            image_feature_: query features of shape ``[B, C]``.
            temp: multiplier applied to every similarity before its softmax.

        Returns:
            Tensor of shape ``[B, C]``: for each query, the captions are
            averaged over ``K``, then ``J``, then ``I``, each time weighted by
            the softmax of their similarity to the query.
        """
        # Step 1: weight the K captions of every (i, j) cell.
        similarity = temp * torch.einsum('ijkc,cb->ijkb', caption, image_feature_.t())  # [I, J, K, B]
        attention_weights = F.softmax(similarity, dim=2)
        weighted_sum = torch.einsum('ijbk,ijkc->ijbc', attention_weights.permute(0, 1, 3, 2),
                                    caption)  # [I, J, B, C]

        # Step 2: weight the J summaries of every i.
        weighted_sum = weighted_sum.permute(2, 0, 1, 3)  # [B, I, J, C]
        similarity_2 = temp * torch.einsum('bijc,bc->bij', weighted_sum, image_feature_)
        attention_weights_2 = F.softmax(similarity_2, dim=2)
        weighted_sum_2 = torch.einsum('bij,bijc->bic', attention_weights_2, weighted_sum)  # [B, I, C]

        # Step 3: weight the I remaining summaries.
        similarity_3 = temp * torch.einsum('bic,bc->bi', weighted_sum_2, image_feature_)
        attention_weights_3 = F.softmax(similarity_3, dim=1)
        return torch.einsum('bi,bic->bc', attention_weights_3, weighted_sum_2)  # [B, C]

    @staticmethod
    def _local_windows(row, col, n_rows, n_cols):
        """Return the (row_slice, col_slice) windows anchored at ``(row, col)``.

        The 1x1 window is not included; the caller handles it separately.  A
        window is offered only when its long side fits inside the grid.  As in
        the original code the 2-wide side is not checked (except for the 2x2
        window), so at the last row/column it is clipped by slicing.

        The third window (2 rows x 4 columns to the right) used to be guarded
        by the *row* bound although it extends along the columns; it is now
        guarded by the column bound like its mirrored counterparts.
        """
        fits_up = row - 2 >= 0
        fits_down = row + 3 <= n_rows - 1
        fits_left = col - 2 >= 0
        fits_right = col + 3 <= n_cols - 1

        rows_2 = slice(row, row + 2)
        rows_up = slice(row - 2, row + 2)
        rows_down = slice(row, row + 4)
        rows_6 = slice(row - 2, row + 4)
        cols_2 = slice(col, col + 2)
        cols_left = slice(col - 2, col + 2)
        cols_right = slice(col, col + 4)
        cols_6 = slice(col - 2, col + 4)

        # Order matches the original window ids 1..10.
        candidates = (
            (row + 1 <= n_rows - 1 and col + 1 <= n_cols - 1, rows_2, cols_2),  # 2x2
            (fits_left, rows_2, cols_left),                                     # 2x4
            (fits_right, rows_2, cols_right),                                   # 2x4
            (fits_up, rows_up, cols_2),                                         # 4x2
            (fits_down, rows_down, cols_2),                                     # 4x2
            (fits_up and fits_left, rows_up, cols_left),                        # 4x4
            (fits_up and fits_right, rows_up, cols_right),                      # 4x4
            (fits_down and fits_left, rows_down, cols_left),                    # 4x4
            (fits_down and fits_right, rows_down, cols_right),                  # 4x4
            (fits_up and fits_down and fits_left and fits_right, rows_6, cols_6),  # 6x6
        )
        return [(rs, cs) for fits, rs, cs in candidates if fits]

    def _score_patch(self, patch_feature, multi_step_input, logit_scale):
        """Return class logits (one row per batch element) for one pooled feature.

        The logit is the mean of (a) the scaled similarity between the
        normalised feature and the class text features and (b) a
        softmax-weighted caption similarity computed from the caption-attended
        version of the feature.
        """
        local_patch = patch_feature / patch_feature.norm(dim=-1, keepdim=True)
        updated_img_feat = self._three_step_attention(multi_step_input, local_patch, TEMP2)

        # Similarity of every caption of every class to the attended feature;
        # the same product serves as attention input and as the value.
        caption_sim = self.caption @ updated_img_feat.t()
        caption_attn = F.softmax(caption_sim * 5, dim=1)
        caption_logit = (logit_scale * torch.sum(caption_attn * caption_sim, dim=1)).t()

        raw_logit = logit_scale * local_patch @ self.text_features.t()
        return (raw_logit + caption_logit) / 2

    def _local_patch_logits(self, image_feat, multi_step_input, logit_scale, num_classes):
        """Score every feature-map location with its best local window.

        For each location the 1x1 feature and every window from
        ``_local_windows`` are scored; the single highest class logit over all
        windows is kept together with its class index.

        Returns:
            Float tensor of shape ``[H * W, B, num_classes]`` that is zero
            everywhere except, for every (location, batch element), at the
            winning class, which holds the winning logit.
        """
        b, _, n_rows, n_cols = image_feat.shape
        grid = image_feat.permute(2, 3, 0, 1)  # [H, W, B, C]
        attnpool = self.clip_model.visual.attnpool

        best_logits = []
        best_classes = []
        for row in range(n_rows):
            for col in range(n_cols):
                center = grid[row, col]  # [B, C]

                # 1x1 window: project the single location without pooling.
                projected = F.linear(center, self.v_linear_weight, self.v_linear_bias)
                projected = F.linear(projected, self.c_linear_weight, self.c_linear_bias)
                window_logits = [self._score_patch(projected, multi_step_input, logit_scale)]

                # Larger windows go through the attention pool with the centre
                # location passed as the extra argument.
                for row_slice, col_slice in self._local_windows(row, col, n_rows, n_cols):
                    window = grid[row_slice, col_slice].permute(2, 3, 0, 1).clone()  # [B, C, h', w']
                    pooled, _ = attnpool(window, True, True, center)
                    window_logits.append(self._score_patch(pooled, multi_step_input, logit_scale))

                stacked = torch.stack(window_logits, dim=0)  # [n_windows, B, cls]
                class_max, class_arg = stacked.max(dim=2)    # best class per window
                location_max, window_arg = class_max.max(dim=0)  # best window per image
                location_class = class_arg.gather(0, window_arg.unsqueeze(0)).squeeze(0)

                best_logits.append(location_max)
                best_classes.append(location_class)

        logit_map = torch.stack(best_logits, dim=0).float()  # [H*W, B]
        class_map = torch.stack(best_classes, dim=0)         # [H*W, B], int64

        # Write each location's winning logit at [location, batch, class].
        # The previous flat-index version omitted the location term, so every
        # location overwrote slot 0 and the rest of the tensor stayed zero.
        result = torch.zeros(n_rows * n_cols, b, num_classes, device=logit_map.device)
        result.scatter_(2, class_map.unsqueeze(-1), logit_map.unsqueeze(-1))
        return result

    @staticmethod
    def _winner_take_all(logits):
        """Aggregate ``[P, B, cls]`` location logits by winner-take-all.

        Within each location only the entries equal to that location's maximum
        over classes survive; the result is, per class, the largest surviving
        value over all locations, or ``-1`` when the class never won.

        Losing entries are masked with ``-inf`` directly (the earlier version
        used ``-1`` as an intermediate marker, which could be confused with a
        genuine logit of ``-1``).
        """
        location_max, _ = logits.max(dim=-1, keepdim=True)  # [P, B, 1]
        winners = torch.where(logits == location_max, logits,
                              torch.full_like(logits, float('-inf')))
        strong_wta, _ = torch.max(winners, dim=0)  # [B, cls]
        strong_wta[strong_wta == float('-inf')] = -1.0
        return strong_wta

    def model_inference(self, image, label):
        """Compute global and local class logits for a batch.

        Args:
            image: input batch passed to ``encode_image``.
            label: label tensor; only ``label.shape[1]`` (number of classes) is
                read here.

        Returns:
            ``(global_logits, local_logits, None, None)``.
        """
        image_feat = self.encode_image(image)
        b, c, h, w = image_feat.shape

        # Globally pooled image feature.
        image_feature_, _ = self.clip_model.visual.attnpool(image_feat)
        image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)

        # Fold the first axis of the multi-label caption bank into a square so
        # the three-step attention can run over it.  reshape raises if that
        # axis is not a perfect square.
        dim1, dim2, depth = self._multi_label_caption.shape
        side = int(math.sqrt(dim1))
        multi_step_input = self._multi_label_caption.reshape(side, side, dim2, depth)
        global_p = self._three_step_attention(multi_step_input, image_feature_, TEMP1)

        # Caption-attended prototype per class for the global feature.
        A = torch.matmul(self.caption, global_p.t())
        A_up = F.softmax(A, dim=1)
        P_up = torch.matmul(A_up.permute(0, 2, 1), self.caption)

        logit_scale = self.clip_model.logit_scale.exp()

        if TASK_ADAPT_LOCAL:
            # Uses the real feature-map size instead of a hard-coded 14x14.
            logits = self._local_patch_logits(image_feat, multi_step_input, logit_scale,
                                              label.shape[1])
        else:
            # Project every location with the attention-pool value/output
            # projections and score it directly.
            x = image_feat.reshape(b, c, h * w).permute(2, 0, 1)
            x = F.linear(x, self.v_linear_weight, self.v_linear_bias)
            x = F.linear(x, self.c_linear_weight, self.c_linear_bias)
            image_features = x / x.norm(dim=-1, keepdim=True)

            caption_logit_ = torch.einsum('hbd,dcn->hbcn', self.caption, image_features.permute(2, 0, 1))
            caption_logit_ = caption_logit_.permute(2, 3, 0, 1)
            A_up_ = F.softmax(caption_logit_ * 5, dim=3)

            logits = logit_scale * (
                    torch.sum(A_up_ * caption_logit_, dim=3) + image_features @ self.text_features.t()) / 2.0

        logits_ = logit_scale * (
                torch.sum(P_up * global_p.unsqueeze(0), dim=2).permute(1, 0)
                + image_feature_ @ self.text_features.t()) / 2.0

        # Collapse the location axis of the local logits.
        if WTA:
            logits = self._winner_take_all(logits)
        else:
            prob_spatial = F.softmax(logits, dim=0)
            logits = torch.sum(logits * prob_spatial, dim=0)

        return logits_, logits, None, None

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline.

        Evaluates on the validation loader when ``split == "val"`` and one is
        available, otherwise on the test loader, and returns the first metric
        reported by the evaluator.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            # Tag the scalars with the split that is actually evaluated.
            split = "test"
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch in tqdm(data_loader):
            images, label = self.parse_batch_test(batch)

            if MEASURE_TIME:
                st = time.time()

            output, output_pos, _, _ = self.model_inference(images, label)

            if MEASURE_TIME:
                # NOTE(review): no device synchronisation is done here, so on
                # an asynchronous device this may under-report — confirm.
                ed = time.time()
                print(ed - st)

            self.evaluator.process(output, label, output_pos)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]


@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling.

    The class text feature is the normalised mean of the normalised text
    features obtained from every template in ``templates``.
    """

    # templates = IMAGENET_TEMPLATES
    templates = IMAGENET_TEMPLATES_SELECT

    def build_model(self):
        """Load CLIP (frozen) and precompute the ensembled class text features.

        Sets ``self.text_features``, ``self.clip_model`` and a per-instance
        ``self.templates`` holding the templates that were actually used.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # Work on a copy: ``self.templates += [...]`` would extend the shared
        # class-level list (the imported template list itself) in place, so
        # every further build would append another copy of the template.
        templates = list(self.templates)

        # add custom-made prompt
        if cfg.DATASET.NAME != "ImageNet":
            custom_template = CUSTOM_TEMPLATES.get(cfg.DATASET.NAME)
            if custom_template is None:
                # Same warning as the sibling trainers; the ensemble is used
                # without a dataset-specific template instead of raising.
                print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            else:
                templates.append(custom_template)
        self.templates = templates

        num_temp = len(templates)
        print(f"Prompt ensembling (n={num_temp})")

        mean_text_features = 0
        with torch.no_grad():
            for temp in templates:
                prompts = [temp.format(c.replace("_", " ")) for c in classnames]
                prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
                text_features = clip_model.encode_text(prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                mean_text_features = mean_text_features + text_features
        mean_text_features = mean_text_features / num_temp
        mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
