import torch
import os
import glob
import pandas as pd
import json
import pandas as pd
import matplotlib.pyplot as plt
from kitti360_labels import labels

# Root folder shared by the per-split input data and the generated output files.
_KITTI360_DATA_ROOT = "./ov-seg-clip/open_clip_training/openclip_data/kitti360"

# Split to process; selects the input folder and names the output files.
DATASET_TYPE = "train"
# Number of class names main() must obtain from kitti360_labels.labels.
NUM_CLASSES = 45
# Per-split input root; main() reads one sub-folder per class index from it.
KITTI_MAIN_DIR = f"{_KITTI360_DATA_ROOT}/kitti_360_{DATASET_TYPE}"
# Destination directory for the generated CSV and class-list files.
OUT_DIR = _KITTI360_DATA_ROOT

def main():
    """Build the open_clip training CSV and class list for one KITTI-360 split.

    Scans ``KITTI_MAIN_DIR/<class_index>/*`` for every class index in
    ``range(NUM_CLASSES)``, pairs each path found with the caption
    ``"a photo of a <class name>"`` and writes:

    * ``OUT_DIR/kitti360_<DATASET_TYPE>_gt.csv`` -- tab-separated, containing
      the DataFrame index plus ``title`` and ``filepath`` columns;
    * ``OUT_DIR/kitti360_<DATASET_TYPE>_classes.txt`` -- the names of the
      classes that had at least one path, one per line, in class-index order.

    Raises:
        ValueError: if the number of class names taken from
            ``kitti360_labels.labels`` (all but the last entry) differs from
            ``NUM_CLASSES``.
    """
    # Every label except the last one is treated as a class.
    # NOTE(review): presumably the last entry is an ignore/invalid label -- confirm.
    classes = [label.name for label in labels[:-1]]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(classes) != NUM_CLASSES:
        raise ValueError(
            f"expected {NUM_CLASSES} class names from kitti360_labels, "
            f"got {len(classes)}"
        )

    samples = {'title': [], 'filepath': []}
    appearing_classes = []
    for class_idx, class_name in enumerate(classes):
        # glob's result order depends on the filesystem; sort it so the CSV
        # row order is reproducible across runs and machines.
        images = sorted(glob.glob(os.path.join(KITTI_MAIN_DIR, str(class_idx), "*")))
        if not images:
            continue
        appearing_classes.append(class_name)
        print("Len %d for class %d" % (len(images), class_idx))
        samples['title'].extend([f"a photo of a {class_name}"] * len(images))
        samples['filepath'].extend(images)

    df = pd.DataFrame.from_dict(samples)
    print(df.head())
    print(len(df))
    # DataFrame.info() prints its report itself and returns None; wrapping it
    # in print() only emitted an extra "None" line.
    df.info()

    # Make sure the destination exists before writing either output file.
    os.makedirs(OUT_DIR, exist_ok=True)
    df.to_csv(os.path.join(OUT_DIR, f"kitti360_{DATASET_TYPE}_gt.csv"), sep="\t")

    classes_path = os.path.join(OUT_DIR, f"kitti360_{DATASET_TYPE}_classes.txt")
    with open(classes_path, 'w', encoding="utf-8") as fp:
        # One class name per line.
        fp.writelines(f"{name}\n" for name in appearing_classes)

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()