import argparse
from trl import DataCollatorForCompletionOnlyLM
from datasets import load_from_disk
import warnings
from utils.data import get_prompt, get_tokenized_ds
from utils.model import model_map, get_trained_model, get_my_peft_model, save_model
from utils.trainer import get_trainer_config
from utils.log import LoggerCallback
from functools import partial
from copy import deepcopy

# Silence every warning category process-wide (note: this also hides library
# deprecation notices, e.g. from trl/transformers API changes).
warnings.simplefilter('ignore')

def main(args):
    """Continue training an instruction-tuned model with SFT or DPO and save it.

    The starting checkpoint is read from
    ``./exp/{ds}/{model}/instruct/{defense}/{instr_suffix}/model`` and the
    training split from ``./datasets/{ds}/adv/train_{trainer}``. Progress is
    logged through a ``LoggerCallback`` created for
    ``./exp/{ds}/{model}/{trainer}/{defense}/{suffix}``, and the trained model
    is saved to ``logger.model_dir``.

    Args:
        args: Parsed command-line namespace. ``ds``, ``model``, ``trainer``,
            ``defense``, ``suffix``, ``instr_suffix``, ``max_seq_len`` and
            ``peft`` are read directly; the whole namespace is also forwarded
            to ``get_trainer_config`` and ``save_model``.
    """
    # Substring test: with the current --defense choices only "delim" matches.
    add_delim = "delim" in args.defense
    is_dpo = args.trainer == "dpo"

    # Logging
    exp_dir = f"./exp/{args.ds}/{args.model}/{args.trainer}/{args.defense}/{args.suffix}"
    logger = LoggerCallback(exp_dir, "train", use_accel=True)

    logger.log('Loading instr tuned model, tokenizer...')
    instr_model_dir = f"./exp/{args.ds}/{args.model}/instruct/{args.defense}/{args.instr_suffix}/model"
    model, tokenizer = get_trained_model(instr_model_dir, args.model, bf16=False, is_train=True)

    model.set_defense(args.defense)
    # Only DPO consumes a reference model. Copy *after* set_defense so the
    # reference shares the same defense setup, and *before* PEFT wrapping so it
    # stays the plain model. Skipping the copy for SFT avoids keeping a second
    # full set of weights alive for the entire run.
    ref_model = deepcopy(model) if is_dpo else None

    logger.log('Loading dataset...')
    dataset = load_from_disk(f"./datasets/{args.ds}/adv/train_{args.trainer}")

    logger.log('Constructing prompts, tokenizing...')
    delimiters = model_map[args.model]["delimiters"]
    dataset = dataset.map(
        partial(get_prompt, include_response=True, format=args.trainer),
        remove_columns=dataset.column_names,
    )
    dataset = get_tokenized_ds(
        dataset,
        tokenizer,
        delimiters,
        add_generation_prompt=False,
        pad_length=args.max_seq_len,
        add_delim=add_delim,
    )

    if args.peft:
        logger.log('Configuring PEFT...')
        model = get_my_peft_model(model=model)

    logger.log('Configuring Trainer...')
    Trainer, config = get_trainer_config(args.trainer, logger.model_dir, args)

    # For SFT, the completion-only collator masks the loss on everything up to
    # the response template (the last delimiter). Otherwise pass None and let
    # the trainer use its own default collator.
    data_collator = (
        DataCollatorForCompletionOnlyLM(response_template=delimiters[-1], tokenizer=tokenizer)
        if args.trainer == "sft"
        else None
    )

    trainer_kwargs = dict(
        model=model,
        args=config,
        train_dataset=dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )
    if is_dpo:
        trainer_kwargs["ref_model"] = ref_model
    trainer = Trainer(**trainer_kwargs)

    trainer.add_callback(logger)

    logger.log('Training...')
    trainer.train()

    logger.log(f'Saving model at {logger.model_dir}')
    save_model(trainer=trainer, args=args, instr_model_dir=instr_model_dir, model_dir=logger.model_dir, tokenizer=tokenizer)
    logger.log('Done!')
 

if __name__ == "__main__":
    # Command-line entry point. Arguments are registered in their original
    # order so the --help listing is unchanged.
    parser = argparse.ArgumentParser(description="Train")
    add = parser.add_argument

    # --- model and data -----------------------------------------------------
    add('--model', type=str, default='llama3.2_3b', choices=model_map.keys(),
        help='Name of the pre-trained model')
    add('--ds', type=str, default='alpaca', help='Name of the dataset to use')

    # --- training hyper-parameters -------------------------------------------
    add('--epochs', type=int, default=3, help='Number of training epochs')
    add('--batch_size', type=int, default=4,
        help='Batch size per device during training')
    add('--lr', type=float, default=2e-4, help='Learning rate')
    add('--max_seq_len', type=int, default=768, help='Maximum sequence length')
    add('--peft', action='store_true', help='Enable PEFT training')
    add('--max_steps', type=int, default=-1,
        help='Maximum number of training steps')
    add('--warmup_ratio', type=float, default=0.1, help='Warmup ratio')
    add('--gradient_accumulation_steps', type=int, default=4,
        help='Gradient accumulation steps')

    # --- experiment setup ----------------------------------------------------
    add('--trainer', type=str, default='dpo', choices=['sft', 'dpo'],
        help='Trainer to use')
    add('--defense', type=str, default='none',
        choices=["none", "delim", "ise", "air"], help='Defense to use')
    add('--suffix', type=str, default='', help='Suffix for the experiment')
    add('--instr_suffix', type=str, default='',
        help='Suffix for the instruct model')
    add('--seed', type=int, default=2025, help='Seed for the experiment')

    args = parser.parse_args()
    main(args)
