import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
from tqdm import tqdm
import numpy as np
import math
import time
import json
import os
import torchvision.transforms as transforms

# Prompt template per dataset name. Each template has a single "{}"
# placeholder that the trainers below fill with a class name (after replacing
# underscores with spaces). Trainers look templates up by cfg.DATASET.NAME.
CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}."
}
# List of 80 object-category names.
# NOTE(review): these look like the MS-COCO detection categories; within this
# file the list is only referenced from commented-out code -- confirm whether
# other modules import it before removing.
label_names = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
    "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush"
]

import sys



# --- Module-level experiment switches (edited by hand between runs) ---

# When True, ZeroshotCLIP_dense.test prints the wall-clock time of every
# model_inference call.
MEASURE_TIME = False
# Not read by any code in this file.
# NOTE(review): confirm whether other modules import it.
WTA = False
# Selects which pre-computed caption files ZeroshotCLIP_dense.build_model
# loads (the "_101" variants when True).
RESNET101 = False
# Temperature values. No arithmetic in this file uses them.
# NOTE(review): confirm whether they are still needed.
TEMP1 = 60.0
TEMP2 = 40.0
# Previous settings, kept for reference:
#K1=1000
#K2=20000
#K2_1=200
#K2_2=100
# K2 is the number of best and worst matching captions that
# ZeroshotCLIP_dense.model_inference writes out per call; K2_1 and K2_2 are
# only printed there. K1 is not read by any code in this file.
K1=5
K2=5
K2_1=25
K2_2=20


# Dataset family whose caption embeddings ZeroshotCLIP_dense.build_model
# loads; the values handled there are "COCO", "NUSWIDE" and "VOC".
DATASET = "NUSWIDE"

# The four *_ADAPT_* flags are not read by any code in this file.
# NOTE(review): confirm whether other modules import them.
TASK_ADAPT_LOCAL = True
TASK_ADAPT_GLOBAL = True

CLASS_ADAPT_LOCAL = True
CLASS_ADAPT_GLOBAL = True

# Running counter of ZeroshotCLIP_dense.model_inference calls; used to name
# the per-call output files written under ROOT_PATH.
IMG_CNT = 0
# Directory that receives the per-call image and caption-ranking dumps.
ROOT_PATH ='./fig5_nuswide'


@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP classifier using one hand-written prompt per class."""

    def build_model(self):
        """Load CLIP and pre-compute unit-norm text features for every class.

        The prompt template is looked up in ``CUSTOM_TEMPLATES`` by
        ``cfg.DATASET.NAME``; unknown datasets fall back to
        ``"a photo of a {}."`` after printing a warning.

        Sets ``self.text_features`` and ``self.clip_model``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # Only a missing dictionary entry should trigger the fallback; any
        # other error (or Ctrl-C) must propagate instead of being swallowed.
        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def model_inference(self, image):
        """Return scaled image/text similarity logits for ``image``.

        Args:
            image: batch of images accepted by ``clip_model.encode_image``.

        Returns:
            A 3-tuple ``(logits, None, None)``; ``logits`` is the product of
            the unit-norm image features (multiplied by CLIP's exponentiated
            ``logit_scale``) and the transposed text features.
        """
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()

        return logits, None, None


def read_npy(numpy_file, device):
    """Load a ``.npy`` array and return it as a float16 tensor on ``device``.

    Args:
        numpy_file: path of the ``.npy`` file to read.
        device: target passed to ``Tensor.to`` for the final placement.

    Returns:
        ``torch.Tensor`` of dtype ``torch.float16`` holding the file contents.
    """
    # NOTE: allow_pickle=True lets NumPy unpickle object arrays, which can run
    # arbitrary code; only point this at trusted files.
    array = np.load(numpy_file, allow_pickle=True)
    print('load complete:', numpy_file)

    # Cast to half precision first, then move to the requested device.
    half_tensor = torch.tensor(array).to(torch.float16)
    return half_tensor.to(device)


