import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patheffects as pe

# Experiment switches. Of these, only ERASE_FP is read by the trainers defined
# below (in ZeroshotCLIP_dense.test, where it overwrites non-ground-truth top-5
# scores using the labels); VISUALIZE and FILTER_FP are not referenced by them.
VISUALIZE = False
ERASE_FP = False
FILTER_FP = False


# 80 object category names (they match the COCO detection categories).
# Not referenced by the trainers defined below.
label_names = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
        "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
        "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
        "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    ]


# Prompt template per dataset, keyed by cfg.DATASET.NAME. When prompts are built,
# "{}" is replaced by the class name (with underscores turned into spaces).
CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}."
}


@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP classifier that uses one hand-written prompt per class."""

    def build_model(self):
        """Load CLIP and pre-compute L2-normalised text features for every class.

        The prompt template is looked up in ``CUSTOM_TEMPLATES`` by
        ``cfg.DATASET.NAME``; datasets without an entry fall back to
        ``"a photo of a {}."`` (a warning is printed).
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            # Only a missing dataset entry is an expected condition here; any
            # other error must propagate instead of being silently swallowed.
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def model_inference(self, image):
        """Return ``(logits, None, None)`` for a batch of images.

        The logits are dot products of L2-normalised image features with the
        pre-computed class text features, multiplied by ``logit_scale.exp()``.
        """
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()

        return logits, None, None


def entropy_from_logits(logits):
    """Return the Shannon entropy (in nats) of ``softmax(logits)`` over the last dim.

    The computation stays in log space (``log_softmax``) for numerical
    stability; the result has the input's shape minus its last dimension.
    """
    log_p = F.log_softmax(logits, dim=-1)
    p = log_p.exp()
    return (p * log_p).sum(dim=-1).neg()

def read_npy(numpy_file, device):
    """Load ``numpy_file`` with NumPy and return it as a float16 tensor on ``device``.

    ``np.load`` is called with ``allow_pickle=True``; only point this at
    trusted files, since unpickling can execute arbitrary code.
    """
    array = np.load(numpy_file, allow_pickle=True)
    print('load complete:', numpy_file)
    as_half = torch.tensor(array).to(torch.float16)
    return as_half.to(device)

