import os
import torch
import numpy as np
from transformers import AutoImageProcessor, AutoModel
import torchvision.transforms.functional as tvF
from torchvision import transforms
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import torch.nn.functional as F
from src.dift.dift_sd import SDFeaturizer
import math
from src.clip import clip


def upsample_imgfeat(img_feats, img_size=(224,224)):
    """Turn per-view token features into smoothed, upsampled feature maps.

    Args:
        img_feats: tensor of shape (b, nv, hw, c); ``hw`` must be a square number
            of tokens, which are reshaped row-major onto a sqrt(hw) x sqrt(hw) grid.
        img_size: output spatial size handed to bilinear interpolation.

    Returns:
        Tensor of shape (b * nv, c, *img_size).
    """
    batch, views, tokens, channels = (img_feats.size(dim) for dim in range(4))
    side = int(tokens ** 0.5)

    # (b*nv, hw, c) -> (b*nv, c, hw) -> (b*nv, c, side, side)
    grid = img_feats.reshape(batch * views, tokens, channels).transpose(1, 2)
    grid = grid.reshape(-1, channels, side, side)

    # Replicate-padding by 2 (left/top) and 3 (right/bottom) followed by a 6x6,
    # stride-1 mean pool leaves the grid size unchanged: a box blur over tokens.
    padded = F.pad(grid, (2, 3, 2, 3), mode="replicate")
    smoothed = F.avg_pool2d(padded, kernel_size=6, stride=1, padding=0)

    return F.interpolate(smoothed, size=img_size, mode="bilinear", align_corners=False)


class DiftEncoder(torch.nn.Module):
    """DIFT (Stable Diffusion) image features projected to ``out_channels``.

    Backbone features are extracted without gradients, memoised per ``id`` (unless
    ``id == -1``), linearly projected, then smoothed and upsampled to the input
    resolution by ``upsample_imgfeat``.
    """

    def __init__(self, use_cache, out_channels=None):
        super(DiftEncoder, self).__init__()
        # use_cache is accepted for signature parity with the other encoders; caching
        # in forward() is keyed solely on ``id``.
        self.dift = SDFeaturizer(null_prompt='')
        # Feature width expected from the featurizer (input width of self.linear).
        self.dift_dim = 1280
        # Honour an explicit out_channels (as Dinov2Encoder does); default stays 96.
        self.out_channels = out_channels if out_channels is not None else 96
        self.linear = torch.nn.Linear(self.dift_dim, self.out_channels)
        self.cache = {}

    def encode_img(self, img):
        """Extract flattened DIFT features for every image of a batch.

        Args:
            img: tensor of shape (B, W, H, 3). Values are mapped with
                ``(x - 0.5) * 2``, so presumably they lie in [0, 1] (TODO confirm).

        Returns:
            Tensor of shape (B, fw * fh, fc): one row per feature-map location.
        """
        img = (img.permute(0, 3, 1, 2) - 0.5) * 2
        img_feats = []
        for single in img:
            # The featurizer is called one (3, W, H) image at a time; its output is
            # reshaped below assuming a leading batch dimension of 1.
            feat = self.dift.forward(single, prompt="")
            _, fc, fw, fh = feat.shape
            img_feats.append(feat.permute(0, 2, 3, 1).reshape(1, fw * fh, fc))
        return torch.cat(img_feats)

    def forward(self, id, img, data_path=None):
        """Encode ``img`` (B, W, H, 3) and return ((B, out_channels, W, H), None).

        ``id`` keys the in-memory feature cache; ``id == -1`` bypasses it.
        ``data_path`` is accepted for interface compatibility and unused.
        """
        _, W, H, _ = img.shape
        with torch.no_grad():
            if id != -1 and id in self.cache:
                img_feats = self.cache[id]
            else:
                img_feats = self.encode_img(img)
                if id != -1:
                    self.cache[id] = img_feats

        img_feats = self.linear(img_feats)
        up_img_feats = upsample_imgfeat(img_feats.unsqueeze(dim=1), img_size=(W, H))
        return up_img_feats, None

