import os.path as osp

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from torchvision.models._utils import IntermediateLayerGetter

from tqdm import tqdm
import pickle5 as pickle

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

from .utils import soft_cross_entropy, softmax_sigmoid_BCEloss, \
    norm_logits_BCEloss, sigmoid_focal_loss, sigmoid_ASL_loss, ranking_loss, ASL_loss
# Module-level CLIP tokenizer (clip.simple_tokenizer.SimpleTokenizer); PromptLearner
# uses it to count how many tokens each class name occupies (``name_lens``).
_tokenizer = _Tokenizer()

import os
import matplotlib.pyplot as plt
import numpy as np
import json
import gc

def load_clip_to_cpu(cfg):
    """Build a CLIP model on the CPU for the backbone named in the config.

    The checkpoint URL comes from ``clip._MODELS[cfg.MODEL.BACKBONE.NAME]`` and is
    fetched with ``clip._download``. The file is first opened as a TorchScript
    archive; if that raises ``RuntimeError`` it is loaded as a plain state dict.
    Either way the model is rebuilt from the weights with ``clip.build_model``.

    Args:
        cfg: Config node providing ``MODEL.BACKBONE.NAME``.

    Returns:
        The model produced by ``clip.build_model``.
    """
    checkpoint_url = clip._MODELS[cfg.MODEL.BACKBONE.NAME]
    checkpoint_path = clip._download(checkpoint_url)

    try:
        jit_archive = torch.jit.load(checkpoint_path, map_location="cpu").eval()
    except RuntimeError:
        # Not a TorchScript archive: treat the file as a raw state dict.
        weights = torch.load(checkpoint_path, map_location="cpu")
    else:
        weights = jit_archive.state_dict()

    return clip.build_model(weights)


class TextEncoder(nn.Module):
    """Run CLIP's text transformer on prompt embeddings or on raw token ids.

    Holds references to the text-side submodules of a loaded CLIP model
    (transformer, positional embedding, final LayerNorm, text projection and
    token embedding).
    """

    def __init__(self, clip_model):
        super().__init__()
        # Assignment order is kept as-is so parameter / state_dict order is unchanged.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype
        self.token_embedding = clip_model.token_embedding

    def forward(self, prompts, tokenized_prompts, if_embedding=True, if_sequence=False):
        """Encode prompts into projected text features.

        Args:
            prompts: Prompt embeddings ``[batch_size, n_ctx, d_model]`` when
                ``if_embedding`` is True; otherwise integer token ids that are
                embedded here.
            tokenized_prompts: Token ids used to locate the end-of-text position
                (the highest id in each row). Ignored when ``if_embedding`` is
                False, because ``prompts`` then already holds the ids.
            if_embedding: Whether ``prompts`` is already embedded.
            if_sequence: If True, return the projected feature of every token
                (``[N, L, d]``); otherwise return only the EOT-position feature
                (``[N, d]``).
        """
        if if_embedding:
            embeddings = prompts
        else:
            # Raw ids were passed: they also serve to find the EOT position.
            tokenized_prompts = prompts
            embeddings = self.token_embedding(prompts).type(self.dtype)

        hidden = embeddings + self.positional_embedding.type(self.dtype)
        # The transformer expects sequence-first layout (LND).
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        if not if_sequence:
            # Pick one row per prompt: the position of the largest token id (EOT).
            rows = torch.arange(hidden.shape[0])
            hidden = hidden[rows, tokenized_prompts.argmax(dim=-1)]
        return hidden @ self.text_projection


