import torch
from torch.nn import Module
import clip
from mapper.attribute_list import ATTRIBUTE_LIST
from mapper.scripts.analyse_channel import channel_selector
import torchvision.transforms as transforms


class LatentMapper(Module):
    """Blend a source latent towards a reference latent through a learnable mask.

    The module holds one learnable mask slot per entry of ``ATTRIBUTE_LIST``
    plus one extra trailing slot, each of shape ``(26, 512)``.  A softmax across
    the slots turns them into per-(layer, channel) weights that sum to one; the
    weights of the requested slots are added up and used to interpolate between
    the source and the reference latent.
    """

    def __init__(self, opts):
        """Build the CLIP helpers and the fixed / learnable masks.

        Args:
            opts: Options object.  This constructor reads ``opts.device`` and
                ``opts.correlation_path``.
        """
        super().__init__()
        self.opts = opts
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=opts.device)
        # NOTE(review): these look like CLIP's image mean/std statistics -- confirm.
        self.transform = transforms.Compose([
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ])
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        self.hairstyle_cut_flag = 0
        self.color_cut_flag = 0

        self.layer_idx, self.channel_idx = channel_selector(
            num_channels=2500, correlation_path=self.opts.correlation_path
        )

        init_mask = torch.ones((1, 26, 512))
        # zero out toRGB layers and the superresolution module
        init_mask[:, 1:20:3, :] = 0
        init_mask[:, 20:, :] = 0
        # Registered as a buffer so that ``.to()`` / ``.cuda()`` / ``.half()`` and
        # multi-device replication keep it on the same device as ``style_mask``.
        # ``persistent=False`` keeps it out of ``state_dict``, so existing
        # checkpoints load exactly as before.
        self.register_buffer("init_mask", init_mask.to(opts.device), persistent=False)

        raw_mask = torch.cat(
            [
                torch.zeros([len(ATTRIBUTE_LIST), 26, 512]),
                torch.ones([1, 26, 512]),
            ],
            0,
        )
        # Slot 0 starts out favoured on the channels picked by ``channel_selector``.
        raw_mask[0][self.layer_idx, self.channel_idx] = 1.8
        # The trailing slot starts as the complement: 1 everywhere slot 0 is
        # unset, 0 on the selected channels.
        raw_mask[-1] = 1. - (raw_mask[0] > 0).float()
        # Created on ``opts.device`` so it lives next to ``init_mask`` even when
        # the caller never moves the module explicitly.
        self.style_mask = torch.nn.Parameter(raw_mask.to(opts.device))

    def forward(self, src_latent, ref_latent, selected_attributes):
        """Interpolate ``src_latent`` towards ``ref_latent`` on the selected slots.

        Args:
            src_latent: Latent to edit.  Must broadcast against the
                ``(1, 26, 512)`` mask -- presumably ``(batch, 26, 512)``; TODO confirm.
            ref_latent: Latent supplying the target values; same shape
                requirements as ``src_latent``.
            selected_attributes: Index (or indices) into the first dimension of
                the softmaxed mask -- presumably positions in ``ATTRIBUTE_LIST``;
                verify against the caller.

        Returns:
            A tuple ``(out, logits)`` where ``out`` is the blended latent and
            ``logits`` is the softmax of ``style_mask`` over the slot dimension
            (despite the name, these are normalised weights, not raw logits).
        """
        # Softmax over dim 0: for every (layer, channel) the slots compete.
        logits = torch.softmax(self.style_mask, dim=0)
        mask = logits[selected_attributes].sum(dim=0, keepdim=True).clamp(0, 1)
        # The fixed mask is never trained, so keep it out of the autograd graph.
        mask = mask * self.init_mask.detach()
        # mask == 0 keeps the source value, mask == 1 takes the reference value.
        out = src_latent + (ref_latent - src_latent) * mask

        return out, logits