@TRAINER_REGISTRY.register()
class ZeroshotCLIP_dense(TrainerX):
    """Zero-shot CLIP trainer scoring pooled and per-location image features.

    Besides the class-prompt text features, ``build_model`` loads pre-computed
    caption embeddings chosen by the module-level ``DATASET`` and ``RESNET101``
    switches. Every ``model_inference`` call additionally writes a min-max
    normalised copy of the input image and a text file listing the best and
    worst matching captions into ``ROOT_PATH``.
    """

    def _load_unit_features(self, numpy_file):
        """Read ``numpy_file`` with ``read_npy`` and L2-normalise the last axis."""
        feats = read_npy(numpy_file, self.device)
        return feats / feats.norm(dim=-1, keepdim=True)

    @staticmethod
    def _load_sentences(json_file):
        """Return the decoded content of ``json_file``, closing the file afterwards."""
        with open(json_file, 'r') as fp:
            return json.load(fp)

    def build_model(self):
        """Load CLIP, the class text features and the caption embeddings.

        Sets ``self.text_features``, ``self.clip_model``, the visual-encoder
        handles used by ``encode_image`` / ``model_inference``, ``self.caption``
        and ``self._multi_label_caption``. ``self.caption_local`` is set only
        for ``DATASET == "COCO"``. ``self.sentence`` is set only for
        COCO with ``RESNET101`` and for NUSWIDE without ``RESNET101``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # Only a missing dictionary entry should trigger the fallback; any
        # other error (or Ctrl-C) must propagate instead of being swallowed.
        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

        # Expose the visual backbone up to "layer4" plus the attention-pool
        # projection weights so they can be applied per spatial location.
        self.visual_encoder = IntermediateLayerGetter(self.clip_model.visual, {"layer4": "0"})
        self.positional_embedding = self.clip_model.visual.attnpool.positional_embedding[1::]
        self.v_linear_weight = self.clip_model.visual.attnpool.v_proj.weight
        self.v_linear_bias = self.clip_model.visual.attnpool.v_proj.bias
        self.c_linear_weight = self.clip_model.visual.attnpool.c_proj.weight
        self.c_linear_bias = self.clip_model.visual.attnpool.c_proj.bias

        if DATASET == "COCO":
            if RESNET101:
                self.caption = self._load_unit_features('COCO_caption_selection_101.npy')
                self.caption_local = self._load_unit_features('COCO_caption_selection_local_101.npy')
                self._multi_label_caption = self._load_unit_features(
                    'COCO_caption_selection_fig4_40000_101.npy')
                self.sentence = self._load_sentences(
                    'COCO_caption_selection_fig4_40000_sentence_101.json')
            else:
                self.caption = self._load_unit_features('COCO_caption_selection.npy')
                self.caption_local = self._load_unit_features('COCO_caption_selection_local.npy')
                self._multi_label_caption = self._load_unit_features(
                    'COCO_multi_label_caption_selection_50_v3.npy')

        elif DATASET == "NUSWIDE":
            if RESNET101:
                self.caption = self._load_unit_features(
                    'nus_llm_caption_selection_101_87500_one_cls.npy')
                self._multi_label_caption = self._load_unit_features(
                    'nus_llm_multi_label_caption_selection_101_87500_one_cls.npy')
            else:
                self.caption = self._load_unit_features(
                    'nus_llm_caption_selection_50_57600_one_cls.npy')
                self._multi_label_caption = self._load_unit_features(
                    'nus_llm_multi_label_caption_selection_50_fig4_57600.npy')
                self.sentence = self._load_sentences(
                    'nus_llm_multi_label_caption_selection_50_fig4_57600_sentence.json')

                print('self._multi_label_caption', self._multi_label_caption.shape)
                print('self.sentence:', len(self.sentence))

        elif DATASET == "VOC":
            if RESNET101:
                self.caption = self._load_unit_features('VOC_caption_selection_100_101.npy')
                self._multi_label_caption = self._load_unit_features(
                    'VOC_multi_label_caption_selection_101_v7_60000.npy')
            else:
                self.caption = self._load_unit_features('VOC_caption_selection.npy')
                self._multi_label_caption = self._load_unit_features(
                    'VOC_multi_label_caption_selection_50_v3.npy')

    def encode_image(self, x):
        """Run ``x`` through the CLIP ResNet stem and ``layer1``..``layer4``.

        Returns the ``layer4`` feature map, i.e. the features *before* the
        attention pooling head.
        """
        def stem(x):
            # Three conv/bn/relu stages followed by average pooling.
            for conv, bn in [(self.visual_encoder.conv1, self.visual_encoder.bn1),
                             (self.visual_encoder.conv2, self.visual_encoder.bn2),
                             (self.visual_encoder.conv3, self.visual_encoder.bn3)]:
                x = self.visual_encoder.relu(bn(conv(x)))
            x = self.visual_encoder.avgpool(x)
            return x

        # Match the dtype of the backbone weights before the first convolution.
        x = x.type(self.visual_encoder.conv1.weight.dtype)
        x = stem(x)
        x = self.visual_encoder.layer1(x)
        x = self.visual_encoder.layer2(x)
        x = self.visual_encoder.layer3(x)
        x = self.visual_encoder.layer4(x)
        return x

    @staticmethod
    def _save_normalized_image(image_tensor, filename):
        """Min-max normalise ``image_tensor`` to [0, 1] and save it as an image.

        Only a leading dimension of size 1 is removed before conversion.
        NOTE(review): this presumably requires a test batch size of 1 --
        confirm against the data loader configuration.
        """
        print(image_tensor.shape)
        image_tensor = image_tensor.squeeze(0)  # drop the batch dimension

        lowest = image_tensor.min()
        highest = image_tensor.max()
        image_tensor = (image_tensor - lowest) / (highest - lowest)

        # Convert the tensor to a PIL image and write it to disk.
        transforms.ToPILImage()(image_tensor).save(filename)

    def _dump_caption_ranking(self, caption, image_feature_, K, out_path):
        """Write the ``K`` best and ``K`` worst matching captions to ``out_path``.

        Args:
            caption: 2-D caption embedding matrix, one row per caption.
            image_feature_: pooled image features, one row per image.
            K: number of captions listed at each end of the ranking.
            out_path: text file to (over)write.

        Reads ``self.sentence``, which ``build_model`` only sets for the
        COCO/``RESNET101`` and NUSWIDE/non-``RESNET101`` configurations.
        """
        # Similarity of every caption with every image: (num_captions, batch).
        logit = torch.matmul(caption, image_feature_.t())

        # Highest and lowest K similarity scores per image, with their indices.
        topk_values, topk_indices = torch.topk(logit, k=K, largest=True, dim=0)
        lowk_values, lowk_indices = torch.topk(logit, k=K, largest=False, dim=0)

        sentences = np.array(self.sentence)
        top_sentences = sentences[topk_indices.detach().cpu().numpy()]
        low_sentences = sentences[lowk_indices.detach().cpu().numpy()]

        # The context manager guarantees the dump is flushed and closed.
        with open(out_path, 'w') as f:
            f.write('top sentences\n')
            for sentence, value in zip(top_sentences, topk_values):
                f.write(f"{sentence}: {value}\n")

            f.write('\nlow sentences\n')
            for sentence, value in zip(low_sentences, lowk_values):
                f.write(f"{sentence}: {value}\n")

            # TODO: only the section header is written; no label information
            # is emitted yet.
            f.write('\nlabels\n')

    def model_inference(self, image, label):
        """Compute pooled and per-location logits and dump debug artefacts.

        Args:
            image: input image batch.
            label: ground-truth labels; accepted for interface compatibility
                and currently not used in any computation.

        Returns:
            ``(logits_, logits, None, None)`` where ``logits_`` is the product
            of the unit-norm attention-pooled image features with the
            transposed text features, and ``logits`` is the same product for
            the per-location features (leading axis ``h * w``).

        Side effects:
            Writes ``Img_<n>.png`` and ``<n>.txt`` into ``ROOT_PATH`` (created
            if missing) and increments the module-level ``IMG_CNT``.
        """
        global IMG_CNT

        image_feat = self.encode_image(image)
        b, c, h, w = image_feat.shape
        # Flatten the spatial grid: (b, c, h, w) -> (h*w, b, c).
        x = image_feat.reshape(b, c, h * w).permute(2, 0, 1)

        # Apply the attention pool's value and output projections to every
        # spatial location individually.
        x = F.linear(x, self.v_linear_weight, self.v_linear_bias)
        x = F.linear(x, self.c_linear_weight, self.c_linear_bias)
        image_features = x

        image_feature_, _ = self.clip_model.visual.attnpool(image_feat)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)

        # Make sure the dump directory exists before the first write.
        os.makedirs(ROOT_PATH, exist_ok=True)
        self._save_normalized_image(image, os.path.join(ROOT_PATH, 'Img_' + str(IMG_CNT) + '.png'))

        # Flatten the caption bank to one row per caption. The row width must
        # equal the image embedding width for the similarity product, so it is
        # taken from the image features instead of being hard-coded.
        multi_step_input = self._multi_label_caption.reshape(-1, image_feature_.shape[-1])

        print(K2, K2_1, K2_2)
        self._dump_caption_ranking(multi_step_input, image_feature_, K2,
                                   os.path.join(ROOT_PATH, str(IMG_CNT) + ".txt"))

        logits_ = image_feature_ @ self.text_features.t()
        logits = image_features @ self.text_features.t()
        IMG_CNT += 1
        print(IMG_CNT)

        return logits_, logits, None, None

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline.

        Evaluates on the validation loader when ``split`` is ``"val"`` and a
        validation loader exists, otherwise on the test loader. Returns the
        first value of the evaluator's result dictionary.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch in tqdm(data_loader):
            images, label = self.parse_batch_test(batch)

            if MEASURE_TIME:
                st = time.time()

            output, output_pos, _, _ = self.model_inference(images, label)

            if MEASURE_TIME:
                ed = time.time()
                print(ed - st)

            self.evaluator.process(output, label, output_pos)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]


@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling."""

    # templates = IMAGENET_TEMPLATES
    templates = IMAGENET_TEMPLATES_SELECT

    def build_model(self):
        """Load CLIP and build class text features averaged over all templates.

        For datasets other than ``"ImageNet"`` the dataset-specific entry of
        ``CUSTOM_TEMPLATES`` is added to the ensemble (a ``KeyError`` is raised
        if the dataset has no entry). Sets ``self.text_features`` and
        ``self.clip_model``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # Work on a copy: ``self.templates`` is the shared class-level list
        # (the imported template list itself), so extending it in place would
        # permanently grow it on every build_model call.
        templates = list(self.templates)

        # add custom-made prompt
        if cfg.DATASET.NAME != "ImageNet":
            templates.append(CUSTOM_TEMPLATES[cfg.DATASET.NAME])

        num_temp = len(templates)
        print(f"Prompt ensembling (n={num_temp})")

        # Sum the unit-norm text features of every template, then average and
        # re-normalise so the ensemble is again unit-norm.
        mean_text_features = 0
        for temp in templates:
            prompts = [temp.format(c.replace("_", " ")) for c in classnames]
            prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            mean_text_features = mean_text_features + text_features
        mean_text_features = mean_text_features / num_temp
        mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
