import torch
import clip
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# import matplotlib
from omegaconf import OmegaConf
from adjustText import adjust_text  
from scannetpp_constants import CLASS_LABELS_SCANNETPP_VAL, CLASS_100_CLEAN
from clip_adapter.clip_opendas import build_model as build_opendas, load_clip_to_cpu as load_opendas_clip_to_cpu, load_model as load_opendas

# import matplotlib.font_manager
# print(matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf'))

# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8" and
# deprecated the bare "seaborn" name (later releases dropped it), so use
# whichever of the two names this installation actually provides.
plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')
# plt.rcParams['font.family'] = 'Georgia'

# True: embed labels with the OpenDAS-adapted CLIP; False: use stock CLIP.
# The flag also ends up in the name of the saved plot.
USE_OPENDAS = True

# The imported CLASS_LABELS_SCANNETPP_VAL list is used as-is. To embed the
# cleaned 100-class list instead, uncomment the following line:
# CLASS_LABELS_SCANNETPP_VAL = CLASS_100_CLEAN

# Only these labels receive a text annotation in the plot; every label is
# still drawn as a point.
LABELS_TO_SHOW = [
    "wall",
    "wall socket",
    "ceiling",
    "ceiling beam",
    "floor",
    "carpet",
    "webcam",
    "monitor",
    "door",
    "door frame",
    "objects",
]

def get_clip_model(clip_model_type, device, build_custom_clip=False):
    """Return either a stock CLIP model or the OpenDAS-adapted one.

    Args:
        clip_model_type: CLIP architecture name. Passed to ``clip.load`` for
            the stock model, or stored as ``CLIP_MODEL_NAME`` in the OpenDAS
            config.
        device: Device handed to ``clip.load``. It is not used on the OpenDAS
            path, where the base model comes from ``load_opendas_clip_to_cpu``.
        build_custom_clip: When true, build the OpenDAS model from the
            hard-coded checkpoint settings below instead of loading stock CLIP.

    Returns:
        The model object produced by ``clip.load`` or by the OpenDAS helpers.
    """
    # Stock CLIP: nothing to configure, the preprocessing transform is unused.
    if not build_custom_clip:
        stock_model, _preprocess = clip.load(clip_model_type, device)
        return stock_model

    # Checkpoint location and prompt-learning hyperparameters of the OpenDAS run.
    opendas_settings = {
        "DIR": "./multimodal-prompt-learning/output/scannetpp_similar_negative_v2/OpenDAS/vit_l14_c2_ep10_batch16_2+2ctx_d24_use_both_losses_0shots/seed429",
        "LOAD_EPOCH": 8,
        "PROMPT_DEPTH_VISION": 24,
        "PROMPT_DEPTH_TEXT": 24,
        "N_CTX_TEXT": 4,
        "N_CTX_VISION": 8,
        "CTX_INIT": "a photo of a",
        "INPUT_SIZE": (224, 224),
    }
    cfg = OmegaConf.create({
        "MODEL": {
            "CLIP_ADAPTER": {"CLIP_MODEL_NAME": clip_model_type},
            "OPENDAS": opendas_settings,
        },
    })
    print(f"[INFO] Loading Custom CLIP with {cfg}...")

    base_clip = load_opendas_clip_to_cpu(cfg).type(torch.float32)
    # The class list given to the builder should be in the same order as the
    # labels that are queried later; load_opendas then loads the weights.
    adapted_clip = load_opendas(
        build_opendas(cfg, CLASS_LABELS_SCANNETPP_VAL, base_clip),
        cfg,
    )
    print("[INFO] Custom CLIP Loaded.")
    return adapted_clip

