#!/usr/bin/env python3

# pyright: reportMissingImports=true, reportUntypedBaseClass=true, reportGeneralTypeIssues=true

import os, time, argparse, copy, pickle, torch
from .load_datasets import get_loaders
from .golatkar import golatkar_precomputation, get_ntk_model, fisher_init, \
                    apply_fisher_noise
from .methods import retrain_lastK, cat_forget_finetune
from .utils import argparse2bool, load_model, print_model_params, save_model, get_forget_retain_loader, print_final_times, get_logger, mkdir
from .models import get_model

def parse_args():
    """Parse the command-line options of the unlearning script.

    Options are grouped as: checkpoint/args paths (all required), general
    training options, and one on/off switch plus hyper-parameters for each
    unlearning baseline (golatkar, retrain-final-K, finetune,
    finetune-final-K).

    Returns:
        argparse.Namespace: the parsed options. A ``--regularization`` value
        of ``'none'`` is normalised to ``None`` before returning.
    """
    parser = argparse.ArgumentParser()

    # --- Checkpoints and stored argument files ---
    parser.add_argument('--oldLoad', type=argparse2bool, nargs='?',
                        const=True, default=False,
                        help='Old (golatkar) styled checkpoints (default: False)')
    parser.add_argument('--path-o', required=True,
                        help='Path to original model checkpoint')
    parser.add_argument('--path-r', required=True,
                        help='Path to retrained model checkpoint')
    parser.add_argument('--path-oarg', required=True,
                        help='Path to original model args')
    parser.add_argument('--path-rarg', required=True,
                        help='Path to retrained model args')
    parser.add_argument('--init-checkpoint', required=True,
                        help='Path to init checkpoint for golatkar')

    # --- General options ---
    parser.add_argument('--num-classes', type=int, default=10,
                        help='Number of Classes')
    parser.add_argument('--scheduler', default=None,
                        choices=['CosineAnnealingWarmRestarts', 'CosineAnnealingLR'],
                        help='Pytorch Scheduler name: (default: The one used for train, in args_re)')
    parser.add_argument('--regularization', default=None,
                        choices=['none', 'cutmix', 'remove'],
                        help='Regularization type (default: None)')

    # --- Golatkar ---
    parser.add_argument('--golatkar', type=argparse2bool, nargs='?',
                        const=True, default=False,
                        help='Whether to use golatkar methods')
    parser.add_argument('--name-go', default='Golatkar')

    # --- Retrain final K layers ---
    parser.add_argument('--retrfinal', type=argparse2bool, nargs='?',
                        const=True, default=False,
                        help='Whether to use retrain final K')
    parser.add_argument('--name-rf', default='RetrFinal')
    parser.add_argument('--epochs-rf', type=int, default=62, metavar='N',
                        help='number of epochs to train (default: 62)')
    parser.add_argument('--maxL-rf', type=int, default=3, metavar='UL',
                        help='Layers to retrain upperbound (default: 3)')
    parser.add_argument('--minL-rf', type=int, default=1, metavar='LL',
                        help='Layers to retrain lowerbound (default: 1)')
    parser.add_argument('--stepL-rf', type=int, default=1, metavar='LS',
                        help='Layers to retrain step size (default: 1)')

    # --- Finetune ---
    parser.add_argument('--finetune', type=argparse2bool, nargs='?',
                        const=True, default=False,
                        help='Whether to use finetune method')
    parser.add_argument('--name-ft', default='Finetune')
    # The help text previously claimed "default: 62" although the default is 30.
    parser.add_argument('--epochs-ft', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: 30)')
    parser.add_argument('--maxlr-ft', type=float, default=0.1, metavar='LR',
                        help='max learning rate for SGDR (default: 0.1)')
    parser.add_argument('--minlr-ft', type=float, default=0.005, metavar='LR',
                        help='min learning rate for SGDR (default: 0.005)')

    # --- Finetune final K layers ---
    parser.add_argument('--finetune-final', type=argparse2bool, nargs='?',
                        const=True, default=False,
                        help='Whether to use finetune final K')
    parser.add_argument('--name-ftF', default='FinetuneFinal')
    parser.add_argument('--epochs-ftF', type=int, default=62, metavar='N',
                        help='number of epochs to train (default: 62)')
    parser.add_argument('--maxL-ftF', type=int, default=3, metavar='UL',
                        help='Layers to finetune upperbound (default: 3)')
    parser.add_argument('--minL-ftF', type=int, default=1, metavar='LL',
                        help='Layers to finetune lowerbound (default: 1)')
    parser.add_argument('--stepL-ftF', type=int, default=1, metavar='LS',
                        help='Layers to finetune step size (default: 1)')

    args = parser.parse_args()
    # The CLI spells "no regularization" as the string 'none'; downstream code
    # receives a real None instead.
    if args.regularization == 'none':
        args.regularization = None
    return args

