import numpy as np
import sys
sys.path.append("/cluster/apps/nss/gcc-8.2.0/python/3.8.5/x86_64/lib64/python3.8/site-packages")
import open3d as o3d
import torch
import clip
import pdb
import matplotlib.pyplot as plt
# from constants import *
from omegaconf import OmegaConf
from clip_adapter.clip_opendas import build_model as build_opendas, load_clip_to_cpu as load_opendas_clip_to_cpu, load_model as load_opendas

# Query vocabulary. main() queries CLASS_LABELS[0]; the last entry ("other") is
# scored as a competing label, and masks that score higher for it than for the
# query get a similarity of 0. When USE_OPENDAS is set, the custom CLIP is built
# for exactly these class names, in this order.
CLASS_LABELS = ["folder cabinet", "other"]
# ScanNet++ scene id; used to build every input path in main().
SCENE_ID = "e91722b5a3"
# True: use the OpenDAS custom CLIP and the OpenDAS feature file; False: stock CLIP features.
USE_OPENDAS = True

class QuerySimilarityComputation():
    """Score text queries against per-mask CLIP features and color points by score.

    The text encoder is either stock CLIP (``clip.load``) or, with
    ``build_custom_clip=True``, an OpenDAS custom CLIP built for a fixed list
    of class names.
    """

    def __init__(self, build_custom_clip=False, class_names=None, clip_model_type='ViT-L/14'):
        """Load the text encoder.

        Args:
            build_custom_clip: load the OpenDAS custom CLIP instead of stock CLIP.
            class_names: class names the custom CLIP is built for, in order;
                custom-CLIP queries must be one of them. Defaults to
                ``CLASS_LABELS``.
            clip_model_type: CLIP backbone name passed to the loader.
        """
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.build_custom_clip = build_custom_clip
        # One shared copy, so the class order used to build the custom CLIP and
        # the index looked up in get_query_embedding can never diverge.
        self.class_names = list(CLASS_LABELS if class_names is None else class_names)
        self.clip_model = self.get_clip_model(clip_model_type)

    def get_query_embedding(self, text_query):
        """Return the L2-normalized text embedding of ``text_query`` as a 1-D numpy array.

        With the custom CLIP, text features are selected by class index, so
        ``text_query`` must be one of ``self.class_names``.

        Raises:
            ValueError: the custom CLIP is used and ``text_query`` is not one
                of ``self.class_names``.
        """
        with torch.no_grad():
            if self.build_custom_clip:
                try:
                    class_idx = self.class_names.index(text_query)
                except ValueError:
                    raise ValueError(
                        f"Query {text_query!r} is not one of the custom CLIP class names {self.class_names}"
                    ) from None
                # Select the features of the queried class (previously always index 0).
                embedding = self.clip_model.get_text_features([class_idx])
            else:
                tokens = clip.tokenize(text_query).to(self.device)
                embedding = self.clip_model.encode_text(tokens)
            embedding = embedding.float()
            # Unit-normalize in both paths so the dot product in
            # compute_similarity_scores is a cosine similarity.
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding.cpu().squeeze().numpy()

    def get_clip_model(self, clip_model_type):
        """Load and return the text encoder (OpenDAS custom CLIP or stock CLIP)."""
        if self.build_custom_clip:
            cfg = OmegaConf.create({
                "MODEL": {
                    "CLIP_ADAPTER": {
                        "CLIP_MODEL_NAME": clip_model_type
                    },
                    "OPENDAS": {
                        "DIR": "./multimodal-prompt-learning/output/scannetpp_similar_negative_v2/OpenDAS/vit_l14_c2_ep10_batch16_2+2ctx_d24_use_both_losses_0shots/seed429",
                        "LOAD_EPOCH": 8,
                        "PROMPT_DEPTH_VISION": 24,
                        "PROMPT_DEPTH_TEXT": 24,
                        "N_CTX_TEXT": 4,
                        "N_CTX_VISION": 8,
                        "CTX_INIT": "a photo of a",
                        "INPUT_SIZE": (224, 224)
                    }
                }
            })
            print(f"[INFO] Loading Custom CLIP with {cfg}...")

            clip_model = load_opendas_clip_to_cpu(cfg).type(torch.float32)
            # Class order here must match the index lookup in get_query_embedding.
            custom_clip = build_opendas(cfg, self.class_names, clip_model)
            custom_clip = load_opendas(custom_clip, cfg)  # loads the weights
            print("[INFO] Custom CLIP Loaded.")
            return custom_clip

        clip_model, _ = clip.load(clip_model_type, self.device)
        return clip_model

    def compute_similarity_scores(self, mask_features, text_query):
        """Cosine similarity between ``text_query`` and each mask feature.

        Args:
            mask_features: 2-D array-like, one feature row per mask.
            text_query: query text (see get_query_embedding).

        Returns:
            1-D float array with one score per mask. Masks whose feature norm
            is below 1e-3 get a score of 0.
        """
        text_emb = self.get_query_embedding(text_query)
        features = np.asarray(mask_features)
        scores = np.zeros(len(features))
        if len(features) == 0:
            return scores

        norms = np.linalg.norm(features, axis=1)
        # Same test as "skip if norm < 1e-3": near-zero features keep score 0.
        valid = ~(norms < 1e-3)
        scores[valid] = (features[valid] / norms[valid, None]) @ text_emb
        return scores

    def get_per_point_colors_for_similarity(self,
                                            per_mask_scores,
                                            masks,
                                            normalize_based_on_current_min_max=False,
                                            normalize_min_bound=0.16, #only used for visualization if normalize_based_on_current_min_max is False
                                            normalize_max_bound=0.23, #only used for visualization if normalize_based_on_current_min_max is False
                                            background_color=(0.77, 0.77, 0.77),
                                            mask_threshold=0.7,
                                        ):
        """Turn per-mask scores into per-point RGB colors using the jet colormap.

        Args:
            per_mask_scores: one score per mask; 0 is treated as "no match"
                and rescales to 0.
            masks: 2-D array (num_masks, num_points); a point belongs to a
                mask where its value is > ``mask_threshold``.
            normalize_based_on_current_min_max: if True, rescale non-zero
                scores to [0, 1] by their own min/max. If all non-zero scores
                are equal they map to 1.0.
            normalize_min_bound, normalize_max_bound: otherwise, non-zero
                scores above ``normalize_min_bound`` are mapped linearly so
                the bounds land on 0 and 1 (clipped to 1 above the max bound);
                the remaining scores map to 0.
            background_color: RGB for points covered by no mask.
            mask_threshold: per-point membership threshold for ``masks``.

        Returns:
            (num_points, 3) float array of RGB colors. Lower-index masks are
            painted last, so they win on overlapping points.
        """
        per_mask_scores = np.asarray(per_mask_scores, dtype=float)
        masks = np.asarray(masks)

        non_zero = per_mask_scores != 0
        rescaled = np.zeros_like(per_mask_scores)
        pms = per_mask_scores[non_zero]

        if normalize_based_on_current_min_max:
            if pms.size > 0:  # .min()/.max() of an empty array would raise
                score_range = pms.max() - pms.min()
                if score_range > 0:
                    rescaled[non_zero] = (pms - pms.min()) / score_range
                else:
                    # All equal: avoid 0/0 (NaN renders black); show them as the top score.
                    rescaled[non_zero] = 1.0
        else:
            above = non_zero & (per_mask_scores > normalize_min_bound)
            rescaled[above] = ((per_mask_scores[above] - normalize_min_bound)
                               / (normalize_max_bound - normalize_min_bound))

        # Out-of-range values would use the colormap's under/over colors, which
        # default to its end colors; clipping makes that explicit.
        mask_colors = plt.cm.jet(np.clip(rescaled, 0.0, 1.0))[:, :3]

        new_colors = np.tile(np.asarray(background_color, dtype=float), (masks.shape[1], 1))

        print("Masks length: ", len(masks))
        # Paint from the last mask to the first so earlier masks overwrite later ones.
        for mask_idx in range(len(masks) - 1, -1, -1):
            new_colors[masks[mask_idx] > mask_threshold, :] = mask_colors[mask_idx]

        return new_colors



