from argparse import ArgumentParser, Namespace
import os
import yaml
from peft import PeftModel, LoraConfig, get_peft_model
import datasets
from transformers import TrainingArguments

import utils
from thirdparty.tofu.dataloader import CustomTrainer, CustomTrainerForgetting
from thirdparty.tofu.data_module import TextDatasetQA, TextForgetDatasetQA
from thirdparty.tofu.data_module import custom_data_collator as tofu_data_collator
from thirdparty.tofu.dataloader import custom_data_collator_forget as tofu_data_collator_forget
from data_modules.base_data import load_tofu_train_dataset, load_arxiv_train_dataset, custom_data_collator_arxiv as arxiv_data_collator, custom_data_collator_forget as arxiv_data_collator_forget
from data_modules.data_module import UnwatermarkedTextDataset, UnwatermarkedTextForgetDataset
from train import load_model_and_tokenizer, load_training_arguments, find_all_linear_names
from unlearning_methods import  scrub, scr_newton

# Globally switch off PyTorch's flash and memory-efficient scaled-dot-product
# attention (SDPA) CUDA backends before any model is created.
# NOTE(review): the motivation is not recorded in this file; presumably a
# later step needs an attention path these kernels do not support -- confirm.
import torch
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)


# Identifiers accepted by the --unlearn_method command-line option.
UNLEARN_METHODS = [
    'original',
    'retraining',
    'finetune',
    'grad_ascent',
    'grad_diff',
    'idk',
    'scrub',
    'npo',
    'scr_newton',
]

def _build_lora_model(config):
    """Load a fresh base model and attach newly created LoRA adapters.

    Args:
        config: Unlearning config namespace; ``config.lora`` is expanded into
            ``LoraConfig`` keyword arguments.

    Returns:
        A ``(peft_model, tokenizer)`` tuple.
    """
    base_model, tokenizer = load_model_and_tokenizer(config)
    lora_config = LoraConfig(
        target_modules=find_all_linear_names(base_model),
        modules_to_save=["embed_tokens", "lm_head"],
        **config.lora,
    )
    return get_peft_model(base_model, lora_config), tokenizer


def _build_text_dataset(raw_data, tokenizer, config, dataset_name):
    """Wrap raw rows in the single-split dataset class for ``dataset_name``.

    Args:
        raw_data: Raw rows to wrap.
        tokenizer: Tokenizer handed to the dataset wrapper.
        config: Unlearning config namespace (reads ``max_seq_length`` and, for
            'tofu', ``model_family``).
        dataset_name: Either 'tofu' or 'arxiv'.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'tofu':
        return TextDatasetQA(raw_data,
                             tokenizer,
                             model_family=config.model_family,
                             max_length=config.max_seq_length,
                             question_key='question',
                             answer_key='answer')
    if dataset_name == 'arxiv':
        return UnwatermarkedTextDataset(raw_data,
                                        tokenizer,
                                        max_length=config.max_seq_length)
    raise NotImplementedError(dataset_name)


def _build_forget_dataset(forget_raw, retain_raw, tokenizer, config, dataset_name, forget_loss):
    """Build the paired forget/retain dataset and its matching collator.

    Args:
        forget_raw: Raw rows to be forgotten in the current round.
        retain_raw: Raw rows to be retained in the current round.
        tokenizer: Tokenizer handed to the dataset wrapper.
        config: Unlearning config namespace (reads ``max_seq_length`` and, for
            'tofu', ``model_family``).
        dataset_name: Either 'tofu' or 'arxiv'.
        forget_loss: Loss identifier forwarded as ``loss_type``.

    Returns:
        A ``(dataset, data_collator)`` tuple.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'tofu':
        dataset = TextForgetDatasetQA(
            forget_raw,
            retain_raw,
            tokenizer,
            model_family=config.model_family,
            max_length=config.max_seq_length,
            loss_type=forget_loss,
        )
        return dataset, tofu_data_collator_forget
    if dataset_name == 'arxiv':
        dataset = UnwatermarkedTextForgetDataset(
            forget_raw,
            retain_raw,
            tokenizer,
            max_length=config.max_seq_length,
            loss_type=forget_loss,
        )
        return dataset, arxiv_data_collator_forget
    raise NotImplementedError(dataset_name)