class PromptLearner(nn.Module):
    """Learnable CoOp-style context vectors for CLIP text prompts.

    Two independent context sets are learned: ``ctx`` (main prompts) and
    ``ctx_double`` (the second, "double"/negative prompts). Three learnable
    scalars (``temperature``, ``spatial_T``, ``ranking_scale``) are also held
    here and returned from :meth:`forward` for the caller to use.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.Caption.N_CTX
        ctx_init = cfg.TRAINER.Caption.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # Use the given words to initialize the context vectors.
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            # The double context starts from the same words but needs its own
            # storage so the two Parameters are optimized independently.
            ctx_vectors_double = ctx_vectors.clone()
            prompt_prefix = ctx_init
        else:
            # Random initialization (class-specific contexts when CSC is set).
            csc = cfg.TRAINER.Caption.CSC
            ctx_shape = (n_cls, n_ctx, ctx_dim) if csc else (n_ctx, ctx_dim)

            print("Initializing class-specific contexts" if csc else "Initializing a generic context")
            ctx_vectors = torch.empty(*ctx_shape, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)

            print("Initializing class-specific double contexts" if csc else "Initializing a generic double context")
            ctx_vectors_double = torch.empty(*ctx_shape, dtype=dtype)
            nn.init.normal_(ctx_vectors_double, std=0.02)

            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f'Initial double context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized
        self.ctx_double = nn.Parameter(ctx_vectors_double)  # to be optimized

        temperature = torch.tensor(3.0, dtype=dtype)  # exp(3.91) = 50
        self.temperature = nn.Parameter(temperature)
        spatial_T = torch.tensor(3.0, dtype=dtype)  # 20
        self.spatial_T = nn.Parameter(spatial_T)
        ranking_scale = torch.tensor(4.0, dtype=dtype)  # 20
        self.ranking_scale = nn.Parameter(ranking_scale)

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        # Class-agnostic token suffix: "<prefix>." without any class name.
        prompts_nocls = [prompt_prefix + "."] * len(classnames)
        tokenized_prompts_nocls = torch.cat([clip.tokenize(p) for p in prompts_nocls])
        with torch.no_grad():
            embedding_nocls = clip_model.token_embedding(tokenized_prompts_nocls).type(dtype)
        self.register_buffer("token_suffix_nocls", embedding_nocls[:, 1 + n_ctx :, :])  # EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = cfg.TRAINER.Caption.CLASS_TOKEN_POSITION

    def _assemble(self, ctx, prefix, suffix):
        """Concatenate [SOS] + context + class tokens + rest per ``class_token_position``.

        ``ctx`` is ``(n_cls, n_ctx, dim)``; ``suffix`` starts with the class-name
        tokens (``name_lens[i]`` of them for class i) followed by the remaining tokens.
        """
        position = self.class_token_position
        if position == "end":
            return torch.cat([prefix, ctx, suffix], dim=1)

        half_n_ctx = self.n_ctx // 2
        rows = []
        for i in range(self.n_cls):
            name_len = self.name_lens[i]
            prefix_i = prefix[i : i + 1, :, :]
            class_i = suffix[i : i + 1, :name_len, :]
            suffix_i = suffix[i : i + 1, name_len:, :]
            ctx_i = ctx[i : i + 1, :, :]
            if position == "middle":
                parts = [prefix_i, ctx_i[:, :half_n_ctx, :], class_i, ctx_i[:, half_n_ctx:, :], suffix_i]
            else:  # "front"
                parts = [prefix_i, class_i, ctx_i, suffix_i]
            rows.append(torch.cat(parts, dim=1))
        return torch.cat(rows, dim=0)

    def forward(self, neg_prompt_wcls=True):
        """Return the learned prompt embeddings and the learnable scalars.

        Args:
            neg_prompt_wcls: If True, the second ("double") prompts contain the
                class-name tokens placed like the main prompts. If False they are
                class-agnostic: [SOS] + ``ctx_double`` + class-free suffix.

        Returns:
            ``(prompts, prompts_neg, temperature, spatial_T, ranking_scale)`` where
            both prompt tensors are ``(n_cls, seq_len, dim)``.

        Raises:
            ValueError: If ``class_token_position`` is not "end", "middle" or "front".
        """
        if self.class_token_position not in ("end", "middle", "front"):
            raise ValueError(f"Unknown class_token_position: {self.class_token_position!r}")

        ctx = self.ctx
        ctx_double = self.ctx_double
        # Generic (shared) contexts are broadcast to every class.
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        if ctx_double.dim() == 2:
            ctx_double = ctx_double.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        prompts = self._assemble(ctx, prefix, suffix)
        if neg_prompt_wcls:
            prompts_neg = self._assemble(ctx_double, prefix, suffix)
        else:
            # No class tokens to position, so the layout is the same for every mode.
            prompts_neg = torch.cat([prefix, ctx_double, self.token_suffix_nocls], dim=1)

        return prompts, prompts_neg, self.temperature, self.spatial_T, self.ranking_scale
    
    
def _save_spatial_overlay(image_chw, heatmap, title, out_path):
    """Save a side-by-side figure: raw image | image with ``heatmap`` overlaid."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 10))
    try:
        rgb = np.transpose(np.clip(image_chw, 0, 1), (1, 2, 0))
        axes[0].imshow(rgb)
        axes[1].imshow(rgb)
        im = axes[1].imshow(heatmap, cmap='jet', alpha=0.5)  # Overlay attention map with transparency
        fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
        axes[1].set_title(title)
        fig.savefig(out_path, format='png')
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)


def visualize_with_spatial_logits(batch_idx, image, spatial_logits, label, probs, save_path='./tai_overlay_images_local', cell_size=7):
    """Save heat-map overlays of per-class spatial logits on top of the input images.

    Two sets of PNG files are written:
      * ``<save_path>/label``: one figure per positive (label == 1) class of each image.
      * ``<save_path>/top5``: one figure for each of the 5 highest-``probs`` classes
        of each image.

    Args:
        batch_idx: Identifier embedded in the output file names.
        image: Channel-first image tensor, e.g. ``[100, 3, 224, 224]``; values are
            clipped to [0, 1] for display.
        spatial_logits: Per-location class logits laid out as
            ``[cell_size * cell_size, B, n_cls]`` (e.g. ``[7x7, 100, 80]``).
        label: Multi-hot label tensor ``[B, n_cls]``.
        probs: Per-class scores ``[B, n_cls]`` used for titles and top-5 ranking.
        save_path: Output root directory (created if missing).
        cell_size: Side length of the square spatial grid (default 7).

    NOTE(review): class names are the 80 COCO categories, so ``n_cls`` must be <= 80.
    """
    label_names = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
        "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
        "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
        "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    ]

    probs = probs.detach().cpu().numpy()
    image = image.detach().cpu().numpy()
    label = label.detach().cpu().numpy()
    batch = image.shape[0]

    # [HW, B, n_cls] -> [B, n_cls, cell, cell], then upsample to the image size.
    maps = spatial_logits.detach().cpu().float()
    maps = maps.reshape(cell_size, cell_size, batch, -1).permute(2, 3, 0, 1)
    maps = F.interpolate(maps, size=(image.shape[2], image.shape[3]), mode='bilinear', align_corners=False)
    maps = maps.numpy()  # [B, n_cls, H, W]

    label_dir = os.path.join(save_path, "label")
    top5_dir = os.path.join(save_path, "top5")
    os.makedirs(label_dir, exist_ok=True)
    os.makedirs(top5_dir, exist_ok=True)

    # Ground-truth positive classes.
    for i in range(batch):
        for label_idx in np.flatnonzero(label[i] == 1):
            name = label_names[label_idx]
            _save_spatial_overlay(
                image[i],
                maps[i, label_idx],
                f'Prob:{probs[i][label_idx]} /Label {name} with PosLogits (Resized)',
                os.path.join(label_dir, f"batch_{batch_idx}_image_{i}_label_{name}.png"),
            )

    # Five highest-scoring classes (descending).
    for i in range(batch):
        for label_idx in probs[i].argsort()[-5:][::-1]:
            name = label_names[label_idx]
            _save_spatial_overlay(
                image[i],
                maps[i, label_idx],
                f'Prob:{probs[i][label_idx]} /{name} Label:{label[i][label_idx]} with PosLogits (Resized)',
                os.path.join(top5_dir, f"batch_{batch_idx}_image_{i}_prob_{name}_{label[i][label_idx]}.png"),
            )

            