def update_args(args):
    """Back-fill a ``scheduler`` attribute on *args* when it has none.

    The attribute is only added if it is absent from the namespace's
    instance dictionary; an existing value (even ``None``) is left as is.
    The same *args* object is returned after the in-place update.
    """
    # vars() exposes the live attribute dict, so setdefault mutates args itself.
    vars(args).setdefault('scheduler', 'CosineAnnealingWarmRestarts')
    return args


def get_unlearn_model(args, model_og, model_re, loader_re, loader_f, loader_v, loader_te, device, child_dir):
    """Run the unlearning baseline selected by ``args.unlearn_model`` and save its checkpoint(s).

    Exactly one of four branches runs, depending on ``args.unlearn_model``:
    ``"golatkar"``, ``"finetune"``, ``"retrfinal"`` or ``"ftfinal"``. Any other
    value runs none of them. Each branch writes its model(s) with
    ``save_model`` into a sub-directory of ``child_dir``.

    Args:
        args: Namespace holding the run configuration. Besides the
            ``name_*``, ``epochs_*``, ``minL_*``/``maxL_*``/``stepL_*`` and
            ``*_ft`` options, the branches read ``unlearn_model``, ``seed``,
            ``num_classes``, ``num_to_forget``, ``momentum``, ``disable_bn``,
            ``regularization``, ``scheduler``, ``cutmix_prob``,
            ``cutmix_alpha``, ``clip``, ``wd``, ``maxlr``, ``minlr`` and
            ``model``. ``args.name_go`` is overwritten in place.
        model_og: Model that every branch starts from (deep-copied in all
            branches except golatkar, which passes it to ``fisher_init``).
            Presumably the originally trained model -- confirm with caller.
        model_re: Second model, used only by the golatkar branch.
            Presumably the retrained reference model -- confirm with caller.
        loader_re: Data loader forwarded to the training helpers; its
            ``.dataset`` is passed to ``fisher_init``.
        loader_f: Not used in the visible part of this function.
        loader_v: Data loader forwarded to the training helpers.
        loader_te: Data loader forwarded to the training helpers.
        device: Device handle forwarded to ``fisher_init`` and stored in the
            per-method namespaces.
        child_dir: Root directory under which the per-method output
            directories are created.

    NOTE(review): no ``return`` statement is visible in this portion of the
    function, despite the ``get_`` name -- confirm what callers expect.
    """
    # NOTE(review): `filters` and `orig_name` are not read anywhere in the visible code.
    filters = 1.
    prefixpath = child_dir
    # NOTE(review): this unconditionally replaces whatever args.name_go held before.
    args.name_go = "Golatkar"
    logname = 'unlearn'
    orig_name = 'log'
    # Golatkar branch: build the Fisher-based models, apply the noise, save the result.
    if args.unlearn_model == "golatkar":
        # NOTE(review): starttime is assigned again a few lines below; only that
        # second value is used for the timing.
        starttime = time.monotonic() 
        os.makedirs(f"{prefixpath}/{args.name_go}/", exist_ok = True)
        modpath_fisher = f"{prefixpath}/{args.name_go}/fisher.pt"
        # Fisher step: fisher_init is run on both models, then apply_fisher_noise
        # receives both results.
        starttime = time.monotonic()
        modelf = fisher_init(device, loader_re.dataset, model_og)
        modelf0 = fisher_init(device, loader_re.dataset, model_re)
        apply_fisher_noise(args.seed, args.num_classes, args.num_to_forget, modelf, modelf0)
        # Elapsed wall time of the Fisher step (excludes the save below).
        # NOTE(review): compute_f is not logged or used in the visible code.
        compute_f = time.monotonic() - starttime
        save_model(modpath_fisher, modelf)

    # Finetune branch: fine-tune a copy of model_og via cat_forget_finetune.
    if args.unlearn_model == "finetune":
        # Configuration namespace handed to the training helper. Unlike the two
        # last-K branches below, the learning-rate bounds come from the
        # finetune-specific options (maxlr_ft / minlr_ft).
        args_ft = argparse.Namespace(device = device, log_dir = 'unlearnedModel', exp_name = 'log', name = 'ft_baseline',
                            momentum=args.momentum, disable_bn = args.disable_bn,
                            regularization = args.regularization, scheduler=args.scheduler,
                            cutmix_prob = args.cutmix_prob, cutmix_alpha = args.cutmix_alpha,
                            clip = args.clip, weight_decay = args.wd,
                            maxlr = args.maxlr_ft, minlr = args.minlr_ft)
        modpath_ft = f"{prefixpath}/{args.name_ft}/{args_ft.name}"
        mkdir(f'{prefixpath}/{args.name_ft}')
        logger_ft = get_logger(folder=f'{prefixpath}/{args.name_ft}/', logname=logname)
        # Work on a copy so model_og itself is not modified by training.
        model_ft = copy.deepcopy(model_og)
        model_ft, compute_ft = cat_forget_finetune(args_ft, model_ft, args.model, loader_re, loader_v,
                                        loader_te, args.epochs_ft, logger_ft, modpath_ft)
        save_model(f'{modpath_ft}.pt', model_ft)
        logger_ft.info(f'Finetune Time: {compute_ft}')
        logger_ft.info(f'Finetune Path: {modpath_ft}')

    # Retrain-final branch: for each k in the configured range, call
    # retrain_lastK on a fresh copy of model_og and save one checkpoint per k.
    if args.unlearn_model == "retrfinal":
        # Uses the general maxlr/minlr options rather than the *_ft ones.
        args_rf = argparse.Namespace(device = device, log_dir = 'unlearnedModel', exp_name = 'log',name = 'retrfinal',
                                    momentum=args.momentum, disable_bn = args.disable_bn, 
                                    regularization = args.regularization, scheduler=args.scheduler,
                                    cutmix_prob = args.cutmix_prob, cutmix_alpha = args.cutmix_alpha,
                                    clip = args.clip, weight_decay = args.wd,
                                    maxlr = args.maxlr, minlr = args.minlr) 
        name_prefix = "RetrFinal_"
        modprefix_rf = f"{prefixpath}/{args.name_rf}/{name_prefix}"
        mkdir(f'{prefixpath}/{args.name_rf}')
        logger_rf = get_logger(folder=f'{prefixpath}/{args.name_rf}/', logname=logname)
        # NOTE(review): retrfinal_times is never appended to in the visible code.
        retrfinal_models, retrfinal_times = [], []
        # k runs from minL_rf to maxL_rf inclusive, in steps of stepL_rf.
        for k in range(args.minL_rf, args.maxL_rf+1, args.stepL_rf):
            # Every k starts again from an untouched copy of model_og.
            model_rf = copy.deepcopy(model_og)
            args_rf.name = f'{name_prefix}{k}'
            # The trailing None argument's meaning is defined by retrain_lastK
            # (not visible here).
            model_rf, compute_rf = retrain_lastK(k, args_rf, model_rf, args.model, loader_re,
                                    loader_v, loader_te, args.epochs_rf, logger_rf, f'{modprefix_rf}{k}', None)
            save_model(f"{modprefix_rf}{k}.pt", model_rf)
            logger_rf.info(f'Retrained last {k} layers. Time: {compute_rf}')
            retrfinal_models.append(model_rf)
        logger_rf.info(f'RetrainFinal Prefix: {modprefix_rf}')

    # Finetune-final branch: same loop shape as above, but calls
    # cat_forget_finetune with k as an extra trailing argument.
    if args.unlearn_model == "ftfinal":
        # Uses the general maxlr/minlr options rather than the *_ft ones.
        args_ftF = argparse.Namespace(device = device, log_dir = 'unlearnedModel', exp_name = 'log', name = 'ft_final',
                                momentum=args.momentum, disable_bn = args.disable_bn, 
                                regularization = args.regularization, scheduler=args.scheduler,
                                cutmix_prob = args.cutmix_prob, cutmix_alpha = args.cutmix_alpha,
                                clip = args.clip, weight_decay = args.wd,
                                maxlr = args.maxlr, minlr = args.minlr)   
        name_prefix = "FTfinal_"
        modprefix_ftF = f"{prefixpath}/{args.name_ftF}/{name_prefix}"
        mkdir(f'{prefixpath}/{args.name_ftF}')
        logger_ftF = get_logger(folder=f'{prefixpath}/{args.name_ftF}/', logname=logname)
        # NOTE(review): FTfinal_times is never appended to in the visible code.
        FTfinal_models, FTfinal_times = [], []
        # k runs from minL_ftF to maxL_ftF inclusive, in steps of stepL_ftF.
        for k in range(args.minL_ftF, args.maxL_ftF+1, args.stepL_ftF):
            # Every k starts again from an untouched copy of model_og.
            model_ftF = copy.deepcopy(model_og)
            args_ftF.name = f'{name_prefix}{k}'
            model_ftF, compute_ftF = cat_forget_finetune(args_ftF, model_ftF, args.model, loader_re, 
                                    loader_v, loader_te, args.epochs_ftF, logger_ftF, f'{modprefix_ftF}{k}', k)
            save_model(f"{modprefix_ftF}{k}.pt", model_ftF)
            logger_ftF.info(f'Finetuned last {k} layers. Time: {compute_ftF}')
            FTfinal_models.append(model_ftF)
        logger_ftF.info(f'FinetuneFinal Prefix: {modprefix_ftF}')