import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
from tqdm import tqdm
import numpy as np
# Per-dataset prompt templates, looked up by ``cfg.DATASET.NAME``. Each value
# is a ``str.format`` pattern whose single ``{}`` is filled with a class name.
# ZeroshotCLIP and ZeroshotCLIP_dense fall back to "a photo of a {}." when a
# dataset has no entry here.
CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}."
}


@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP classifier using a single prompt template per dataset.

    ``build_model`` encodes one prompt per class name and caches the
    L2-normalized text features; ``model_inference`` scores images against
    those cached features.
    """

    def build_model(self):
        """Load CLIP and precompute normalized text features for all classes.

        Sets ``self.text_features`` and ``self.clip_model``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # Only a missing dataset entry should trigger the generic fallback;
        # a bare ``except`` would also hide unrelated errors and interrupts.
        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def model_inference(self, image):
        """Score ``image`` against the cached class text features.

        Returns:
            tuple: ``(logits, None, None)`` where ``logits`` is the scaled
            similarity between normalized image features and the cached text
            features. The two trailing ``None`` values are placeholders.
        """
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()

        return logits, None, None

def read_npy(numpy_file, device):
    """Load a ``.npy`` file and return its contents as a tensor on ``device``.

    Args:
        numpy_file: Path (or file object) accepted by ``numpy.load``.
        device: Target device passed to ``Tensor.to``.

    Returns:
        torch.Tensor: A tensor copy of the loaded array, moved to ``device``.
    """
    loaded = np.load(numpy_file)
    print('load complete:', numpy_file)
    as_tensor = torch.tensor(loaded)
    return as_tensor.to(device)
def get_eot_feature_2d(input):  # [N, L, D]
    """Select, for every ``(n, d)`` pair, the entry at the argmax along ``L``.

    Args:
        input: A 3-D tensor of shape ``[N, L, D]``.

    Returns:
        torch.Tensor: Shape ``[N, D]``; element ``(n, d)`` is
        ``input[n, argmax_l input[n, l, d], d]``.

    NOTE(review): the position is chosen from the feature values themselves
    (no token ids are consulted), independently for each channel ``d``.
    """
    # Unpacking doubles as a rank check: a non-3-D tensor raises here.
    _n, _l, _d = input.shape
    top_positions = input.argmax(dim=1)  # [N, D]
    # gather along L with a singleton L axis, then drop that axis again.
    picked = torch.gather(input, 1, top_positions.unsqueeze(1))  # [N, 1, D]
    return picked.squeeze(1)

def replicate_cls(cls_tensor, batch_size):
    """Broadcast ``cls_tensor`` along a new leading batch axis.

    Args:
        cls_tensor: Expected to be a ``[cls, D]`` tensor.
        batch_size: Size of the new leading dimension.

    Returns:
        torch.Tensor: An expanded ``[batch_size, cls, D]`` tensor.
    """
    # cls_tensor must be a tensor of shape [cls, D].
    with_batch_axis = cls_tensor[None]
    return with_batch_axis.expand(batch_size, -1, -1)