def json2list(file, withclass=False):
    """Load per-class description lists from a JSON file as a numpy array.

    The JSON is read as ``{class_name: {request: [description, ...], ...}, ...}``.
    For each class (in file order) the description lists of all its requests are
    concatenated. With ``withclass=True`` every description becomes
    ``description + " " + class_name``.

    Args:
        file: Path to the JSON file (read as UTF-8).
        withclass: Append the class name to every description.

    Returns:
        ``np.array`` of one description list per class, e.g. shape ``[80, 50]``.
        NOTE(review): assumes every class has the same number of descriptions so
        the result is rectangular.
    """
    # JSON text is UTF-8; do not depend on the platform's default locale encoding.
    with open(file, 'r', encoding='utf-8') as f:
        prior_json = json.load(f)

    priors_list = []
    for class_name, requests in prior_json.items():
        descriptions = []
        for description_list in requests.values():
            if withclass:
                descriptions.extend(desc + " " + class_name for desc in description_list)
            else:
                descriptions.extend(description_list)
        priors_list.append(descriptions)

    return np.array(priors_list)


def write_bg_npy(json_file, text_encoder):
    """Encode the background descriptions in ``json_file`` and save them as ``<stem>.npy``.

    Only the first class entry of the JSON (``json2list(json_file)[0]``) is used.
    Each description is tokenized with ``clip.tokenize`` and encoded with
    ``text_encoder(..., if_embedding=False, if_sequence=True)``, so per-token
    projected features ``[N, 77, D]`` are stored. The file is written to the
    current working directory, named after the JSON file's base name.
    """
    prompts = json2list(json_file)[0]  # e.g. [50] strings
    # cat (instead of stack + squeeze) keeps the batch dim even for a single prompt.
    tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
    print(tokenized_prompts.shape)  # [N, 77]
    with torch.no_grad():
        text_features = text_encoder(tokenized_prompts, None, if_embedding=False, if_sequence=True)
        print(text_features.shape)  # [N, 77, D]

    out_name = osp.basename(json_file).split('.json')[0]
    np.save(out_name + '.npy', text_features.cpu().numpy())


def write_prior_npy(json_file, text_encoder, withclass=False):
    """Encode every per-class description in ``json_file`` and save ``<stem>.npy``.

    Descriptions come from ``json2list(json_file, withclass)`` (``[n_cls, N]``),
    are tokenized with ``clip.tokenize`` and encoded with
    ``text_encoder(..., if_embedding=False, if_sequence=True)``. The saved array is
    reshaped to ``[n_cls, N, 77, D]`` and written to the current working
    directory, named after the JSON file's base name.
    """
    prior_text_origin = json2list(json_file, withclass)  # [n_cls, N]
    prompts = prior_text_origin.flatten()  # [n_cls * N]

    # cat (instead of stack + squeeze) keeps the batch dim even for a single prompt.
    tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
    print(tokenized_prompts.shape)  # [n_cls * N, 77]
    with torch.no_grad():
        text_features = text_encoder(tokenized_prompts, None, if_embedding=False, if_sequence=True)
        print(text_features.shape)  # [n_cls * N, 77, D]
        prior_features = text_features.reshape(
            prior_text_origin.shape[0], -1, text_features.shape[-2], text_features.shape[-1]
        )  # [n_cls, N, 77, D]

    out_name = osp.basename(json_file).split('.json')[0]
    np.save(out_name + '.npy', prior_features.cpu().numpy())
    
def read_npy(numpy_file, device):
    """Load a ``.npy`` file and return its contents as a torch tensor on ``device``.

    ``torch.from_numpy`` wraps the loaded array without an extra host copy
    (``torch.tensor`` would duplicate it, which is costly for large prior files);
    ``.to(device)`` then copies only when moving off the CPU.
    """
    all_features = np.load(numpy_file)
    print('load complete:', numpy_file)
    return torch.from_numpy(all_features).to(device)

def write_prototype_npy(text_encoder, label_names=None, out_path='COCO_proto.npy'):
    """Encode "a photo of <class>" for every class and save the features to ``out_path``.

    Args:
        text_encoder: Called as ``text_encoder(tokens, None, if_embedding=False,
            if_sequence=True)``; per-token features ``[n_cls, 77, D]`` are saved.
        label_names: Class names to encode; defaults to the 80 COCO categories.
        out_path: Destination ``.npy`` path (default ``'COCO_proto.npy'`` in the
            current working directory).
    """
    if label_names is None:
        label_names = [
            "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
            "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
            "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
            "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
            "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
            "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
            "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
            "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
            "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
            "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
            "hair drier", "toothbrush"
        ]

    proto_prior_list = ['a photo of ' + label for label in label_names]

    # cat (instead of stack + squeeze) keeps the batch dim even for a single class.
    tokenized_prompts = torch.cat([clip.tokenize(p) for p in proto_prior_list])
    print(tokenized_prompts.shape)  # [n_cls, 77]
    with torch.no_grad():
        text_features = text_encoder(tokenized_prompts, None, if_embedding=False, if_sequence=True)
        print(text_features.shape)  # [n_cls, 77, D]

    np.save(out_path, text_features.cpu().numpy())

