import argparse
from tqdm import tqdm
import pickle
import json
import openai

import transformers
# Module-level GPT-2 tokenizer, loaded once at import time. It is used by
# truncate_prompt() and truncate_history() below to count tokens (and, in
# truncate_prompt(), to turn a truncated id list back into text).
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

# Shrink a prompt so that, including its ' <proof>' suffix, it fits in 1024 tokens.
def truncate_prompt(prompt):
    """Return ``prompt`` shortened to the token budget, keeping its ' <proof>' suffix.

    The suffix is removed, the remaining text is shortened until it tokenizes
    to at most 1021 ids, and the suffix is appended again.  Shortening happens
    in two stages: trailing ``<reference>...</reference>`` blocks are removed
    one at a time, and if that is not enough the token sequence itself is cut.

    Args:
        prompt: Prompt text; it must end with ' <proof>'.

    Returns:
        The (possibly shortened) prompt, ending with ' <proof>'.

    Raises:
        AssertionError: If ``prompt`` does not end with ' <proof>'.
    """
    suffix = ' <proof>'
    # 1024 minus the 3 tokens that ' <proof>' is stated to occupy.
    budget = 1021
    # Number of trailing tokens preserved by the hard cut; the original author
    # notes these are '</content> </theorem>'.
    kept_tail = 7

    assert prompt.endswith(suffix)
    body = prompt[:-len(suffix)]

    def encode(text):
        return tokenizer(text)['input_ids']

    ids = encode(body)

    # Stage 1: peel off the last reference block while the text is too long.
    while len(ids) > budget and body.endswith('</reference>'):
        # The extra -1 also removes the space that precedes '<reference>'.
        cut = body.rfind('<reference>') - 1
        body = body[:cut]
        ids = encode(body)

    # Stage 2: still too long, so cut the middle of the token sequence and
    # rebuild the text from the surviving head and tail tokens.
    if len(ids) > budget:
        ids = ids[:budget - kept_tail] + ids[-kept_tail:]
        tokens = tokenizer.convert_ids_to_tokens(ids)
        body = tokenizer.convert_tokens_to_string(tokens)

    return body + suffix

# truncate proof history to 900 tokens (by default)
def truncate_history(history, max_tokens=900):
    """Drop the oldest proof steps until ``history`` fits in ``max_tokens`` tokens.

    Steps are delimited by the two-character sequence backslash + 'n' (the
    literal text, not a newline character).  While the tokenized history is
    over budget, everything up to and including the first delimiter is
    removed and a single leading space is put back in front of the rest.

    Args:
        history: Proof history text; may be empty.
        max_tokens: Maximum number of tokenizer ids allowed (default 900).

    Returns:
        The shortened history.  If the text is still over budget and contains
        no delimiter, it is treated as one over-long step and dropped, so the
        result is ' ' -- the same value the loop yields when an over-long
        final step is followed by a delimiter.
    """
    input_ids = tokenizer(history)['input_ids']

    while len(input_ids) > max_tokens:
        sep = history.find('\\n')
        if sep == -1:
            # No delimiter left: slicing from find() + 2 == 1 would rebuild
            # the same string forever, so discard the remaining step instead.
            return ' '
        history = ' ' + history[sep + 2:]  # get rid of the oldest history
        input_ids = tokenizer(history)['input_ids']

    return history

def split_to_steps(proof):
    """Group the lines of ``proof`` into proof steps.

    ``proof`` is split on the two-character sequence backslash + 'n' (the
    literal text, not a newline character).  Empty pieces are ignored.  A
    piece whose first character is lowercase is treated as a continuation and
    is appended to the step being built, joined by that same delimiter; any
    other piece closes the current step and starts a new one.

    Args:
        proof: Proof text using the literal delimiter described above.

    Returns:
        A list of step strings, in order of appearance.
    """
    delimiter = '\\n'
    steps = []
    current = ''

    for piece in filter(None, proof.split(delimiter)):
        if piece[0].islower():
            # Continuation line: glue it onto the step in progress.
            current = current + delimiter + piece
            continue
        if current:
            steps.append(current)
        current = piece

    # Flush whatever step was still being accumulated.
    if current:
        steps.append(current)
    return steps

