
import argparse
import glob
import json
import os

import pandas as pd
import torch
import tqdm

import _settings
import dataeval.coqa_new as coqa
import dataeval.nq_open as nq_open
import dataeval.triviaqa as triviaqa
import models
import utils

parser = argparse.ArgumentParser(
    description='Sample several generations per prompt from a model on a QA dataset.'
)
parser.add_argument('--model', type=str, default='gpt-3.5-turbo',
                    help="Model to query; '/' is replaced by '_' when building the cache directory name.")
parser.add_argument('--dataset', type=str, default='triviaqa',
                    choices=['triviaqa', 'coqa', 'nq_open'],
                    help='Dataset whose preprocess_data builds the prompts.')
parser.add_argument('--fraction_of_data_to_use', type=float, default=1.0,
                    help='Fraction of the dataset to keep, in (0, 1]; values below 1 subsample '
                         'with a seeded train_test_split.')
parser.add_argument('--num_generations_per_prompt', type=int, default=20,
                    help='Number of samples to query for each prompt.')
parser.add_argument('--seed', type=int, default=10,
                    help='Seed for dataset subsampling; also part of the cache directory name.')
parser.add_argument('--nprocess', type=int, default=None,
                    help='If set, queries are queued on a utils.TaskPartitioner and this value is '
                         'passed to its run_multi_process.')


args = parser.parse_args()
# Reject out-of-range fractions up front: a value <= 0 would otherwise reach
# train_test_split with test_size >= 1 and fail there with an obscure error.
if not 0.0 < args.fraction_of_data_to_use <= 1.0:
    parser.error('--fraction_of_data_to_use must be in the range (0, 1]')


# Despite its name, this tokenizer is used: get_generations passes it to the
# selected dataset's preprocess_data.
_UNUSED_TOKENIZER = models.load_tokenizer()


def get_dataset_fn(data_name):
    """Return the ``preprocess_data`` function of the named dataset module.

    Args:
        data_name: One of ``'triviaqa'``, ``'coqa'`` or ``'nq_open'``.

    Returns:
        The module's ``preprocess_data`` callable. In this file it is called
        with the tokenizer as its only argument.

    Raises:
        ValueError: If ``data_name`` is not one of the supported datasets.
    """
    if data_name == 'triviaqa':
        return triviaqa.preprocess_data
    if data_name == 'coqa':
        return coqa.preprocess_data
    if data_name == 'nq_open':
        return nq_open.preprocess_data
    # Fail here instead of returning None. Returning None would only surface
    # later as "'NoneType' object is not callable" at the call site.
    raise ValueError(
        f"Unsupported dataset {data_name!r}; expected one of 'triviaqa', 'coqa', 'nq_open'."
    )

def get_generations(model_name:str, args, seed=10, old_sequences=None, task_runner:utils.TaskPartitioner=None):
    """Query ``args.num_generations_per_prompt`` completions for every prompt.

    Args:
        model_name: Model identifier forwarded to ``models.openai_query``.
        args: Parsed CLI namespace. Reads ``dataset``, ``fraction_of_data_to_use``
            and ``num_generations_per_prompt``.
        seed: Seed for the ``train_test_split`` used to subsample the dataset
            when ``args.fraction_of_data_to_use < 1.0``.
        old_sequences: Previously produced records (dicts with an ``'id'`` key).
            Prompts whose id appears here reuse the old record instead of being
            re-queried.
        task_runner: If given, each query is queued with ``add_task`` instead of
            being executed, and no records are built.

    Returns:
        ``task_runner`` when one was supplied. Otherwise, a list with one record
        dict per prompt, with keys ``prompt``, ``id``, ``question``, ``answer``,
        ``additional_answers`` and ``generations``.
    """
    dataset = get_dataset_fn(args.dataset)(_UNUSED_TOKENIZER)
    if args.fraction_of_data_to_use < 1.0:
        dataset = dataset.train_test_split(test_size=(1 - args.fraction_of_data_to_use), seed=seed)['train']
    # With batch_size=1 every collated field holds a single example, hence the
    # ``[0]`` indexing throughout the loop.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    if old_sequences is None:
        old_sequences = []
    old_by_id = {seq['id']: seq for seq in old_sequences}
    sequences = []
    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        sample_id = batch['id'][0]
        if sample_id in old_by_id:
            sequences.append(old_by_id[sample_id])
            continue
        prompt = batch['prompt'][0]
        # The keyword spelling ``attemptd_id`` is kept exactly as the original
        # passes it; presumably it matches models.openai_query's parameter.
        if task_runner is not None:
            # Queue-only mode: the runner executes the queries later.
            for attempt_id in range(args.num_generations_per_prompt):
                task_runner.add_task(models.openai_query, prompt, model=model_name, attemptd_id=attempt_id, max_tries=50)
            continue
        generated_texts = [
            models.openai_query(prompt, model=model_name, attemptd_id=attempt_id, max_tries=50)
            for attempt_id in range(args.num_generations_per_prompt)
        ]
        curr_seq = dict(
            prompt=prompt,
            id=sample_id,
            question=batch['question'][0],
            answer=batch['answer'][0],
            additional_answers=[],
            generations=generated_texts,
        )
        if args.dataset == 'coqa':
            # Collation turns the per-example list into length-1 batches, so
            # take element 0 of each.
            curr_seq['additional_answers'] = [x[0] for x in batch['additional_answers']]
        sequences.append(curr_seq)
    # Use an explicit None check. ``task_runner or sequences`` would return the
    # list whenever the runner object happened to be falsy (e.g. if it defines
    # ``__len__`` and has no queued tasks).
    return task_runner if task_runner is not None else sequences