@TRAINER_REGISTRY.register()
class ZeroshotCLIP_dense(TrainerX):
    """Zero-shot CLIP that scores both pooled and per-location image features.

    ``model_inference`` returns two sets of class logits: one computed from
    the attention-pooled image feature and one aggregated over the spatial
    locations of the ``layer4`` feature map. In both cases the class text
    features are first re-weighted using the image features themselves.
    """

    # Multiplier applied to image/text similarities before the softmax that
    # produces the re-weighting attention.
    ATTN_SCALE = 5
    # Mixing weights of the original text features and the image-derived
    # features when building per-image class embeddings for the dense branch.
    TEXT_WEIGHT = 0.9
    VISUAL_WEIGHT = 0.1

    def build_model(self):
        """Load CLIP, cache text features and expose the visual trunk pieces.

        Sets ``self.text_features``, ``self.clip_model``,
        ``self.visual_encoder`` and the attention-pool projection parameters
        used by ``model_inference``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # Only a missing dataset entry should trigger the generic fallback;
        # a bare ``except`` would also hide unrelated errors and interrupts.
        try:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        except KeyError:
            print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
            temp = "a photo of a {}."

        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

        # Wrap the visual backbone so its sub-modules up to ``layer4`` can be
        # called individually in ``encode_image``.
        self.visual_encoder = IntermediateLayerGetter(self.clip_model.visual, {"layer4": "0"})
        # Positional embedding without its first entry.
        self.positional_embedding = self.clip_model.visual.attnpool.positional_embedding[1::]
        # Value / output projections of the attention pool; applied directly
        # to every spatial location in ``model_inference``.
        self.v_linear_weight = self.clip_model.visual.attnpool.v_proj.weight
        self.v_linear_bias = self.clip_model.visual.attnpool.v_proj.bias
        self.c_linear_weight = self.clip_model.visual.attnpool.c_proj.weight
        self.c_linear_bias = self.clip_model.visual.attnpool.c_proj.bias

    def encode_image(self, x):
        """Run the stem and ``layer1``..``layer4`` and return the feature map.

        The input is cast to the dtype of the first convolution's weights
        before being fed through the network.
        """
        encoder = self.visual_encoder

        def stem(x):
            for conv, bn in (
                (encoder.conv1, encoder.bn1),
                (encoder.conv2, encoder.bn2),
                (encoder.conv3, encoder.bn3),
            ):
                x = encoder.relu(bn(conv(x)))
            return encoder.avgpool(x)

        x = x.type(encoder.conv1.weight.dtype)
        x = stem(x)
        for layer in (encoder.layer1, encoder.layer2, encoder.layer3, encoder.layer4):
            x = layer(x)
        return x

    def _global_logits(self, global_feats, logit_scale):
        """Return class logits for the pooled image features.

        ``global_feats`` is assumed to be ``[B, D]`` (one normalized feature
        per image) — TODO confirm against the attention pool in use.

        NOTE(review): the softmax runs over dim 0, i.e. across the images of
        the batch, so the result depends on batch composition. This mirrors
        the original computation; confirm it is intended.
        """
        attn = torch.matmul(global_feats, self.text_features.permute(1, 0)) * self.ATTN_SCALE
        attn = F.softmax(attn, dim=0)
        # Image-derived class embeddings, one row per class.
        adapted_text = torch.matmul(attn.t(), global_feats)
        return logit_scale * global_feats @ adapted_text.t()

    def _dense_logits(self, dense_feats, logit_scale):
        """Return class logits aggregated over spatial locations.

        ``dense_feats`` is ``[HW, B, D]``. For each image, class embeddings
        are a mix of the text features and an attention-weighted sum of that
        image's location features; location logits are then combined with a
        softmax-weighted sum over the spatial axis.
        """
        text = self.text_features
        text_t = text.permute(1, 0)

        per_image = []
        for feats in dense_feats.unbind(dim=1):  # feats: [HW, D]
            attn = F.softmax(torch.matmul(feats, text_t) * self.ATTN_SCALE, dim=0)
            adapted_text = (
                self.TEXT_WEIGHT * text
                + self.VISUAL_WEIGHT * torch.matmul(attn.t(), feats)
            )
            per_image.append(logit_scale * feats @ adapted_text.t())

        # Stack once instead of growing a tensor with ``torch.cat`` per image.
        logits = torch.stack(per_image, dim=1)  # [HW, B, cls]
        prob_spatial = F.softmax(logits, dim=0)
        return torch.sum(logits * prob_spatial, dim=0)

    def model_inference(self, image):
        """Compute pooled and dense class logits for a batch of images.

        Returns:
            tuple: ``(global_logits, dense_logits, None, None)``. The first
            comes from the attention-pooled feature, the second from the
            per-location features; the trailing ``None`` values are
            placeholders unpacked (and ignored) by ``test``.
        """
        image_feat = self.encode_image(image)
        b, c, h, w = image_feat.shape
        # [B, C, H, W] -> [HW, B, C]
        x = image_feat.reshape(b, c, h * w).permute(2, 0, 1)

        # Project every location with the attention pool's value and output
        # projections only (no attention mixing across locations).
        x = F.linear(x, self.v_linear_weight, self.v_linear_bias)
        x = F.linear(x, self.c_linear_weight, self.c_linear_bias)
        image_features = x / x.norm(dim=-1, keepdim=True)

        image_feature_, _ = self.clip_model.visual.attnpool(image_feat)
        image_feature_ = image_feature_ / image_feature_.norm(dim=-1, keepdim=True)

        logit_scale = self.clip_model.logit_scale.exp()

        # The pooled branch uses the pooled features for every image; mixing
        # in rows of the dense features here yields mismatched shapes.
        logits_ = self._global_logits(image_feature_, logit_scale)
        logits = self._dense_logits(image_features, logit_scale)

        return logits_, logits, None, None

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline.

        Evaluates on the validation loader when ``split == "val"`` and one is
        available, otherwise on the test loader. Both logits returned by
        ``model_inference`` are forwarded to the evaluator. Returns the first
        value of the evaluator's result dict.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch in tqdm(data_loader):
            images, label = self.parse_batch_test(batch)
            output, output_pos, _, _ = self.model_inference(images)
            self.evaluator.process(output, label, output_pos)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]


@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling.

    Text features are computed for every template in ``templates`` (plus the
    dataset-specific template when the dataset is not ImageNet), averaged
    per class and re-normalized.
    """

    # Default template list; IMAGENET_TEMPLATES is the alternative list
    # imported alongside it.
    templates = IMAGENET_TEMPLATES_SELECT

    def build_model(self):
        """Load CLIP and cache the ensemble-averaged text features.

        Sets ``self.text_features``, ``self.clip_model`` and an
        instance-level ``self.templates`` holding the templates actually used.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # Work on a copy: extending ``self.templates`` in place would mutate
        # the class-level list (the imported template constant itself), so
        # every build would permanently grow the shared template set.
        templates = list(self.templates)

        # add custom-made prompt
        if cfg.DATASET.NAME != "ImageNet":
            try:
                templates.append(CUSTOM_TEMPLATES[cfg.DATASET.NAME])
            except KeyError:
                # Same best-effort behavior as the single-prompt trainers:
                # warn and continue with the generic templates only.
                print('!! WARNING: Not found template for {}'.format(cfg.DATASET.NAME))
        self.templates = templates

        num_temp = len(templates)
        print(f"Prompt ensembling (n={num_temp})")

        mean_text_features = 0
        with torch.no_grad():
            for temp in templates:
                prompts = [temp.format(c.replace("_", " ")) for c in classnames]
                prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
                text_features = clip_model.encode_text(prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                mean_text_features = mean_text_features + text_features
        mean_text_features = mean_text_features / num_temp
        mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
