from argparse import ArgumentParser, Namespace
import os
import yaml
from peft import PeftModel, LoraConfig, get_peft_model

import utils
from thirdparty.tofu.dataloader import CustomTrainer, CustomTrainerForgetting
from thirdparty.tofu.data_module import TextDatasetQA, TextForgetDatasetQA
from thirdparty.tofu.data_module import custom_data_collator as tofu_data_collator
from thirdparty.tofu.dataloader import custom_data_collator_forget as tofu_data_collator_forget
from data_modules.base_data import load_tofu_train_dataset, load_arxiv_train_dataset, custom_data_collator_arxiv as arxiv_data_collator, custom_data_collator_forget as arxiv_data_collator_forget
from data_modules.data_module import UnwatermarkedTextDataset, UnwatermarkedTextForgetDataset
from train import load_model_and_tokenizer, load_training_arguments, find_all_linear_names
from unlearning_methods import scrub, scr_newton

# Turn off two of PyTorch's scaled-dot-product-attention (SDPA) backends at
# import time: the flash kernel and the memory-efficient kernel.
# NOTE(review): the motivation is not recorded in this file -- presumably one
# of the unlearning methods needs an attention path these kernels do not
# support; confirm before removing.
import torch
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)


# Method names accepted by the --unlearn_method command-line option.
UNLEARN_METHODS = [
    "original",
    "retraining",
    "finetune",
    "grad_ascent",
    "grad_diff",
    "idk",
    "scrub",
    "npo",
    "scr_newton",
]


def _build_plain_dataset(raw_data, dataset_name, tokenizer, config):
    """Wrap raw rows in the plain (non-forgetting) text dataset for ``dataset_name``.

    Args:
        raw_data: Rows returned by the dataset loader for one split.
        dataset_name: Either ``'tofu'`` or ``'arxiv'``.
        tokenizer: Tokenizer forwarded to the dataset class.
        config: Unlearning config namespace; ``max_seq_length`` is always read,
            ``model_family`` only for ``'tofu'``.

    Returns:
        A ``TextDatasetQA`` for ``'tofu'`` or an ``UnwatermarkedTextDataset``
        for ``'arxiv'``.

    Raises:
        NotImplementedError: If ``dataset_name`` is neither of the two.
    """
    if dataset_name == 'tofu':
        return TextDatasetQA(raw_data,
                             tokenizer,
                             model_family=config.model_family,
                             max_length=config.max_seq_length,
                             question_key='question',
                             answer_key='answer')
    if dataset_name == 'arxiv':
        return UnwatermarkedTextDataset(raw_data,
                                        tokenizer,
                                        max_length=config.max_seq_length)
    raise NotImplementedError(f'Unsupported dataset: {dataset_name}')