def replicate_cls(cls_tensor, batch_size):
    """Broadcast a ``[cls, D]`` tensor to ``[batch_size, cls, D]``.

    ``cls_tensor`` must be 2-D. The result is an ``expand`` view that shares
    storage with ``cls_tensor`` (no copy), so it must not be written in place.
    """
    return cls_tensor[None].expand(batch_size, -1, -1)

def get_eot_feature(input):  # [C,N,L,D]
    """Reduce a ``[C, N, L, D]`` sequence-feature tensor to ``[C, N, D]``.

    For every channel ``d`` this takes the value at the position along ``L``
    where that channel is largest, i.e. a channel-wise max over ``L``.

    NOTE(review): despite the name, this does NOT select the end-of-text token.
    A real EOT lookup needs the token ids (highest id per row, as used in
    ``TextEncoder.forward``), which are not available here. Confirm that
    max-pooling over tokens is the intended behavior.

    Raises:
        ValueError: If ``input`` is not 4-D.
    """
    if input.dim() != 4:
        raise ValueError(f"expected a 4-D [C, N, L, D] tensor, got shape {tuple(input.shape)}")
    # Gathering input[c, n, argmax_l, d] per channel is exactly max over dim 2;
    # max(dim=...) also routes the gradient to the selected position only.
    return input.max(dim=2).values

def get_eot_feature_2d(input):  # [N, L, D]
    """Reduce an ``[N, L, D]`` sequence-feature tensor to ``[N, D]``.

    For every channel ``d`` this takes the value at the position along ``L``
    where that channel is largest, i.e. a channel-wise max over ``L``.

    NOTE(review): despite the name, this does NOT select the end-of-text token.
    A real EOT lookup needs the token ids (highest id per row), which are not
    available here. Confirm that max-pooling over tokens is intended.

    Raises:
        ValueError: If ``input`` is not 3-D.
    """
    if input.dim() != 3:
        raise ValueError(f"expected a 3-D [N, L, D] tensor, got shape {tuple(input.shape)}")
    # Gathering input[n, argmax_l, d] per channel is exactly max over dim 1.
    return input.max(dim=1).values