def main():
    """Color a ScanNet++ scene by per-mask similarity to a text query and save it as .ply.

    Loads the scene point cloud, predicted instance masks and per-mask
    OpenMask3D features for ``SCENE_ID``. Scores every mask against the query
    and zeroes masks that score higher for the last label in ``CLASS_LABELS``.
    Writes the heat-map-colored point cloud under ``out/data/``.

    Raises:
        ValueError: the scene, masks and features disagree in size.
        OSError: the output point cloud could not be written.
    """
    import os

    # --------------------------------
    # Set the paths
    # --------------------------------

    path_scene_pcd = f"./scannet++/data/{SCENE_ID}/scans/mesh_aligned_0.05.ply"
    path_pred_masks = f"./openmask3d/output/2024-03-23-15-14-28-scannetpp_office/masks/{SCENE_ID}_masks.pt"

    if USE_OPENDAS:
        path_openmask3d_features = f"./scannetpp_masks_and_results_with_opendas_num_levels_1_frequency_20/{SCENE_ID}/masks_results/masks_openmask3d_features.npy"
    else:
        path_openmask3d_features = f"./scannetpp_masks_and_results_with_clip_num_levels_3_frequency_20/{SCENE_ID}/masks_results/masks_openmask3d_features.npy"

    # path_scene_pcd = "./scannet++/data/036bce3393/scans/mesh_aligned_0.05.ply"
    # path_pred_masks = "./openmask3d/output/2024-03-23-15-14-28-scannetpp_office/masks/036bce3393_masks.pt"
    # path_openmask3d_features = "./scannetpp_masks_and_results_with_opendas_num_levels_1/036bce3393/masks_results/masks_openmask3d_features.npy"

    # --------------------------------
    # Load data
    # --------------------------------
    # load the scene pcd
    scene_pcd = o3d.io.read_point_cloud(path_scene_pcd)

    # load the predicted masks; map to CPU so masks saved from GPU tensors also load on CPU-only hosts
    pred_masks = np.asarray(torch.load(path_pred_masks, map_location="cpu")).T # (num_instances, num_points)

    # load the openmask3d features
    openmask3d_features = np.load(path_openmask3d_features) # (num_instances, 768)

    # Fail early on mismatched inputs instead of an IndexError while coloring
    # or a point cloud whose colors do not line up with its points.
    num_points = len(scene_pcd.points)
    if pred_masks.ndim != 2 or pred_masks.shape[1] != num_points:
        raise ValueError(
            f"Masks of shape {pred_masks.shape} do not match the {num_points} points of {path_scene_pcd}"
        )
    if len(openmask3d_features) != pred_masks.shape[0]:
        raise ValueError(
            f"{len(openmask3d_features)} mask features but {pred_masks.shape[0]} masks"
        )

    # initialize the query similarity computer
    query_similarity_computer = QuerySimilarityComputation(build_custom_clip=USE_OPENDAS)

    # --------------------------------
    # Set the query text
    # --------------------------------
    query_text = CLASS_LABELS[0] # change the query text here

    # --------------------------------
    # Get the similarity scores
    # --------------------------------
    # per-mask cosine similarity between the query embedding and each mask feature
    per_mask_query_sim_scores = query_similarity_computer.compute_similarity_scores(openmask3d_features, query_text)
    per_mask_other_sim_scores = query_similarity_computer.compute_similarity_scores(openmask3d_features, CLASS_LABELS[-1])
    # set sim score to 0 where the mask is more similar to the "other" label than to the query
    per_mask_query_sim_scores[per_mask_query_sim_scores < per_mask_other_sim_scores] = 0.0

    # --------------------------------
    # Visualize the similarity scores
    # --------------------------------
    # get the per-point heatmap colors for the similarity scores
    # (see get_per_point_colors_for_similarity for the normalization options)
    per_point_similarity_colors = query_similarity_computer.get_per_point_colors_for_similarity(per_mask_query_sim_scores, pred_masks)
    print(per_point_similarity_colors)

    # build a colored copy of the scene geometry
    scene_pcd_w_sim_colors = o3d.geometry.PointCloud()
    scene_pcd_w_sim_colors.points = scene_pcd.points
    scene_pcd_w_sim_colors.colors = o3d.utility.Vector3dVector(per_point_similarity_colors)
    scene_pcd_w_sim_colors.estimate_normals()
    # o3d.visualization.draw_geometries([scene_pcd_w_sim_colors])

    # save as .ply; write_point_cloud fails (returns False) if the directory is missing
    out_dir = "out/data"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir,
        "scene_{}_{}_USE_OPENDAS_{}.ply".format(SCENE_ID, query_text.replace(' ', '_'), USE_OPENDAS),
    )
    if not o3d.io.write_point_cloud(out_path, scene_pcd_w_sim_colors):
        raise OSError(f"Failed to write point cloud to {out_path}")

# Run the visualization only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