def main(overwrite=True, continue_from=None, parallel=None):
    """Generate, or resume generating, samples and pickle them to a cache dir.

    Args:
        overwrite: If False and the cache directory already holds finished
            (non-``_partial``) ``*.pkl`` results, print a message and return
            without generating. If True, a new run is written as
            ``<run_id>.pkl``, where ``run_id`` is the number of finished
            results already present. Existing files are not deleted.
        continue_from: Path to a ``<run_id>_partial.pkl`` file. Its sibling
            ``args<run_id>.json`` replaces the global ``args``, and its pickled
            records are reused instead of regenerated.
        parallel: If not None, queries are queued on a ``utils.TaskPartitioner``
            and run via ``run_multi_process(parallel)``. In that mode nothing is
            pickled by this function.

    Returns:
        The result of ``task_runner.run_multi_process(parallel)`` in parallel
        mode, otherwise None.
    """
    task_runner = None if parallel is None else utils.TaskPartitioner()
    if continue_from:
        cache_dir = os.path.dirname(continue_from)
        fname = os.path.basename(continue_from)
        # Build the sibling path from the directory. A str.replace over the
        # whole path would also rewrite any directory component that happens
        # to contain ``fname``.
        args_path = os.path.join(cache_dir, 'args' + fname.replace("_partial.pkl", ".json"))
        args.__dict__ = utils.jload(args_path)
        old_sequences = pd.read_pickle(continue_from)
        run_id = int(fname.replace("_partial.pkl", ""))
        model_name = args.model
    else:
        old_sequences = []
        # NOTE(review): the '/'-sanitized name is also what get_generations
        # forwards to the model query below, whereas the resume branch forwards
        # the raw args.model. Confirm which form models.openai_query expects.
        model_name = args.model.replace('/', '_')
        cache_dir = os.path.join(_settings.GENERATION_FOLDER, f'{model_name}_{args.dataset}_{args.seed}')
        os.makedirs(cache_dir, exist_ok=True)
        # Test only the file name for '_partial'. Testing the full path would
        # drop every result whenever the cache dir itself contains '_partial'.
        old_results = [
            path for path in glob.glob(os.path.join(cache_dir, '*.pkl'))
            if '_partial' not in os.path.basename(path)
        ]
        if old_results and not overwrite:
            print(f'Found {len(old_results)} generations in {cache_dir}.')
            return
        run_id = len(old_results)

        with open(os.path.join(cache_dir, f'args{run_id}.json'), 'w') as f:
            json.dump(args.__dict__, f)
    out_path = os.path.join(cache_dir, f'{run_id}.pkl')
    print(f'Generating {args.num_generations_per_prompt} generations per prompt for {model_name} on {args.dataset}...')
    print(f"Saving to {out_path}")
    sequences = get_generations(model_name, args, seed=args.seed, old_sequences=old_sequences, task_runner=task_runner)
    if task_runner is not None:
        return task_runner.run_multi_process(parallel)
    print(f'Writing {len(sequences)} generations to {cache_dir}...')
    pd.to_pickle(sequences, out_path)
    return


if __name__ == "__main__":
    # Script entry point: always a fresh run, parallelised per --nprocess.
    # In parallel mode the name below holds what run_multi_process returned;
    # otherwise it is None.
    task_runner = main(overwrite=True, continue_from=None, parallel=args.nprocess)