class DenseCLIP(nn.Module):
    def __init__(self, cfg, classnames, clip_model, device, return_interm_layers=False):
        """Assemble prompt learner, text/visual encoders and precomputed prior features.

        Args:
            cfg: Experiment config (stored as ``self.cfg``; also passed to PromptLearner).
            classnames: Class names for the prompt learner.
            clip_model: Loaded CLIP model whose submodules/parameters are reused (not copied).
            device: Device the precomputed ``.npy`` features are loaded onto.
            return_interm_layers: If True, the visual getter exposes layer1-4;
                otherwise only layer4.

        Side effects: reads ``COCO1.npy``, ``COCO2.npy``, ``COCO_bg.npy`` and
        ``COCO_proto.npy`` from the current working directory, reads JSON files
        from hard-coded absolute paths, and rewrites ``COCO_bg.npy`` (see note below).
        """
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        # self.image_encoder = clip_model.visual

        self.model = clip_model
        self.return_interm_layers = return_interm_layers
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {"layer4": "0"}
        # torchvision IntermediateLayerGetter over the CLIP visual backbone's children.
        self.visual_encoder = IntermediateLayerGetter(self.model.visual, return_layers)
        # Attention-pool positional embedding without its first row
        # (presumably the pooled/global-token slot — confirm).
        self.positional_embedding = self.model.visual.attnpool.positional_embedding[1::]
        # The attnpool v/c projection weights; these are the same Parameter objects
        # as in attnpool (assigning them here also registers them on this module).
        self.v_linear_weight = self.model.visual.attnpool.v_proj.weight
        self.v_linear_bias = self.model.visual.attnpool.v_proj.bias
        self.c_linear_weight = self.model.visual.attnpool.c_proj.weight
        self.c_linear_bias = self.model.visual.attnpool.c_proj.bias
        
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.cfg = cfg
        
        ##TODO
        #### Load per-class prior knowledge -> self.prior_dict
        #prior knowledge text - concept
        
        # The tensors loaded below are plain attributes, not registered buffers,
        # so they are not moved by .to() and not stored in state_dict.
        #write_prior_npy('/data/CVPR23_medfm/DATA/priors/COCO1.json',self.text_encoder)
        self.prior_dict = read_npy('COCO1.npy',device)
        print ('self.prior_dict:',self.prior_dict.shape)

        self.image_raw = json2list('/data/CVPR23_medfm/DATA/priors/COCO2.json',True) # [80,50]
        #write_prior_npy('/data/CVPR23_medfm/DATA/priors/COCO2.json',self.text_encoder,True)
        self.image_dict = read_npy('COCO2.npy',device)
        print ('self.image_dict:',self.image_dict.shape)

        self.bg_raw = json2list('/data/CVPR23_medfm/DATA/priors/COCO_bg.json') #[1, 50]
        # NOTE(review): unlike the commented-out write_*_npy calls above/below, this
        # one runs on every construction: it re-encodes the background prompts and
        # overwrites COCO_bg.npy in the current working directory before it is read
        # back. Confirm this is intended rather than a leftover.
        write_bg_npy('/data/CVPR23_medfm/DATA/priors/COCO_bg.json',self.text_encoder)
        self.bg_dict = read_npy('COCO_bg.npy',device)
        print ('self.bg_dict:',self.bg_dict.shape)

        # NOTE(review): MCA is neither defined nor imported in the visible part of
        # this file — presumably defined further down; verify. dim=1024 is
        # hard-coded and must match the feature width of the chosen backbone.
        #### Declare the text aggregation layer
        self.agg_text = MCA(dim=1024)

        #### Declare the image aggregation layer
        self.agg_img = MCA(dim=1024)
        
        #### Declare prototypes -> self.prototype (mscoco- 80, VOC - 20,...)
        #a photo of [CLS name]
        #write_prototype_npy(self.text_encoder)
        self.prototype = read_npy('COCO_proto.npy',device)
        print ('self.prototype:',self.prototype.shape)

        self.device = device

        #self.avg_pool = nn.AdaptiveAvgPool1d(1)

        
        

    
    def encode_image(self, x):
        def stem(x):
            for conv, bn in [(self.visual_encoder.conv1, self.visual_encoder.bn1), \
                (self.visual_encoder.conv2, self.visual_encoder.bn2), (self.visual_encoder.conv3, self.visual_encoder.bn3)]:
                x = self.visual_encoder.relu(bn(conv(x)))
            x = self.visual_encoder.avgpool(x)
            return x

        x = x.type(self.visual_encoder.conv1.weight.dtype)
        x = stem(x)
        x = self.visual_encoder.layer1(x)
        x = self.visual_encoder.layer2(x)
        x = self.visual_encoder.layer3(x)
        x = self.visual_encoder.layer4(x)
        return x
    
    def forward(self, image=None, c=None, if_test=False,label=None):

        usage_text_att = 0
        usage_image_att = 0

        if if_test:

            if usage_image_att:
                proto_set = self.prototype  # [cls,L,D]
                proto_features = get_eot_feature_2d(proto_set)
                proto_feat = replicate_cls(proto_features, image.shape[0]) #[B,cls,D]

            image_feat = self.encode_image(image)
            b, c, h, w = image_feat.shape
            
            x = image_feat.reshape(b, c, h * w).permute(2, 0, 1)
            # g_x = x.mean(0, keepdim=True)
            # x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW)xBxC        
            
            x = F.linear(x, self.v_linear_weight, self.v_linear_bias)
            x = F.linear(x, self.c_linear_weight, self.c_linear_bias)
            image_features = x #[49,100,1024] #[HW*B*C]
            image_feature_, _ = self.model.visual.attnpool(image_feat, if_pos=False) #[100,1024] #[B*C]


            if usage_image_att:
                local_feat = image_features.permute(1,0,2) #[B,HW,D]
                global_feat = image_feature_.unsqueeze(1) #[B,1,D]
                images_global_local_feat = torch.cat((global_feat, local_feat), dim=1)
                output = self.agg_img(x_q=images_global_local_feat, x_k=proto_feat, x_v=proto_feat)

                local_feat = local_feat + output[:, 1:, :]  # [B,HW,D]
                global_feat = global_feat + output[:, 0, :].unsqueeze(1)  # [B,1,D]

                image_features = local_feat.permute(1,0,2) #[HW,B,D]
                image_feature_ = global_feat[:,0,:] #[B,D]


            # ===============================================================

            prompts, prompts_double, temperature, spatial_T, rk_scale = self.prompt_learner()
            tokenized_prompts = self.tokenized_prompts
            text_features = self.text_encoder(prompts, tokenized_prompts)
            text_features_neg = self.text_encoder(prompts_double, tokenized_prompts)
            
            
            image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            text_features_neg = text_features_neg / text_features_neg.norm(dim=-1, keepdim=True)

            logit_scale = temperature.exp()  # rk_scale
            logit_scale = logit_scale if self.cfg.TRAIN.IF_LEARN_SCALE else 4.0 # 50
            logits_ = logit_scale * image_feature_ @ text_features.t()   # B * C,  cls * C, = B * cls
            logits_neg = image_features @ text_features_neg.t()    #  HW * B * C,  cls * C,  HW * B * cls
            
            #print ("image_features:",image_features.shape)
            #print ("image_feat,image_feature_,logits_:", image_feat.shape, image_feature_.shape,logits_.shape)
            
           
            tmp_scale = spatial_T.exp() if self.cfg.TRAIN.IF_LEARN_spatial_SCALE else self.cfg.TRAIN.spatial_SCALE_image  # 5 #
            prob_spatial = torch.nn.functional.softmax(logits_neg * tmp_scale, dim=0)
            logits_local = torch.sum(logit_scale * logits_neg * prob_spatial, dim=0)
            
     
            #print ("logits_neg.shape",logits_neg.shape)

            return logits_, logits_local, logits_neg, image_features @ text_features.t()  # compare additional branch with global proxy
        else:
            
            #print(label.shape)

            #### text conder
            ## Global prompt - prompts
            ## Local prompt - prompts_double

            ## prior knolwedge 불러오기
            proto_set = self.prototype #[cls,L,D]
            proto_features = get_eot_feature_2d(proto_set) #[cls,D]
        
            labeled_prior_set = self.prior_dict #[cls,N,L,D]
            #print (labeled_prior_set.shape)

            prior_features = get_eot_feature(labeled_prior_set)#[cls,N,D]

            ## text aggregation layer 에 통과 시키기
            prompts, prompts_double, temperature, spatial_T, rk_scale = self.prompt_learner()
            tokenized_prompts = self.tokenized_prompts
            text_features = self.text_encoder(prompts, tokenized_prompts)
            text_features_neg = self.text_encoder(prompts_double, tokenized_prompts)

            if usage_text_att:
                text_features = text_features.unsqueeze(1)  #[cls,1,D]
                text_features_neg = text_features_neg.unsqueeze(1)  #[cls,1,D]
                global_local_feat = torch.cat((text_features, text_features_neg),dim=1) #[cls,2,D]

                #print (prior_features.shape,global_local_feat.shape)
            
                output = self.agg_text(x_q= global_local_feat,  x_k=prior_features, x_v = prior_features)
                #[C,2,D]
                text_features = text_features + output[:,0,:].unsqueeze(1) #[cls,1,D]
                text_features_neg = text_features_neg + output[:,1,:].unsqueeze(1) #[cls,1,D]

                text_features = text_features.squeeze(1)
                text_features_neg = text_features_neg.squeeze(1)

            #print(text_features.shape, text_features_neg.shape)
            #print('success text encoder')
            #### img conder

            ##prior knolwedge for image 불러오기
            labeled_image_prior_set = self.image_dict #[cls,N,L,D]
            #print(labeled_image_prior_set.shape)
            image_features = get_eot_feature(labeled_prior_set) #[cls,N,D]

            bg_prior_set = self.bg_dict  # [N,L,D]
            #print(bg_prior_set.shape)
            bg_features = get_eot_feature_2d(bg_prior_set)  # [N,D]


            ##7x7 grid 생성
            images_batch = torch.empty(label.shape[0], 49, 1024).to(self.device)  #[B,HW,D]
            for idx,batch_label in enumerate(label):  #[64,80]
                num_class = np.sum(batch_label.detach().cpu().numpy())
                max_num = int(39 / num_class)

                label_candidates = torch.where(batch_label==1)[0].detach().cpu().numpy()
                class_per_num = []
                add_candidates ={}
                for i in range(num_class):
                    allocation = np.random.randint(1, max_num+1)
                    class_per_num.append(allocation)
                    add_candidates[label_candidates[i]] = allocation

                total_num = sum(class_per_num)
                add_background_feature = 49 - total_num

                draw = 49
                write_idx = 0
                while draw:
                    #print(idx,draw)
                    selected_cls = np.random.randint(0,num_class+1) #0:bg, 1~num_class:cls
                    if selected_cls == 0:
                        if add_background_feature:
                            selected_idx = np.random.randint(0,bg_features.shape[0])
                            images_batch[idx,write_idx]=bg_features[selected_idx].clone()
                            draw = draw -1
                            write_idx+=1
                            add_background_feature = add_background_feature-1
                    else:
                        selected_real_cls = list(add_candidates)[selected_cls - 1]
                        if add_candidates[selected_real_cls] !=0:
                            selected_idx = np.random.randint(0,image_features[selected_real_cls].shape[0])
                            images_batch[idx,write_idx] = image_features[selected_real_cls][selected_idx].clone()
                            draw = draw -1
                            write_idx += 1
                            add_candidates[selected_real_cls] = add_candidates[selected_real_cls]-1
            #print('make images done.')
            # images_batch # [B,HW,D]
            F_ = images_batch
            
            
            
            
            #f generation by attn pooling of F_, #[B,HW,D]
            f = F_.mean(dim=1, keepdim=True) #[B,1,D]
            
            
            F_norm = F_ / F_.norm(dim=-1, keepdim=True)
            f_norm = f / f.norm(dim=-1, keepdim=True)
            
            sim_map = f_norm @ F_norm.t()
            S_matrix = torch.nn.functional.softmax(sim_map, dim= -1) #[B,1,HW]
            f = S_matrix @ sim_map #[B,1,D]
            
            f = f / f.norm(dim=-1, keepdim=True)

            # f making

            # 1) avgpooling
            #f = self.avg_pool(images_batch.permute(0, 2, 1))[...,0] #[B,D]

            # 2) argmax
            #f =  get_eot_feature_2d(F_) # [B,D]

            # 3) sum
            #logit_scale = temperature.exp()  # rk_scale
            #tmp_scale = spatial_T.exp() if self.cfg.TRAIN.IF_LEARN_spatial_SCALE else self.cfg.TRAIN.spatial_SCALE_text
            #f_prob_sptial  =  torch.nn.functional.softmax (F_ * tmp_scale, dim=1) # [B,HW,D]
            #f =  torch.sum (logit_scale *  F_  *  f_prob_sptial , dim=1) # [B,D]

            # 4) select each class and one background sentence ,attach it and embedding for one [B,D]
            
            
            
            
            
            
            
            
            
            
