
from collections import Counter
from nltk.corpus import stopwords
import nltk


import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
import torch
from copy import deepcopy
import os
from tqdm import tqdm

# Input image directory, per-image caption scratch file, and the output file
# that receives the names of images flagged by the loop below.
directory = '/PATH/landbirds'
caption_path = './PATH/captions_temp.txt'
# NOTE(review): '.txt' with no stem looks like a redacted placeholder — set a real filename.
problematic_images_path = './PATH/.txt'

# Captioning model plus its image preprocessors, loaded in eval mode onto CUDA device 7.
device = torch.device("cuda:7")
model, vis_processors, _ = load_model_and_preprocess(
    name="texpel_caption",
    model_type="base_coco",
    is_eval=True,
    device=device,
)

# Extra words merged into NLTK's English stopwords and stripped from the
# caption tokens before the most frequent remaining word is picked. They
# describe position, size, pose/action, count or filler — presumably because
# they dominate captions without naming the subject. Each word is listed once.
add_to_stopwords = {
    # Spatial / positional
    'next', 'front', 'rear', 'besides', 'below', 'under', 'near', 'back',
    'side', 'background', 'foreground', 'behind', 'along', 'top', 'outside',
    'inside', 'around', 'area',
    # Size and generic adjectives
    'small', 'large', 'old', 'open', 'close', 'new', 'cute', 'little',
    'happy', 'big',
    # Actions / verbs
    'sitting', 'driving', 'riding', 'laying', 'standing', 'looking',
    'holding', 'wearing', 'like', 'likes', 'love', 'looks', 'look', 'get',
    'got', 'go', 'going', 'coming', 'eating', 'sleeping', 'flying',
    'blowing', 'traveling', 'moving', 'think', 'put',
    # Numbers
    'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
    'ten',
    # Other filler
    'another', 'together', 'owner', 'home', 'day', 'time', 'world', 'head',
    'hat', 'today', 'really', 'body',
}

def _top_caption_word(captions, stop_words):
    """Return the most frequent non-stopword across *captions*, or None.

    Captions are joined with spaces (not concatenated) so the last word of one
    caption cannot fuse with the first word of the next. Tokens are lower-cased
    and kept only if alphanumeric and not in *stop_words*. Ties go to the word
    seen first, the same order a stable descending sort of the counts gives.
    Returns None when no token survives the filtering.
    """
    text = " ".join(captions)
    tokens = (token.lower() for token in nltk.word_tokenize(text))
    counts = Counter(
        token for token in tokens if token.isalnum() and token not in stop_words
    )
    top = counts.most_common(1)
    return top[0][0] if top else None


# NLTK's English stopwords plus the caption-specific extras; built once
# rather than once per image.
stop_words = set(stopwords.words('english'))
stop_words.update(add_to_stopwords)

# The context manager guarantees the output list is flushed and closed even if
# captioning fails part-way through the directory.
with open(problematic_images_path, 'w') as txt_file:
    # iterate over files
    for filename in tqdm(os.listdir(directory)):
        f = os.path.join(directory, filename)
        if not os.path.isfile(f):
            continue

        # Close the underlying file handle; convert() returns a new loaded
        # image, so it remains usable after the with-block.
        with Image.open(f) as raw_image:
            image = raw_image.convert('RGB')
        image = vis_processors["eval"](image).unsqueeze(0).to(device)
        captions = model.generate({"image": image})

        # Still dump the current image's captions (one per line) to caption_path,
        # but analyse the in-memory list instead of re-reading the file.
        with open(caption_path, 'w') as file:
            file.writelines(sentence + "\n" for sentence in captions)

        # Flag the image when its dominant content word is not 'bird'/'birds',
        # including when no content word survives filtering at all.
        if _top_caption_word(captions, stop_words) not in ('bird', 'birds'):
            txt_file.write(filename + "\n")