if __name__ == '__main__':
    # Script entry point: parse options, load configs/data/model, run the
    # selected unlearning method, then record wall-clock time, peak CUDA
    # memory and (optionally) the resulting model under --output_dir.
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    # The next options are declared required because the script cannot proceed
    # without them (previously their absence surfaced as a late TypeError /
    # NotImplementedError).
    parser.add_argument('--dataset_name', type=str, choices=['tofu', 'arxiv'],
                        required=True,
                        help='Dataset name')
    parser.add_argument('--data_config_path', type=str, required=True,
                        help='Path to dataset and split config')
    parser.add_argument('--unlearn_config_path', type=str, required=True,
                        help='Path to unlearning config')
    parser.add_argument('--orig_model_path', type=str,
                        help='Path to model that contains trainable parameters (peft parameters)')
    parser.add_argument('--unlearn_method', type=str, choices=UNLEARN_METHODS,
                        required=True,
                        help='Unlearning method')
    parser.add_argument('--num_epochs', type=int, default=None,
                        help='Number of epochs for unlearning (finetuning)')
    parser.add_argument('--output_dir', type=str, default='results/',
                        help='Directory to save results and models')
    parser.add_argument('--no_save', action='store_true',
                        help='Not save the unlearned model (for debugging purpose)')
    args = parser.parse_args()
    # Every method except 'retraining' loads PEFT weights from --orig_model_path.
    if args.unlearn_method != 'retraining' and args.orig_model_path is None:
        parser.error("--orig_model_path is required unless --unlearn_method is 'retraining'")
    utils.set_seed(args.seed)
    print(f'Unlearning method: {args.unlearn_method}')

    # load data config
    with open(args.data_config_path, 'r') as f:
        data_config = Namespace(**yaml.safe_load(f))
    print('data_config:', vars(data_config))

    # load unlearning config
    with open(args.unlearn_config_path, 'r') as f:
        config = yaml.safe_load(f)
        if args.unlearn_method in config:   # update method-specific hyperparameters
            config.update(config[args.unlearn_method])
        config = Namespace(**config)

    print('unlearn_config:', vars(config))

    # load dataset and splits
    if args.dataset_name == 'tofu':
        train_data, forget_data, retain_data = load_tofu_train_dataset(**vars(data_config))
        data_collator = tofu_data_collator
    elif args.dataset_name == 'arxiv':
        train_data, forget_data, retain_data = load_arxiv_train_dataset(**vars(data_config))
        data_collator = arxiv_data_collator
    else:
        raise NotImplementedError
    print('num_train_rows:', len(train_data))
    print('num_forget_rows:', len(forget_data))
    print('num_retain_rows:', len(retain_data))

    # create and load finetuned model and tokenizer
    model, tokenizer = load_model_and_tokenizer(config)
    if args.unlearn_method != 'retraining':
        # Start from the PEFT weights stored at --orig_model_path.
        print(f'Loading trainable parameters from {args.orig_model_path}')
        model = PeftModel.from_pretrained(model, args.orig_model_path, is_trainable=True)
    else:
        # Retraining attaches freshly created LoRA adapters instead.
        lora_config = LoraConfig(
            target_modules=find_all_linear_names(model),
            modules_to_save=["embed_tokens", "lm_head"],
            **config.lora,
        )
        model = get_peft_model(model, lora_config)

    utils.unfreeze_lora_parameters(model)  # unfreeze LoRA parameters for unlearning

    # create necessary datasets for the unlearner
    if args.unlearn_method in ('retraining', 'finetune'):
        # These two methods train on the retain split only.
        train_data = _build_plain_dataset(retain_data, args.dataset_name, tokenizer, config)
    elif args.unlearn_method == 'scr_newton':
        # Equality test: `in ('scr_newton')` would be a substring check against
        # a plain string, not tuple membership.
        forget_data = _build_plain_dataset(forget_data, args.dataset_name, tokenizer, config)
        retain_data = _build_plain_dataset(retain_data, args.dataset_name, tokenizer, config)
    else:
        oracle_model = None
        forget_loss = args.unlearn_method
        if args.unlearn_method in ('scrub', 'npo'):
            # These methods get a second model loaded from the same
            # --orig_model_path weights, passed to the trainer as oracle_model.
            oracle_model, _ = load_model_and_tokenizer(config)
            oracle_model = PeftModel.from_pretrained(oracle_model, args.orig_model_path)
        if args.dataset_name == 'tofu':
            data_collator = tofu_data_collator_forget
            train_data = TextForgetDatasetQA(
                forget_data,
                retain_data,
                tokenizer,
                model_family=config.model_family,
                max_length=config.max_seq_length,
                loss_type=forget_loss,
            )
        elif args.dataset_name == 'arxiv':
            data_collator = arxiv_data_collator_forget
            train_data = UnwatermarkedTextForgetDataset(
                forget_data,
                retain_data,
                tokenizer,
                max_length=config.max_seq_length,
                loss_type=forget_loss,
            )

    # start unlearning
    import time
    start_time = time.time()
    torch.cuda.reset_peak_memory_stats()
    training_args = load_training_arguments(args, config, num_train_samples=len(train_data))
    model.config.use_cache = False      # disable KV cache
    if args.unlearn_method in ('retraining', 'finetune'):
        if args.num_epochs is not None:
            training_args.num_train_epochs = args.num_epochs
        trainer = CustomTrainer(
            model=model,
            train_dataset=train_data,
            args=training_args,
            data_collator=data_collator,
        )
        trainer.train()

    elif args.unlearn_method == 'scr_newton':
        # NOTE(review): train_data is still the unwrapped full training split
        # returned by the loader on this path -- confirm scr_newton expects that.
        trainer_init_func = CustomTrainer
        trainer_init_kwargs = Namespace(
            model=model,
            train_dataset=train_data,
            args=training_args,
            data_collator=data_collator,
        )

        scr_newton.unlearn(model, forget_data, retain_data, config, trainer_init_func, trainer_init_kwargs, model.device)

    else:
        trainer = CustomTrainerForgetting(
            model=model,
            tokenizer=tokenizer,
            train_dataset=train_data,
            compute_metrics=None,
            args=training_args,
            data_collator=data_collator,
            oracle_model=oracle_model,
            forget_loss=forget_loss,
            eval_cfg=None,      # turn off evaluate during unlearning
        )
        if args.unlearn_method == 'scrub':
            scrub.unlearn(trainer, config)
        else:
            trainer.train()

    end_time = time.time()
    elapsed = end_time - start_time
    max_memory = torch.cuda.max_memory_allocated() / 1024**2  # in MB
    print(f'Unlearning completed in {elapsed} seconds.')

    # Nothing above is guaranteed to have created the output directory
    # (the scr_newton path never calls trainer.train() here), so create it
    # before writing the measurement files.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, 'unlearning_time.txt'), 'w') as f:
        f.write(f'{elapsed} seconds\n')
    with open(os.path.join(args.output_dir, 'memory.txt'), 'w') as f:
        f.write(f'{max_memory} MB\n')

    if not args.no_save:
        if args.unlearn_method == 'scr_newton':
            # No trainer object exists on this path; save the model directly.
            model.save_pretrained(args.output_dir)
        else:
            trainer.model.save_pretrained(args.output_dir)
        print(f'Saved unlearned model to {args.output_dir}')
