import os
import openai
import pandas as pd
import math
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import time

# Demographic groups available per bias category. The script below iterates
# over every ordered pair of distinct groups inside the selected category.
group_categories = {
    'gender': ['men', 'women'],
    'race': ['black', 'white', 'asian', 'hispanic'],
    'religion': ['judaism', 'islam', 'christianity']
}

openai.api_key = os.getenv("OPENAI_API_KEY")
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
# Completion budget, expressed as a fraction of the article's GPT-2 token count.
percent = 0.5

# Category whose groups are paired, and the share of the first group (g1) in
# each input file; the second group (g2) gets the complement, 1 - ratio.
cat = 'religion'
ratios = [0.1, 0.5, 0.9]

# Root of the synthetic data tree; inputs live under articles/, outputs
# under gpt3/.
_DATA_DIR = '/home/user/project_2021_nlp-summarization-bias/data/synthetic_data'

# Retry policy for transient API failures: up to _MAX_ATTEMPTS calls, waiting
# _RETRY_BASE_DELAY * 2**attempt seconds between them (2, 4, 8, 16 s).
_MAX_ATTEMPTS = 5
_RETRY_BASE_DELAY = 2.0


def _request_completion(article, max_length):
    """Request a completion for *article*, retrying transient API errors.

    Args:
        article: Text sent verbatim as the prompt.
        max_length: Upper bound on generated tokens (``max_tokens``).

    Returns:
        The response object returned by ``openai.Completion.create``.

    Raises:
        openai.error.RateLimitError, openai.error.APIConnectionError:
            Re-raised if the final attempt still fails.
    """
    for attempt in range(_MAX_ATTEMPTS):
        try:
            return openai.Completion.create(
                model="text-davinci-002",
                # alternative model: "text-curie-001"
                prompt=article,
                temperature=0.5,
                max_tokens=max_length,
                top_p=1.0,
                frequency_penalty=0.0,
                presence_penalty=0.0
            )
        except (openai.error.RateLimitError, openai.error.APIConnectionError):
            # A single throttled or dropped request should not discard all
            # summaries gathered so far for this file; back off and retry.
            if attempt == _MAX_ATTEMPTS - 1:
                raise
            time.sleep(_RETRY_BASE_DELAY * 2 ** attempt)


for ratio in ratios:
    for g1 in group_categories[cat]:
        for g2 in group_categories[cat]:
            if g1 == g2:
                continue
            print(ratio, g1, g2)

            # Both the input and output file names share this suffix.
            pair_name = f'{ratio:.2f}_{g1}_{1-ratio:.2f}_{g2}.csv'
            df = pd.read_csv(
                f'{_DATA_DIR}/articles/multigroup/{pair_name}', index_col=0)
            articles = df.article.to_list()
            summaries = []

            for article in tqdm(articles):
                # Cap the completion at `percent` of the article's length in
                # GPT-2 tokens.
                # NOTE(review): prompt tokens + max_tokens must fit in the
                # model's context window; very long articles are not guarded
                # against here.
                art_length = len(tokenizer(article).input_ids)
                max_length = math.ceil(percent * art_length)

                response = _request_completion(article, max_length)

                # Flatten the summary to one line so it fits a CSV cell cleanly.
                summaries.append(response.choices[0].text.replace('\n', ' '))

            df['summary'] = summaries
            df.to_csv(
                f'{_DATA_DIR}/gpt3/multigroup/{percent:.1f}_{pair_name}')
