from pathlib import Path
from attacks.data_extraction import parse_args, run_data_extraction
import re
from attacks.mia_utils import prepare_model
from src.better_tasks import get_preprocessed_dataset
import os


# Model loaded instead of `args.init_checkpoint` when `args.score_pretrained`
# is truthy.
PRETRAINED_MODEL_NAME = 'EleutherAI/pythia-1b'

# Ordered meaning of the '-'-separated fields in a checkpoint directory name.
# The last field is optional (see DEFAULT_PER_SAMPLE_MAX_GRAD_NORM).
CHECKPOINT_NAME_FIELDS = (
    'dataset', 'setting', 'lr', 'epoch-name', 'batch_size', 'target_epsilon',
    'prefix_length', 'prefix_type', 'shadow_id', 'per_sample_max_grad_norm',
)

# Value used when the checkpoint name omits the trailing
# `per_sample_max_grad_norm` field.
DEFAULT_PER_SAMPLE_MAX_GRAD_NORM = '0.1'

# Matches the exponent of values such as "1e-05" so that its '-' is not
# mistaken for a field separator.
_SCI_NOTATION_EXPONENT = re.compile(r"(\d+)e-(\d+)")


def split_checkpoint_name(checkpoint: str) -> list:
    """Split the last path component of *checkpoint* into raw name fields.

    The '-' inside scientific-notation exponents (e.g. ``1e-05``) is
    temporarily rewritten to '~' so that splitting on '-' keeps such values in
    one piece; the returned fields therefore still contain '~' in its place.
    Trailing '/' characters are ignored so that ``"runs/name/"`` is treated
    like ``"runs/name"``.
    """
    basename = checkpoint.rstrip('/').split('/')[-1]
    return _SCI_NOTATION_EXPONENT.sub(r"\1e~\2", basename).split('-')


def parse_checkpoint_name(checkpoint: str) -> dict:
    """Decode the run settings embedded in a checkpoint name.

    Args:
        checkpoint: Path or identifier whose last '/'-separated component is
            laid out as the '-'-joined ``CHECKPOINT_NAME_FIELDS``.

    Returns:
        A dict mapping every entry of ``CHECKPOINT_NAME_FIELDS`` to its string
        value. A missing trailing ``per_sample_max_grad_norm`` is filled with
        ``DEFAULT_PER_SAMPLE_MAX_GRAD_NORM``; fields beyond the known ones are
        ignored.

    Raises:
        ValueError: If the name has fewer fields than the mandatory ones.
    """
    fields = split_checkpoint_name(checkpoint)
    mandatory = len(CHECKPOINT_NAME_FIELDS) - 1
    if len(fields) < mandatory:
        raise ValueError(
            f"Cannot parse checkpoint name {checkpoint!r}: expected at least "
            f"{mandatory} '-'-separated fields "
            f"({', '.join(CHECKPOINT_NAME_FIELDS[:mandatory])}), "
            f"got {len(fields)}"
        )
    if len(fields) < len(CHECKPOINT_NAME_FIELDS):
        fields.append(DEFAULT_PER_SAMPLE_MAX_GRAD_NORM)
    # Undo the '~' placeholder introduced by split_checkpoint_name.
    return {
        key: value.replace('~', '-')
        for key, value in zip(CHECKPOINT_NAME_FIELDS, fields)
    }


def main():
    """Run the data-extraction attack for the checkpoint given on the CLI.

    Parses the arguments, creates the output directories, loads the model and
    tokenizer, decodes the run settings from the checkpoint name, builds the
    dataset once with the run's prefix type and once with ``'none'``, and
    passes everything to ``run_data_extraction``.

    Raises:
        ValueError: If the checkpoint name cannot be parsed or encodes the
            prefix type ``'none'``.
    """
    args = parse_args()
    print(args)

    output_dir = Path(args.out_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "plots").mkdir(parents=True, exist_ok=True)

    model_name = PRETRAINED_MODEL_NAME if args.score_pretrained else args.init_checkpoint
    model, tokenizer = prepare_model(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = model.eval()

    # Best-effort conversion: keep the unconverted model when it fails, but
    # report why and let KeyboardInterrupt/SystemExit propagate.
    try:
        model = model.to_bettertransformer()
    except Exception as err:
        print(f"to_bettertransformer() failed, using the model as-is: {err!r}")

    # The run settings always come from `init_checkpoint`, even when the
    # pretrained model is the one being scored.
    print(split_checkpoint_name(args.init_checkpoint))
    info = parse_checkpoint_name(args.init_checkpoint)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if info['prefix_type'] == 'none':
        raise ValueError(
            f"Checkpoint {args.init_checkpoint!r} has prefix_type 'none'; "
            "a checkpoint with a real prefix type is required"
        )

    def load_dataset(prefix_type):
        """Build the dataset for this run with the given prefix type."""
        return get_preprocessed_dataset(
            info['dataset'].lower(), args.data_cache_dir,
            tokenizer, args.max_seq_length,
            int(info['shadow_id']), args.topk,
            prefix_type, int(info['prefix_length']), args.ratio_change,
            z_ratio=0.1,
        )

    output = load_dataset(info['prefix_type'])
    output_none = load_dataset('none')

    run_data_extraction(output, output_none, info, model, tokenizer, output_dir, args)


if __name__ == "__main__":
    main()
