import copy
import open_clip
import pandas as pd
import json
import torch
# from sklearn.cluster import DBSCAN
import argparse



def class_names(path, dataset):
    """Build CLIP text prompts and label ids for a dataset's class list.

    Args:
        path: Path to the dataset's class-definition JSON file.
        dataset: ``'Replica'`` or ``'ScanNet'``; selects how the JSON is parsed.
            Replica files are read as ``{'classes': [{'name': ..., 'id': ...}]}``;
            ScanNet files as a flat ``{key: class_name}`` mapping.

    Returns:
        ``(prompts, indices)`` — two parallel lists. ``prompts[k]`` is the text
        prompt ``'a photo of <name>'`` for class ``k`` and ``indices[k]`` is its
        label id: ``int(id) - 1`` for Replica, the JSON key (a string) for
        ScanNet.

    Raises:
        ValueError: if ``dataset`` is neither ``'Replica'`` nor ``'ScanNet'``
            (previously this silently returned two empty lists).
    """
    if dataset not in ('Replica', 'ScanNet'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'Replica' or 'ScanNet'"
        )

    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    prompts = []
    indices = []
    if dataset == 'Replica':
        for entry in data['classes']:
            prompts.append('a photo of ' + entry['name'])
            # Replica ids are shifted down by one to form the label id.
            indices.append(int(entry['id']) - 1)
    else:
        for key, name in data.items():
            indices.append(key)
            # 'picture' is given a more descriptive prompt than the bare name.
            if name != 'picture':
                prompts.append('a photo of ' + name)
            else:
                prompts.append('a photo of a picture on wall')

    return prompts, indices

def dbscan_3d(points, eps, min_samples):
    """Cluster ``points`` with DBSCAN and return the per-sample cluster labels.

    Args:
        points: Samples passed straight to ``DBSCAN(...).fit``. Presumably an
            (N, 3) array of xyz coordinates given the function name — TODO
            confirm against callers.
        eps: Forwarded to ``DBSCAN(eps=...)``.
        min_samples: Forwarded to ``DBSCAN(min_samples=...)``.

    Returns:
        ``clustering.labels_`` from the fitted estimator; -1 marks noise.

    NOTE(review): ``DBSCAN`` is not defined in this module — its import
    (``from sklearn.cluster import DBSCAN``) is commented out at the top of
    the file, so calling this function raises NameError. Restore that import
    (or import it locally here) before using this helper.
    """
    # Run DBSCAN
    clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(points)
    # Get labels (-1 = noise)
    labels = clustering.labels_

    return labels

def clip_text(text, model, tokenizer_name='ViT-H-14'):
    """Encode text prompts with an OpenCLIP model's text tower.

    Args:
        text: A string or list of strings to encode.
        model: An OpenCLIP model exposing ``encode_text``. It is switched to
            eval mode and moved to the selected device in place.
        tokenizer_name: OpenCLIP model name whose tokenizer is used. Defaults
            to ``'ViT-H-14'`` (the previously hard-coded value); pass the same
            name that was used to create ``model`` to keep them matched.

    Returns:
        The tensor returned by ``model.encode_text``, computed without
        gradient tracking, on CUDA when available, otherwise on CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Models start in train mode, which affects layers such as BatchNorm or
    # stochastic depth; inference must run in eval mode.
    model.eval()
    model.to(device)
    tokenizer = open_clip.get_tokenizer(tokenizer_name)
    tokens = tokenizer(text).to(device)
    with torch.no_grad():
        # encode_text already returns a tensor; wrapping it in torch.tensor()
        # only made a redundant copy and raised a UserWarning.
        embeddings = model.encode_text(tokens)

    return embeddings

def _parse_args():
    """Parse the command-line options for point-label prediction."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_model", type=str, default='EVA02-L-14-336')
    parser.add_argument("--path", type=str, default='')
    # --dataset and --scene are concatenated into file paths, so a missing
    # value used to fail with an opaque TypeError (str + None). class_names
    # only understands these two dataset names.
    parser.add_argument("--dataset", type=str, required=True,
                        choices=['Replica', 'ScanNet'])
    parser.add_argument('--scene', type=str, required=True)
    return parser.parse_args()


def _predict_point_labels(df_points_to_id, ids_to_embeddings,
                          class_embeddings, class_indices, device):
    """Return one predicted label per row of ``df_points_to_id``.

    Each distinct value of the ``'Object id'`` column is scored once: its
    embedding (``ids_to_embeddings[str(id)]['embedding']``) is dot-multiplied
    in float64 with every row of ``class_embeddings``, and the object takes
    ``int(class_indices[best])`` of the highest-scoring class. Rows then
    inherit their object's label.
    """
    object_column = df_points_to_id['Object id'].tolist()
    object_ids = list(dict.fromkeys(object_column))  # unique, first-seen order
    if not object_ids:
        return []

    # Score each object once instead of once per point: many points share an
    # object id, and the old per-point loop did a host->device copy per point.
    embeddings = torch.stack([
        torch.tensor(ids_to_embeddings[str(object_id)]['embedding']).double()
        for object_id in object_ids
    ]).to(device)
    best_class = torch.argmax(embeddings @ class_embeddings.T, dim=1).tolist()

    id_to_label = {
        object_id: int(class_indices[k])
        for object_id, k in zip(object_ids, best_class)
    }
    return [id_to_label[object_id] for object_id in object_column]


def main():
    """Predict a semantic label for every point of one scene and save it as CSV.

    Reads ``<path>/embeddings/<dataset>/<scene>_points_to_ids_gran_6_scannet.csv``
    and the matching ``_ids_to_embeddings_gran_6_scannet.json``, encodes the
    dataset's class prompts with the CLIP text tower, labels each object with
    its highest dot-product class, and writes the point table plus a
    ``labels`` column to
    ``<path>/predicted_labels1/<scene>_predicted_labels_gran6_scannet.csv``.
    """
    import os  # only needed to create the output directory

    args = _parse_args()
    path = args.path
    dataset = args.dataset
    scene = args.scene

    model, _, _ = open_clip.create_model_and_transforms(args.clip_model, pretrained=None)
    # map_location='cpu' lets a checkpoint saved on GPU load on CPU-only
    # machines; clip_text moves the model to the active device afterwards.
    model.load_state_dict(
        torch.load(path + '/models/open_clip_pytorch_model.bin', map_location='cpu')
    )

    # Other input variants used in earlier runs: '_points_to_ids.csv' and
    # '_points_to_ids_sam.csv', with matching '_ids_to_embeddings*.json'.
    scene_prefix = path + '/embeddings/' + dataset + '/' + scene
    df_points_to_id = pd.read_csv(scene_prefix + '_points_to_ids_gran_6_scannet.csv')
    with open(scene_prefix + '_ids_to_embeddings_gran_6_scannet.json', 'r') as file:
        ids_to_embeddings = json.load(file)

    if dataset == 'Replica':
        classes, class_indices = class_names(path + '/dataset/Replica-data/info_semantic.json', dataset)
    else:
        classes, class_indices = class_names(path + '/dataset/ScanNet/classes.json', dataset)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # clip_text already returns a tensor: cast to float64 directly rather than
    # copying it through torch.tensor(..., dtype=float).
    class_embeddings = clip_text(classes, model).double().to(device)

    df_points_to_id['labels'] = _predict_point_labels(
        df_points_to_id, ids_to_embeddings, class_embeddings, class_indices, device
    )

    out_path = path + '/predicted_labels1/' + scene + '_predicted_labels_gran6_scannet.csv'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    df_points_to_id.to_csv(out_path, index=False)

if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()