import os, sys
import numpy as np
import pandas as pd
import time
import openai
import argparse

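# API key: uncomment and set openai.api_key = '...' here, or export the
# OPENAI_API_KEY environment variable before running this script.
# openai.api_key = ''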

def _print_ft_events(job_id, limit = 50):
    """Print up to ``limit`` events of fine-tuning job ``job_id``, oldest first."""
    events = openai.FineTuningJob.list_events(id = job_id, limit = limit)['data']
    # Reversed so the log reads chronologically (assumes the API lists
    # newest-first, as the original reverse() implies).
    for event in reversed(events):
        print(event['message'])


def gpt_ft(train_data, num_agents, num_rounds, num_epochs = 4, poll_interval = 30):
    """Upload a JSONL training file and fine-tune gpt-3.5-turbo-0613 on it.

    Blocks until the fine-tuning job reaches a terminal state, then prints
    the job's event log (at most 50 events).

    Args:
        train_data: Path of the training file to upload (read in binary mode).
        num_agents: Only used to build the fine-tuned model's name suffix.
        num_rounds: Only used to build the fine-tuned model's name suffix.
        num_epochs: Passed to the job as the ``n_epochs`` hyperparameter.
        poll_interval: Seconds to sleep between job-status polls.

    Returns:
        ``(fine_tuned_model_id, result_file_id)``. ``result_file_id`` is
        ``None`` if the finished job lists no result files.

    Raises:
        RuntimeError: if the job ends with status ``failed`` or ``cancelled``.
    """
    print('GPT-3.5 Finetuning')
    print(f'Creating training file {train_data}')
    time.sleep(10) #Wait 10 seconds just to be sure training file is correct
    # Use a context manager so the file handle is closed after upload.
    with open(train_data, 'rb') as train_file:
        train_response = openai.File.create(
            file = train_file,
            purpose = 'fine-tune'
        )
    train_id = train_response['id']
    openai.File.wait_for_processing(id = train_id)

    print('Start finetuning job')
    #NOTE: Don't use more than 6 epochs. 2 epochs is enough. Keep spending low.
    ft_job = openai.FineTuningJob.create(
        training_file = train_id,
        model = 'gpt-3.5-turbo-0613',
        suffix = f'5k-fin-a{num_agents}-r{num_rounds}-{num_epochs}',
        hyperparameters = {
            'n_epochs': num_epochs
        }
    )

    job_id = ft_job['id']
    response = openai.FineTuningJob.retrieve(job_id)
    # Poll with a delay rather than busy-waiting against the API, and stop on
    # terminal failure states so a failed/cancelled job cannot loop forever.
    while response['status'] != 'succeeded':
        if response['status'] in ('failed', 'cancelled'):
            _print_ft_events(job_id)
            raise RuntimeError(
                f'Finetuning job {job_id} ended with status {response["status"]!r}'
            )
        time.sleep(poll_interval)
        response = openai.FineTuningJob.retrieve(job_id)
    print('Finetuning finished!')

    _print_ft_events(job_id)

    # `response` already holds the final (succeeded) job state.
    finetuned_model_id = response['fine_tuned_model']
    result_files = response.get('result_files') or []
    result_files_id = result_files[0] if result_files else None
    return finetuned_model_id, result_files_id

if __name__ == '__main__':
    # Entry point: fine-tune on ft_data/5k-ft-final-train_agents{A}_rounds{R}.jsonl
    # and append the resulting model id to ft_data/model_names.txt.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', action = 'store', type = int, default = 4, dest = 'epochs',
                        help = 'number of fine-tuning epochs (n_epochs)')
    parser.add_argument('--agents', action = 'store', type = int, default = 4, dest = 'agents',
                        help = 'number of agents; selects the training file and model suffix')
    parser.add_argument('--rounds', action = 'store', type = int, default = 3, dest = 'rounds',
                        help = 'number of rounds; selects the training file and model suffix')
    args = parser.parse_args()
    num_agents = args.agents
    num_rounds = args.rounds
    train_data = os.path.join(
        'ft_data', f'5k-ft-final-train_agents{num_agents}_rounds{num_rounds}.jsonl'
    )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(train_data):
        parser.error(f'training file not found: {train_data}')

    model_id, result_files_id = gpt_ft(train_data, num_agents, num_rounds, num_epochs = args.epochs)
    print('MODEL ID:', model_id)
    # Append mode keeps model ids recorded by earlier runs.
    with open(os.path.join('ft_data', 'model_names.txt'), 'a', encoding = 'utf-8') as f:
        f.write(f'{model_id}\n')
    print('FILE RESULTS ID:', result_files_id)