from argparse import ArgumentParser
from xmeta.maml.maml import save_explainer
import pickle
import os
import sys
import torch

# Module-level cross-entropy criterion with mean reduction.
# NOTE(review): `loss` is not referenced anywhere in this script; confirm no
# other module imports it before removing it. The `torch` import itself may
# still be needed (e.g. to unpickle the explainer) — verify before dropping it.
loss = torch.nn.CrossEntropyLoss(reduction='mean')


def _append_command_line(save_dir):
    """Append the current command line to ``args_<script file name>.txt``.

    The file lives in *save_dir* and is opened in append mode, so every run
    adds one more line (each preceded by a newline).
    """
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(save_dir, args_file_name)
    # Explicit encoding so non-ASCII arguments do not depend on the locale.
    with open(args_file_path, mode='a', encoding='utf-8') as f:
        f.write("\n" + " ".join(sys.argv))


def _build_pkl_prefix(args):
    """Return the filename prefix that encodes the options of this run.

    Components, in order: ``einv_`` for an exact inverse, otherwise
    ``order{order}_acc{acc}_`` optionally followed by ``noreg_``; then
    ``nh{num_hessian_elements}`` when Hessian truncation was requested.
    """
    pkl_prefix = ''
    if args.num_hessian_elements is not None:
        # NOTE(review): unlike the other components this one has no trailing
        # '_'. It is kept as-is so existing output filenames do not change.
        pkl_prefix = f'nh{args.num_hessian_elements}' + pkl_prefix
    if args.exact_inverse:
        pkl_prefix = 'einv_' + pkl_prefix
    else:
        if args.no_reg:
            pkl_prefix = 'noreg_' + pkl_prefix
        pkl_prefix = f'order{args.order}_acc{args.acc}_' + pkl_prefix
    return pkl_prefix


def _truncate_hessian(hessian, max_elements):
    """Reduce *hessian* to at most *max_elements* entries and rescale it.

    If the Hessian already holds ``max_elements`` entries or fewer, nothing
    happens. Otherwise the surplus is removed with
    ``hessian.delete_small_elements`` and the remaining entries are scaled by
    ``num_elements_before / max_elements``.
    """
    num_elements = hessian.num_elements
    if num_elements <= max_elements:
        return
    print(f'number of hessian elements '
          f'{num_elements} > {max_elements}')
    n_del = num_elements - max_elements
    # Computed before deletion, because num_elements changes afterwards.
    # Equals num_elements / (num_elements - n_del).
    scale = num_elements / max_elements
    hessian.delete_small_elements(n_del)
    hessian.scale(scale)
    print(f'deleted {n_del} elements')


def main(args):
    """Rebuild the source parameter matrix of a pickled explainer and save it.

    Steps:
      1. Append the command line to ``args_<script>.txt`` in ``args.dir``.
      2. Load the pickled explainer from ``args.explainer``.
      3. If ``args.num_hessian_elements`` is set, truncate
         ``explainer.src_test_hessian`` to at most that many entries.
      4. Call ``explainer.set_src_param_matrix``. With ``args.exact_inverse``
         it is called with ``taylor_series=False``; otherwise it is called
         with ``pseudo_inv=not args.no_reg``, ``order`` and ``acc``.
      5. Optionally call ``explainer.discard_intermediate()``, then save it
         with ``save_explainer`` under a prefix that encodes the options.

    Args:
        args: parsed CLI namespace providing ``dir``, ``explainer``,
            ``order``, ``acc``, ``no_reg``, ``exact_inverse``,
            ``num_hessian_elements`` and ``discard_intermediate``.

    Raises:
        ValueError: if ``args.num_hessian_elements`` is given but is not a
            positive integer (0 would divide by zero when rescaling).
    """
    # Validate before any side effects (log file, loading the pickle).
    if args.num_hessian_elements is not None and args.num_hessian_elements < 1:
        raise ValueError(
            f'num_hessian_elements must be a positive integer, '
            f'got {args.num_hessian_elements}')

    _append_command_line(args.dir)
    pkl_prefix = _build_pkl_prefix(args)

    # load explainer
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(args.explainer, 'rb') as f:
        explainer = pickle.load(f)

    if args.num_hessian_elements is not None:
        _truncate_hessian(explainer.src_test_hessian, args.num_hessian_elements)

    if args.exact_inverse:
        # This branch only runs when exact_inverse is set, so taylor_series
        # is always False here.
        explainer.set_src_param_matrix(taylor_series=False)
    else:
        # --no-reg turns off the pseudo_inv option.
        explainer.set_src_param_matrix(pseudo_inv=(not args.no_reg),
                                       order=args.order,
                                       acc=args.acc,
                                       )
    if args.discard_intermediate:
        explainer.discard_intermediate()

    save_explainer(explainer, prefix=pkl_prefix)


if __name__ == '__main__':
    parser = ArgumentParser(
        description='Load a pickled explainer, recompute its source '
                    'parameter matrix and save it again.')
    # --dir and --explainer used to default to None, which crashed later with
    # a TypeError in os.path.join/open; require them explicitly instead.
    parser.add_argument('--dir', type=str, required=True,
                        help='directory in which the command line is '
                             'appended to args_<script name>.txt')
    parser.add_argument('--explainer', type=str, required=True,
                        help='path of the pickled explainer to load '
                             '(must be a trusted file)')
    parser.add_argument('--order', type=int, default=100,
                        help='order passed to set_src_param_matrix '
                             '(ignored with --exact-inverse)')
    parser.add_argument('--acc', type=float, default=0.001,
                        help='acc passed to set_src_param_matrix '
                             '(ignored with --exact-inverse)')
    parser.add_argument('--no-reg', action='store_true',
                        help='pass pseudo_inv=False to set_src_param_matrix '
                             '(ignored with --exact-inverse)')
    parser.add_argument('--exact-inverse', action='store_true',
                        help='call set_src_param_matrix with '
                             'taylor_series=False')
    parser.add_argument('--num-hessian-elements', type=int, default=None,
                        help='keep at most this many elements of the '
                             'source test Hessian (positive integer)')
    parser.add_argument('--discard-intermediate', action='store_true',
                        help='call explainer.discard_intermediate() before '
                             'saving')

    args = parser.parse_args()
    main(args)
