import sys
import pathlib
# Absolute path of the directory that contains this file.
BASELINE_PATH = pathlib.Path(__file__).parent.resolve()
# sys.path entries must be plain strings: the import machinery ignores any
# other type, so a Path object appended here would silently have no effect.
sys.path.append(str(BASELINE_PATH))

from baselines import it_unlearn, tv_unlearn, finetune
from baselines.prism_unlearn import prism_unlearn

import argparse
from os.path import basename, dirname, join as pathjoin

def _ft_model_dir(out_dir):
    """Return the sibling path ``<out_dir>_ft`` used for the fine-tuned model.

    Trailing path separators are dropped first: ``basename("run/")`` is ``""``,
    so without this step the result would be ``run/_ft`` (a directory inside
    the output directory) instead of the sibling ``run_ft``.
    """
    import os  # local import: only names from os.path are imported at file level

    separators = os.sep + (os.altsep or "")
    # Keep the raw value if it consists solely of separators (e.g. "/").
    trimmed = out_dir.rstrip(separators) or out_dir
    return pathjoin(dirname(trimmed), basename(trimmed) + "_ft")


def main():
    """Parse the command line and run the selected unlearning baseline.

    Dispatch on ``args.algo``:

    * ``'kn'``: not implemented.
    * ``'tv'``: call ``finetune`` to write a model to ``<out_dir>_ft``, then
      call ``tv_unlearn`` with the original and fine-tuned model directories.
    * any value containing ``"prism"``: call ``prism_unlearn``.
    * anything else: call ``it_unlearn`` with ``args.algo`` as ``loss_type``.

    Raises:
        NotImplementedError: if ``args.algo`` is ``'kn'``.
    """
    args = get_args()
    print(args.out_dir)

    if args.algo == 'kn':
        raise NotImplementedError()

    if args.algo == 'tv':
        ft_model_dir = _ft_model_dir(args.out_dir)
        finetune(
            args.model_dir, args.data_file, ft_model_dir,
            epochs=args.epochs,
            per_device_batch_size=args.per_device_batch_size,
            learning_rate=args.lr,
            max_len=args.max_len,
            tokenizer_dir=args.tokenizer_dir,
        )
        tv_unlearn(
            args.model_dir, args.out_dir,
            some_pt_model_dir=args.model_dir,
            some_ft_model_dir=ft_model_dir,
            alpha=args.alpha,
        )
        return

    # Keyword arguments passed identically to prism_unlearn and it_unlearn.
    shared_kwargs = dict(
        retain_data_file=args.retain_data_file,
        loss_type=args.algo,
        per_device_batch_size=args.per_device_batch_size,
        epochs=args.epochs,
        learning_rate=args.lr,
        max_len=args.max_len,
        tokenizer_dir=args.tokenizer_dir,
        resume_from_checkpoint=args.resume_from_checkpoint,
        beta=args.beta,
        coeff=args.coeff,
        npo_coeff=args.npo_coeff,
        gamma=args.gamma,
    )

    if "prism" in args.algo:
        prism_unlearn(
            args.model_dir, args.data_file, args.out_dir,
            sam_rho=args.sam_rho,
            pretrained_probe_path=args.pretrained_probe_path,
            adv_gamma=args.adv_gamma,
            select_layer=args.select_layer,
            **shared_kwargs,
        )
    else:
        it_unlearn(
            args.model_dir, args.data_file, args.out_dir,
            **shared_kwargs,
        )

def get_args(argv=None):
    """Parse and validate the command-line options of the unlearning script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse (exit status 2, usage message on
            stderr) when ``--algo`` is missing, when ``--algo gd`` is given
            without ``--retain_data_file``, or when
            ``--resume_from_checkpoint`` is combined with ``--algo tv``.
    """
    parser = argparse.ArgumentParser(description="Unlearning baselines")
    # Required: the dispatcher in this file evaluates `"prism" in args.algo`,
    # which raises TypeError when the value is None.
    parser.add_argument('--algo', type=str, required=True)
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--tokenizer_dir', type=str, default=None)
    parser.add_argument('--data_file', type=str)
    parser.add_argument('--out_dir', type=str)
    parser.add_argument('--max_len', type=int, default=4096)
    parser.add_argument('--resume_from_checkpoint', action='store_true')
    parser.add_argument('--per_device_batch_size', type=int, default=2)
    parser.add_argument('--retain_data_file', type=str, default=None)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--beta', type=float, default=0.1)
    parser.add_argument('--coeff', type=float, default=0.1)
    parser.add_argument('--npo_coeff', type=float, default=0.1)
    parser.add_argument('--gamma', type=float, default=0.0)
    parser.add_argument('--sam_rho', type=float, default=0.01)
    parser.add_argument('--pretrained_probe_path', type=str, default=None)
    parser.add_argument('--adv_gamma', type=float, default=0.005)
    parser.add_argument('--select_layer', type=int, default=32)
    args = parser.parse_args(argv)

    # Validate with parser.error rather than assert: assert statements are
    # removed under `python -O`, and a usage error is clearer than a traceback.
    if args.algo == 'gd' and args.retain_data_file is None:
        parser.error("Gradient difference selected. Retain set required.")
    if args.resume_from_checkpoint and args.algo == 'tv':
        parser.error("Cannot resume from checkpoint if the method is task vector.")
    return args

if __name__ == '__main__':
    main()