import copy
import torch
import numpy
import argparse
import os
import json
import train_func
import methods
import settings
import utils
import gc
import logging
import time

# Datasets served by a HuggingFace model + tokenizer and trained with transformers.Trainer.
_LLM_DATASETS = ('imdb', 'samsum', 'xsum', 'agnews')


def _load_json(path):
    """Parse the JSON file at ``path`` and return the result, closing the file handle."""
    with open(path, 'r') as f:
        return json.load(f)


def _prepare_data_and_model(dataset):
    """Build the data splits and model for ``dataset`` via ``settings.prepare_data_and_model``.

    For datasets in ``_LLM_DATASETS`` that helper returns ``(model, tokenizer)`` in the
    model slot; for every other dataset it returns just the model. This wrapper hides
    the difference.

    Returns:
        ``(train_set, test_set, model, tokenizer, model_cfg)``; ``tokenizer`` is ``None``
        for non-LLM datasets.
    """
    if dataset in _LLM_DATASETS:
        (train_set, test_set), (model, tokenizer), model_cfg = settings.prepare_data_and_model(dataset)
    else:
        (train_set, test_set), model, model_cfg = settings.prepare_data_and_model(dataset)
        tokenizer = None
    return train_set, test_set, model, tokenizer, model_cfg


def _get_sel_size(sel_freq, tot_samples):
    """Translate the ``--sel_freq`` value into a number of samples to forget.

    ``-1`` selects all ``tot_samples``; a value ``< 1`` is a fraction of ``tot_samples``
    (truncated by ``int``); anything else is used as an absolute count via ``int``.
    """
    if sel_freq == -1:
        return tot_samples
    if sel_freq < 1:
        return int(tot_samples * sel_freq)
    return int(sel_freq)


def _select_unlearn_indices(tot_samples, sel_level, sel_freq, sel_sort_by):
    """Choose which training indices (in ``range(tot_samples)``) will be unlearned.

    Returns:
        A numpy array of distinct indices, in unlearning order.

    Raises:
        NotImplementedError: for ``sel_level='class'`` or any ``sel_sort_by`` other than 'random'.
        ValueError: for an unknown ``sel_level``.
    """
    if sel_level == 'instance':
        sel_size = _get_sel_size(sel_freq, tot_samples)
    elif sel_level == 'class':
        raise NotImplementedError("sel_level='class' is not implemented")
    else:
        raise ValueError(f'Unknown sel_level: {sel_level!r}')

    if sel_sort_by != 'random':
        raise NotImplementedError(f'sel_sort_by={sel_sort_by!r} is not implemented')
    # Same RNG draws as choosing from list(range(tot_samples)).
    return numpy.random.choice(tot_samples, size=sel_size, replace=False)


def _get_trainable_state_dict(model):
    """Return the entries of ``model.state_dict()`` whose parameter has ``requires_grad=True``.

    Frozen parameters and buffers are left out, so checkpoints only carry what training changed.
    """
    trainable = {name for name, p in model.named_parameters() if p.requires_grad}
    return {k: v for k, v in model.state_dict().items() if k in trainable}


def _merge_cfg(old_cfg_dict, new_cfg_dict):
    """Return a deep copy of ``old_cfg_dict`` overlaid with ``new_cfg_dict`` (new keys win).

    Both inputs are copied, so the caller can reuse them across rounds even if the
    merged config is mutated downstream.
    """
    merged_cfg_dict = copy.deepcopy(old_cfg_dict)
    merged_cfg_dict.update(copy.deepcopy(new_cfg_dict))
    return merged_cfg_dict


def _train_llm(model, tokenizer, train_set, test_set, train_cfg_dict, output_dir):
    """Fine-tune ``model`` with ``transformers.Trainer`` and log train/eval metrics.

    ``train_cfg_dict`` is forwarded verbatim as keyword arguments to ``TrainingArguments``.
    """
    from transformers import TrainingArguments, Trainer, DataCollatorWithPadding

    print(train_cfg_dict)
    training_args = TrainingArguments(output_dir=output_dir, **train_cfg_dict)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=test_set,
        tokenizer=tokenizer,
        # Dynamically pads the examples in each batch to the same length.
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=utils.compute_metrics,
    )

    train_results = trainer.train()
    trainer.log_metrics('train', train_results.metrics)
    torch.cuda.empty_cache()
    gc.collect()
    with torch.no_grad():
        eval_metrics = trainer.evaluate()
    trainer.log_metrics('eval', eval_metrics)


