import pretraining_attribution
import argparse
import json

def _optional_int(value):
    """Convert a command-line value to ``int``, or to ``None`` when unset.

    argparse always hands the ``type`` callable the raw string from the
    command line, so the conversion has to be done explicitly here.

    Args:
        value: Raw argument text (e.g. ``"100"`` or ``"None"``).

    Returns:
        The parsed integer, or ``None`` if the text is empty or spells
        "none" (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the text is neither an integer nor
            a spelling of "none".
    """
    text = value.strip()
    if text.lower() in ("", "none"):
        return None
    try:
        return int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer or 'None', got {value!r}"
        ) from None


parser = argparse.ArgumentParser(description="Description of your program")

# Model / dataset selection.
parser.add_argument('--model', type=str, default="llama", help='Model name')
parser.add_argument('--n_params', type=str, default="7b", help='String with number of params, e.g. "7b"')
parser.add_argument('--dataset', type=str, default="model_written_evals", help='Dataset name')

# Generation and batching settings.
parser.add_argument('--n_samples', default=16, type=int, help='number of samples to generate per prompt')
parser.add_argument('--max_new_tokens', default=20, type=int, help='num tokens in generation')
parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
parser.add_argument('--save_freq', default=200, type=int, help='Frequency of saving')
# The previous converter compared the type of the raw string against int,
# which never matches, so every supplied value was silently replaced by None.
parser.add_argument('--n_prompts', default=None, type=_optional_int, help='Number of prompts (can be None)')
parser.add_argument('--dtype', type=str, default='float16', help='Data type (default: float16)')

# Prompt formatting switches.
parser.add_argument('--no_formatting', action='store_true', help='Disable model-native chat formatting')
parser.add_argument('--use_sysprompt', action='store_true', help='Use system prompt if available')

# Job splitting: this process handles split `idx` out of `n_splits`.
parser.add_argument('--idx', default=0, type=int, help='Index among quantifier splits')
parser.add_argument('--n_splits', default=1, type=int, help='Divide quantifiers into many jobs')

parser.add_argument('--debug', action='store_true', help='Enable debug mode')
parser.add_argument('--multi_gpu', action='store_true', help='Spread model across GPUs (very slow)')

args = parser.parse_args()

# Echo the effective configuration so it is recorded alongside the run output.
print(json.dumps(vars(args), indent=4))

experiment = pretraining_attribution.Experiment(args)
# The CLI exposes the opt-in flag --multi_gpu, while the method takes the
# inverse (single_gpu), hence the negation.
experiment.dump_model_written_evals_logprobs(
    idx=args.idx,
    n_splits=args.n_splits,
    single_gpu=not args.multi_gpu,
)