import torch
import torch.nn as nn
import numpy as np
import clip
import gc
import sys
import os
import yaml
import time
import torch.nn.functional as F
from PIL import Image
from models.softgroup.model import SoftGroup

from lib.ap_helper.ap_helper_fcos import parse_predictions

from torch.profiler import record_function
from utils.util import cuda_cast
from models.objectrenderer.object_renderer import ObjectRenderer
from models.long_clip.model import longclip


class ThreeLayerMLP(nn.Module):
    """Point-wise three-layer MLP built from kernel-size-1 ``Conv1d`` layers.

    The stack is two hidden stages (bias-free convolution, batch
    normalisation, ReLU, dropout with p=0.3) that keep the channel count at
    ``dim``, followed by a biased convolution projecting to ``out_dim``.
    """

    def __init__(self, dim, out_dim):
        """Build the layer stack.

        Args:
            dim: Channel count of the input and of both hidden stages.
            out_dim: Channel count produced by the final projection.
        """
        super().__init__()
        stages = []
        # Two identical hidden stages; the layers stay in one flat Sequential.
        for _ in range(2):
            stages.extend([
                nn.Conv1d(dim, dim, 1, bias=False),
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
        # Output projection (this is the only convolution that has a bias).
        stages.append(nn.Conv1d(dim, out_dim, 1))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run the stack on ``x``, which can be shaped (B, dim, N)."""
        out = self.net(x)
        return out

class JointNet(nn.Module):
    def __init__(self, num_class, class_name,
                 input_feature_dim=0, width=1,
                 num_proposal=128, num_target=32, num_rec_other=16, num_locals=-1, vote_factor=1, sampling="vote_fps",
                 no_caption=False, use_topdown=False, query_mode="corner",
                 use_lang_classifier=True, use_bidir=False, no_reference=False,
                 emb_size=300, hidden_size=256, args=None, cfg=None, vocabulary=None):
        """Record the configuration and instantiate the sub-models.

        Args:
            num_class: Number of classes; also the output width of the
                language classifier head.
            class_name: Stored as-is on the instance.
            input_feature_dim, num_proposal, num_target, vote_factor,
            sampling, no_caption, use_lang_classifier, use_bidir,
            no_reference, emb_size, hidden_size: Stored as attributes of the
                same name.
            num_rec_other: Capped at ``num_proposal - num_target`` before
                being stored.
            args: Run-time options; ``pretrain_model_on``, ``pretrain_model``
                and ``distribute`` are read here. Despite the ``None``
                default it is dereferenced unconditionally.
            cfg: Model configuration; ``cfg.model.softgroup``,
                ``cfg.model.clip_model`` and ``cfg.model.object_renderer``
                are read here. Also dereferenced unconditionally.
            width, num_locals, use_topdown, query_mode, vocabulary: Accepted
                but not used by this constructor.
        """
        super().__init__()

        def freeze(module):
            """Disable gradients for every parameter of ``module``."""
            for weight in module.parameters():
                weight.requires_grad = False

        # ---- plain configuration -------------------------------------------
        self.args = args
        self.cfg = cfg
        self.num_class = num_class
        self.class_name = class_name
        self.input_feature_dim = input_feature_dim
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        self.vote_factor = vote_factor
        self.sampling = sampling
        self.use_lang_classifier = use_lang_classifier
        self.use_bidir = use_bidir
        self.no_reference = no_reference
        self.no_caption = no_caption

        # ---- proposal bookkeeping ------------------------------------------
        self.num_proposal = num_proposal
        self.num_target = num_target
        self.num_other = num_proposal - num_target
        self.num_rec_other = min(num_rec_other, self.num_other)

        # ---- SoftGroup proposal generator (optional, frozen) ---------------
        if args.pretrain_model_on and args.pretrain_model == "softgroup":
            self.softgroup = SoftGroup(
                **cfg.model.softgroup,
                num_proposal=num_proposal)
            freeze(self.softgroup)

        # ---- frozen CLIP and LongCLIP encoders -----------------------------
        self.clip_model = clip.load(cfg.model.clip_model)[0]
        freeze(self.clip_model)

        self.longclip_model = longclip.load("models/long_clip/checkpoints/LongCLIP-L/longclip-L.pt")[0]
        freeze(self.longclip_model)

        # ---- frozen language classifier on 768-channel text features -------
        self.lang_cls = ThreeLayerMLP(768, self.num_class)
        freeze(self.lang_cls)

        # ---- object renderer (not frozen here) -----------------------------
        self.objectrenderer = ObjectRenderer(**cfg.model.object_renderer)

        if args.distribute:
            # NOTE(review): the return value is discarded; this presumably
            # relies on child modules being swapped in place - confirm.
            nn.SyncBatchNorm.convert_sync_batchnorm(self)

    # @cuda_cast
    def forward(self, data_dict, use_tf=True, is_eval=False):
        """Run the pipeline and store every intermediate result in ``data_dict``.

        Steps:
            1. Obtain proposals: either run SoftGroup (when detection is
               enabled and the pretrained model is "softgroup") or alias the
               precomputed ``prop_feat`` / ``objectness`` / ``prop`` /
               ``prop_sem`` entries onto the keys the rest of the method
               reads.
            2. Encode ``longclip_token`` with LongCLIP and classify the text
               feature with ``lang_cls`` (-> ``lang_scores``).
            3. Encode the tokens from ``get_token`` with CLIP
               (-> ``clip_text_feat``) and keep the LongCLIP text feature
               (-> ``longclip_text_feat``).
            4. Render the proposals, encode the renders with CLIP and pool
               per proposal (-> ``clip_img_feat``).
            5. Load cached per-frame features from ``img_feat_path``, select
               frames per proposal and pool them (-> ``longclip_img_feat``).
            6. Combine everything into ``coarse_weights``.

        Args:
            data_dict: Batch dictionary; it is modified in place.
            use_tf: Not used by this method.
            is_eval: Not used by this method.

        Returns:
            The same ``data_dict`` object with the keys listed above added.
        """
        run_args = self.args

        if run_args.nodetect:
            # Precomputed proposals: expose them under the pipeline's key names.
            for target_key, source_key in (
                ("proposals_feat", "prop_feat"),
                ("objectness_pred", "objectness"),
                ("pred_masks", "prop"),
                ("sem_cls_scores", "prop_sem"),
            ):
                data_dict[target_key] = data_dict[source_key]
        else:
            with record_function("pretrain model"):
                if run_args.pretrain_model_on and run_args.pretrain_model == "softgroup":
                    softgroup_inputs = [
                        data_dict[key] for key in (
                            "batch_idxs",
                            "voxel_coords",
                            "p2v_map",
                            "v2p_map",
                            "coords_float",
                            "feats",
                            "spatial_shape",
                            "batch_size",
                            "scan_ids",
                        )
                    ]
                    data_dict.update(self.softgroup(*softgroup_inputs))

        # ---- text branch ----------------------------------------------------
        # Merge the first two token dimensions before LongCLIP encoding.
        flat_tokens = data_dict["longclip_token"].flatten(start_dim=0, end_dim=1)
        longclip_text = self.longclip_model.encode_text(flat_tokens)
        # Text class prediction: add a trailing length-1 axis for the Conv1d
        # head, then drop it again.
        class_logits = self.lang_cls(longclip_text.float().unsqueeze(-1))
        data_dict["lang_scores"] = class_logits.squeeze(dim=-1)

        data_dict["clip_text_feat"] = self.clip_model.encode_text(self.get_token(data_dict))
        data_dict["longclip_text_feat"] = longclip_text

        # ---- visual branch: CLIP on rendered views ---------------------------
        num_render_views = len(self.cfg.model.object_renderer.eye)
        rendered_imgs, _ = self.objectrenderer(data_dict, None)
        # Move the last axis to position 1 (presumably channels-last ->
        # channels-first; confirm against ObjectRenderer).
        render_feats = self.clip_model.encode_image(rendered_imgs.permute(dims=(0, 3, 1, 2)))
        objectness = data_dict["objectness_pred"]
        data_dict["clip_img_feat"] = self.get_proposal_feat(render_feats, objectness, num_render_views)

        # ---- visual branch: cached LongCLIP frame features -------------------
        vis_topk = 15  # frames kept per proposal
        frame_feats = self.get_context_imgs_feat(data_dict, data_dict["img_feat_path"], vis_topk)
        data_dict["longclip_img_feat"] = self.get_proposal_feat(
            frame_feats, data_dict["objectness_pred"], vis_topk)

        data_dict["coarse_weights"] = self.get_weight(data_dict)

        return data_dict

    def get_weight(self, data_dict):
        """Fuse three query-to-proposal similarities into one normalised weight.

        For each of the CLIP features, the LongCLIP features and the class
        scores, a dot-product similarity between every query and every
        proposal is computed, proposals whose objectness equals 0 are set to
        ``-inf``, and a softmax is taken over proposals. The three results
        are multiplied element-wise and rescaled so each query row sums to 1.

        Args:
            data_dict: Must provide ``clip_token`` (only its first two
                dimensions are used, as batch size and queries per sample),
                ``objectness_pred``, ``lang_scores``, ``clip_text_feat``,
                ``clip_img_feat``, ``longclip_text_feat``,
                ``longclip_img_feat`` and ``sem_cls_scores``.

        Returns:
            Tensor shaped (batch, queries, proposals).
        """
        batch_size, num_queries = data_dict["clip_token"].shape[:2]

        # Broadcast per-proposal objectness across the query axis.
        objectness = data_dict["objectness_pred"].unsqueeze(1).expand(-1, num_queries, -1)
        rejected = objectness == 0

        def per_query(flat_feat):
            """Reshape a flattened feature to (batch, queries, -1)."""
            return flat_feat.reshape(batch_size, num_queries, -1)

        def masked_similarity(query_feat, proposal_feat):
            """Softmax over proposals of the dot-product similarity."""
            logits = torch.einsum('bmd,bnd->bmn', query_feat, proposal_feat)
            return logits.masked_fill(rejected, float('-inf')).softmax(-1)

        clip_sim = masked_similarity(
            per_query(data_dict["clip_text_feat"]), data_dict["clip_img_feat"])
        longclip_sim = masked_similarity(
            per_query(data_dict["longclip_text_feat"]), data_dict["longclip_img_feat"])
        class_sim = masked_similarity(
            per_query(data_dict["lang_scores"]).float(), data_dict["sem_cls_scores"])

        fused = clip_sim * longclip_sim * class_sim
        # Rescale each query row so that it sums to one.
        inverse_mass = 1 / fused.sum(dim=-1)
        return fused * inverse_mass[:, :, None]
    
    def get_token(self, data_dict):
        """Tokenise one "<modifiers> <target word>" phrase per query for CLIP.

        For every entry of ``data_dict["target_words"]`` the modifier words at
        the same index of ``data_dict["mod_words"]`` are joined with single
        spaces and followed by a space and the target word. The phrase is
        tokenised with truncation enabled and moved to the device of
        ``data_dict["lang_scores"]``.

        Returns:
            The per-phrase token tensors stacked along a new first dimension.
        """
        target_words = data_dict["target_words"]
        mod_words = data_dict["mod_words"]

        def tokenize_phrase(index, cls_name):
            """Build and tokenise the phrase for the query at ``index``."""
            # With no modifiers the phrase keeps a leading space.
            phrase = ' '.join(mod_words[index]) + ' ' + cls_name
            token = clip.tokenize(phrase, truncate=True)[0]
            return token.to(data_dict["lang_scores"].device)

        clip_tokens = [
            tokenize_phrase(index, cls_name)
            for index, cls_name in enumerate(target_words)
        ]
        return torch.stack(clip_tokens, 0)

    def get_context_imgs_feat(self, data_dict, feat_path, views):
        """Collect ``views`` cached frame features for every proposal.

        For each scene the feature file is loaded and, for every proposal
        mask, the frames are ranked by the visibility summed over the
        proposal's points (``mask > 0``); the ``views`` highest-ranked frame
        features are kept. When a scene has fewer than ``views`` frames, all
        frame features are used and the last one is repeated until ``views``
        rows are reached.

        Args:
            data_dict: Provides ``visibility_mask`` and ``pred_masks``, both
                indexable per scene. ``visibility_mask[i]`` is indexed as
                ``[frame, point]`` (presumably (num_frames, num_points) -
                confirm against the data loader).
            feat_path: One feature-file path per scene. The stored object may
                be a NumPy array or a tensor whose first dimension is the
                frame axis.
            views: Number of frame features to return per proposal.

        Returns:
            All per-proposal selections concatenated along dimension 0, in
            scene order and then proposal order.
        """
        context_imgs_feat = []
        for scene_idx, path in enumerate(feat_path):
            visibility_mask = data_dict["visibility_mask"][scene_idx]
            prop = data_dict["pred_masks"][scene_idx]
            # NOTE(security): torch.load unpickles the file - only load
            # feature caches from trusted sources.
            # as_tensor accepts both ndarray and tensor payloads.
            img_feat = torch.as_tensor(torch.load(path)).to(prop.device)

            # The padded tensor is the same for every proposal of the scene,
            # so build it once instead of once per proposal.
            num_frames = visibility_mask.shape[0]
            padded_feat = None
            if num_frames < views and len(prop) > 0:
                repeats = (views - num_frames,) + (1,) * (img_feat.dim() - 1)
                filler = img_feat[-1].unsqueeze(0).repeat(*repeats)
                padded_feat = torch.cat([img_feat, filler], dim=0)

            for proposal_mask in prop:
                if padded_feat is not None:
                    context_imgs_feat.append(padded_feat)
                    continue
                visibility_score = visibility_mask[:, proposal_mask > 0].sum(dim=-1)
                top_frames = torch.topk(visibility_score, k=views)[1]
                context_imgs_feat.append(img_feat[top_frames])

            # Release the per-scene cache before loading the next scene.
            del img_feat, padded_feat
            gc.collect()
            torch.cuda.empty_cache()

        context_imgs_feat = torch.cat(context_imgs_feat)
        return context_imgs_feat
    
    def get_context_imgs(self, data_dict, img_path, views):
        """Collect ``views`` cached frames for every proposal.

        For each scene the image file is loaded and, for every proposal mask,
        the frames are ranked by the visibility summed over the proposal's
        points (``mask > 0``); the ``views`` highest-ranked frames are kept.

        When a scene has fewer than ``views`` frames, all frames are used and
        the last one is repeated until ``views`` frames are reached, the same
        fallback ``get_context_imgs_feat`` applies. Without it ``torch.topk``
        raises because ``k`` exceeds the number of frames.

        Args:
            data_dict: Provides ``visibility_mask`` and ``pred_masks``, both
                indexable per scene. ``visibility_mask[i]`` is indexed as
                ``[frame, point]`` (presumably (num_frames, num_points) -
                confirm against the data loader).
            img_path: One image-file path per scene. The stored object may be
                a NumPy array or a tensor whose first dimension is the frame
                axis.
            views: Number of frames to return per proposal.

        Returns:
            All per-proposal selections concatenated along dimension 0, in
            scene order and then proposal order.
        """
        context_imgs = []
        for scene_idx, path in enumerate(img_path):
            visibility_mask = data_dict["visibility_mask"][scene_idx]
            prop = data_dict["pred_masks"][scene_idx]
            # NOTE(security): torch.load unpickles the file - only load image
            # caches from trusted sources.
            # as_tensor accepts both ndarray and tensor payloads.
            imgs = torch.as_tensor(torch.load(path)).to(prop.device)

            # Too few frames for top-k: pad once per scene with the last frame.
            num_frames = visibility_mask.shape[0]
            padded_imgs = None
            if num_frames < views and len(prop) > 0:
                repeats = (views - num_frames,) + (1,) * (imgs.dim() - 1)
                filler = imgs[-1].unsqueeze(0).repeat(*repeats)
                padded_imgs = torch.cat([imgs, filler], dim=0)

            for proposal_mask in prop:
                if padded_imgs is not None:
                    context_imgs.append(padded_imgs)
                    continue
                visibility_score = visibility_mask[:, proposal_mask > 0].sum(dim=-1)
                top_frames = torch.topk(visibility_score, k=views)[1]
                context_imgs.append(imgs[top_frames])

            # Release the per-scene cache before loading the next scene.
            del imgs, padded_imgs
            gc.collect()
            torch.cuda.empty_cache()

        context_imgs = torch.cat(context_imgs)
        return context_imgs
    
    def get_proposal_feat(self, img_feats, objectness_pred, views):
        """Average every run of ``views`` consecutive rows and lay them out per scene.

        Args:
            img_feats: 2-D tensor whose rows are per-view features, with the
                ``views`` rows of one proposal stored consecutively.
            objectness_pred: Per-scene objectness; its per-row sum is used as
                the number of proposals of that scene.
            views: Number of consecutive rows averaged into one proposal
                feature.

        Returns:
            Zero-padded tensor shaped
            (objectness_pred.shape[0], self.num_proposal, feature_dim).
        """
        # avg_pool1d pools along the last axis, so put the row axis last.
        channel_first = img_feats.transpose(0, 1)
        pooled = F.avg_pool1d(channel_first, kernel_size=views, stride=views)
        per_proposal = pooled.transpose(0, 1)
        return self.convert_sparse_tensor_to_dense(
            per_proposal, objectness_pred, self.num_proposal)

    def convert_sparse_tensor_to_dense(self, sparse_info, objectness_pred, num_proposal):
        """Scatter concatenated per-proposal rows into a zero-padded batch tensor.

        ``sparse_info`` holds the rows of all scenes back to back. The number
        of rows belonging to scene ``i`` is the integer sum of
        ``objectness_pred[i]``; those rows are copied to the first slots of
        scene ``i`` in the output and the remaining slots stay zero.

        Args:
            sparse_info: Tensor whose first dimension enumerates proposals
                across all scenes.
            objectness_pred: Tensor whose first dimension is the batch.
            num_proposal: Number of slots per scene in the output.

        Returns:
            Tensor shaped (batch, num_proposal, *sparse_info.shape[1:]) with
            the dtype and device of ``sparse_info``.
        """
        batch_size = objectness_pred.shape[0]
        dense = sparse_info.new_zeros(
            (batch_size, num_proposal, *sparse_info.shape[1:]))

        offset = 0  # start row of the current scene inside sparse_info
        for scene_idx, scene_objectness in enumerate(objectness_pred):
            count = scene_objectness.long().sum().item()
            dense[scene_idx, :count] = sparse_info[offset:offset + count]
            offset += count
        return dense
    
    def get_max_min_norm(self, sim_matrix, objectness):
        """Min-max normalise the valid prefix of every similarity row.

        For each (batch, query) row the number of valid proposals is the sum
        of ``objectness[i][j]``; the first that many similarities are mapped
        linearly so their minimum becomes 0 and their maximum becomes 1.
        Entries beyond the valid prefix are left at 0.

        Degenerate rows are handled explicitly:
            * a row with no valid proposal is left all-zero (previously
              ``min()`` on an empty slice raised);
            * a row whose valid values are all equal, including a single
              proposal, maps to 0 (previously 0/0 produced NaN).

        Args:
            sim_matrix: Tensor indexed as [batch, query, proposal].
            objectness: Tensor indexed as [batch, query, proposal]; only its
                per-row sum is used.

        Returns:
            Tensor with the shape, dtype and device of ``sim_matrix``.
        """
        B, max_num = objectness.shape[:2]
        sim_norm = torch.zeros_like(sim_matrix)
        for i in range(B):
            for j in range(max_num):
                prop_num = int(objectness[i][j].sum().item())
                if prop_num <= 0:
                    # Nothing to normalise; keep the zero row.
                    continue
                valid = sim_matrix[i, j, :prop_num]
                low = valid.min()
                span = valid.max() - low
                # A zero range would divide by zero; dividing by 1 instead
                # yields 0 for every (identical) value.
                safe_span = torch.where(span > 0, span, torch.ones_like(span))
                sim_norm[i, j, :prop_num] = (valid - low) / safe_span

        return sim_norm




    def get_templaes(self):
        """Return the 80 prompt templates as a new list.

        Every template contains a single ``{}`` placeholder to be filled via
        ``str.format``. A fresh list is built on each call, so callers may
        modify the result freely.
        """
        prompt_templates = (
            "a bad photo of a {}.",
            "a photo of many {}.",
            "a sculpture of a {}.",
            "a photo of the hard to see {}.",
            "a low resolution photo of the {}.",
            "a rendering of a {}.",
            "graffiti of a {}.",
            "a bad photo of the {}.",
            "a cropped photo of the {}.",
            "a tattoo of a {}.",
            "the embroidered {}.",
            "a photo of a hard to see {}.",
            "a bright photo of a {}.",
            "a photo of a clean {}.",
            "a photo of a dirty {}.",
            "a dark photo of the {}.",
            "a drawing of a {}.",
            "a photo of my {}.",
            "the plastic {}.",
            "a photo of the cool {}.",
            "a close-up photo of a {}.",
            "a black and white photo of the {}.",
            "a painting of the {}.",
            "a painting of a {}.",
            "a pixelated photo of the {}.",
            "a sculpture of the {}.",
            "a bright photo of the {}.",
            "a cropped photo of a {}.",
            "a plastic {}.",
            "a photo of the dirty {}.",
            "a jpeg corrupted photo of a {}.",
            "a blurry photo of the {}.",
            "a photo of the {}.",
            "a good photo of the {}.",
            "a rendering of the {}.",
            "a {} in a video game.",
            "a photo of one {}.",
            "a doodle of a {}.",
            "a close-up photo of the {}.",
            "a photo of a {}.",
            "the origami {}.",
            "the {} in a video game.",
            "a sketch of a {}.",
            "a doodle of the {}.",
            "a origami {}.",
            "a low resolution photo of a {}.",
            "the toy {}.",
            "a rendition of the {}.",
            "a photo of the clean {}.",
            "a photo of a large {}.",
            "a rendition of a {}.",
            "a photo of a nice {}.",
            "a photo of a weird {}.",
            "a blurry photo of a {}.",
            "a cartoon {}.",
            "art of a {}.",
            "a sketch of the {}.",
            "a embroidered {}.",
            "a pixelated photo of a {}.",
            "itap of the {}.",
            "a jpeg corrupted photo of the {}.",
            "a good photo of a {}.",
            "a plushie {}.",
            "a photo of the nice {}.",
            "a photo of the small {}.",
            "a photo of the weird {}.",
            "the cartoon {}.",
            "art of the {}.",
            "a drawing of the {}.",
            "a photo of the large {}.",
            "a black and white photo of a {}.",
            "the plushie {}.",
            "a dark photo of a {}.",
            "itap of a {}.",
            "graffiti of the {}.",
            "a toy {}.",
            "itap of my {}.",
            "a photo of a cool {}.",
            "a photo of a small {}.",
            "a tattoo of the {}.",
        )
        return list(prompt_templates)