def get_text_encodings(clip_model, texts, device, feature_size=768):
    """Encode ``texts`` into one float32 CPU embedding row per entry.

    With ``USE_OPENDAS`` set, the model is asked for the features of the class
    indices ``0 .. len(texts) - 1`` via ``get_text_features`` (project code;
    presumably these index the class list the model was built with -- confirm
    against the OpenDAS implementation). Otherwise each text is tokenized,
    encoded with ``clip_model.encode_text`` and L2-normalized.

    Args:
        clip_model: Model returned by ``get_clip_model``.
        texts: Sequence of label strings.
        device: Device the token tensors are moved to on the stock CLIP path.
        feature_size: Embedding width used only for the empty result returned
            when ``texts`` is empty on the stock CLIP path. For non-empty
            input the width is taken from the model output, so CLIP variants
            whose text width is not 768 work without passing it.

    Returns:
        A ``torch.Tensor`` on the CPU with one row per text.
    """
    embedding_rows = []

    # Inference only: keeping autograd off avoids building a graph and
    # guarantees the result can be converted with ``.numpy()`` by the caller.
    with torch.no_grad():
        if USE_OPENDAS:
            return clip_model.get_text_features(range(len(texts))).float().cpu()

        # ViT_L14_336px for OpenSeg, clip_model_vit_B32 for LSeg
        for sentence in texts:
            tokens = clip.tokenize(sentence).to(device)
            embedding = clip_model.encode_text(tokens)
            unit_embedding = embedding / embedding.norm(dim=-1, keepdim=True)
            embedding_rows.append(unit_embedding.float().cpu())

    if not embedding_rows:
        return torch.zeros((0, feature_size))

    # Each entry has shape (1, width); stacking along dim 0 yields (N, width).
    return torch.cat(embedding_rows, dim=0)

def generate_text_embeddings(texts):
    """Embed ``texts`` with a ViT-L/14 CLIP model and return a NumPy array.

    The model is the OpenDAS-adapted one when ``USE_OPENDAS`` is set and stock
    CLIP otherwise. CUDA is requested when available, else the CPU.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    clip_model = get_clip_model(
        "ViT-L/14",
        device=device,
        build_custom_clip=USE_OPENDAS,
    )
    encodings = get_text_encodings(clip_model, texts, device)
    return encodings.numpy()

def tsne_visualization(embeddings, labels):
    """Project ``embeddings`` to 2-D with t-SNE and save a scatter plot as PDF.

    Every embedding is drawn as a point; only labels listed in
    ``LABELS_TO_SHOW`` are annotated with text, which ``adjust_text`` then
    repositions (with red arrows) to reduce overlaps. The plot is written to
    ``out/visualizations/tsne_viz_opendas_<USE_OPENDAS>.pdf``; the output
    directory is created if it does not exist yet. The figure is left open
    as the current pyplot figure.

    Args:
        embeddings: Array-like with one row per label, as accepted by
            ``TSNE.fit_transform``.
        labels: Label strings; ``labels[i]`` belongs to row ``i`` of
            ``embeddings``.
    """
    import os  # local import keeps this function self-contained

    output_path = f"out/visualizations/tsne_viz_opendas_{USE_OPENDAS}.pdf"

    # Fixed random_state so repeated runs produce the same layout.
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    plt.figure(figsize=(8, 8))
    annotations = []
    for i, label in enumerate(labels):
        x, y = tsne_results[i, 0], tsne_results[i, 1]
        # One scatter call per point, so each point takes the next colour of
        # the style's colour cycle.
        plt.scatter(x, y)
        if label in LABELS_TO_SHOW:
            annotations.append(
                plt.text(x, y, label, fontdict={'fontsize': 13}, ha='center', va='center')
            )

    adjust_text(annotations, arrowprops=dict(arrowstyle='->', color='red'))
    plt.xlabel("t-SNE feature 0", fontsize=15)
    plt.ylabel("t-SNE feature 1", fontsize=15)
    plt.title("Text Embeddings t-SNE Visualization", fontsize=17)

    # savefig does not create missing parent directories; without this a fresh
    # checkout fails with FileNotFoundError after the t-SNE fit has already run.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path)

# Script driver: these statements run whenever the module is executed or imported.
#
# Labels to embed. Both model paths currently receive the bare class names; the
# commented-out alternative below would wrap each name in a prompt template for
# the stock CLIP path instead.
# if USE_OPENDAS:
texts = CLASS_LABELS_SCANNETPP_VAL
# else:
# texts = ["A photo of a {}".format(classname) for classname in CLASS_LABELS_SCANNETPP_VAL]

# Encode every label into one row of a NumPy array.
embeddings = generate_text_embeddings(texts)

# Reduce the embeddings to 2-D with t-SNE and save the annotated scatter plot.
tsne_visualization(embeddings, texts)
