import os
import numpy as np
import clip
import torch
import pdb
# from eval_semantic_instance import evaluate
import tqdm
import argparse
import json
from omegaconf import OmegaConf
from clip_adapter.clip_opendas import build_model as build_opendas, load_clip_to_cpu as load_opendas_clip_to_cpu, load_model as load_opendas
from scannetpp_constants import CLASS_LABELS_IGNORE, VALID_CLASS_IDS_SCANNETPP_VAL, CLASS_LABELS_SCANNETPP_VAL

# Class name -> class id, pairing the two ScanNet++ validation lists position by position.
LABEL_TO_ID = dict(zip(CLASS_LABELS_SCANNETPP_VAL, VALID_CLASS_IDS_SCANNETPP_VAL))

# When False, every label listed in CLASS_LABELS_IGNORE is dropped from the
# instance vocabulary (presumably wall/floor-like classes, going by the flag
# name -- confirm against scannetpp_constants).
INCLUDE_WALL_ETC = False

VALID_INST_LABELS = (
    CLASS_LABELS_SCANNETPP_VAL
    if INCLUDE_WALL_ETC
    else [name for name in CLASS_LABELS_SCANNETPP_VAL if name not in CLASS_LABELS_IGNORE]
)
# Class ids in the same order as VALID_INST_LABELS.
VALID_INST_INDICES = [LABEL_TO_ID[name] for name in VALID_INST_LABELS]
# Alias used for building the text queries and the custom CLIP class names.
CLASS_LABELS = VALID_INST_LABELS

print(f"[INFO] Valid class labels: {CLASS_LABELS}")
print(f"[INFO] Valid class indices: {VALID_INST_INDICES}")

def write_json(path, data):
    """Serialize ``data`` as JSON (4-space indent) and write it to ``path``.

    Args:
        path: Destination file path; an existing file is overwritten.
        data: Any object accepted by ``json.dumps``.
    """
    with open(path, "w") as out_file:
        serialized = json.dumps(data, indent=4)
        out_file.write(serialized)

def rle_encode(mask):
    """Run-length encode a 1D binary mask.

    The encoding is a space-separated sequence of ``start length`` pairs, one
    pair per run of foreground values, where ``start`` is 1-based.

    Args:
        mask (np.ndarray): 1D binary mask
    Returns:
        rle (dict): ``{"length": <mask length>, "counts": "<start length ...>"}``
    """
    num_elements = mask.shape[0]
    # Pad with background on both sides so runs touching either end still
    # produce an opening and a closing transition.
    padded = np.concatenate([[0], mask, [0]])
    # 1-based positions where the value changes: even slots open a run,
    # odd slots are one past its end.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each end position into a run length.
    boundaries[1::2] = boundaries[1::2] - boundaries[::2]
    return {"length": num_elements, "counts": ' '.join(map(str, boundaries))}