def generate_full_proof(ckpt, item):
    """Generate a complete proof for ``item`` with greedy decoding.

    The prompt is first shortened with ``truncate_prompt``.  The completion
    request is repeated until it succeeds: on any ``Exception`` the error is
    written through ``tqdm.write`` and the call is retried after a 10 minute
    sleep.

    Args:
        ckpt: Model identifier passed as ``model`` to the completion API.
        item: Mapping with a 'prompt' entry.

    Returns:
        The text of the first returned choice, with leading and trailing
        spaces removed.
    """
    import time

    request = {
        'model': ckpt,
        'prompt': truncate_prompt(item['prompt']),
        'max_tokens': 1020,
        'temperature': 0.0,
        'stop': '</proof>',
    }

    while True:
        try:
            completion = openai.Completion.create(**request)
        except Exception as err:
            # Rate-limit errors and all other errors are handled the same way;
            # only the logged notice differs (by one space).
            if isinstance(err, openai.error.RateLimitError):
                notice = "Retrying in 10 min ..."
            else:
                notice = "Retrying in 10 min..."
            tqdm.write(str(err))
            tqdm.write(notice)
            time.sleep(600)
        else:
            break

    first_choice = completion['choices'][0]
    return first_choice['text'].strip(' ')

def _complete_step_with_retry(ckpt, prompt, **sampling_kwargs):
    """Request next-step completions, retrying until the request succeeds.

    Every request uses ``max_tokens=120`` and stops at the literal
    backslash + 'n' sequence or at '</proof>'.  On any ``Exception`` the error
    is written through ``tqdm.write`` and the same request is retried after a
    10 minute sleep.

    Args:
        ckpt: Model identifier passed as ``model`` to the completion API.
        prompt: Full prompt text (theorem prompt followed by proof history).
        **sampling_kwargs: Extra request parameters, e.g. ``temperature``,
            ``top_p`` or ``n``.

    Returns:
        A list with the text of the first ``n`` choices (``n`` is taken from
        ``sampling_kwargs`` and defaults to 1), each with leading and trailing
        spaces removed.
    """
    import time  # function-scope import: `time` is not imported at module level

    num_choices = sampling_kwargs.get('n', 1)
    while True:
        try:
            completion = openai.Completion.create(
                model=ckpt,
                prompt=prompt,
                max_tokens=120,
                stop=['\\n', '</proof>'],
                **sampling_kwargs
            )
            # Parsed inside the try so a malformed response is retried as well.
            return [
                completion['choices'][k]['text'].strip(' ')
                for k in range(num_choices)
            ]
        except openai.error.RateLimitError as e:
            tqdm.write(str(e))
            tqdm.write("Retrying in 10 min ...")
            time.sleep(600)
        except Exception as e:
            tqdm.write(str(e))
            tqdm.write("Retrying in 10 min...")
            time.sleep(600)