class CLIPDepthEncoder(torch.nn.Module):
    """Smooth and resize an already-computed feature map; no learnable weights.

    ``forward`` receives a (b, c, h, w) feature tensor, applies a replicate-padded
    6x6 box blur (size preserving) and bilinearly resizes it to ``out_size``.
    """

    def __init__(self, use_cache, out_channels=None, tokenW=57, tokenH=57, out_size=224):
        super(CLIPDepthEncoder, self).__init__()
        # Channels pass through unchanged, so this only advertises the channel count
        # callers are expected to feed in (default 512). use_cache/tokenW/tokenH are
        # accepted for signature compatibility and unused.
        self.out_channels = out_channels if out_channels is not None else 512
        # Output spatial size: an int gives a square map, or pass an (h, w) pair.
        self.out_size = out_size

    def forward(self, id, feat, data_path=None):
        """Return (smoothed+resized float32 feature map, 0).

        ``id`` and ``data_path`` are accepted for interface compatibility and unused.
        """
        output = feat.float()
        # Pad 2/3 then 6x6 stride-1 mean pool keeps h, w unchanged.
        output = F.pad(output, (2, 3, 2, 3), mode="replicate")
        output = F.avg_pool2d(output, kernel_size=6, stride=1, padding=0)
        output = F.interpolate(output, size=self.out_size, mode="bilinear", align_corners=False)
        return output, 0
    
class ImageEncoderClip(torch.nn.Module):
    """Frozen CLIP ViT-B/16 patch features projected to ``out_channels``."""

    def __init__(self, use_cache, out_channels=None):
        super(ImageEncoderClip, self).__init__()
        model, preprocess = clip.load("ViT-B/16")
        self.model = model
        for param in self.model.parameters():
            param.requires_grad = False
        self.preprocess = preprocess
        self.hidden_size = 512  # token width consumed by self.linear
        # Honour an explicit out_channels (as Dinov2Encoder does); default stays 96.
        self.out_channels = out_channels if out_channels is not None else 96
        self.linear = torch.nn.Linear(self.hidden_size, self.out_channels)
        self.cache = {}

    # NOTE: an earlier ``forward(self, img)`` returning L2-normalised features was
    # removed: it was shadowed by the ``forward`` below and therefore never called.

    def encode_img(self, image):
        """Run the CLIP image encoder and return its per-patch token features.

        Args:
            image: channel-last tensor of shape (B, W, H, 3); it is permuted to
                (B, 3, W, H) and bilinearly resized to 224x224 when needed.

        Returns:
            Second element returned by ``self.model.encode_image`` — per the original
            inline comment (B, 196, 512) patch tokens.
        """
        # NOTE(review): no CLIP mean/std normalisation is applied here; confirm the
        # caller or the custom encode_image handles it.
        image = image.permute(0, 3, 1, 2)
        W, H = image.shape[2], image.shape[3]
        if not (W == 224 and H == 224):
            image = F.interpolate(image, size=(224, 224), mode='bilinear', align_corners=False)

        _, img_feat = self.model.encode_image(image)
        return img_feat

    def encode_text(self, text):
        """Encode a list of strings with the CLIP text tower.

        Args:
            text: list of strings, e.g. ["a purple sleeveless dress", "a bouquet of pink flowers"].
        """
        # nn.Module has no ``device`` attribute; use the device the CLIP weights live on.
        device = next(self.model.parameters()).device
        tokens = clip.tokenize(text).to(device)
        return self.model.encode_text(tokens)

    def forward(self, id, img, data_path=None):
        """Encode ``img`` (B, W, H, 3) and return ((B, out_channels, W, H), None).

        ``id`` keys the in-memory feature cache; ``id == -1`` bypasses it.
        ``data_path`` is accepted for interface compatibility and unused.
        """
        _, W, H, _ = img.shape
        if id != -1 and id in self.cache:
            img_feat = self.cache[id]
        else:
            img_feat = self.encode_img(img).float()
            if id != -1:
                self.cache[id] = img_feat

        img_feat = self.linear(img_feat)
        up_img_feat = upsample_imgfeat(img_feat.unsqueeze(dim=0), img_size=(W, H))
        return up_img_feat, None
    