class InstSegEvaluator():
    """Assign a class label to each predicted instance mask using CLIP text queries.

    On construction the evaluator loads a CLIP model (stock, or the custom
    OpenDAS-adapted one), builds one query sentence per entry of the
    module-level ``CLASS_LABELS`` and stores the resulting text embeddings as
    a NumPy matrix in ``self.text_query_embeddings``.
    """

    def __init__(self, dataset_type, clip_model_type, sentence_structure, build_custom_clip):
        """Load the model and precompute the text query embeddings.

        Args:
            dataset_type: Dataset key; only ``'scannetpp100'`` is supported.
            clip_model_type: CLIP variant name; must be one known to
                ``get_feature_size``.
            sentence_structure: Format string with one ``{}`` placeholder that
                receives the class label.
            build_custom_clip: If true, build the OpenDAS model instead of
                loading stock CLIP.

        Raises:
            NotImplementedError: For an unsupported dataset or CLIP variant.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("[INFO] Device:", self.device)
        self.dataset_type = dataset_type
        self.clip_model_type = clip_model_type
        self.build_custom_clip = build_custom_clip
        self.clip_model = self.get_clip_model(clip_model_type, build_custom_clip=self.build_custom_clip)
        self.feature_size = self.get_feature_size(clip_model_type)
        print("[INFO] Feature size:", self.feature_size)
        print("[INFO] Getting label mapper...")
        self.set_label_and_color_mapper(dataset_type)
        print("[INFO] Got label mapper...")
        print("[INFO] Loading query sentences...")
        self.query_sentences = self.get_query_sentences(dataset_type, sentence_structure)
        print("[INFO] Loaded query sentences.")
        print("[INFO] Computing text query embeddings...")
        # One row per query sentence.
        self.text_query_embeddings = self.get_text_query_embeddings().numpy()
        print("[INFO] Computed text query embeddings.")
        print("[INFO] Shape of query embeddings matrix:", self.text_query_embeddings.shape)

    def set_label_and_color_mapper(self, dataset_type):
        """Install ``self.label_mapper`` for the given dataset.

        The mapper converts a position in ``VALID_INST_INDICES`` (i.e. the
        index of a query sentence) into the corresponding class id. Despite
        the method name, no color mapper is set.

        Raises:
            NotImplementedError: If ``dataset_type`` is not ``'scannetpp100'``.
        """
        if dataset_type == 'scannetpp100':
            print(f"VALID_INST_LABELS: {VALID_INST_LABELS}")
            print(f"VALID_INST_INDICES: {VALID_INST_INDICES}")
            self.label_mapper = np.vectorize({idx: el for idx, el in enumerate(VALID_INST_INDICES)}.get)
        else:
            raise NotImplementedError

    def get_query_sentences(self, dataset_type, sentence_structure="a {} in a scene"):
        """Return one formatted query sentence per label in ``CLASS_LABELS``.

        ``dataset_type`` is accepted for interface symmetry but is not used.
        """
        label_list = list(CLASS_LABELS)
        return [sentence_structure.format(label) for label in label_list]

    def get_clip_model(self, clip_model_type, build_custom_clip=False):
        """Return the model used to embed the text queries.

        Args:
            clip_model_type: CLIP variant name.
            build_custom_clip: If true, build the OpenDAS model from the
                hard-coded configuration below; otherwise load stock CLIP
                onto ``self.device`` with ``clip.load``.
        """
        if build_custom_clip:
            cfg = OmegaConf.create({
                "MODEL": {
                    "CLIP_ADAPTER": {
                        "CLIP_MODEL_NAME": clip_model_type
                    },
                    "OPENDAS": {
                        "DIR": "./multimodal-prompt-learning/output/scannetpp_similar_negative_v2/OpenDAS/vit_l14_c2_ep10_batch16_2+2ctx_d24_use_both_losses_0shots/seed429",
                        "LOAD_EPOCH": 8,
                        "PROMPT_DEPTH_VISION": 24,
                        "PROMPT_DEPTH_TEXT": 24,
                        "N_CTX_TEXT": 4,
                        "N_CTX_VISION": 8,
                        "CTX_INIT": "a photo of a",
                        "INPUT_SIZE": (224, 224)
                    }
                }
            })
            print(f"[INFO] Loading Custom CLIP with {cfg}...")

            clip_model = load_opendas_clip_to_cpu(cfg).type(torch.float32)
            class_names = CLASS_LABELS  # should have the same order as in the query label
            custom_clip = build_opendas(cfg, class_names, clip_model)
            custom_clip = load_opendas(custom_clip, cfg)  # loads the weights
            print("[INFO] Custom CLIP Loaded.")
            return custom_clip

        clip_model, _ = clip.load(clip_model_type, self.device)
        return clip_model

    def get_feature_size(self, clip_model_type):
        """Return the embedding dimension for a supported CLIP variant.

        Raises:
            NotImplementedError: For any variant other than ``'ViT-L/14'``,
                ``'ViT-L/14@336px'`` or ``'ViT-B/32'``.
        """
        if clip_model_type in ('ViT-L/14', 'ViT-L/14@336px'):
            return 768
        elif clip_model_type == 'ViT-B/32':
            return 512
        else:
            raise NotImplementedError

    def get_text_query_embeddings(self):
        """Embed every query sentence and return a float CPU tensor.

        Returns:
            torch.Tensor: One row per entry of ``self.query_sentences``. The
            result never tracks gradients, so callers may call ``.numpy()``.
        """
        if self.build_custom_clip:
            # The custom model is queried by class index rather than by text.
            class_indices = range(len(self.query_sentences))
            # Inference only: without no_grad/detach the caller's ``.numpy()``
            # would fail if the returned features track gradients.
            with torch.no_grad():
                text_features = self.clip_model.get_text_features(class_indices)
            # NOTE(review): unlike the stock-CLIP path below, no normalization
            # is applied here -- confirm get_text_features already normalizes.
            return text_features.detach().float().cpu()

        # ViT_L14_336px for OpenSeg, clip_model_vit_B32 for LSeg
        text_query_embeddings = torch.zeros((len(self.query_sentences), self.feature_size))

        for label_idx, sentence in enumerate(self.query_sentences):
            text_input_processed = clip.tokenize(sentence).to(self.device)
            with torch.no_grad():
                sentence_embedding = self.clip_model.encode_text(text_input_processed)

            # Unit-normalize so the later dot product acts as a cosine similarity.
            sentence_embedding_normalized = (sentence_embedding/sentence_embedding.norm(dim=-1, keepdim=True)).float().cpu()
            text_query_embeddings[label_idx, :] = sentence_embedding_normalized

        return text_query_embeddings

    def compute_classes_per_mask_diff_scores(self, masks_path, mask_features_path, keep_first=None, scores_path=None):
        """Predict a class id for every mask of one scene.

        Args:
            masks_path: File loaded with ``torch.load``; its second dimension
                indexes the masks.
            mask_features_path: ``.npy`` file with one feature row per mask.
            keep_first: If truthy, keep only the first ``keep_first`` masks;
                ``None`` (or ``0``) keeps all of them.
            scores_path: Optional ``.npy`` file with one score per mask. When
                omitted every kept mask gets a score of 1.

        Returns:
            tuple: ``(pred_masks, pred_classes, pred_scores)`` restricted to
            the kept masks.
        """
        pred_masks = torch.load(masks_path)
        mask_features = np.load(mask_features_path)

        # Explicit bool dtype: building this from a Python list yields a
        # float array when there are zero masks, which cannot be used as a
        # boolean index.
        keep_mask = np.ones(pred_masks.shape[1], dtype=bool)
        if keep_first:
            keep_mask[keep_first:] = False

        # Normalize mask features. All-zero rows divide by zero; that is
        # expected, so silence the warning and zero the rows out afterwards.
        with np.errstate(divide='ignore', invalid='ignore'):
            mask_features_normalized = mask_features/np.linalg.norm(mask_features, axis=1)[..., None]
        mask_features_normalized[~np.isfinite(mask_features_normalized)] = 0.0

        # (num_masks, num_queries) similarity matrix; the best query wins.
        per_class_similarity_scores = mask_features_normalized@self.text_query_embeddings.T
        max_ind = np.argmax(per_class_similarity_scores, axis=1)
        if max_ind.size:
            max_ind_remapped = self.label_mapper(max_ind)
        else:
            # np.vectorize raises on size-0 input, so short-circuit a scene
            # without masks. Integer dtype assumes integer class ids.
            max_ind_remapped = np.empty(0, dtype=int)

        pred_masks = pred_masks[:, keep_mask]
        pred_classes = max_ind_remapped[keep_mask]

        if scores_path is not None:
            orig_scores = np.load(scores_path)
            pred_scores = orig_scores[keep_mask]
        else:
            pred_scores = np.ones(pred_classes.shape)

        return pred_masks, pred_classes, pred_scores

    def evaluate_full(self, preds, scene_gt_dir, dataset, output_file='temp_output.txt'):
        """Run the external instance-segmentation evaluation and return its result.

        Args:
            preds: Predictions passed through to ``evaluate`` unchanged.
            scene_gt_dir: Ground-truth directory passed through unchanged.
            dataset: Dataset key forwarded as ``dataset=``.
            output_file: Forwarded as ``output_file=``.

        Raises:
            ImportError: If ``eval_semantic_instance`` is not importable.
        """
        # Imported lazily: the module-level import is disabled in this file,
        # which left ``evaluate`` undefined (NameError) whenever this method
        # was called. A local import keeps the rest of the script usable
        # without the evaluation module.
        from eval_semantic_instance import evaluate

        inst_AP = evaluate(preds, scene_gt_dir, output_file=output_file, dataset=dataset)
        return inst_AP

def export_preds_in_scannet_format(pred_masks, pred_classes, pred_scores, pred_export_dir, scene_name):
    """Write one scene's predictions as a main txt file plus one JSON file per mask.

    Layout produced under ``pred_export_dir.format(scene_name)``:

    * ``<scene_name>.txt`` -- one line per mask:
      ``predicted_masks/<scene_name>_<idx>.json <class> <score>``
      (lines joined with newlines, no trailing newline).
    * ``predicted_masks/<scene_name>_<idx>.json`` -- the RLE-encoded mask,
      with ``<idx>`` zero-padded to at least three digits.

    Args:
        pred_masks: 2D array-like whose columns are the individual masks.
        pred_classes: Per-mask class label, indexable by mask position.
        pred_scores: Per-mask confidence, indexable by mask position.
        pred_export_dir: Output directory; passed through ``str.format`` with
            ``scene_name``, so it may contain a ``{}`` placeholder.
        scene_name: Scene identifier used in the file names.
    """
    main_export_dir = pred_export_dir.format(scene_name)
    inst_masks_dir = os.path.join(main_export_dir, "predicted_masks")
    # exist_ok avoids the check-then-create race and also creates the parent.
    os.makedirs(inst_masks_dir, exist_ok=True)

    main_txt_file = os.path.join(main_export_dir, f'{scene_name}.txt')
    num_masks = pred_masks.shape[1]
    main_txt_lines = []

    for inst_ndx in range(num_masks):
        # Path is relative to the main txt file, as written into that file.
        mask_path_relative = f'predicted_masks/{scene_name}_{inst_ndx:03d}.json'
        inst_sem_label = pred_classes[inst_ndx]
        inst_pred_score = pred_scores[inst_ndx]
        main_txt_lines.append(f'{mask_path_relative} {inst_sem_label} {inst_pred_score}')

        # Save the instance mask next to the main file.
        mask_path = os.path.join(main_export_dir, mask_path_relative)
        write_json(mask_path, rle_encode(pred_masks[:, inst_ndx]))

    with open(main_txt_file, 'w') as f:
        f.write('\n'.join(main_txt_lines))

def test_pipeline_full_scannetpp100(mask_features_dir,
                                    gt_dir,
                                    pred_mask_dir,
                                    sentence_structure,
                                    feature_file_template,
                                    pred_export_dir,
                                    dataset_type='scannetpp100',
                                    clip_model_type='ViT-L/14',
                                    keep_first = None,
                                    scene_list_file='',
                                    masks_template='{}.npy',
                                    scores_dir=None,
                                    scores_template='{}.npy',
                                    build_custom_clip=False
                                ):
    """Classify and export the predicted masks of every scene in a list file.

    For each scene name read from ``scene_list_file`` (processed in sorted
    order) the per-mask features are matched against the text queries and
    the labelled masks are exported. Scenes whose feature file is missing
    are skipped with a message.

    Args:
        mask_features_dir: Directory template; formatted with the scene name.
        gt_dir: Ground-truth directory. Currently unused because the
            evaluation step is disabled.
        pred_mask_dir: Directory containing the predicted mask files.
        sentence_structure: Query sentence template with one ``{}``.
        feature_file_template: Feature file name inside the formatted
            ``mask_features_dir``; used as-is, not formatted.
        pred_export_dir: Export directory; created if it does not exist.
        dataset_type: Dataset key passed to the evaluator.
        clip_model_type: CLIP variant passed to the evaluator.
        keep_first: Optional cap on the number of masks kept per scene.
        scene_list_file: Text file with one scene name per line.
        masks_template: Mask file name template; formatted with the scene name.
        scores_dir: Optional directory with per-scene score files.
        scores_template: Score file name template; formatted with the scene name.
        build_custom_clip: Whether the evaluator builds the custom CLIP model.
    """
    evaluator = InstSegEvaluator(dataset_type, clip_model_type, sentence_structure, build_custom_clip)
    print('[INFO]', dataset_type, clip_model_type, sentence_structure)

    with open(scene_list_file, 'r') as scene_list:
        scene_names = sorted(scene_list.read().splitlines())
    print(f"[INFO] Scenes: {scene_names}")

    # An existing export dir is reused (files may be overwritten), not rejected.
    if not os.path.exists(pred_export_dir):
        os.makedirs(pred_export_dir)
    else:
        print("Warning! Pred export dir already exists! - ", pred_export_dir)

    for scene_name in tqdm.tqdm(scene_names):
        masks_file = masks_template.format(scene_name)
        masks_path = os.path.join(pred_mask_dir, masks_file)

        features_dir = mask_features_dir.format(scene_name)
        scene_per_mask_feature_path = os.path.join(features_dir, feature_file_template)

        scores_path = (
            None if scores_dir is None
            else os.path.join(scores_dir, scores_template.format(scene_name))
        )

        # Only the feature file gates the scene; the mask file is not checked.
        if not os.path.exists(scene_per_mask_feature_path):
            print('--- SKIPPING ---', scene_per_mask_feature_path)
            continue

        scene_predictions = evaluator.compute_classes_per_mask_diff_scores(
            masks_path=masks_path,
            mask_features_path=scene_per_mask_feature_path,
            keep_first=keep_first,
            scores_path=scores_path,
        )
        pred_masks, pred_classes, pred_scores = scene_predictions

        export_preds_in_scannet_format(pred_masks, pred_classes, pred_scores, pred_export_dir, scene_name)

    # Evaluation through evaluator.evaluate_full(...) is currently disabled;
    # this function only exports the predictions.


if __name__ == '__main__':

    # Run predictions for the ScanNet++ office scenes.
    SCENE_LIST_VAL = "./data/scannet++/scenes_val.txt"
    PRED_MASK_DIR = "./openmask3d/output/2024-03-23-15-14-28-scannetpp_office/masks"
    GT_DIR = "./scannetpp_openmask3d/scannetpp/GT_INST_100"
    CLIP_MODEL_TYPE = "ViT-L/14"
    USE_OPENDAS = True
    OPENMASK3D_FREQUENCY = 20

    # The results directory name encodes the method and its number of levels.
    NUM_LEVELS, method = (1, "opendas") if USE_OPENDAS else (3, "clip")

    results_dir = f"./scannetpp_masks_and_results_with_{method}_num_levels_{NUM_LEVELS}_frequency_{OPENMASK3D_FREQUENCY}/"
    # "{}" is filled in with the scene name inside the pipeline.
    mask_features_dir = results_dir + "{}/masks_results"
    pred_export_dir = results_dir + f"predictions_with_{method}"

    pipeline_kwargs = dict(
        mask_features_dir=mask_features_dir,
        gt_dir=GT_DIR,
        pred_mask_dir=PRED_MASK_DIR,
        sentence_structure="a {} in a scene",
        feature_file_template='masks_openmask3d_features.npy',
        clip_model_type=CLIP_MODEL_TYPE,
        scene_list_file=SCENE_LIST_VAL,
        masks_template='{}_masks.pt',
        pred_export_dir=pred_export_dir,
        build_custom_clip=USE_OPENDAS,
    )
    test_pipeline_full_scannetpp100(**pipeline_kwargs)

    # Alternative invocation kept for reference (single-scene export):
    # test_pipeline_full_scannetpp100(mask_features_dir="./openmask3d/output/2024-03-01-11-44-38-scannetpp/",
    #                             gt_dir="./scannetpp_openmask3d/scannetpp/GT_INST_100",
    #                             pred_mask_dir="./openmask3d/output/2024-03-01-11-44-38-scannetpp/",
    #                             sentence_structure="a {} in a scene",
    #                             feature_file_template='{}_openmask3d_features.npy',
    #                             clip_model_type='ViT-L/14',
    #                             scene_list_file='./scannetpp/scene_list.txt',
    #                             masks_template='{}_masks.pt',
    #                             pred_export_dir="./scannetpp_masks_and_results/28a9ee4557",
    #                     )