#             f = []
#             global_sentence_list =[]
#             for idx, batch_label in enumerate(label):
#                 #self.image_raw[] # [80,50]
#                 #print (idx)

#                 label_candidates = torch.where(batch_label == 1)[0].detach().cpu().numpy()

#                 global_sentence = ""

#                 for label_c in label_candidates:
#                     selected_idx = np.random.randint(0,len(self.image_raw[label_c]))
#                     selected_sentence = self.image_raw[label_c,selected_idx] # one sentence

#                     global_sentence += (selected_sentence+",")

#                 selected_idx = np.random.randint(0, len(self.bg_raw[0]))
#                 global_sentence += self.bg_raw[0][selected_idx]
#                 #print(global_sentence)

#                 global_sentence_list.append(global_sentence)


#             tokenized_prompts = [clip.tokenize(p,truncate=True) for p in global_sentence_list]  # [B,L]
#             tokenized_prompts = torch.stack(tokenized_prompts) #[B,1,L]
#             tokenized_prompts = tokenized_prompts.squeeze().to(self.device) #[B,L]
#             #print (tokenized_prompts.shape)
#             with torch.no_grad():
#                 global_features = self.text_encoder(tokenized_prompts, None, if_embedding=False, if_sequence=False)
#                 #print(global_features.shape)  # [B,D]
#                 f.append(global_features)

