import torch
import os
import glob
import pandas as pd

# Number of class indices iterated by main(); each index i is also used as the
# name of a sub-directory under ADE_MAIN_DIR.
NUM_CLASSES = 150
# Root directory that main() globs as ADE_MAIN_DIR/<class index>/*.
ADE_MAIN_DIR = "./ov-seg-clip/open_clip_training/openclip_data/ade20k_150/ade_gt_150cls_train_v3"
# Text file read line by line by read_classes(); the text after the last ":"
# on each line is taken as a comma-separated list of names.
ADE_CLASS_PATH = "./ov-seg-clip/datasets/ade20k_150_with_prompt_eng.txt"
# Directory into which main() writes "ade_gt_150cls_v3.csv".
OUT_DIR = "./ov-seg-clip/open_clip_training/openclip_data/ade20k_150/"

def read_classes(path=None):
    """Load the per-class name strings from a class-list text file.

    For every non-blank line, only the text after the last ``:`` is used. It
    is split on commas, each piece is stripped of surrounding whitespace, and
    the pieces are re-joined with ``", "``. Lines without a ``:`` are used
    in full.

    Args:
        path: File to read. Defaults to the module-level ``ADE_CLASS_PATH``.

    Returns:
        A list with one string per non-blank line, excluding the first one.
    """
    if path is None:
        path = ADE_CLASS_PATH
    classes = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            # A blank line (e.g. a trailing newline at EOF) carries no class;
            # keeping it would add an empty name and shift the class count.
            if not line.strip():
                continue
            names = line.rpartition(":")[2].split(",")
            # Strip every synonym so "a, b" and "a,b" both yield "a, b".
            classes.append(", ".join(name.strip() for name in names))
    # The first entry is dropped so that index 0 of the result lines up with
    # class directory "0" in main(). NOTE(review): presumably the first line
    # is a placeholder/invalid class -- confirm against the data file.
    return classes[1:]

def main():
    """Build the caption/filepath table for the class crops and save it.

    For every class index ``i`` in ``0..NUM_CLASSES-1`` the files matching
    ``ADE_MAIN_DIR/<i>/*`` are collected and each one is paired with the
    caption ``"a photo of a <class names>"``. The resulting table, with
    columns ``title`` and ``filepath``, is written tab-separated (including
    the DataFrame index) to ``OUT_DIR/ade_gt_150cls_v3.csv``.

    Raises:
        ValueError: If the class list does not contain exactly
            ``NUM_CLASSES`` entries.
    """
    classes = read_classes()
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if len(classes) != NUM_CLASSES:
        raise ValueError(
            "Expected %d classes, got %d" % (NUM_CLASSES, len(classes))
        )

    ade_dict = {'title': [], 'filepath': []}
    for i, class_name in enumerate(classes):
        # glob returns files in arbitrary order; sort for a reproducible CSV.
        images = sorted(glob.glob(os.path.join(ADE_MAIN_DIR, str(i), "*")))
        if not images:
            # An empty or missing class directory must not crash the run.
            print("No images found for class %d, skipping" % i)
            continue
        print(images[0])
        print("Len %d for class %d" % (len(images), i))
        caption = f"a photo of a {class_name}"
        ade_dict['title'].extend([caption] * len(images))
        ade_dict['filepath'].extend(images)

    df = pd.DataFrame.from_dict(ade_dict)
    print(df.head())
    print(len(df))
    # DataFrame.info() prints its summary itself and returns None.
    df.info()
    df.to_csv(os.path.join(OUT_DIR, "ade_gt_150cls_v3.csv"), sep="\t")

# Run the conversion only when executed as a script, so importing this module
# does not trigger the filesystem scan and CSV write.
if __name__ == "__main__":
    main()