import copy
import os
import open_clip
import pandas as pd
import json
import torch
import argparse


def class_names(path, dataset):
    """Load a dataset's class vocabulary as CLIP text prompts.

    Every class name becomes a prompt of the form ``'a photo of <name>'``.
    The ScanNet class ``'picture'`` is rewritten to
    ``'a photo of a picture on wall'`` to disambiguate it.

    Args:
        path: Path to the class-definition JSON file.
        dataset: ``'Replica'`` — the JSON holds
            ``{"classes": [{"id": ..., "name": ...}, ...]}``; or
            ``'scannet'`` — the JSON is a flat ``{id: name}`` mapping.

    Returns:
        ``(prompts, indices)``: two parallel lists. For Replica the indices
        are ints shifted to be 0-based (Replica ids start at 1); for ScanNet
        they are the JSON keys, unchanged (strings).

    Raises:
        ValueError: if *dataset* is neither ``'Replica'`` nor ``'scannet'``
            (previously this silently produced empty lists, which only
            failed much later with an obscure error).
    """
    if dataset not in ('Replica', 'scannet'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'Replica' or 'scannet'"
        )

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    prompts = []
    indices = []
    if dataset == 'Replica':
        for cls in data['classes']:
            prompts.append('a photo of ' + cls['name'])
            # Replica ids start from 1 -> make them 0-based
            indices.append(int(cls['id']) - 1)
    else:  # 'scannet'
        for key, name in data.items():
            indices.append(key)
            if name == 'picture':
                prompts.append('a photo of a picture on wall')
            else:
                prompts.append('a photo of ' + name)

    return prompts, indices


def clip_text(texts, model, clip_model_name):
    """Return CLIP text embeddings for a list of prompt strings.

    Side effects: *model* is put into eval mode and moved to the GPU when
    CUDA is available (CPU otherwise). The tokenizer is looked up from
    *clip_model_name*, so that name must describe the same architecture as
    *model*.

    Args:
        texts: List of prompt strings to encode.
        model: An OpenCLIP model exposing ``encode_text``.
        clip_model_name: OpenCLIP architecture name used to pick the tokenizer.

    Returns:
        ``torch.Tensor`` of shape ``[len(texts), D]``, computed without
        autograd tracking, on the same device as the model.
    """
    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"

    model.eval().to(target)

    tokenize = open_clip.get_tokenizer(clip_model_name)
    token_ids = tokenize(texts).to(target)

    with torch.no_grad():
        return model.encode_text(token_ids)


def _predict_object_labels(object_ids, ids_to_embeddings, class_embeddings,
                           class_indices, device):
    """Assign each distinct object id the class whose prompt embedding is closest.

    Args:
        object_ids: Iterable of int object ids (duplicates are fine).
        ids_to_embeddings: Mapping ``str(id) -> {"embedding": [...]}`` as
            loaded from the ``*_ids_to_embeddings_*.json`` file.
        class_embeddings: Tensor ``[num_classes, D]`` of CLIP text embeddings.
        class_indices: Dataset label for each row of *class_embeddings*.
        device: Device on which the similarities are computed.

    Returns:
        Dict mapping every distinct object id to its predicted ``int`` label.
    """
    unique_ids = sorted(set(object_ids))
    if not unique_ids:
        return {}

    # One row per object; flatten in case an embedding was stored with
    # extra dimensions.
    object_embeddings = torch.stack([
        torch.tensor(
            ids_to_embeddings[str(obj_id)]["embedding"],
            dtype=torch.float32,
            device=device,
        ).reshape(-1)
        for obj_id in unique_ids
    ])  # [num_objects, D]

    # Cosine similarity, as in CLIP zero-shot classification: with a raw dot
    # product, prompts whose text embedding has a larger norm would win the
    # argmax regardless of how well they match. (F.normalize clamps the
    # denominator, so an all-zero embedding stays zero instead of NaN.)
    text = torch.nn.functional.normalize(
        class_embeddings.to(device=device, dtype=torch.float32), dim=-1
    )
    objects = torch.nn.functional.normalize(object_embeddings, dim=-1)
    best = (objects @ text.T).argmax(dim=1).tolist()  # [num_objects]

    return {
        obj_id: int(class_indices[class_index])
        for obj_id, class_index in zip(unique_ids, best)
    }