#             f = torch.stack(f,dim=0) #[B,1,D]
#             f = f.squeeze() #[B,D]









            #print (f.shape)

            ## img aggregation layer에 통과 시키기 그리고  prototype 불러오기

            image_feature_global = f #.unsqueeze(1)  # [B,1,D]
            image_feature_local = F_  # [B,HW,D]
            
            #image_feature_ = image_feature_global #image_feature_ / image_feature_.norm(dim=-1, keepdim=True)
            image_feature_local =  image_feature_local / image_feature_local.norm(dim=-1, keepdim=True)

            

            if usage_image_att:
                images_global_local_feat = torch.cat((image_feature_global, image_feature_local), dim=1)  # [B,1+HW,D]

                proto_feat = replicate_cls(proto_features,label.shape[0])#[B,cls,D]

                output = self.agg_img(x_q=images_global_local_feat, x_k=proto_feat, x_v=proto_feat)

                image_feature_global = image_feature_global + output[:,0,:].unsqueeze(1) # [B,1,D]
                image_feature_local = image_feature_local + output[:,1:,:] # [B,HW,D]

                image_feature_global = image_feature_global.squeeze(1)

            image_feature_ = image_feature_global
            image_features = image_feature_local

            #print (image_features.shape, image_feature_.shape )

            #print ('end of our code.')

            #origin_code
            #================================================================
            #image_feat = self.text_encoder(captions, None, if_embedding=False, if_sequence=True) 
            # b, l, d = image_feat.shape
            #image_feature_ = image_feat[torch.arange(image_feat.shape[0]), captions.argmax(dim=-1)]  # BD
            #image_features = image_feat.permute(1, 0, 2)  # LBD

            #prompts, prompts_double, temperature, spatial_T, rk_scale = self.prompt_learner()
            #tokenized_prompts = self.tokenized_prompts
            #text_features = self.text_encoder(prompts, tokenized_prompts)
            #text_features_neg = self.text_encoder(prompts_double, tokenized_prompts)
            
            # ===============================================================

            image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            text_features_neg = text_features_neg / text_features_neg.norm(dim=-1, keepdim=True)
            
            # mask irrelavent tokens
            #text_mask = (captions == 0).long() * (-10000)  # BL

            logit_scale = temperature.exp()  # rk_scale
            logit_scale = logit_scale if self.cfg.TRAIN.IF_LEARN_SCALE else 4.0 # 50 # temperature.exp()  # self.logit_scale.exp()
            logits_ = logit_scale * image_feature_ @ text_features.t()   # B * C,  cls * C, = B * cls
            logits_neg = image_features.permute(1,0,2) @ text_features_neg.t()    #  L * B * C,  cls * C =  L * B * cls
            #logits_neg = logits_neg.permute(2, 1, 0) + text_mask[None, :, :] #  cls*B*L
            #logits_neg = logits_neg.permute(2, 1, 0)

            tmp_scale = spatial_T.exp() if self.cfg.TRAIN.IF_LEARN_spatial_SCALE else self.cfg.TRAIN.spatial_SCALE_text
            prob_spatial = torch.nn.functional.softmax(logits_neg * tmp_scale, dim=0)
            # logits_neg : torch.Size([49, 512, 80]) => prob_spatial : torch.Size([49, 512, 80])

            logits_local = torch.sum(logit_scale * logits_neg * prob_spatial, dim=0)
            # logits_neg : torch.Size([49, 512, 80]) * prob_spatial : torch.Size([49, 512, 80]) => [512,80]
            # [B*cls]
        
            
            return logits_, logits_local, image_features, text_features
        
        
