import torch
import os
import glob
import pandas as pd
import json
import pandas as pd
import matplotlib.pyplot as plt

DATASET_TYPE = "train"
NUM_CLASSES = 1659
SCANNET_MAIN_DIR = f"./ov-seg-clip/open_clip_training/openclip_data/scannetpp/scannetpp_{DATASET_TYPE}_subsetv2"
OUT_DIR = "./ov-seg-clip/open_clip_training/openclip_data/scannetpp"

# Load the ScanNet++ semantic class list. Each entry is expected to be an
# object with a "name" key; the list position is used as the class index
# (main() looks for images under SCANNET_MAIN_DIR/<index>/).
# A `with` block closes the file after parsing (the previous bare open()
# leaked the handle); CLASSES_FILE stays bound for any code that refers to it.
with open("./data/scannet++/semantic_classes.json", encoding="utf-8") as CLASSES_FILE:
    SCANNET_CATEGORIES = json.load(CLASSES_FILE)
SCANNET_CLASSES = [x["name"] for x in SCANNET_CATEGORIES]

def main(classes=None, main_dir=None, out_dir=None, dataset_type=None):
    """Build a tab-separated (caption, image path) CSV from per-class image folders.

    For every class index ``i``, all entries matching ``<main_dir>/<i>/*`` are
    paired with the caption ``"a photo of a <classes[i]>"``. Two files are
    written to ``out_dir`` (created if missing):

    * ``scannetpp_<dataset_type>_gt.csv`` -- tab-separated, with columns
      ``title`` and ``filepath`` (plus the pandas index column).
    * ``scannetpp_<dataset_type>_classes.txt`` -- names of the classes that
      had at least one image, one per line, in class-index order.

    Every argument is optional. Each one defaults to the matching
    module-level constant (``SCANNET_CLASSES``, ``SCANNET_MAIN_DIR``,
    ``OUT_DIR``, ``DATASET_TYPE``), so ``main()`` with no arguments keeps
    its original behaviour.

    Raises:
        ValueError: if ``classes`` is not given and ``SCANNET_CLASSES`` does
            not contain exactly ``NUM_CLASSES`` entries.
    """
    if classes is None:
        classes = SCANNET_CLASSES
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if len(classes) != NUM_CLASSES:
            raise ValueError(
                f"expected {NUM_CLASSES} classes, found {len(classes)}"
            )
    if main_dir is None:
        main_dir = SCANNET_MAIN_DIR
    if out_dir is None:
        out_dir = OUT_DIR
    if dataset_type is None:
        dataset_type = DATASET_TYPE

    ade_dict = {'title': [], 'filepath': []}
    appearing_classes = []
    # Escape the base directory so characters such as "[" in the path are
    # matched literally instead of being interpreted as glob syntax.
    escaped_dir = glob.escape(main_dir)
    for i, class_name in enumerate(classes):
        # glob returns entries in arbitrary order; sort for a reproducible CSV.
        images = sorted(glob.glob(os.path.join(escaped_dir, str(i), "*")))
        if not images:
            continue
        appearing_classes.append(class_name)
        print(f"Len {len(images)} for class {i}")
        ade_dict['title'].extend([f"a photo of a {class_name}"] * len(images))
        ade_dict['filepath'].extend(images)

    df = pd.DataFrame.from_dict(ade_dict)
    print(df.head())
    print(len(df))
    # DataFrame.info() prints its report itself and returns None; wrapping it
    # in print() used to emit a stray "None" line.
    df.info()

    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(os.path.join(out_dir, f"scannetpp_{dataset_type}_gt.csv"), sep="\t")

    classes_path = os.path.join(out_dir, f"scannetpp_{dataset_type}_classes.txt")
    with open(classes_path, 'w', encoding="utf-8") as fp:
        # One class name per line.
        fp.writelines(f"{item}\n" for item in appearing_classes)

# Run only when executed as a script, so importing this module does not
# trigger the full dataset scan and file writes.
if __name__ == "__main__":
    main()