@TRAINER_REGISTRY.register()
class ZeroshotCLIP_dense(TrainerX):
    """Zero-shot CLIP variant that also scores local regions of the feature map.

    Besides the global (attention-pooled) prediction, the spatial output of the
    ResNet's layer4 is projected through the attention pool's value/output
    projections to get per-location embeddings. Region scores average the
    class-prompt similarity with an attention over caption embeddings loaded
    from ``COCO_caption_selection_101.npy``.
    """

    def build_model(self):
        """Load CLIP, pre-compute class text features and load caption features."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            # Only a missing dataset entry is expected; other errors propagate.
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features  # [num_classes, D]
        self.clip_model = clip_model

        # The visual trunk up to layer4, plus the attention pool's projection
        # weights, so spatial features can be projected without being pooled.
        attnpool = self.clip_model.visual.attnpool
        self.visual_encoder = IntermediateLayerGetter(self.clip_model.visual, {"layer4": "0"})
        # Drops index 0 of the positional embedding (presumably the pooled
        # token's position in CLIP's AttentionPool2d — confirm if used).
        self.positional_embedding = attnpool.positional_embedding[1::]
        self.v_linear_weight = attnpool.v_proj.weight
        self.v_linear_bias = attnpool.v_proj.bias
        self.c_linear_weight = attnpool.c_proj.weight
        self.c_linear_bias = attnpool.c_proj.bias

        self.caption = read_npy('COCO_caption_selection_101.npy', self.device)
        # L2-normalise every caption embedding; assumed shape [num_classes, K, D].
        self.caption = self.caption / self.caption.norm(dim=-1, keepdim=True)

    def encode_image(self, x):
        """Run CLIP's ResNet stem and layers 1-4 and return the spatial map [B, C, H, W].

        H and W depend on the input resolution.
        """
        enc = self.visual_encoder
        x = x.type(enc.conv1.weight.dtype)
        for conv, bn in ((enc.conv1, enc.bn1), (enc.conv2, enc.bn2), (enc.conv3, enc.bn3)):
            x = enc.relu(bn(conv(x)))
        x = enc.avgpool(x)
        x = enc.layer1(x)
        x = enc.layer2(x)
        x = enc.layer3(x)
        x = enc.layer4(x)
        return x

    def model_inference(self, image):
        """Return ``(global_logits, spatial_logits, patch_logits, None, None)``.

        - global_logits [B, cls]: mean of caption-attention logits and prompt
          logits, both computed on CLIP's attention-pooled image feature.
        - spatial_logits [B, cls]: per-location prompt logits aggregated per
          class with a softmax over locations.
        - patch_logits [H*W, B, cls]: the per-location logits before aggregation.
        """
        image_feat = self.encode_image(image)  # [B, C, H, W]
        b, c, h, w = image_feat.shape
        x = image_feat.reshape(b, c, h * w).permute(2, 0, 1)  # [H*W, B, C]

        # Per-location embeddings through the attention pool's value and output projections.
        x = F.linear(x, self.v_linear_weight, self.v_linear_bias)
        image_features = F.linear(x, self.c_linear_weight, self.c_linear_bias)  # [H*W, B, D]

        image_feature_, _ = self.clip_model.visual.attnpool(image_feat)  # [B, D]

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)

        logit_scale = self.clip_model.logit_scale.exp()

        # Similarity of every caption embedding to the global feature: [cls, K, B].
        caption_sim = torch.matmul(self.caption, image_feature_.t())
        caption_weights = F.softmax(caption_sim, dim=1)
        caption_logits = torch.sum(caption_weights * caption_sim, dim=1).permute(1, 0)  # [B, cls]
        logits_ = logit_scale * (caption_logits + image_feature_ @ self.text_features.t()) / 2.0

        logits = logit_scale * image_features @ self.text_features.t()  # [H*W, B, cls]
        patch_logits = logits.clone()

        prob_spatial = F.softmax(logits, dim=0)  # softmax over locations, per class
        logits = torch.sum(logits * prob_spatial, dim=0)  # [B, cls]

        return logits_, logits, patch_logits, None, None

    def attnpool_sliding_window(self, feature_map, input_window, output_size):
        """
        Slides a window over the feature_map and applies the attention pooling.

        :param feature_map: Input feature map. Shape: [B, C, H, W]
        :param input_window: Size of the (square) window to slide. For 5x5 window, it's 5.
        :param output_size: The resulting feature map size. For 3x3 result, it's 3.
            With ``output_size == 1`` a single window at the top-left corner is used.
        :return: Processed feature map. Shape: [B, output_size, output_size, D]
        """
        b, c, h, w = feature_map.shape
        # Separate strides so non-square maps are covered along both axes; a
        # single output position needs no stride (and would divide by zero).
        if output_size > 1:
            stride_h = (h - input_window) // (output_size - 1)
            stride_w = (w - input_window) // (output_size - 1)
        else:
            stride_h = stride_w = 0

        pooled_features = []
        for i in range(output_size):
            top = i * stride_h
            for j in range(output_size):
                left = j * stride_w
                sub_feat = feature_map[:, :, top:top + input_window, left:left + input_window]
                pooled, _ = self.clip_model.visual.attnpool(sub_feat)  # [B, D]
                pooled_features.append(pooled)

        pooled_features = torch.stack(pooled_features, dim=1).reshape(b, output_size, output_size, -1)
        return pooled_features.to(feature_map.device)

    @staticmethod
    def _candidate_windows(row, col, grid_h, grid_w):
        """List the windows scored around cell (row, col) of a grid_h x grid_w map.

        Each window is ``(row_start, row_stop, col_start, col_stop)`` with
        exclusive stops; windows that would leave the grid are skipped. Order
        matters: when windows tie for the best score, the earliest one wins.
        """
        has_up = row >= 1
        has_down = row + 1 < grid_h
        has_left = col >= 1
        has_right = col + 1 < grid_w

        windows = [(row, row + 1, col, col + 1)]  # 1x1: the cell itself
        if has_down and has_right:
            windows.append((row, row + 2, col, col + 2))  # 2x2 extending down-right
        if has_up:
            windows.append((row - 1, row + 1, col, col + 1))  # 2x1 with the cell above
        if has_down:
            windows.append((row, row + 2, col, col + 1))  # 2x1 with the cell below
        if has_left:
            windows.append((row, row + 1, col - 1, col + 1))  # 1x2 with the cell to the left
        if has_right:
            windows.append((row, row + 1, col, col + 2))  # 1x2 with the cell to the right
        if has_up and has_left:
            windows.append((row - 1, row + 1, col - 1, col + 1))  # 2x2 up-left
        if has_up and has_right:
            windows.append((row - 1, row + 1, col, col + 2))  # 2x2 up-right
        if has_down and has_left:
            windows.append((row, row + 2, col - 1, col + 1))  # 2x2 down-left
        if has_up and has_down and has_left and has_right:
            windows.append((row - 1, row + 2, col - 1, col + 2))  # 3x3 centred on the cell
        return windows

    def _pool_region(self, region):
        """Pool one window of the permuted feature map into an embedding per image.

        :param region: feature slice of shape [h, w, B, C].
        :return: un-normalised embedding of shape [B, D].
        """
        tokens = region.reshape(-1, region.shape[-2], region.shape[-1])  # [h*w, B, C]
        values = F.linear(tokens, self.v_linear_weight, self.v_linear_bias)  # [h*w, B, D]
        if values.shape[0] == 1:
            # A single cell is only projected; there is nothing to attend over.
            pooled = values[0]
        else:
            # Parameter-free attention: per channel, a softmax over tokens of
            # (mean value * token value) weights the token values.
            query = values.mean(dim=0, keepdim=True)  # [1, B, D]
            weights = F.softmax(query * values, dim=0)
            pooled = torch.sum(weights * values, dim=0)  # [B, D]
        return F.linear(pooled, self.c_linear_weight, self.c_linear_bias)

    def _region_logits(self, pooled, logit_scale):
        """Score pooled region embeddings [B, D] against every class -> [B, cls].

        Averages the class-prompt logits with caption-attention logits.
        """
        pooled = pooled / pooled.norm(dim=-1, keepdim=True)
        raw_logit = logit_scale * pooled @ self.text_features.t()  # [B, cls]
        caption_sim = torch.matmul(self.caption, pooled.t())  # [cls, K, B]
        caption_weights = F.softmax(caption_sim, dim=1)
        caption_logit = (logit_scale * torch.sum(caption_weights * caption_sim, dim=1)).t()  # [B, cls]
        return (raw_logit + caption_logit) / 2

    def _best_region_logits(self, image_feat):
        """For every grid cell, keep the logits of its most confident window.

        :param image_feat: permuted feature map of shape [H, W, B, C].
        :return: float32 tensor [H, W, B, cls]; entry [r, c, b] holds the class
            logits of the window around (r, c) whose top class score is highest
            for image b.
        :raises RuntimeError: if any window produces NaN logits.
        """
        grid_h, grid_w, batch_size = image_feat.shape[:3]
        num_classes = self.text_features.shape[0]
        logit_scale = self.clip_model.logit_scale.exp()
        batch_index = torch.arange(batch_size, device=image_feat.device)

        # Sized from the actual grid and fully written below, so no
        # uninitialised memory reaches the caller.
        best = torch.empty(grid_h, grid_w, batch_size, num_classes, device=self.device)
        for row in range(grid_h):
            for col in range(grid_w):
                candidates = torch.stack([
                    self._region_logits(self._pool_region(image_feat[r0:r1, c0:c1]), logit_scale)
                    for r0, r1, c0, c1 in self._candidate_windows(row, col, grid_h, grid_w)
                ], dim=0)  # [num_windows, B, cls]

                if torch.isnan(candidates).any():
                    raise RuntimeError("There are NaNs in logit_list")

                top_scores = candidates.max(dim=2).values  # [num_windows, B]
                _, best_window = top_scores.max(dim=0)  # [B]
                best[row, col] = candidates[best_window, batch_index]
        return best

    @staticmethod
    def _erase_false_positives(output, label):
        """Debug aid: set every top-5 score that is not a ground-truth label to -1, in place.

        Uses the labels, so metrics computed afterwards are not a fair evaluation.
        Assumes ``label`` is multi-hot with 1 marking positives.
        """
        top5_indices = output.topk(5, dim=1).indices
        for i in range(output.size(0)):
            active_labels = set((label[i] == 1).nonzero(as_tuple=True)[0].tolist())
            for idx in top5_indices[i].tolist():
                if idx not in active_labels:
                    output[i, idx] = -1

    @torch.no_grad()
    def test(self, split=None):
        """Evaluate global and region-based predictions on the val/test split.

        For each batch, ``output`` is the global prediction from
        ``model_inference`` and ``output_pos`` aggregates the best-window logits
        of every grid cell with a per-class softmax over cells. Both are passed
        to ``self.evaluator.process``.

        :return: the first metric reported by the evaluator.
        """
        import os

        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        # Output folder for visualisations (created even if nothing is saved).
        os.makedirs('visualize/', exist_ok=True)

        for batch in tqdm(data_loader):
            images, label = self.parse_batch_test(batch)
            output = self.model_inference(images)[0]

            if ERASE_FP:
                self._erase_false_positives(output, label)

            image_feat = self.encode_image(images)  # [B, C, H, W]
            image_feat = image_feat.permute(2, 3, 0, 1).clone()  # [H, W, B, C]
            batch_size = image_feat.shape[2]

            best = self._best_region_logits(image_feat)  # [H, W, B, cls]
            logits = best.reshape(-1, batch_size, best.shape[-1])  # [H*W, B, cls]
            prob_spatial = F.softmax(logits, dim=0)
            output_pos = torch.sum(logits * prob_spatial, dim=0)  # [B, cls]

            self.evaluator.process(output, label, output_pos)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]


@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling."""

    # templates = IMAGENET_TEMPLATES
    templates = IMAGENET_TEMPLATES_SELECT

    def build_model(self):
        """Load CLIP and build class text features by averaging over prompt templates.

        Uses ``self.templates`` plus, for non-ImageNet datasets, the
        dataset-specific template from ``CUSTOM_TEMPLATES``. If that entry is
        missing, a warning is printed and ``"a photo of a {}."`` is used, as in
        ``ZeroshotCLIP``. Each template's features are L2-normalised, averaged
        and normalised again.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # Work on a copy: ``self.templates += [...]`` would extend the shared
        # class attribute (the imported template list) in place, so the custom
        # prompt would pile up on every build and leak into other users.
        templates = list(self.templates)
        if cfg.DATASET.NAME != "ImageNet":
            try:
                templates.append(CUSTOM_TEMPLATES[cfg.DATASET.NAME])
            except KeyError:
                print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
                templates.append("a photo of a {}.")
        self.templates = templates

        num_temp = len(templates)
        print(f"Prompt ensembling (n={num_temp})")

        mean_text_features = 0
        with torch.no_grad():
            for temp in templates:
                prompts = [temp.format(c.replace("_", " ")) for c in classnames]
                prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
                text_features = clip_model.encode_text(prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                mean_text_features = mean_text_features + text_features
        mean_text_features = mean_text_features / num_temp
        mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