def main():
    """Label every point of a scene with the class closest to its object.

    Reads ``embeddings/<dataset>/<scene>_points_to_ids_ov_scannet.csv`` and
    ``embeddings/<dataset>/<scene>_ids_to_embeddings_ov_scannet.json`` under
    ``--path``. It then encodes the dataset's class prompts with the
    OVSeg-finetuned OpenCLIP model. Finally it writes the CSV with an added
    ``labels`` column to
    ``predicted_labels1/<scene>_predicted_labels_ov_scannet.csv``.
    """
    parser = argparse.ArgumentParser()
    # The architecture must match models/ovseg_clip.pth; the script has
    # always built ViT-L-14 for that checkpoint, so that is the default.
    # (Previously this flag was parsed and then silently overwritten.)
    parser.add_argument("--clip_model", type=str, default="ViT-L-14")
    parser.add_argument("--path", type=str, default="")
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--scene", type=str, required=True)
    args = parser.parse_args()

    # Build paths with os.path.join so the default "" stays relative to the
    # working directory; the old `rstrip("/") + "/"` turned it into "/".
    path = args.path
    dataset = args.dataset
    scene = args.scene
    clip_model_name = args.clip_model

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ------------------------------------------------------------------
    # 1) Create the CLIP model and load the OVSeg-finetuned weights once,
    #    through open_clip's own (strict) checkpoint loader.
    # ------------------------------------------------------------------
    clip_path = os.path.join(path, "models/ovseg_clip.pth")
    model, _, _ = open_clip.create_model_and_transforms(
        clip_model_name,
        pretrained=clip_path,
    )

    # ------------------------------------------------------------------
    # 2) Load the point -> object id CSV and object id -> embedding JSON
    # ------------------------------------------------------------------
    embeddings_dir = os.path.join(path, "embeddings", dataset)
    df_points_to_id = pd.read_csv(
        os.path.join(embeddings_dir, f"{scene}_points_to_ids_ov_scannet.csv")
    )
    with open(
        os.path.join(embeddings_dir, f"{scene}_ids_to_embeddings_ov_scannet.json"),
        "r",
        encoding="utf-8",
    ) as file:
        ids_to_embeddings = json.load(file)

    # ------------------------------------------------------------------
    # 3) Load class prompts & label indices (Replica / ScanNet)
    # ------------------------------------------------------------------
    if dataset == "Replica":
        class_file = os.path.join(path, "dataset/Replica-data/info_semantic.json")
    else:
        class_file = os.path.join(path, "dataset/scannet/classes.json")
    classes, class_indices = class_names(class_file, dataset)

    # ------------------------------------------------------------------
    # 4) Encode the class prompts with the OVSeg CLIP model
    # ------------------------------------------------------------------
    class_embeddings = clip_text(classes, model, clip_model_name).to(device)

    # ------------------------------------------------------------------
    # 5) Predict once per distinct object id, then copy the label to every
    #    CSV row with that id (many rows presumably share an id — it is a
    #    point table — so this avoids redundant per-row work).
    # ------------------------------------------------------------------
    object_ids = df_points_to_id["Object id"].astype(int)
    id_to_label = _predict_object_labels(
        object_ids.tolist(),
        ids_to_embeddings,
        class_embeddings,
        class_indices,
        device,
    )
    df_points_to_id["labels"] = object_ids.map(id_to_label)

    # ------------------------------------------------------------------
    # 6) Save predicted labels
    # ------------------------------------------------------------------
    out_dir = os.path.join(path, "predicted_labels1")
    os.makedirs(out_dir, exist_ok=True)

    df_points_to_id.to_csv(
        os.path.join(out_dir, f"{scene}_predicted_labels_ov_scannet.csv"),
        index=False,
    )


# Run the labelling pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
