import argparse
from accelerate import Accelerator   
from utils.data import get_prompt, get_tokenized_ds, get_dataset
from functools import partial
from utils.model import model_map, get_trained_model, run_generation_dp
from utils.log import LoggerCallback

# Module-level Accelerator shared by main(): it supplies process indices,
# the main-process check, cross-process synchronization (wait_for_everyone),
# and is handed to run_generation_dp for data-parallel generation.
# Note: it is constructed at import time, not inside main().
accelerator = Accelerator()

def main(args):
    """Generate model responses on an evaluation split and save them as JSON.

    Loads the trained model for the experiment identified by ``args``
    (training dataset, model, trainer, defense, suffix). Builds generation
    prompts for the first ``args.n_examples`` test examples of
    ``args.ds_eval``. Generates responses across all Accelerate processes,
    then writes a list of ``{"instruction", "output", "generator"}`` records
    to ``<logger.eval_dir>/eval_out.json``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` parser).
            ``args.n_examples`` is overwritten in place with the number of
            examples actually evaluated (``None`` means the whole split).
    """
    # Only the "delim" defense asks the tokenization helper to add delimiters.
    add_delim = "delim" in args.defense

    # Logging
    exp_dir = f"./exp/{args.ds}/{args.model}/{args.trainer}/{args.defense}/{args.suffix}"

    logger = LoggerCallback(exp_dir, "eval", use_accel=True)
    accelerator.wait_for_everyone()

    logger.log('Loading model, tokenizer...')

    # Use the *local* process index: process_index is global across nodes,
    # so on multi-node runs it would name GPUs that do not exist on this
    # machine. On a single node the two indices are identical.
    device = f"cuda:{accelerator.local_process_index}"
    model, tokenizer = get_trained_model(logger.model_dir, args.model, bf16=True, is_train=False, device=device)

    model.set_defense(args.defense)

    logger.log('Loading dataset...')
    ds_orig = get_dataset(args.ds_eval, split="test")

    # Clamp the requested number of examples to the split size.
    if args.n_examples is None:
        args.n_examples = len(ds_orig)
    else:
        args.n_examples = min(args.n_examples, len(ds_orig))

    ds_orig = ds_orig.select(range(args.n_examples))

    logger.log('Constructing prompts, tokenizing...')
    ds = ds_orig.map(partial(get_prompt, include_response=False, format="sft"))
    ds = get_tokenized_ds(
        ds,
        tokenizer,
        model_map[args.model]["delimiters"],
        pad_to_multiple=args.batch_size * accelerator.num_processes,
        batch_size=args.batch_size,
        add_generation_prompt=True,
        add_delim=add_delim,
    )  # pad, tokenize, batch

    logger.log('Generating responses...')
    resps = run_generation_dp(ds, tokenizer, model, accelerator)

    logger.log('Formatting responses...')
    eval_out = []
    if accelerator.is_local_main_process:
        # Trim the extra rows introduced by pad_to_multiple above so that
        # responses line up one-to-one with the original examples.
        resps = resps[:len(ds_orig)]
        generator = f"{args.model}_{args.trainer}_{args.defense}"
        for example, resp in zip(ds_orig, resps):
            instruction = example["instruction"]
            # Append the optional "input" field. A missing, empty, or None
            # input leaves the instruction unchanged.
            if example.get("input"):
                instruction = example["instruction"] + " " + example["input"]
            eval_out.append({
                "instruction": instruction,
                "output": resp,
                "generator": generator,
            })

    # NOTE(review): every process reaches this call, and non-main processes
    # pass an empty list. This assumes LoggerCallback(use_accel=True) only
    # writes from the main process; confirm, otherwise results may be clobbered.
    logger.log(f'Saving results in {logger.eval_dir}/eval_out.json')
    logger.dump_json(eval_out, f'{logger.eval_dir}/eval_out.json')

    logger.log('Done!')
    
        
if __name__ == "__main__":
    # Command-line interface for the evaluation script. Each entry is
    # (flag, add_argument keyword options), registered in this order.
    cli = argparse.ArgumentParser(description="Eval")
    cli_options = [
        ("--model", dict(type=str, default="llama3.2_3b", help="Name of the pre-trained model", choices=model_map.keys())),
        ("--ds", dict(type=str, default="alpaca", help="Name of the dataset used for training")),
        ("--ds_eval", dict(type=str, default="alpaca", help="Name of the dataset to use for evaluation")),
        ("--trainer", dict(type=str, default="dpo", help="Name of the trainer used", choices=["sft", "dpo", "instruct"])),
        ("--defense", dict(type=str, default="none", help="Defense to use", choices=["none", "delim", "ise", "air"])),
        ("--batch_size", dict(type=int, default=32, help="batch size")),
        ("--n_examples", dict(type=int, default=None, help="Number of examples to evaluate")),
        ("--suffix", dict(type=str, default="", help="Suffix for the experiment")),
    ]
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    main(args)