##TODO - cross attention
class MCA(nn.Module):
    """Multi-head cross-attention: queries from ``x_q`` attend over ``x_k`` / ``x_v``.

    Args:
        dim: Channel dimension of queries, keys, values and of the output.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the q/k/v projections carry a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim=192, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self._reset_parameters()

    def _reset_parameters(self):
        """Deterministically initialise the q/k/v/output projections.

        Weights are Xavier-uniform, drawn (in q, k, v, proj order) from the CPU generator
        seeded with 0. Every instance therefore starts from identical weights. The seeding
        happens inside ``torch.random.fork_rng``, so the caller's global RNG state is
        restored afterwards instead of being reset to seed 0 as a hidden side effect.

        Biases are zero-initialised. Xavier initialisers need tensors with at least two
        dimensions and raise on 1-D bias vectors.
        """
        layers = (self.q, self.k, self.v, self.proj)
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(0)
            for layer in layers:
                nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.)

    def forward(self, x_q, x_k, x_v):
        """Attend from ``x_q`` (B, N_q, C) over ``x_k`` / ``x_v`` (B, N_kv, C).

        ``x_k`` and ``x_v`` must have the same sequence length. Returns a tensor of
        shape (B, N_q, C).
        """
        B, N_q, C = x_q.shape
        N_k = x_k.shape[1]
        N_v = x_v.shape[1]
        head_dim = C // self.num_heads

        # Split channels into heads: [B, N, C] -> [B, heads, N, head_dim]
        q = self.q(x_q).reshape(B, N_q, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.k(x_k).reshape(B, N_k, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.v(x_v).reshape(B, N_v, self.num_heads, head_dim).permute(0, 2, 1, 3)

        # [B, h, N_q, d] @ [B, h, d, N_kv] -> [B, h, N_q, N_kv]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: [B, h, N_q, d] -> [B, N_q, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x




# kl_loss = nn.KLDivLoss(reduction="batchmean")
# ce_loss = torch.nn.CrossEntropyLoss()

@TRAINER_REGISTRY.register()
class Caption_ours(TrainerX):
    def model_inference(self, input):
        """Forward ``input`` through the model in test mode (``if_test=True``) and return its outputs."""
        model = self.model
        return model(input, if_test=True)

    @torch.no_grad()
    def test(self, split=None):
        """Evaluate the model on the validation or test split.

        Args:
            split: Split name; defaults to ``cfg.TEST.SPLIT``. ``"val"`` uses the
                validation loader when one exists, otherwise the test loader is used.

        Returns:
            The first metric value produced by ``self.evaluator.evaluate()``.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        split = self.cfg.TEST.SPLIT if split is None else split

        if split == "val" and self.val_loader is not None:
            loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            loader = self.test_loader
            print("Do evaluation on test set")

        for batch in tqdm(loader):
            inputs, labels = self.parse_batch_test(batch)
            # The trailing two outputs (image/text features) are not needed for scoring.
            logits_global, logits_local, _, _ = self.model_inference(inputs)
            self.evaluator.process(logits_global, labels, logits_local)

        metrics = self.evaluator.evaluate()
        for metric_name, value in metrics.items():
            self.write_scalar("{}/{}".format(split, metric_name), value, self.epoch)

        return list(metrics.values())[0]

    def check_cfg(self, cfg):
        """Ensure ``cfg.TRAINER.Caption.PREC`` names a supported precision mode."""
        supported_precisions = ["fp16", "fp32", "amp"]
        assert cfg.TRAINER.Caption.PREC in supported_precisions

    def build_model(self):
        """Build the CLIP-based model and its optimisation state.

        Loads CLIP on CPU and casts it to fp32 unless the precision is "fp16". Wraps it in
        ``DenseCLIP`` and turns off gradients for every parameter whose name lacks
        ``"prompt_learner"``. Optionally loads prompt-learner weights from
        ``cfg.MODEL.INIT_WEIGHTS`` and moves the model to ``self.device``. Then builds the
        optimizer and LR scheduler, registers the prompt learner with the trainer, and
        wraps the model in ``nn.DataParallel`` when several GPUs are visible.
        """
        print('==================== Building model in Caption_distill_double ======================')
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        precision = cfg.TRAINER.Caption.PREC
        if precision in ("fp32", "amp"):
            # CLIP ships in fp16; promote it for fp32 / mixed-precision training.
            clip_model.float()

        print("Building custom CLIP")
        self.model = DenseCLIP(cfg, classnames, clip_model, self.device)

        print("Turning off gradients in both the image and the text encoder")
        for param_name, param in self.model.named_parameters():
            if "prompt_learner" in param_name:
                continue
            param.requires_grad_(False)
        # NOTE(review): unless agg_img / agg_text are also reachable under a "prompt_learner"
        # name, the loop above freezes them, yet both are handed to build_optimizer below and
        # only prompt_learner is registered for checkpointing — confirm this is intended.

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM, self.model.agg_img, self.model.agg_text)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if precision == "amp" else None

        # Multi-GPU training may be slow: CLIP is large, so DataParallel's replica copies are costly.
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print(f"Multiple GPUs detected (n_gpus={gpu_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def _compute_loss(self, output, output_local, label):
        """Return the training loss selected by ``cfg.TRAIN.LOSSFUNC``.

        ``output`` and ``output_local`` are the first two values returned by the model
        (global and local logits). ``label`` is the batch target. It is passed through
        ``.float()`` where the original loss calls did so. Unknown names fall back to
        ``soft_cross_entropy``.
        """
        loss_name = self.cfg.TRAIN.LOSSFUNC
        if loss_name == 'sigmoid':
            return norm_logits_BCEloss(output, label.float()) + norm_logits_BCEloss(output_local, label.float())
        if loss_name == 'focal':
            return sigmoid_focal_loss(output, label)
        if loss_name == 'asl':
            return ASL_loss(output, label) + ASL_loss(output_local, label.float())
        if loss_name == 'ranking':
            return ranking_loss(output, label)
        if loss_name == 'double_ranking':
            loss_global = ranking_loss(output, label, scale_=1.0, margin_=1)
            loss_local = ranking_loss(output_local, label, scale_=1.0, margin_=1)
            # Per-step debug output of the two loss terms (kept from the original code).
            print(loss_global, loss_local)
            return loss_global + loss_local
        return soft_cross_entropy(output, label)

    def forward_backward(self, batch):
        """Run one optimisation step on ``batch`` and return ``{"loss": float}``.

        Both precision paths call the model with the same arguments and use the loss
        chosen by ``cfg.TRAIN.LOSSFUNC`` via ``_compute_loss``. The "amp" path used to call
        ``self.model(image)`` without the label and apply ``F.cross_entropy`` to the
        multi-label targets, which is inconsistent with the fp16/fp32 path. The LR
        scheduler is stepped after the last batch of an epoch.
        """
        image, label = self.parse_batch_train(batch)

        if self.cfg.TRAINER.Caption.PREC == "amp":
            with autocast():
                output, output_local, _, _ = self.model(None, image, False, label)
                loss = self._compute_loss(output, output_local, label)
            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()
        else:
            output, output_local, _, _ = self.model(None, image, False, label)
            loss = self._compute_loss(output, output_local, label)
            self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Return the batch's ``"img"`` and ``"label"`` entries, moved to ``self.device``."""
        images, labels = batch["img"], batch["label"]
        device = self.device
        return images.to(device), labels.to(device)

    def load_model(self, directory, epoch=None):
        """Load saved weights for every registered model from ``directory``.

        Args:
            directory: Directory holding one sub-folder per registered model name.
                If falsy, loading is skipped with a notice.
            epoch: If given, load ``model.pth.tar-<epoch>``; otherwise ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: If an expected checkpoint file does not exist.

        The fixed ``token_prefix`` / ``token_suffix`` entries are dropped before loading.
        Loading is non-strict, and any missing or unexpected keys are printed so that
        mismatches are not silently ignored.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar" if epoch is None else "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            ckpt_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            state_dict.pop("token_prefix", None)
            state_dict.pop("token_suffix", None)

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch))
            # strict=False tolerates mismatches; report them rather than dumping every tensor.
            incompatible = self._models[name].load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys:
                print("Missing keys when loading {}: {}".format(name, incompatible.missing_keys))
            if incompatible.unexpected_keys:
                print("Unexpected keys when loading {}: {}".format(name, incompatible.unexpected_keys))
