import pretraining_attribution
import argparse
import json

# Command-line interface for this script. Every option has a default, so the
# script can be launched without any arguments.
parser = argparse.ArgumentParser(
    description="Build a pretraining_attribution Experiment and dump its MSJ attributions."
)

parser.add_argument('--model', type=str, default="llama", help='Model name')
parser.add_argument('--interpolation_type', type=str, default="alpha_scaling", help='Interpolation type, e.g. "alpha_scaling"')
parser.add_argument('--n_params', type=str, default="7b", help='String with number of params, e.g. "7b"')
parser.add_argument('--dataset', type=str, default="model_written_evals", help='Dataset name')
parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
parser.add_argument('--save_freq', default=200, type=int, help='Frequency of saving')
# Parsed as an int by argparse so that a malformed value is reported as a usage
# error; omitting the flag leaves the value as None.
parser.add_argument('--n_prompts', default=None, type=int, help='Number of prompts (omit to leave as None)')
parser.add_argument('--dtype', type=str, default='float16', help='Data type (default: float16)')
parser.add_argument('--use_sysprompt', action='store_true', help='Use system prompt if available')
# Opt-in flag: the script derives `no_formatting` as the negation of this value.
parser.add_argument('--formatting', action='store_true', help='Use model-specific formatting (off by default)')
parser.add_argument('--idx', default=0, type=int, help='Index among quantifier splits')
parser.add_argument('--n_splits', default=1, type=int, help='Divide quantifiers into many jobs')
parser.add_argument('--debug', action='store_true', help='Enable debug mode')
parser.add_argument('--multi_gpu', action='store_true', help='Spread model across GPUs (very slow)')
parser.add_argument('--only_alpha_1', action='store_true', help='Only use alpha = 1')

args = parser.parse_args()

# Normalise the parsed options before they are passed on to the experiment.
# n_prompts stays None when it was not supplied; otherwise it is coerced to int.
args.n_prompts = None if args.n_prompts is None else int(args.n_prompts)
# Derived option: the negation of the --formatting switch.
args.no_formatting = not args.formatting

# Echo the resolved configuration as indented JSON.
print(json.dumps(vars(args), indent=4))

# Values passed as `grid` to the attribution dump. With --only_alpha_1 the
# list collapses to the single value 1.0.
grid = [1.0] if args.only_alpha_1 else [0.75, 0.9, 0.95, 1.0, 1.05, 1.1, 1.25]

# Construct the experiment from the full argument namespace, then run the MSJ
# attribution dump for this job's split (idx out of n_splits) over `grid`.
experiment = pretraining_attribution.Experiment(args)
# NOTE: --interpolation_type is not forwarded here as a `mode` keyword; it is
# only visible to the experiment through `args`.
experiment.dump_msj_attributions(idx=args.idx, n_splits=args.n_splits, grid=grid)