class Dinov2Encoder(torch.nn.Module):
    """Frozen DINOv2-base patch features compressed by a 1x1-conv auto-encoder.

    ``forward`` returns the compressed map bilinearly resized to (W, H) together
    with the MSE reconstruction loss of the 1x1 decoder.
    """

    def __init__(self, use_cache, out_channels=None, tokenW=57, tokenH=57):
        super(Dinov2Encoder, self).__init__()
        self.use_cache = use_cache
        self.hidden_size = 768
        # Default compression: 768 // 8 = 96 channels.
        self.out_channels = out_channels if out_channels is not None else self.hidden_size // 8
        # With use_cache=True the backbone is never loaded, so only precomputed
        # (B, 768, h, w) feature maps can be passed to forward().
        if not use_cache:
            self.dinov2 = AutoModel.from_pretrained('facebook/dinov2-base')
            for param in self.dinov2.parameters():
                param.requires_grad = False
        self.conv1 = torch.nn.Conv2d(self.hidden_size, self.out_channels, (1, 1))
        self.decode = torch.nn.Conv2d(self.out_channels, self.hidden_size, (1, 1))
        self.relu = torch.nn.ReLU()  # currently unused by forward()
        self.cache = {}

    def apply_transform(self, images_tensor):
        """Resize and ImageNet-normalise a channel-last image batch.

        Args:
            images_tensor: tensor of shape (B, W, H, 3).

        Returns:
            (B, 3, W', H') tensor; torchvision's int-size resize matches the shorter
            edge to ``img_size`` and keeps the aspect ratio.
        """
        B, W, H, C = images_tensor.shape
        images_tensor = images_tensor.permute(0, 3, 1, 2)
        # 800 is special-cased to 798 (= 57 * 14, matching the tokenW/tokenH defaults),
        # presumably to land on a multiple of the 14-pixel patch size — confirm.
        img_size = W if W != 800 else 798
        resized_tensor = tvF.resize(images_tensor, size=img_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
        normalized_tensor = tvF.normalize(resized_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        return normalized_tensor

    def encode_img(self, img):
        """Return DINOv2 patch features as a (B, 768, s, s) map for images (B, W, H, 3).

        Raises:
            RuntimeError: if the encoder was built with ``use_cache=True`` (no backbone).
        """
        if getattr(self, "dinov2", None) is None:
            raise RuntimeError(
                "Dinov2Encoder was built with use_cache=True, so no DINOv2 backbone is "
                "loaded; pass precomputed (B, 768, h, w) features instead of images."
            )
        img = self.apply_transform(img)
        with torch.no_grad():
            outputs = self.dinov2(img, output_hidden_states=False, output_attentions=False)
        # Drop the first token (presumably CLS) and keep the patch tokens.
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]
        # Assumes a square patch grid.
        ebd_wh = int(math.sqrt(patch_embeddings.shape[1]))
        patch_embeddings = patch_embeddings.reshape(-1, ebd_wh, ebd_wh, self.hidden_size)
        return patch_embeddings.permute(0, 3, 1, 2)

    def forward(self, id, img, data_path=None):
        """Encode images or precomputed features.

        Args:
            id: cache key; ``-1`` bypasses the in-memory feature cache.
            img: either images (B, W, H, 3) or precomputed features (B, 768, h, w).
            data_path: accepted for interface compatibility and unused.

        Returns:
            (x, loss): x of shape (B, out_channels, W, H), loss the scalar MSE between
            the 1x1 reconstruction and the backbone features.
        """
        B, W, H, C = img.shape
        # A channel-last RGB batch (last dim 3) is an image even if its second dim
        # happens to be 768; only channel-first 768-wide maps are treated as features.
        is_feature_map = img.shape[1] == self.hidden_size and img.shape[-1] != 3
        if is_feature_map:
            # NOTE(review): for feature input (W, H) are (768, h), so the output size
            # below is probably not what is intended — confirm the expected size.
            img_feat = img
        elif id != -1 and id in self.cache:
            img_feat = self.cache[id]
        else:
            img_feat = self.encode_img(img)
            if id != -1:
                self.cache[id] = img_feat

        x = self.conv1(img_feat)
        img_feat_pred = self.decode(x)
        loss = F.mse_loss(img_feat_pred, img_feat)
        x = torch.nn.functional.interpolate(x, size=(W, H), mode="bilinear", align_corners=False)
        return x, loss

class SAMEncoder(torch.nn.Module):
    """Frozen SAM image embeddings compressed by a 1x1-conv auto-encoder."""

    def __init__(self, use_cache, out_channels=None,
                 sam_checkpoint="./checkpoints/sam_vit_h_4b8939.pth", model_type="vit_h"):
        """
        Args:
            use_cache: if True, embeddings are memoised per ``id`` in forward().
            out_channels: compressed channel count (default 96).
            sam_checkpoint: path to the local SAM weights (replace with your own path).
            model_type: key into ``sam_model_registry``.
        """
        super(SAMEncoder, self).__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.use_cache = use_cache
        self.hidden_size = 256  # embedding channels consumed by conv1
        self.out_channels = out_channels if out_channels is not None else 96

        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
        self.predictor = SamPredictor(sam)
        for param in self.predictor.model.parameters():
            param.requires_grad = False

        self.conv1 = torch.nn.Conv2d(self.hidden_size, self.out_channels, (1, 1))
        self.decode = torch.nn.Conv2d(self.out_channels, self.hidden_size, (1, 1))
        self.relu = torch.nn.ReLU()  # currently unused by forward()

        self.cache = {}

    def encode_img(self, img):
        """Return SAM image embeddings for each image of a (B, W, H, 3) batch.

        Images are converted with ``* 255`` to uint8, so presumably they lie in
        [0, 1] (TODO confirm). Per-image embeddings are concatenated on dim 0.
        """
        img = (img.cpu().numpy() * 255).astype(np.uint8)
        img_feat = []
        with torch.no_grad():
            for view in img:
                self.predictor.set_image(view)
                img_feat.append(self.predictor.get_image_embedding())
        return torch.cat(img_feat, dim=0)

    def forward(self, id, img, data_path=None):
        """Encode ``img`` (B, W, H, 3).

        ``data_path`` is accepted (and unused) so this matches the
        ``encoder(id, x, data_path)`` call made by ImageEncoder.

        Returns:
            (x, loss): x of shape (B, out_channels, W, H), loss the scalar MSE of the
            1x1 reconstruction against the SAM embedding.
        """
        _, W, H, _ = img.shape
        if self.use_cache and id != -1:
            if id in self.cache:
                img_feat = self.cache[id]
            else:
                img_feat = self.encode_img(img)
                self.cache[id] = img_feat
        else:
            img_feat = self.encode_img(img)

        x = self.conv1(img_feat)
        img_feat_pred = self.decode(x)
        loss = F.mse_loss(img_feat_pred, img_feat)
        x = torch.nn.functional.interpolate(x, size=(W, H), mode="bilinear", align_corners=False)
        return x, loss


class ImageEncoder(torch.nn.Module):
    """Select an image encoder backbone by name and expose its output width."""

    _KNOWN_ENCODERS = ("dinov2", "sam", "dift", "clip", "clip_depth", "dinov2_sam")

    def __init__(self, img_encoder, use_cache):
        """
        Args:
            img_encoder: one of "dinov2", "sam", "dift", "clip", "clip_depth",
                "dinov2_sam".
            use_cache: forwarded to the chosen encoder(s).

        Raises:
            ValueError: if ``img_encoder`` is not a known name.
        """
        super(ImageEncoder, self).__init__()
        self.img_encoder = img_encoder
        self.use_cache = use_cache
        if img_encoder == "dinov2_sam":
            # Two backbones whose outputs are concatenated along channels in forward().
            self.dinov2 = Dinov2Encoder(use_cache)
            self.sam = SAMEncoder(use_cache)
            self.out_dim = self.dinov2.out_channels + self.sam.out_channels
            return

        if img_encoder == "dinov2":
            encoder_cls = Dinov2Encoder
        elif img_encoder == "sam":
            encoder_cls = SAMEncoder
        elif img_encoder == "dift":
            encoder_cls = DiftEncoder
        elif img_encoder == "clip":
            encoder_cls = ImageEncoderClip
        elif img_encoder == "clip_depth":
            encoder_cls = CLIPDepthEncoder
        else:
            # Fail fast instead of leaving `encoder`/`out_dim` undefined until forward().
            raise ValueError(
                f"Unknown img_encoder {img_encoder!r}; expected one of {self._KNOWN_ENCODERS}"
            )
        self.encoder = encoder_cls(use_cache)
        self.out_dim = self.encoder.out_channels

    def forward(self, id, x, data_path=None):
        """Encode ``x``.

        Args:
            id: cache key forwarded to the encoder(s).
            x: image batch, (B, W, H, 3) for the image backbones.
            data_path: forwarded to the encoder(s).

        Returns:
            (img_feat, loss): features with channels on dim 1 (``out_dim`` of them)
            and the encoder's auxiliary loss (summed for "dinov2_sam").
        """
        if self.img_encoder == "dinov2_sam":
            img_feat_dinov2, loss1 = self.dinov2(id, x, data_path)
            img_feat_sam, loss2 = self.sam(id, x, data_path)
            img_feat = torch.cat([img_feat_dinov2, img_feat_sam], dim=1)
            loss = loss1 + loss2
        else:
            img_feat, loss = self.encoder(id, x, data_path)
        return img_feat, loss
        
    
if __name__ == "__main__":
    # Smoke test: build the requested encoder (default "dinov2") and report its
    # output width. ImageEncoder requires both img_encoder and use_cache.
    import sys

    encoder_name = sys.argv[1] if len(sys.argv) > 1 else "dinov2"
    model = ImageEncoder(encoder_name, use_cache=False)
    print(f"{encoder_name}: out_dim={model.out_dim}")