import argparse
from trl import DataCollatorForCompletionOnlyLM
from datasets import load_from_disk
import warnings
from utils.data import get_prompt, get_tokenized_ds
from utils.model import model_map, get_ptm
from utils.trainer import get_trainer_config
from utils.log import LoggerCallback
from functools import partial
from accelerate import Accelerator
# Install a blanket 'ignore' filter (no category/module criteria), so Python
# warnings are suppressed for the whole run.
warnings.filterwarnings('ignore')
# Module-level Accelerator instance; main() calls wait_for_everyone() on it
# as a synchronisation point after logger setup and before training.
accelerator = Accelerator()

def main(args):
    """Run one supervised fine-tuning job and save the resulting model.

    Steps, in order: set up the experiment logger, load the pre-trained model
    and tokenizer, load the clean training split from disk, build and tokenize
    the prompts, configure the trainer with a completion-only data collator,
    train, and save the model to ``logger.model_dir``.

    Args:
        args: Parsed command-line namespace. This function reads ``ds``,
            ``model``, ``defense``, ``suffix`` and ``max_seq_len`` directly;
            the whole namespace is also forwarded to ``get_trainer_config``.
    """
    # Substring test on the defense name; already a bool, no ternary needed.
    add_delim = "delim" in args.defense

    # Logging
    exp_dir = f"./exp/{args.ds}/{args.model}/instruct/{args.defense}/{args.suffix}"
    logger = LoggerCallback(exp_dir, "train", use_accel=True)
    accelerator.wait_for_everyone()

    logger.log('Loading model, tokenizer...')
    model, tokenizer = get_ptm(args.model)
    model.set_defense(args.defense)

    logger.log('Loading dataset...')
    dataset = load_from_disk(f"./datasets/{args.ds}/clean/train")

    logger.log('Constructing prompts and tokenizing...')
    # The prompt-building map drops every original column, keeping only what
    # get_prompt returns.
    build_prompt = partial(get_prompt, include_response=True, format="sft")
    dataset = dataset.map(build_prompt, remove_columns=dataset.column_names)
    delimiters = model_map[args.model]["delimiters"]
    dataset = get_tokenized_ds(
        dataset,
        tokenizer,
        delimiters,
        add_generation_prompt=False,
        pad_length=args.max_seq_len,
        add_delim=add_delim,
    )

    logger.log('Configuring Trainer...')
    trainer_cls, config = get_trainer_config("sft", logger.model_dir, args)
    # The collator is keyed on the last delimiter of this model's template
    # (presumably the response marker -- confirm against utils.model).
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template=delimiters[-1],
        tokenizer=tokenizer,
    )
    trainer = trainer_cls(
        model=model,
        args=config,
        train_dataset=dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )
    trainer.add_callback(logger)
    accelerator.wait_for_everyone()

    logger.log('Training...')
    trainer.train()

    logger.log(f'Saving model at {logger.model_dir}')
    # When FSDP is active, switch the plugin to FULL_STATE_DICT before saving.
    # Written as a plain `if` rather than a conditional expression evaluated
    # only for its side effect.
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model(logger.model_dir)

    logger.log('Done!')
    

if __name__ == "__main__":
    # Command-line entry point: declare the training options in a fixed order
    # (which is also the order shown by --help), then hand the parsed
    # namespace to main().
    parser = argparse.ArgumentParser(description="Train")

    # Model and dataset selection.
    parser.add_argument(
        '--model',
        type=str,
        default='llama3.2_3b',
        choices=model_map.keys(),
        help='Name of the pre-trained model',
    )
    parser.add_argument(
        '--ds',
        type=str,
        default='alpaca',
        help='Name of the dataset to use',
    )

    # Optimisation hyperparameters.
    parser.add_argument(
        '--epochs',
        type=int,
        default=3,
        help='Number of training epochs',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=4,
        help='Batch size per device during training',
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-5,
        help='Learning rate',
    )
    parser.add_argument(
        '--max_seq_len',
        type=int,
        default=768,
        help='Maximum sequence length',
    )
    parser.add_argument(
        '--max_steps',
        type=int,
        default=-1,
        help='Maximum number of training steps',
    )
    parser.add_argument(
        '--warmup_ratio',
        type=float,
        default=0.1,
        help='Warmup ratio',
    )
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=4,
        help='Gradient accumulation steps',
    )

    # Experiment identity and reproducibility.
    parser.add_argument(
        '--defense',
        type=str,
        default='none',
        choices=["none", "delim", "ise", "air"],
        help='Defense to use',
    )
    parser.add_argument(
        '--suffix',
        type=str,
        default='',
        help='Suffix for the experiment',
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility',
    )

    args = parser.parse_args()
    main(args)
