from modules.hoi4abot.hoibot.modules.semantic_extractor.clip_module import SemanticExtractor
import os
import torch
import json

def prepare_files(annotation_dir, template="obj_categories.json"):
    """
    Load a JSON annotation file from the annotation directory.

    :param annotation_dir: directory that contains the annotation file
    :param template: file name of the JSON file; it is passed through
        ``str.format("")``, so a positional ``{}`` placeholder, if present,
        is replaced by an empty string
    :return: the decoded JSON content
    """
    file_name = template.format("")
    json_path = os.path.join(annotation_dir, file_name)
    with open(json_path) as handle:
        return json.load(handle)

def preprocess_text(texts, add_template="", remove_char=(("\n", ""),)):
    r"""
    Preprocess texts before sending them to CLIP.

    :param texts: iterable of input texts
    :param add_template: template prefixed to every text, such as "A picture of a "
    :param remove_char: iterable of (old, new) pairs applied in order with
        ``str.replace``; e.g. ("\n", "") strips newlines and ("_", " ") turns
        underscores into spaces. The default only strips newlines.
    :return: list with one preprocessed string per input text, in input order
    """
    # Materialize once so a one-shot iterable can be reused for every text.
    replacements = list(remove_char)
    processed = []
    for text in texts:
        text = f"{add_template}{text}"
        # Order matters: a later pair sees the output of the earlier ones.
        for old, new in replacements:
            text = text.replace(old, new)
        processed.append(text)
    return processed


if __name__ == '__main__':
    # Only needed when this file is run as a script, hence the local import.
    from configs.paths import annotations_dir

    semantic_extractor = SemanticExtractor(device="cuda:0", use_clip=True, clip_model="ViT-L/14",
                                           annotation_dir=annotations_dir)

    # (category file to read, feature file to write): objects first, then relationships.
    extraction_jobs = (
        ("obj_categories.json", "object_classes.pt"),
        ("pred_categories.json", "relationship_classes.pt"),
    )

    for categories_file, output_file in extraction_jobs:
        categories_gt = prepare_files(annotations_dir, categories_file)

        # Strip newlines and turn underscores into spaces before encoding.
        adapted_text = preprocess_text(categories_gt, add_template="",
                                       remove_char=[("\n", ""), ("_", " ")])

        # Encode the ADAPTED text for both files, so the stored "text" is exactly
        # what produced the stored "feat" (objects used to be encoded from the raw names).
        # NOTE(review): presumably one feature row per input text; the original
        # comment said [B, 512] — confirm the feature width for ViT-L/14.
        features = semantic_extractor(adapted_text)

        # Keyed by the original category name; "index" is the position in the JSON file.
        class_data = {
            gt: {"index": index, "text": t, "feat": feat.detach().cpu()}
            for index, (gt, t, feat) in enumerate(zip(categories_gt, adapted_text, features))
        }

        torch.save(class_data, os.path.join(annotations_dir, output_file))