import torch
import torch.nn.functional as F
import clip
import torchvision.transforms as transforms

from mapper.attribute_list import ATTRIBUTE_LIST, TEMPLATES
from mapper.scripts.analyse_channel import channel_selector


# class CLIPLoss(torch.nn.Module):

#     def __init__(self, opts):
#         super(CLIPLoss, self).__init__()
#         self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
#         self.upsample = torch.nn.Upsample(scale_factor=7)
#         self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

#     def forward(self, image, text): 
#         image = self.avg_pool(self.upsample(image))
#         similarity = 1 - self.model(image, text)[0] / 100
#         return similarity
    
    
class CLIPLoss(torch.nn.Module):
    """CLIP-based attribute-consistency loss.

    For every entry of ``ATTRIBUTE_LIST``, image embeddings are scored against
    precomputed CLIP text embeddings and turned into per-group softmax
    distributions. The manipulated image's distribution is pulled towards the
    reference image for target attributes, and towards the source image for
    all other attributes.
    """

    # Attributes whose penalty is restricted to the first prompt group
    # (``[:, 0]``) whenever either compared image's argmax in that group is
    # option 0. The value is the weight applied in that case.
    _FIRST_GROUP_WEIGHTS = {'beard': 3, 'glasses': 1.5}
    # Attributes whose full (all-groups) penalty is up-weighted.
    _FULL_WEIGHTS = {'hat': 3, 'hair color': 3}

    def __init__(self, device, clip_model='ViT-B/32'):
        """Load CLIP on ``device`` and precompute text embeddings for every attribute.

        Args:
            device: device passed to ``clip.load`` and used for tokens.
            clip_model: CLIP model name passed to ``clip.load``.
        """
        super(CLIPLoss, self).__init__()

        self.device = device
        self.model, clip_preprocess = clip.load(clip_model, device=self.device)
        self.clip_preprocess = clip_preprocess

        # Tensor-only preprocessing pipeline:
        #   * Normalize(mean=-1, std=2) computes (x + 1) / 2, mapping GAN output in [-1, 1] to [0, 1];
        #   * clip_preprocess.transforms[:2] reuses CLIP's first two steps to match its input scale;
        #   * clip_preprocess.transforms[4:] reuses CLIP's trailing steps, skipping the PIL -> tensor conversion.
        self.preprocess = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])]
            + clip_preprocess.transforms[:2]
            + clip_preprocess.transforms[4:]
        )

        self.attribute_list = ATTRIBUTE_LIST
        self.attribute_loss = torch.nn.MSELoss()  # kept for compatibility; not used by forward()

        # self.text_features[i][j] holds the normalized text embeddings of prompt group j
        # of attribute i. Each group is zero-padded to the longest group of that attribute,
        # so all groups of one attribute share the same shape.
        self.text_features = [[] for _ in self.attribute_list]
        with torch.no_grad():  # the embeddings are constants; avoid building an autograd graph
            for idx, attr in enumerate(self.attribute_list):
                groups = attr['value']  # each group is a list of prompt strings
                max_len = max(len(group) for group in groups)
                for group in groups:
                    feats = self.get_text_features(group, TEMPLATES)
                    # Pad on the same device/dtype as the CLIP output. Hard-coding
                    # .cuda()/.half() broke CPU runs and float32 models.
                    padding = feats.new_zeros(max_len - len(group), feats.shape[-1])
                    self.text_features[idx].append(torch.cat([feats, padding], dim=0))

    def tokenize(self, strings: list):
        """Tokenize ``strings`` with CLIP's tokenizer and move the tokens to ``self.device``."""
        return clip.tokenize(strings).to(self.device)

    def encode_text(self, tokens: list) -> torch.Tensor:
        """Encode already-tokenized text with the CLIP text encoder."""
        return self.model.encode_text(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Apply ``self.preprocess`` to ``images`` and encode them with the CLIP image encoder."""
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def get_text_features(self, texts: list, template=TEMPLATES, norm: bool = True) -> torch.Tensor:
        """Return detached CLIP embeddings for ``template.format(t)`` for each ``t`` in ``texts``.

        When ``norm`` is true, each embedding is L2-normalized along the last dimension.
        """
        template_text = self.compose_text_with_templates(texts, template)
        tokens = clip.tokenize(template_text).to(self.device)
        text_features = self.encode_text(tokens).detach()

        if norm:
            # In-place is safe here: the tensor is detached from any graph.
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        """Return CLIP embeddings for ``img``, optionally L2-normalized along the last dim.

        Gradients flow through this method back to ``img``.
        """
        image_features = self.encode_images(img)

        if norm:
            # Out-of-place division keeps the autograd graph intact.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        return image_features

    def compose_text_with_templates(self, texts: list, template=TEMPLATES) -> list:
        """Return ``[template.format(text) for text in texts]``."""
        return [template.format(text) for text in texts]

    def calc_similarity(self, image_features: torch.Tensor, text_features: torch.Tensor, template=TEMPLATES) -> torch.Tensor:
        """Return ``25 * image_features @ text_features.T``.

        ``template`` is accepted for backward compatibility and is unused.
        """
        return 25. * image_features @ text_features.T

    def _attribute_probs(self, image_features: torch.Tensor, text_features: torch.Tensor, length: int) -> torch.Tensor:
        """Return softmax distributions of shape ``(batch, length, slots)``, one per prompt group.

        ``text_features`` is the concatenation of ``length`` equally sized, zero-padded
        groups. Padded rows are all-zero vectors, so their similarity is exactly 0.
        Those entries are set to -inf so they receive zero probability.
        """
        logits = self.calc_similarity(image_features, text_features)
        logits[logits == 0] = float('-inf')
        num_slots = text_features.shape[0] // length
        # Keep the batch dimension separate instead of folding it into the slot axis.
        return F.softmax(logits.reshape((-1, length, num_slots)), dim=-1)

    def _attribute_distance(self, attr_name: str, probs_manipulated: torch.Tensor, probs_other: torch.Tensor) -> torch.Tensor:
        """Weighted L1 distance between two attribute distributions.

        For the attributes in ``_FIRST_GROUP_WEIGHTS``, if either distribution's argmax in
        the first group (of the first batch element) is option 0, only the first group is
        compared. All other cases compare every group, up-weighted per ``_FULL_WEIGHTS``.
        """
        diff = torch.abs(probs_manipulated - probs_other)

        first_weight = self._FIRST_GROUP_WEIGHTS.get(attr_name)
        if first_weight is not None and (
            torch.argmax(probs_other[0][0]) == 0 or torch.argmax(probs_manipulated[0][0]) == 0
        ):
            return diff[:, 0].sum() * first_weight

        total = diff.sum()
        full_weight = self._FULL_WEIGHTS.get(attr_name)
        return total if full_weight is None else total * full_weight

    def forward(self, src_img: torch.Tensor, manipulated_img: torch.Tensor, ref_img: torch.Tensor, target_attributes: list):
        """Compute the attribute-transfer loss.

        Args:
            src_img: source image batch (presumably GAN output in [-1, 1], given
                ``self.preprocess``; TODO confirm).
            manipulated_img: edited image batch; the only input whose features
                are compared against both the reference and the source.
            ref_img: reference image batch providing the target attributes.
            target_attributes: indices into ``self.attribute_list``. These attributes
                should follow ``ref_img``; every other attribute should stay as in ``src_img``.

        Returns:
            The mean reference penalty over the target attributes (divisor clamped to
            at least 1), plus the mean of the three largest source-preservation penalties
            (0.0 if every attribute is a target).
        """
        src_img_features = self.get_image_features(src_img)
        ref_img_features = self.get_image_features(ref_img)
        manipulated_img_features = self.get_image_features(manipulated_img)

        clip_loss_ref = 0.0
        src_terms = []
        for idx, attr in enumerate(self.attribute_list):
            groups = self.text_features[idx]
            length = len(groups)
            text_features = torch.cat(groups, 0)

            probs_manipulated = self._attribute_probs(manipulated_img_features, text_features, length)

            # elif-chain semantics live in _attribute_distance, so each attribute
            # contributes exactly one term. Previously, beard could be counted twice.
            if idx in target_attributes:
                probs_ref = self._attribute_probs(ref_img_features, text_features, length)
                clip_loss_ref += self._attribute_distance(attr['attr'], probs_manipulated, probs_ref)
            else:
                probs_src = self._attribute_probs(src_img_features, text_features, length)
                src_terms.append(self._attribute_distance(attr['attr'], probs_manipulated, probs_src))

        if src_terms:
            # torch.stack (not torch.tensor) keeps the terms attached to the autograd graph.
            clip_loss_src = torch.stack(src_terms).sort(descending=True).values[:3].mean()
        else:
            clip_loss_src = 0.0

        return clip_loss_ref / max(len(target_attributes), 1) + clip_loss_src
            
            
class MaskLoss(torch.nn.Module):
    """Encourage confident predictions from a probability tensor.

    The loss is the negative mean log of the largest probability along dim 0.
    It is minimized when, at every position, one entry along dim 0 carries all
    the probability mass.
    """

    def __init__(self, opts):
        super(MaskLoss, self).__init__()

        self.opts = opts
        self.attribute_list = ATTRIBUTE_LIST

        # Disabled experiment: a region-importance matrix built from a channel
        # correlation file, consumed by the commented-out region loss in forward().
        # correlation = torch.load(opts.correlation_path)
        # corr = torch.stack(correlation, 0).abs().transpose(0, 1)
        # # ric = corr / corr.sum(dim=1, keepdim=True)
        # self.rim = corr / corr.sum(dim=0, keepdim=True).clamp(0.000001).detach()
        # self.rim[:, 1:20:3, :] = 0

        # self.num_channels = 4000

    def forward(self, output_probs):
        """Return ``-mean(log(max(output_probs, dim=0)))``.

        Args:
            output_probs: probability tensor; dim 0 is reduced with ``max``.
        """
        # Maximize the predicted (largest) probability at every position. The previous
        # ascending sort followed by a full-length mean did not change the result, so it
        # was dropped.
        max_prob = torch.max(output_probs, dim=0).values
        return -torch.log(max_prob).mean()

        # Disabled: channels predicted for each class should lie within the top
        # `num_channels` of the corresponding face region.
        # loss_region = 0.0
        # cnt = 0
        # for idx, attr in enumerate(self.attribute_list):
        #     if not 'face_region' in attr:
        #         continue
        #     values, indices = self.rim[attr['face_region']].reshape(-1).topk(self.num_channels)
        #     mask_i = (output_probs.argmax(dim=0) == idx)
        #     max_prob_i = torch.where(mask_i, max_prob, 0).flatten()
        #     max_prob_i[indices] = 0.
        #     loss_region += max_prob_i.sum()
        #     cnt += 1
        # loss += loss_region / cnt * self.opts.mask_loss_lambda
            