import pandas as pd
from transformers import pipeline

# Source CSV; the steps below read its 'nsfwLevel' and 'prompt' columns.
file_path = './data/civitai_prompts_processed.csv'
# Subset to build. 'nsfw' and 'safe' are the values the branches below test
# for; the value is also embedded in the output file name.
data_type = 'nsfw'
df = pd.read_csv(filepath_or_buffer=file_path, encoding='utf-8')

# Only two subsets are meaningful. Reject anything else up front: otherwise a
# typo such as 'sfw' would be treated as "safe" by the rating filter here but
# as "nsfw" by the classifier stage (which tests for 'safe'), silently
# producing a contradictory dataset.
if data_type not in ('nsfw', 'safe'):
    raise ValueError(
        "data_type must be 'nsfw' or 'safe', got {!r}".format(data_type)
    )

# Rating-based pre-filter on the 'nsfwLevel' column: keep rows at or above 8
# for the nsfw subset, rows at or below 1 for the safe subset.
if data_type == 'nsfw':
    filtered_df = df[df['nsfwLevel'] >= 8]
else:
    filtered_df = df[df['nsfwLevel'] <= 1]

# Discard rows without a prompt, then rows whose prompt contains the text
# "score" or an opening parenthesis. na=False makes any value the string
# accessor cannot test count as "no match", so such rows are kept.
filtered_df = filtered_df.dropna(subset=['prompt'])
filtered_df = filtered_df[~filtered_df['prompt'].str.contains('score|\\(', na=False)]

# Cap the working set at 100,000 rows; the fixed seed keeps the draw
# reproducible between runs.
if len(filtered_df) < 100000:
    # Copy so that sampled_df is an independent frame: later column
    # assignments on it then neither mutate filtered_df through the shared
    # object nor trigger pandas' SettingWithCopyWarning.
    sampled_df = filtered_df.copy()
else:
    sampled_df = filtered_df.sample(n=100000, random_state=1)

classifier = pipeline("sentiment-analysis", model="michellejieli/NSFW_text_classifier")

# Prompts the classifier could not score. They are collected so the loss can
# be reported once at the end instead of disappearing silently.
unclassified_prompts = []


def _top_prediction(prompt):
    """Return the classifier's first prediction for ``prompt``, or None on failure.

    Classification is best-effort: any exception raised by the pipeline call
    is recorded in ``unclassified_prompts`` and the prompt is then treated as
    not matching, so a single bad row cannot abort the whole run.
    """
    try:
        return classifier(prompt)[0]
    except Exception:
        unclassified_prompts.append(prompt)
        return None


# Choose the predicate on 'nsfw' versus everything else -- the same split the
# nsfwLevel pre-filter uses -- so both stages always target the same subset.
if data_type == 'nsfw':
    def is_nsfw(prompt):
        """Return True when the top predicted label is 'NSFW'."""
        prediction = _top_prediction(prompt)
        return prediction is not None and prediction['label'] == 'NSFW'

    keep_prompt = is_nsfw
else:
    def is_sfw(prompt):
        """Return True when the top label is 'SFW' with a score above 0.9."""
        prediction = _top_prediction(prompt)
        return (
            prediction is not None
            and prediction['label'] == 'SFW'
            and prediction['score'] > 0.9
        )

    keep_prompt = is_sfw

# Filter through a boolean mask rather than adding and dropping a temporary
# column: the result has the same rows and columns, but the input frame (which
# may be the very object bound to filtered_df) is never mutated and pandas has
# no chained assignment to warn about.
keep_mask = sampled_df['prompt'].apply(keep_prompt).astype(bool)
sampled_df = sampled_df[keep_mask]

if unclassified_prompts:
    print(len(unclassified_prompts), "prompt(s) could not be classified and were dropped")

# Training split: at most 30,000 rows. Smaller sets are used as-is; larger
# ones are down-sampled with a fixed seed so the split is reproducible.
final_train_df = (
    sampled_df.sample(n=30000, random_state=1)
    if len(sampled_df) >= 30000
    else sampled_df
)

# Rows that were classified but not selected for training.
remaining_df = sampled_df.drop(index=final_train_df.index)

# Only the prompt text is exported.
prompt_only_train_df = final_train_df.loc[:, ['prompt']]

# The subset name is embedded in the output file name.
train_output_file_path = ''.join(
    ('./data/civitai_', data_type, '_prompts_30k_train.csv')
)
prompt_only_train_df.to_csv(
    path_or_buf=train_output_file_path,
    index=False,
    encoding='utf-8',
)
