from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import clip

from models.adapter import SimplifiedAdapter
from models.foldingnet import SkipVariationalEncoder


class UniL(nn.Module):
    """Contrastive model aligning point-cloud embeddings with text and depth-image embeddings.

    Point clouds, pre-computed CLIP text features and multi-view depth images are
    each projected into a shared ``embed_dim`` space. ``forward`` returns scaled
    cosine-similarity logits of point clouds against text and against images.
    CLIP's visual encoder is frozen at construction time.
    """

    def __init__(self, pc_encoder=None, img_dim=512, text_dim=512, pc_dim=1088, embed_dim=512,
                 num_views=10):
        """
        :param pc_encoder: point-cloud encoder module. If None, a fresh
            ``SkipVariationalEncoder(512)`` is built for this instance. (Building it in
            the signature would create one module at import time and silently share
            its weights between every UniL instance.)
        :param img_dim: feature size fed to the adapter / image projection.
        :param text_dim: feature size of the pre-computed CLIP text features.
        :param pc_dim: flattened feature size produced by ``pc_encoder``.
        :param embed_dim: size of the shared embedding space.
        :param num_views: number of depth views passed to the adapter.
        """
        super(UniL, self).__init__()
        # Only the visual tower is kept; the preprocessing transform is unused. The rest
        # of clip_model is dropped when this local goes out of scope, so no copy is needed.
        clip_model, _ = clip.load("ViT-B/32", device='cpu')
        self.visual_encoder = clip_model.visual

        if pc_encoder is None:
            pc_encoder = SkipVariationalEncoder(512)
        self.point_encoder = pc_encoder
        self.adapter = SimplifiedAdapter(num_views=num_views, in_features=img_dim)

        self.image_projection = nn.Parameter(torch.empty(img_dim, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(text_dim, embed_dim))
        self.pc_projection = nn.Parameter(torch.empty(pc_dim, embed_dim))
        # Learnable temperature stored in log space, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        nn.init.normal_(self.image_projection, std=img_dim ** -0.5)
        nn.init.normal_(self.text_projection, std=text_dim ** -0.5)
        nn.init.normal_(self.pc_projection, std=pc_dim ** -0.5)
        self.freeze_visual()

    def freeze_visual(self):
        """Disable gradients for every parameter of the CLIP visual encoder."""
        print("===> Freezing CLIP's visual encoder...")
        self.visual_encoder.requires_grad_(False)
        print("===> Freezing complete!")

    def encode_image(self, img):
        """Embed multi-view depth images.

        :param img: tensor of shape [B, n_views, C, H, W].
        :return: un-normalized image embeddings, [B, embed_dim].
        """
        b, n, c, h, w = img.size()
        # Run all views through CLIP as one flat batch.
        img_feat = self.visual_encoder(img.reshape(b * n, c, h, w))  # [B * n_views, img_dim]
        # NOTE(review): the adapter is assumed to regroup the flat batch into B groups and
        # fuse the views, so n should match the adapter's num_views — confirm in SimplifiedAdapter.
        img_feat = self.adapter(img_feat)  # [B, img_dim] (assumed)
        return img_feat @ self.image_projection  # [B, embed_dim]

    def encode_text(self, text):
        """Embed pre-computed CLIP text features, averaging over the prompts of each sample.

        :param text: tensor of shape [B, T, text_dim] (T prompts per sample).
        :return: L2-normalized text embeddings, [B, embed_dim].
        """
        text_embed = text @ self.text_projection  # [B, T, embed_dim]
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        # Average the unit-norm prompt embeddings, then re-normalize the mean.
        text_embed = text_embed.mean(dim=1)  # [B, embed_dim]
        return text_embed / text_embed.norm(dim=-1, keepdim=True)

    def encode_pc(self, pc):
        """Embed a batch of point clouds.

        :param pc: tensor of shape [B, N, 3].
        :return: un-normalized point-cloud embeddings, [B, embed_dim].
        """
        pc_feat = self.point_encoder(pc)
        # Flatten to [B, pc_dim] while keeping the batch axis explicitly. A bare .squeeze()
        # also drops the batch axis when B == 1, which breaks the norms and matmuls in forward.
        pc_feat = pc_feat.reshape(pc.size(0), -1)  # [B, pc_dim]
        return pc_feat @ self.pc_projection  # [B, embed_dim]

    def forward(self, pc, text, depth):
        """
        :param pc: point cloud with size of [B, N, 3]
        :param text: text tensor encoded by CLIP with size of [B, 64, 512]
        :param depth: depth map with size of [B, num_views, 3, 224, 224]
        :return: tuple (logits_per_pc_text, logits_per_pc_image), each of size [B, B]
        """
        pc_embed = self.encode_pc(pc)
        text_embed = self.encode_text(text)
        image_embed = self.encode_image(depth)

        # Normalized features (text is already unit-norm; re-normalizing is harmless).
        pc_features = pc_embed / pc_embed.norm(p=2, dim=1, keepdim=True)
        text_features = text_embed / text_embed.norm(p=2, dim=1, keepdim=True)
        image_features = image_embed / image_embed.norm(p=2, dim=1, keepdim=True)

        # Cosine similarity as logits, scaled by the learned temperature.
        logit_scale = self.logit_scale.exp()
        logits_per_pc_text = logit_scale * pc_features @ text_features.T
        logits_per_pc_image = logit_scale * pc_features @ image_features.T

        return logits_per_pc_text, logits_per_pc_image


def model_state_dict(model):
    """Print every entry of ``model.state_dict()`` together with its tensor size.

    :param model: an ``nn.Module`` (anything exposing ``state_dict()``).
    """
    print("Model's state_dict:")
    # Build the state dict once. Calling state_dict() per key would re-walk the
    # whole module tree on every iteration.
    for param_tensor, value in model.state_dict().items():
        print(param_tensor, "\t", value.size())


if __name__ == '__main__':
    model = UniL()
    total = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total / 1e6:.2f}M")

    # Show which parameters remain trainable (CLIP's visual tower should be frozen).
    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    # Smoke test with random inputs: point clouds [B, N, 3], CLIP text features
    # [B, 64, 512] and depth views [B, 10, 3, 224, 224].
    p = torch.randn(8, 1024, 3)
    t = torch.randn(8, 64, 512)
    d = torch.randn(8, 10, 3, 224, 224)

    # Inference only: skip building an autograd graph for this shape check.
    with torch.no_grad():
        logits_pc_text, logits_pc_image = model(p, t, d)
    print(logits_pc_text.shape)
    print(logits_pc_image.shape)
