import torch
from rtpt import RTPT
from clip_models import BaseNet
import CLIP.clip as clip
from PIL import Image
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from torchvision.transforms import Normalize
import matplotlib.pyplot as plt
import os
import cv2
import pandas as pd
from tqdm import tqdm
import glob
import pickle
from main.clip_models.baseline import initialize_model


def main(dataset_dir='/workspace/datasets/yfcc100m'):
    """Split caption text files into sentences and pickle them per directory.

    For every sub-directory of ``dataset_dir``, the first line of each
    ``*.txt`` file is read, everything from the first ``' Tags: '`` marker
    onwards is discarded, and the remainder is split on ``'.'``. A ``'.'`` is
    appended to every piece again. The resulting list of lists (one inner
    list per text file, in ``glob`` order) is written to ``sentences.p``
    inside that sub-directory.

    Args:
        dataset_dir: Root directory whose sub-directories contain the
            ``*.txt`` files. Defaults to the previously hard-coded path.
    """
    # os.listdir also returns plain files; only real directories can hold
    # text files and receive a 'sentences.p', so filter them up front.
    image_dirs = [
        name for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name))
    ]
    for i, image_dir in enumerate(image_dirs):
        image_path = os.path.join(dataset_dir, image_dir)
        sentences = list()
        text_files = glob.glob(os.path.join(image_path, '*.txt'))
        print(f'Reading dir {i}/{len(image_dirs)}')
        for text_file in tqdm(text_files):
            # Only the first line of each file is used; the context manager
            # closes the handle right away instead of leaking one per file.
            with open(text_file, 'r') as text_handle:
                texts = text_handle.readline()
            texts = texts.split(' Tags: ')[0]
            texts = texts.split('.')
            # NOTE(review): splitting on '.' leaves a trailing piece (empty
            # or whitespace/newline) when the text ends with a period, which
            # is stored as a bare '.' entry. Kept as-is because consumers of
            # 'sentences.p' may rely on the current layout -- confirm.
            sentences.append([text + '.' for text in texts])
        # Close (and thereby flush) the pickle file deterministically.
        with open(os.path.join(image_path, 'sentences.p'), 'wb') as pickle_handle:
            pickle.dump(sentences, pickle_handle)


# Run the sentence extraction only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