def generate_next_steps(ckpt, item, orig):
    """Generate next-step predictions for every gold step of a proof.

    The gold proof ``orig['target']`` is split into steps with
    ``split_to_steps``.  For step ``i`` the model is prompted with the
    truncated theorem prompt followed by the preceding gold steps (joined by
    the literal backslash + 'n' sequence and shortened by
    ``truncate_history``), and is asked for one greedy completion plus six
    sets of 10 sampled completions.

    Each request is retried independently, so a failure late in a step does
    not discard (and re-request) the completions already obtained for it.

    Args:
        ckpt: Model identifier passed as ``model`` to the completion API.
        item: Mapping with a 'prompt' entry.
        orig: Mapping with a 'target' entry holding the gold proof text.

    Returns:
        ``{'proof_lines': [...]}`` with one dict per gold step containing the
        keys 'greedy', 'true', 'samples', 'samples_p9', 'samples_p7',
        'samples_p5', 'samples_t8' and 'samples_t6'.
    """
    num_samples = 10
    # Output key -> extra request parameters for each sampled set.
    # NOTE: a `best_of=20` "beam" variant existed only as disabled code in the
    # earlier version of this function and is intentionally not requested.
    sampling_configs = [
        ('samples', {}),
        ('samples_p9', {'top_p': 0.9}),
        ('samples_p7', {'top_p': 0.7}),
        ('samples_p5', {'top_p': 0.5}),
        ('samples_t8', {'temperature': 0.8}),
        ('samples_t6', {'temperature': 0.6}),
    ]

    prompt = truncate_prompt(item['prompt'])
    gold_steps = split_to_steps(orig['target'])

    proof_lines = []
    for i, gold_step in enumerate(gold_steps):
        # Teacher forcing: condition on the gold steps that precede step i.
        history = '' if i == 0 else (' ' + '\\n'.join(gold_steps[:i]) + '\\n')
        history = truncate_history(history)
        context = prompt + history

        line_output = {
            'greedy': _complete_step_with_retry(ckpt, context, temperature=0.0)[0],
            'true': gold_step,
        }
        for key, extra in sampling_configs:
            line_output[key] = _complete_step_with_retry(
                ckpt, context, n=num_samples, **extra
            )
        proof_lines.append(line_output)

    return {'proof_lines': proof_lines}


def main():
    """Run full-proof or next-step generation and pickle the results.

    Command-line arguments:
        --name: Label stored in the output (default 'gpt3').
        --ckpt: Model identifier passed to the completion API (required).
        --refs: One of 'norefs', 'gtrefs', 'retrefs'; selects the input
            ``.jsonl`` files (required).
        --mode: 'fullgen' or 'nextstep' (required).
        --reduced: Evaluate only the examples listed in ``reduced_ixs.txt``
            instead of the whole 'valid' split.
        --codename: Base name of the output file
            ``eval/latest/<codename>.pkl`` (required).

    The output pickle is a dict with the generations (under
    'full_generations' or 'next_step_generations'), 'name' and 'ckpt'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, default='gpt3')
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--refs', choices=['norefs', 'gtrefs', 'retrefs'], required=True)
    parser.add_argument('--mode', choices=['fullgen', 'nextstep'], required=True)
    parser.add_argument('--reduced', action='store_true')
    parser.add_argument('--codename', type=str, required=True)
    args = parser.parse_args()

    with open('../data/latest/proofwiki__refs_ground_truth.json') as f:
        ds = json.load(f)

    # NOTE(review): this path is relative to the working directory, whereas
    # the other inputs are read from '../data/latest' -- confirm this is intended.
    gpt3_ds = {}
    for split in ['valid', 'test']:
        with open(f'data/latest/gpt3ft_proofwiki_{args.refs}.{split}.jsonl') as f:
            gpt3_ds[split] = [json.loads(line.strip('\n')) for line in f]

    if args.reduced:
        # Each line of the index file is "<split> <index>".
        with open('../data/latest/reduced_ixs.txt') as f:
            reduced_ixs = [line.strip('\n').split(' ') for line in f]
        pairs = [(ds[ix[0]][int(ix[1])], gpt3_ds[ix[0]][int(ix[1])]) for ix in reduced_ixs]
    else:
        pairs = list(zip(ds['valid'], gpt3_ds['valid']))

    if args.mode == 'fullgen':
        results_key = 'full_generations'
        generations = []
        for (orig, item) in tqdm(pairs):
            proof = generate_full_proof(args.ckpt, item)
            generations.append({
                'metadata': orig['id'],
                'text': proof,
                'orig': orig,
            })
    else:
        # argparse restricts --mode to 'fullgen' or 'nextstep'.
        results_key = 'next_step_generations'
        generations = []
        for (orig, item) in tqdm(pairs):
            output = generate_next_steps(args.ckpt, item, orig)
            generations.append({'output': output, 'orig': orig})

    outfile = f'eval/latest/{args.codename}.pkl'
    # Use a context manager so the file is flushed and closed deterministically.
    with open(outfile, 'wb') as f:
        pickle.dump({
            results_key: generations,
            'name': args.name,
            'ckpt': args.ckpt
        }, f)


if __name__ == '__main__':
    main()