def _train_classifier(model, train_set, test_set, train_cfg_dict, device):
    """Train ``model`` with a plain PyTorch loop driven by ``train_cfg_dict``.

    Reads ``batch_size``, ``optim`` ('sgd' or 'adam'), ``learning_rate``, ``weight_decay``,
    ``n_epoch`` and ``lr_update_interval`` from the config.

    Raises:
        ValueError: if ``optim`` is not 'sgd' or 'adam'.
    """
    train_cfg = argparse.Namespace(**train_cfg_dict)
    print(train_cfg)
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=train_cfg.batch_size)
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=train_cfg.batch_size)
    model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()

    print('========= TRAINING ==========')
    if train_cfg.optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=train_cfg.learning_rate, weight_decay=train_cfg.weight_decay)
    elif train_cfg.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg.learning_rate, weight_decay=train_cfg.weight_decay)
    else:
        raise ValueError(f'Unsupported optimizer in train config: {train_cfg.optim!r}')
    # Each scheduler.step() halves the LR; it is stepped every `lr_update_interval` epochs.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.5)

    for epoch in range(1, train_cfg.n_epoch + 1):
        model.train()
        print('Epoch {}:'.format(epoch))
        for step, batch in enumerate(train_loader):
            train_func.generic_step(model, loss_fn, batch, optimizer, verbose=(step % 150 == 0))
        if epoch % train_cfg.lr_update_interval == 0:
            scheduler.step()

        model.eval()
        if epoch % 5 == 0 or epoch == train_cfg.n_epoch:
            print('train_acc:', train_func.test(model, train_loader), end='\n')
        print('test_acc:', train_func.test(model, test_loader), end='\n')

        w_norm = [round(torch.linalg.norm(p).item(), 3) for p in model.parameters() if p.requires_grad]
        print('w_norm:', sum(w_norm))


