import torch
import torchvision
from torch import nn
from torchvision import transforms
from BLIP.models.blip_itm import blip_itm
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import threestudio

@threestudio.register("blip-guidance")
class blip_sim_loss(nn.Module):
    """Hair-region guidance module built on CLIPSeg segmentation.

    ``forward`` segments the "hair" region of every rendered image with
    CLIPSeg and returns the renders multiplied by that (detached) mask.

    The BLIP text/image similarity loss this class is named after is
    currently disabled: ``cal_blip_text_embeddings`` is kept as an opt-in
    helper but is not called from ``__init__`` or ``forward``.
    """

    def __init__(self, device=None,hair_prompt=None):
        """Load the CLIPSeg processor and segmentation model.

        Args:
            device: Device the segmentation model is moved to. When ``None``
                the model stays where ``from_pretrained`` put it and can be
                moved later with the usual ``nn.Module.to``.
            hair_prompt: Text template with a ``{}`` placeholder for the view
                direction. Only used by ``cal_blip_text_embeddings``.
        """
        super().__init__()
        self.image_size = 224
        self.device = device
        # View-direction schedule and prompt template consumed by
        # cal_blip_text_embeddings(). They are cheap to store, so keep them
        # available even though the BLIP branch is not enabled by default.
        self.view_dirs = (
            ['back'] * 3 + ['right side'] * 6 + ['front'] * 6
            + ['left side'] * 6 + ['back'] * 4
        )
        self.prompt = hair_prompt

        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
        # The segmentation network is only ever run under torch.no_grad(), so
        # freeze it: this keeps its weights out of any optimizer built from
        # this module's parameters.
        self.seg_model.eval()
        for param in self.seg_model.parameters():
            param.requires_grad_(False)
        if device is not None:
            self.seg_model.to(device)

        # NOTE: the disabled BLIP branch resized images to
        # (image_size, image_size) and normalized them with the CLIP
        # mean/std before extracting image features.

    # both blip and clip work well
    def cal_blip_text_embeddings(self):
        """Load the BLIP ITM model and embed the per-view prompts.

        Returns:
            A tuple ``(model, text_z, text_z_neg, blip_text_str)`` where
            ``model`` is the frozen BLIP model on ``self.device``, ``text_z``
            is a list with one text feature per entry of ``self.view_dirs``,
            ``text_z_neg`` is the feature of the fixed negative prompt, and
            ``blip_text_str`` is the list of formatted prompt strings.

        Raises:
            ValueError: If no ``hair_prompt`` was given at construction time.
        """
        if self.prompt is None:
            raise ValueError(
                "hair_prompt must be provided to compute BLIP text embeddings"
            )

        path = 'pretrained_model/model_base.pth'
        vit  = 'base'
        model = blip_itm(pretrained=path, image_size=self.image_size, vit=vit)
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
        model = model.to(device=self.device)

        blip_text_str = [self.prompt.format(d) for d in self.view_dirs]
        text_neg = 'watermark, blur, face, beard, eyes, mouth, earrings, low resolution, dirty, tattoos, floating artifacts, ugly, tiling, poorly drawn face'
        # All text features are constants for the loss, so compute every one
        # of them (including the negative prompt) without building a graph.
        with torch.no_grad():
            text_z = [
                model.get_text_feature(text, device=self.device)
                for text in blip_text_str
            ]
            text_z_neg = model.get_text_feature(text_neg, device=self.device).detach()
        return model, text_z, text_z_neg, blip_text_str

    def clipseg_mask(self, prompts, image):
        """Run CLIPSeg on one image and return a boolean mask per prompt.

        Args:
            prompts: A prompt string or a list of prompt strings.
            image: The image to segment (a PIL image in ``forward``).

        Returns:
            Boolean tensor ``sigmoid(logits) > 0.5`` on the segmentation
            model's device.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        # The processor pairs texts and images one-to-one, so the image has
        # to be repeated once per prompt.
        inputs = self.processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
        # Follow the model's actual device rather than self.device: the module
        # may have been moved with .to() after construction, or device may be
        # None.
        seg_device = next(self.seg_model.parameters()).device
        inputs = {k: v.to(seg_device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.seg_model(**inputs)
        preds = outputs.logits
        return torch.sigmoid(preds)>0.5

    def forward(self, pred_rgb,mask_temp): #pred_rgb: BHWC 0-1
        """Mask the rendered images down to their CLIPSeg "hair" region.

        Args:
            pred_rgb: Rendered images, BHWC, expected in [0, 1].
            mask_temp: Unused; kept so existing callers keep working.

        Returns:
            ``pred_rgb`` permuted to BCHW and multiplied by the detached hair
            mask (resized to the render resolution).
        """
        height, width = pred_rgb.size(1), pred_rgb.size(2)

        # segment the hair region
        masks = []
        for rgb in pred_rgb:
            # to_pil_image converts float tensors to 8-bit; clamp first so
            # values slightly outside [0, 1] do not wrap around, and detach
            # because the segmentation input needs no gradient.
            chw = rgb.detach().permute(2, 0, 1).clamp(0.0, 1.0)
            pil_image = transforms.functional.to_pil_image(chw)
            mask = self.clipseg_mask(["hair"], pil_image)
            # Single-sample logits may come back as (H, W) or (1, H, W)
            # depending on the transformers version; force (1, H, W) so the
            # stacked tensor is always 4-D for bilinear interpolation.
            masks.append(mask.reshape(1, *mask.shape[-2:]))
        mask_im = torch.stack(masks, dim=0)
        mask_im = torch.nn.functional.interpolate(mask_im.float(), (height, width), mode='bilinear', antialias=True).to(pred_rgb.device)

        pred_rgb = pred_rgb.permute(0,3,1,2)
        # The mask is a constant w.r.t. the render: gradients flow only
        # through pred_rgb.
        pred_rgb = pred_rgb*mask_im.detach()
        # Debug dump of the masked render; overwritten on every call.
        torchvision.utils.save_image(pred_rgb.data, 'hair_wom.png', normalize=True)
        return pred_rgb
        # NOTE: the disabled BLIP loss normalized the masked render, took
        # cosine similarities between image features and the per-view text
        # features (diagonal) and the negative-prompt feature, and returned
        # mean(cos_sim_neg - diag(cos_sim)) together with pred_rgb.

