import os
import openai
import pandas as pd
import math
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import time

# Groups to process; every step reads and writes one CSV per group.
groups = [
    'men', 'women',
    'black', 'white', 'asian', 'hispanic',
    'judaism', 'islam', 'christianity',
]

# Completion budget as a fraction of each input's GPT-2 token length.
percent = 0.9
# Number of recursion steps; the main loop writes step files up to num_steps - 1.
num_steps = 10

openai.api_key = os.environ.get("OPENAI_API_KEY")
# GPT-2 tokenizer, used only to count input tokens.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

# num_steps-1 because we start with the original summaries generated
# Resume point: the first input step to process, and (step, group) pairs to
# skip at that point.
# NOTE(review): presumably (8, 'men') was already generated by an earlier,
# interrupted run -- confirm before reusing these values.
START_STEP = 8
SKIP = {(8, 'men')}

# Directory holding the per-step CSVs, named '{step}_{percent:.1f}_{group}.csv'.
DATA_DIR = '/home/user/project_2021_nlp-summarization-bias/data/synthetic_data/gpt3/recursive'

# Context window of text-davinci-002: prompt tokens + max_tokens must not exceed this.
MAX_CONTEXT_TOKENS = 4097

# Retry policy for transient OpenAI API failures (rate limits, timeouts, ...).
MAX_RETRIES = 6
INITIAL_BACKOFF_SECONDS = 2.0

# Error classes worth retrying. They are looked up by name because the set
# available may differ between openai 0.x releases.
_RETRYABLE_ERRORS = tuple(
    getattr(openai.error, name)
    for name in ('RateLimitError', 'APIError', 'APIConnectionError',
                 'Timeout', 'ServiceUnavailableError')
    if hasattr(openai.error, name)
)


def _step_path(step, group):
    """Return the CSV path for ``group`` at recursion ``step``."""
    return os.path.join(DATA_DIR, f'{step}_{percent:.1f}_{group}.csv')


def _complete_with_retry(prompt, max_tokens):
    """Request a completion, retrying transient API errors with exponential backoff.

    Re-raises the last error once ``MAX_RETRIES`` retries are exhausted. A
    persistent failure therefore still stops the run instead of silently
    writing empty summaries.
    """
    delay = INITIAL_BACKOFF_SECONDS
    for attempt in range(MAX_RETRIES + 1):
        try:
            return openai.Completion.create(
                model="text-davinci-002",
                prompt=prompt,
                temperature=0.5,
                max_tokens=max_tokens,
                top_p=1.0,
                frequency_penalty=0.0,
                presence_penalty=0.0,
            )
        except _RETRYABLE_ERRORS as err:
            if attempt == MAX_RETRIES:
                raise
            print(f'OpenAI request failed ({err!r}); retrying in {delay:.0f}s')
            time.sleep(delay)
            delay *= 2


def _summarize(article):
    """Return a completion for ``article`` capped at ``percent`` of its token length.

    Returns '' without calling the API in two cases:
    - the budget rounds to 0 tokens (e.g. an empty article);
    - prompt + completion would not fit in the model's context window.
    """
    # GPT-2 token counts are used as an approximation of the model's tokenizer.
    art_length = len(tokenizer(article).input_ids)
    max_length = math.ceil(percent * art_length)
    # The API rejects requests whose prompt plus max_tokens exceeds the
    # context window. Checking max_tokens alone let long inputs crash the run.
    if max_length <= 0 or art_length + max_length > MAX_CONTEXT_TOKENS:
        return ''
    response = _complete_with_retry(article, max_length)
    # Newlines are flattened, presumably so each summary stays on one CSV line.
    return response.choices[0].text.replace('\n', ' ')


# Step i reads '{i}_...' and writes '{i+1}_...'. The original summaries are the
# input to the first step, so output indices always run one ahead of inputs.
# The range stops at num_steps - 1, so the last file written is step num_steps - 1.
for i in range(START_STEP, num_steps - 1):
    for g in groups:
        print(g)
        if (i, g) in SKIP:
            continue
        df = pd.read_csv(_step_path(i, g), index_col=0, lineterminator='\n')
        df['summary'] = df['summary'].fillna('')
        # The previous step's summaries become this step's "articles".
        # Bracket assignment always creates or overwrites a real column. Attribute
        # assignment (df.article = ...) would not create a missing column.
        df['article'] = df['summary']
        articles = df['article'].to_list()
        df['summary'] = [
            _summarize(article) for article in tqdm(articles, total=len(articles))
        ]
        df.to_csv(_step_path(i + 1, g))