if __name__ == '__main__':
    # Sequential unlearning driver: the forget set is consumed in chunks of
    # --unlearn_batch_size rows; each round unlearns one chunk and (unless
    # --no_save is given) saves the model to '<output_dir>-<round>'.
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    # The options below are marked required because the script cannot proceed
    # without them (it previously failed later with an unrelated error).
    parser.add_argument('--dataset_name', type=str, choices=['tofu', 'arxiv'],
                        required=True,
                        help='Dataset name')
    parser.add_argument('--data_config_path', type=str, required=True,
                        help='Path to dataset and split config')
    parser.add_argument('--unlearn_config_path', type=str, required=True,
                        help='Path to unlearning config')
    parser.add_argument('--orig_model_path', type=str,
                        help='Path to model that contains trainable parameters (peft parameters)')
    parser.add_argument('--unlearn_method', type=str, choices=UNLEARN_METHODS,
                        required=True,
                        help='Unlearning method')
    parser.add_argument('--num_epochs', type=int, default=None,
                        help='Number of epochs for unlearning (finetuning)')
    parser.add_argument("--unlearn_batch_size", type=int, default=100,
                        help="Number of points to unlearn in batch")
    parser.add_argument('--output_dir', type=str, default='results/',
                        help='Directory to save results and models')
    parser.add_argument('--no_save', action='store_true',
                        help='Not save the unlearned model (for debugging purpose)')
    args = parser.parse_args()

    # A non-positive chunk size would never shrink the forget set, so the
    # round loop below would not terminate (or would slice nonsensically).
    if args.unlearn_batch_size <= 0:
        parser.error('--unlearn_batch_size must be a positive integer')
    # Every method except 'retraining' starts from the saved PEFT checkpoint.
    if args.unlearn_method != 'retraining' and args.orig_model_path is None:
        parser.error('--orig_model_path is required unless --unlearn_method is retraining')

    utils.set_seed(args.seed)
    print(f'Unlearning method: {args.unlearn_method}')

    # load data config
    with open(args.data_config_path, 'r') as f:
        data_config = Namespace(**yaml.safe_load(f))
    print('data_config:', vars(data_config))

    # load unlearning config
    with open(args.unlearn_config_path, 'r') as f:
        config = yaml.safe_load(f)
        if args.unlearn_method in config:   # update method-specific hyperparameters
            config.update(config[args.unlearn_method])
        config = Namespace(**config)

    print('unlearn_config:', vars(config))

    # load dataset and splits
    if args.dataset_name == 'tofu':
        train_data, forget_data, retain_data = load_tofu_train_dataset(**vars(data_config))
        data_collator = tofu_data_collator
    elif args.dataset_name == 'arxiv':
        train_data, forget_data, retain_data = load_arxiv_train_dataset(**vars(data_config))
        data_collator = arxiv_data_collator
    else:
        raise NotImplementedError
    print('num_train_rows:', len(train_data))
    print('num_forget_rows:', len(forget_data))
    print('num_retain_rows:', len(retain_data))

    # create and load finetuned model and tokenizer
    if args.unlearn_method != 'retraining':
        model, tokenizer = load_model_and_tokenizer(config)
        print(f'Loading trainable parameters from {args.orig_model_path}')
        model = PeftModel.from_pretrained(model, args.orig_model_path, is_trainable=True)
    else:
        model, tokenizer = _build_lora_model(config)

    utils.unfreeze_lora_parameters(model)  # unfreeze LoRA parameters for unlearning

    print("========= SEQUENTIAL UNLEARNING ==========")
    unlearn_round = 0
    remaining_forget_data = forget_data
    while len(remaining_forget_data) > 0:
        unlearn_round += 1
        print("unlearn_round:", unlearn_round)

        # The slices are passed through Dataset.from_dict to turn them back
        # into Dataset objects (presumably slicing yields a dict of columns).
        # Rows not yet forgotten are still treated as retain data this round.
        curr_forget_data = remaining_forget_data[:args.unlearn_batch_size]
        curr_forget_data = datasets.Dataset.from_dict(curr_forget_data)
        remaining_forget_data = remaining_forget_data[args.unlearn_batch_size:]
        remaining_forget_data = datasets.Dataset.from_dict(remaining_forget_data)
        curr_retain_data = datasets.concatenate_datasets([retain_data, remaining_forget_data])

        print("forget_size:", len(curr_forget_data))
        print("retain_size:", len(curr_retain_data))

        # create necessary datasets for the unlearner
        if args.unlearn_method in ('retraining', 'finetune'):
            train_data = _build_text_dataset(curr_retain_data, tokenizer, config, args.dataset_name)
            if args.unlearn_method == 'retraining':
                # Retraining restarts from a freshly initialised adapter each round.
                model, tokenizer = _build_lora_model(config)
                utils.unfreeze_lora_parameters(model)
        elif args.unlearn_method == 'scr_newton':
            # Use round-local names so the raw `retain_data` split, which is
            # concatenated again at the top of the next round, is never
            # overwritten by a tokenized wrapper.
            _forget_data = _build_text_dataset(curr_forget_data, tokenizer, config, args.dataset_name)
            _retain_data = _build_text_dataset(curr_retain_data, tokenizer, config, args.dataset_name)
        else:
            oracle_model = None
            forget_loss = args.unlearn_method
            if args.unlearn_method in ('scrub', 'npo'):
                # These methods get a second copy of the original checkpoint
                # as the trainer's oracle model.
                oracle_model, _ = load_model_and_tokenizer(config)
                oracle_model = PeftModel.from_pretrained(oracle_model, args.orig_model_path)
            train_data, data_collator = _build_forget_dataset(
                curr_forget_data, curr_retain_data, tokenizer, config,
                args.dataset_name, forget_loss,
            )

        # start unlearning
        # NOTE(review): for 'scr_newton' `train_data` is still the raw training
        # split returned by the loader (it is never rebuilt per round) -- confirm
        # that this is what the training arguments and trainer kwargs expect.
        training_args = load_training_arguments(args, config, num_train_samples=len(train_data))
        model.config.use_cache = False      # disable KV cache

        if args.unlearn_method in ('retraining', 'finetune'):
            if args.num_epochs is not None:
                training_args.num_train_epochs = args.num_epochs
            trainer = CustomTrainer(
                model=model,
                train_dataset=train_data,
                args=training_args,
                data_collator=data_collator,
            )
            trainer.train()

        elif args.unlearn_method == 'scr_newton':
            trainer_init_func = CustomTrainer
            trainer_init_kwargs = Namespace(
                model=model,
                train_dataset=train_data,
                args=training_args,
                data_collator=data_collator,
            )

            scr_newton.unlearn(model, _forget_data, _retain_data, config,
                               trainer_init_func, trainer_init_kwargs, model.device)
        else:
            trainer = CustomTrainerForgetting(
                model=model,
                tokenizer=tokenizer,
                train_dataset=train_data,
                compute_metrics=None,   # no metric callback is passed
                args=training_args,
                data_collator=data_collator,
                oracle_model=oracle_model,
                forget_loss=forget_loss,
                eval_cfg=None,      # turn off evaluate during unlearning
            )
            if args.unlearn_method == 'scrub':
                scrub.unlearn(trainer, config)
            else:
                trainer.train()

        if not args.no_save:
            # scr_newton does not create a `trainer` in this script, so the
            # model is saved directly in that case.
            if args.unlearn_method == 'scr_newton':
                model.save_pretrained(f'{args.output_dir}-{unlearn_round}')
            else:
                trainer.model.save_pretrained(f'{args.output_dir}-{unlearn_round}')
            print(f'Saved unlearned model to {args.output_dir}-{unlearn_round}')