def _run_unlearn_method(method, df, dr, model, unlearn_cfg, trainer):
    """Apply one unlearning step with the method named ``method``.

    The caller evaluates and checkpoints ``model`` afterwards, so the methods are
    expected to update it in place.

    Returns:
        The value returned by ``methods.cr_newton`` for 'cr_newton'; ``None`` otherwise.

    Raises:
        ValueError: for an unknown method name.
    """
    if method == 'retraining':
        methods.retrain(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'gd':
        methods.gd(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'sgd':
        methods.sgd(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'negate_loss':
        methods.negate_loss(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'random_label':
        methods.random_label(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'fisher_newton':
        # NOTE(review): called with (df, dr) like every other method; confirm against
        # methods.fisher_newton's signature.
        methods.fisher_newton(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'cr_newton':
        return methods.cr_newton(df, dr, model, unlearn_cfg)
    elif method == 'scr_newton':
        methods.scr_newton(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'pinv_newton':
        methods.pinv_newton(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'damp_newton':
        methods.damp_newton(df, dr, model, unlearn_cfg, trainer=trainer)
    elif method == 'lbfgs':
        methods.lbfgs(df, dr, model, unlearn_cfg)
    elif method == 'ntk':
        methods.ntk(df, dr, model, unlearn_cfg)
    else:
        raise ValueError(f'Unknown unlearn method: {method!r}')
    return None


if __name__ == '__main__':
    # Continual unlearning driver: the train/test sets are split into `--divide` shards.
    # Each round trains on one shard (starting from the previous round's unlearned
    # weights), picks points to forget, unlearns them batch by batch, and writes
    # per-round results and a trainable-parameter checkpoint under `save_dir`.
    UNLEARN_METHODS = ['retraining', 'negate_loss', 'random_label', 'gd', 'sgd', 'fisher_newton', 'cr_newton', 'scr_newton', 'pinv_newton', 'damp_newton', 'lbfgs', 'ntk']
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name')
    parser.add_argument('--seed', type=int, default=1, help='Random seed')
    parser.add_argument('--llama', action="store_true", help="Whether the model is llama")
    parser.add_argument('--divide', type=int, default=1, help='How many continual rounds')
    parser.add_argument('--unlearn_method', type=str, choices=UNLEARN_METHODS, required=True, help='unlearning method')
    parser.add_argument('--unlearn_batch_size', type=int, default=1, help='number of points to unlearn in batch')
    parser.add_argument('--sel_level', choices=['class', 'instance'], required=True)
    parser.add_argument('--sel_freq', type=float, required=True,
                        help='-1: every sample; < 1: fraction of the shard; otherwise an absolute count')
    parser.add_argument('--sel_sort_by', choices=['random', 'lowest_loss_first', 'highest_loss_first'], required=True)
    args = parser.parse_args()
    if args.divide < 1:
        parser.error('--divide must be >= 1')
    if args.unlearn_batch_size < 1:
        # A non-positive batch would never shrink the pending indices (infinite loop).
        parser.error('--unlearn_batch_size must be >= 1')

    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    torch.manual_seed(args.seed)
    numpy.random.seed(args.seed)
    n_divide = args.divide
    is_llm = args.dataset in _LLM_DATASETS

    save_dir = os.path.join(f'continual_unlearning/{args.unlearn_method}_results', f'{args.dataset}_{args.seed}')
    train_config_path = f'settings/train_config/{args.dataset}.json'
    unlearn_config_path = f'methods/unlearn_config/{args.dataset}.json'
    os.makedirs(save_dir, exist_ok=True)

    # Both config files are static for the run, so read them once (and fail fast if missing).
    train_cfg_dict = _load_json(train_config_path)
    method_cfg_dict = _load_json(unlearn_config_path)[args.unlearn_method]

    entire_train_set, entire_test_set, model, tokenizer, model_cfg = _prepare_data_and_model(args.dataset)

    init_state_dict = copy.deepcopy(model.state_dict())
    print('no_params (trainable): %d' % sum(p.numel() for p in model.parameters() if p.requires_grad))
    print('train_size:', len(entire_train_set))
    print('test_size:', len(entire_test_set))

    for i in range(n_divide):
        print(f'Training for {i}-th continual unlearning process')
        train_set = entire_train_set.shard(num_shards=n_divide, index=i)
        test_set = entire_test_set.shard(num_shards=n_divide, index=i)

        if i != 0:
            # Later rounds start from a fresh model overlaid with the previous round's
            # unlearned weights. The checkpoint holds only trainable params, hence strict=False.
            # Drop the old model first so two copies never coexist in memory.
            del model
            gc.collect()
            torch.cuda.empty_cache()
            prev_ckpt_path = f'{save_dir}/{args.unlearn_method}_{i-1}.pth'
            _, _, model, _, _ = _prepare_data_and_model(args.dataset)
            prev_state_dict = torch.load(prev_ckpt_path)['state_dict']
            model.load_state_dict(prev_state_dict, strict=False)
            del prev_state_dict
            os.remove(prev_ckpt_path)

        # ======================== Training ===============================
        if is_llm:
            _train_llm(model, tokenizer, train_set, test_set, train_cfg_dict, save_dir)
        else:
            _train_classifier(model, train_set, test_set, train_cfg_dict, DEVICE)

        # ======================== Select Points ===============================
        print(f'Selecting Points for {i}-th continual unlearning process')
        sel = _select_unlearn_indices(len(train_set), args.sel_level, args.sel_freq, args.sel_sort_by)
        print('size:', len(sel))
        unlearn_order_path = f'{save_dir}/unlearn_order_{i}.pth'
        torch.save(sel, unlearn_order_path)
        print('Saved unlearn order to %s' % unlearn_order_path)

        # ======================== Unlearning ===============================
        print(f'Unlearning for {i}-th continual unlearning process')
        unlearn_cfg = _merge_cfg(train_cfg_dict, method_cfg_dict)
        unlearn_cfg['init_state_dict'] = init_state_dict
        unlearn_cfg['llama'] = args.llama
        unlearn_cfg = argparse.Namespace(**unlearn_cfg)

        model.to(DEVICE)
        if is_llm:
            from transformers import TrainingArguments, Trainer
            training_args = TrainingArguments(output_dir=save_dir, **train_cfg_dict)
            trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)
        else:
            # NOTE(review): non-LLM datasets have no HF Trainer, so methods/utils.evaluate
            # receive trainer=None — confirm they handle that.
            trainer = None

        # Use the in-memory selection (as plain ints) rather than re-reading the .pth just
        # written: torch>=2.6 defaults to weights_only=True, which rejects numpy arrays.
        unlearn_indices = sel.tolist()
        unlearn_round = 1
        all_logs = []
        tot_df_indices = []
        dr_indices = list(range(len(train_set)))
        results_path = f'{save_dir}/{args.unlearn_method}_results_{i}.json'
        ckpt_path = f'{save_dir}/{args.unlearn_method}_{i}.pth'
        alpha = None
        while unlearn_indices:
            df_indices = unlearn_indices[:args.unlearn_batch_size]
            unlearn_indices = unlearn_indices[len(df_indices):]
            tot_df_indices.extend(df_indices)
            # Order-preserving removal keeps dr deterministic (ascending indices).
            forgotten = set(df_indices)
            dr_indices = [j for j in dr_indices if j not in forgotten]
            df = train_set.select(df_indices)
            dr = train_set.select(dr_indices)
            df_all = train_set.select(tot_df_indices)
            print('unlearn_round:', unlearn_round)
            print('df_all_size:', len(df_all))
            print('df_size:', len(df))
            print('dr_size:', len(dr))

            start_time = time.perf_counter()
            alpha = _run_unlearn_method(args.unlearn_method, df, dr, model, unlearn_cfg, trainer)
            tot_time = time.perf_counter() - start_time

            with torch.no_grad():
                results = utils.evaluate(df, df_all, dr, test_set, model, trainer=trainer)
            results['unlearn_round'] = unlearn_round
            results['time'] = tot_time
            if args.unlearn_method == 'cr_newton':
                results['alpha'] = alpha
            print(results, end='\n')
            all_logs.append(results)
            # Rewrite the full log each round so progress survives an interrupted run.
            with open(results_path, 'w') as f:
                json.dump(all_logs, f, indent=2)
            unlearn_round += 1

        # The checkpoint path is the same for every unlearn round and only the final state
        # is consumed (by the next continual round), so save once per continual round.
        ckpt = {'state_dict': _get_trainable_state_dict(model), 'cfg': vars(unlearn_cfg), 'pytorch-lightning_version': "2.2.1"}
        torch.save(ckpt, ckpt_path)
        print(f'Unlearned models are saved to {save_dir}')

        # Release this round's trainer (and its model reference) before the next round.
        del trainer
        gc.collect()
        torch.cuda.empty_cache()