"""
usage: python train.py --dataset cora --Pi 128  --split public
"""
import argparse

import dkm
import json
import dkm.onesolve as OS
import torch as t
from dataset import graph as GD
import torch.nn.functional as F
import time
import traceback
import dkm.train_util as TU
from warnings import warn
import dataset.meta as DM
import analysis_util as AU
import dkm.util as U

def train_model(**kwargs):
    """
    Run ``_train_model`` without letting a training failure take down other calls.

    Returns:
        The stats dict returned by ``_train_model`` on success, or ``None`` if
        training raised. On failure the traceback is printed, followed by the
        ``DKMTRAINFAIL`` marker line.
    """
    try:
        return _train_model(**kwargs)
    except Exception:
        # deliberate best-effort boundary: report and carry on
        traceback.print_exc()
        print("DKMTRAINFAIL")
        return None

# Per-process cache of preprocessed device tensors (features, labels, adjacency)
# that _train_model reuses across repeated calls, e.g. during hyperparameter searches.
DS_CACHE = {}

def _train_model(**kwargs):
    """
    Train one graph DKM and return the stats of its final epoch.

    Unlike ``train_model``, this does not swallow exceptions. ``_search`` catches
    them itself so it can skip failed (seed, split) runs.

    Keys read from ``kwargs``:
      required -- device, dtype ('float32' / 'float64'), seed, dataset, split,
                  Pi (-1 = one inducing point per training node), scale_inputs,
                  mixup_scheme, learn_Xi ('yes' enables do_learn_Xi), model
                  ('res', 'kipf' or 'kipfres'), num_layers, dof, center,
                  center_learned, num_epochs
      optional -- edit_hr, normalize_features, kernel, adj_lambda, weight_decay,
                  no_mixup, return_kernel ('no', 'yes' or 'yes-<sample_size>')

    Returns:
        A dict with ``epoch_i`` and, for each mode in train / (val) / test,
        ``<mode>_obj``, ``<mode>_ll`` and either ``<mode>_acc`` (classification)
        or ``<mode>_mse``. If a kernel was requested it also holds
        ``kernels = dict(K=..., y=..., cka=...)``.

    Raises:
        ValueError: for an unknown ``model``, or a ``return_kernel`` whose
            ``-<sample_size>`` suffix is not an integer.
    """
    args = kwargs  # **kwargs always arrives as a dict; no Namespace conversion needed
    print("DKMTRAINRUN")
    print("begin args ========")
    print(json.dumps(args, indent=2))
    print("end args ========")
    device = args['device']
    typ = {'float32': t.float32, 'float64': t.float64}[args['dtype']]
    if args['dtype'] == 'float32':
        ## disable tf32, so that results on A100 are the same as results on 3090
        t.backends.cuda.matmul.allow_tf32 = False
        t.backends.cudnn.allow_tf32 = False
    t.set_default_dtype(typ)
    TU.set_seed(args['seed'])

    # option to edit homophily ratio of the graph (missing / None / negative => no edit)
    edit_hr = args.get('edit_hr', -1)
    if edit_hr is None or edit_hr < 0:
        edit_hr = None
    else:
        assert edit_hr <= 1.0
    normalize_features = args.get('normalize_features', False)
    ds = GD.get_dataset(args['dataset'], split=args['split'], edit_hr=edit_hr,
                        normalize_features=normalize_features)

    # Pi == -1 => one inducing point per training node. Cast to a Python int so that
    # slicing, tensor constructors and the model never receive a 0-d tensor.
    Pi = int(ds.train_mask.sum()) if args['Pi'] == -1 else args['Pi']

    N = ds.num_features
    num_classes = ds.num_classes
    assert not ds.is_multi_graph, "single graph ==> node task"

    do_val = ds.val_mask is not None

    # The cached tensors depend on more than the dataset name, so every setting that
    # changes them is part of the key. Otherwise a later call with a different
    # dtype / device / scaling / normalisation / homophily edit would silently reuse
    # stale tensors.
    cache_key = (args['dataset'], str(device), args['dtype'], bool(args['scale_inputs']),
                 normalize_features, edit_hr)
    if cache_key not in DS_CACHE:
        Xt = ds.X.to(device, dtype=typ)
        if args['scale_inputs']:
            tt_diag = (Xt * Xt).sum(-1)
            tt_diag[tt_diag <= 0.] = 1.
            Xt = Xt * tt_diag.rsqrt().unsqueeze(-1) # make sure Xt is on a sensible scale for the linear kernel!
        y = ds.y.to(device)
        adj_sp = ds.adj_sp.to(device)
        DS_CACHE[cache_key] = (Xt, y, adj_sp)
    else:
        (Xt, y, adj_sp) = DS_CACHE[cache_key]

    ytrue = dict(
        train=y[ds.train_mask],
        test=y[ds.test_mask],
        all=y
    )
    dataset_size = len(ds.train_mask)  # number of nodes; used to normalise reg
    assert dataset_size == len(ds.test_mask)
    if do_val:
        ytrue['val'] = y[ds.val_mask]
        assert dataset_size == len(ds.val_mask)

    ## select some inducing points: the first Pi training nodes with non-zero features
    non_zero = Xt.sum(dim=1).abs() > 1e-10
    is_train = ds.train_mask.to(device)
    _ind_mask = t.logical_and(non_zero, is_train)
    indices = t.arange(Xt.size(0)).to(device)
    ind_ixs = indices[_ind_mask][:Pi]

    ## it's super important that the inducing points are not in the test set
    ## so do a sanity check!
    test_ixs = t.arange(len(ds.test_mask))[ds.test_mask.cpu()]
    assert set(ind_ixs.cpu().numpy()).isdisjoint(set(test_ixs.cpu().numpy())), "inducing points are definitely not test"
    if do_val:
        val_ixs = t.arange(len(ds.val_mask))[ds.val_mask.cpu()]
        assert set(ind_ixs.cpu().numpy()).isdisjoint(set(val_ixs.cpu().numpy())), "inducing points are definitely not val"

    Xi = Xt[ind_ixs,:]
    yi = ytrue['all'][ind_ixs]
    init_mu = 2*F.one_hot(yi, num_classes).to(dtype=typ) - 1 ## initialize inducing labels in {-1, +1}

    if len(ind_ixs) < Pi:
        warn("too many inducing points for the training data - creating pseudo inducing points!")
        _Pi = Pi - len(ind_ixs)

        ## try to roughly match distribution of existing inducing points (this is approximate)
        Xtrain = Xt[ds.train_mask,:]
        _Xi = t.rand(_Pi, N).to(device=device, dtype=typ) * (Xtrain.max() - Xtrain.min())
        _yi = 2*t.rand(_Pi, num_classes, device=device, dtype=typ) - 1

        Xi = t.cat((_Xi, Xi), dim=0)
        init_mu = t.cat((_yi, init_mu), dim=0)

        ## pick some random ind ixs
        ## these are only used if we are doing 'fixed-full' mixup, where we need
        ## to have some adjacency information for inducing points.
        ## for this, we sample from the train indices
        ## though in principle it shouldn't even be a problem to sample from test!
        sample_ixs = t.randint(0, len(ind_ixs), (_Pi,)).to(device)
        _ind_ixs = ind_ixs[sample_ixs]
        ind_ixs = t.cat((_ind_ixs, ind_ixs), dim=0)

    if args['mixup_scheme'] == 'fixed-full':
        ind_ixs = t.arange(len(ds.train_mask), device=device)[ind_ixs]
        adj_sp = GD.prepend_duplicate_rows(adj_sp, ind_ixs)

    kernel = args.get('kernel', 'relu')
    adj_sp = adj_sp.to(device, dtype=typ)

    do_checkpointing = args['dataset'] in DM.bigger_datasets
    do_learn_Xi = args['learn_Xi'] == 'yes'
    feat_to_gram_params = dict(Xi=Xi, do_learn_Xi=do_learn_Xi)
    mc_samples = ds.mc_samples if ds.mc_samples is not None else 500
    output_params = dict(mc_samples=mc_samples, init_mu=init_mu,
                         learn_mu=True,
                         chunk_size=ds.chunk_size)
    gram_params = dict()
    mixup_params = dict(mode=args['mixup_scheme'], lmbda=args.get('adj_lambda', 0.))

    if args['model'] == 'res':
        model = OS.ResGraphDKM(Pi=Pi,
                        Nin=N,
                        Nout=num_classes,
                        num_layers=args['num_layers'],
                        dof=args['dof'],
                        feat_to_gram_params=feat_to_gram_params,
                        kernel=kernel,
                        output_params=output_params,
                        center=args['center'],
                        center_learned=args['center_learned'],
                        gram_params=gram_params,
                        do_checkpointing=do_checkpointing,
                        mixup_params=mixup_params,
                        ).to(device, dtype=typ)
    elif args['model'] == 'kipf' or args['model'] == 'kipfres':
        residual = args['model'] == 'kipfres'
        model = OS.KipfGraphDKM(Pi=Pi,
                              Nin=N,
                              Nout=num_classes,
                              num_layers=args['num_layers'],
                              dof=args['dof'],
                              kernel=kernel,
                              feat_to_gram_params=feat_to_gram_params,
                              output_params=output_params,
                              center=args['center'],
                              center_learned=args['center_learned'],
                              gram_params=gram_params,
                              do_checkpointing=do_checkpointing,
                              mixup_params=mixup_params,
                              residual=residual
                             ).to(device, dtype=typ)
    else:
        raise ValueError(f"unknown model {args['model']!r}; expected 'res', 'kipf' or 'kipfres'")
    print(model)

    total_num_params = sum(p.numel() for p in model.parameters())
    print(f"total_num_params = {total_num_params}")

    num_epochs = args['num_epochs']
    opt, scheduler = TU.mk_opt_and_scheduler(model, num_epochs,
                                            warmup=num_epochs // 4,
                                            schedule='cosine',
                                            init_lr=1e-4,
                                            max_lr=1e-2,
                                            min_lr=1e-5,
                                            weight_decay=args.get('weight_decay',0.))

    ## training loop
    ## fullrank, single graph and multi-graph all a bit different, so do them separately
    print("TRAINLOOPBEGIN")
    timestart = time.time()
    # store stats
    epoch_stats = None
    # we do a forward pass to initialize gram layers
    model.eval()
    lls, mean_logprobs, reg = model(Xt, adj_sp, labels=y)
    print("========")

    def eval_classification(lls, mean_logprobs):
        """Split per-node log-likelihoods and argmax predictions into train/val/test."""
        train_ll = lls[ds.train_mask].mean()
        if do_val:
            val_ll = lls[ds.val_mask].mean()
        else:
            val_ll = None
        test_ll = lls[ds.test_mask].mean()

        ypred = mean_logprobs.argmax(dim=-1)
        ytrain = ypred[ds.train_mask]
        if do_val:
            yval = ypred[ds.val_mask]
        else:
            yval = None
        ytest = ypred[ds.test_mask]
        return dict(train_ll=train_ll,
                    val_ll=val_ll,
                    test_ll=test_ll,
                    ytrain=ytrain,
                    yval=yval,
                    ytest=ytest)

    def zero_opt():
        """
        more memory efficient than [opt.zero_grad]:
        https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-parameter-grad-none-instead-of-model-zero-grad-or-optimizer-zero-grad
        """
        for p in opt.param_groups[0]['params']:
            p.grad = None
    minibatch_size = ds.minibatch_size

    if args.get('no_mixup', False):
        adj_sp = None

    epoch_i = 0
    finished_training = False
    while not finished_training:
        epoch_i += 1
        model.train()
        epochstarttime = time.time()

        if minibatch_size is None or minibatch_size <= 0: # dont do minibatching
            zero_opt()
            lls, mean_logprobs, reg = model(Xt, adj_sp, labels=y)
            reg = reg / dataset_size # normalize reg so that it is 'reg per dataset'
            obj = lls[ds.train_mask].mean() + reg
            (-obj).backward()
            opt.step()
            ev = eval_classification(lls, mean_logprobs)
        else:
            Ptrain = int(ds.train_mask.sum())
            perm = t.randperm(Ptrain)
            # reshape(-1) keeps a 1-D index even with a single training node
            # (squeeze() would collapse it to 0-d, which t.split rejects)
            train_ixs = t.argwhere(ds.train_mask).reshape(-1)
            train_splits = t.split(train_ixs[perm], minibatch_size)
            for train_split in train_splits:
                zero_opt()
                train_split = train_split.to(device, dtype=t.long)
                lls, mean_logprobs, reg = model(Xt, adj_sp, labels=y, ixs=train_split)
                reg = reg / dataset_size # normalize reg so that it is 'reg per dataset'
                obj = lls.mean() + reg
                (-obj).backward()
                opt.step()

            model.eval()
            with t.no_grad():
                lls, mean_logprobs, reg = model(Xt, adj_sp, labels=y)
                reg = reg / dataset_size
                ev = eval_classification(lls, mean_logprobs)

        exceeded_max_n_epochs = epoch_i >= num_epochs
        finished_training = exceeded_max_n_epochs

        if finished_training: # last epoch!
            model.eval()
            # evaluation only: no autograd graph needed
            with t.no_grad():
                lls, mean_logprobs, reg = model(Xt, adj_sp, labels=y)
                reg = reg / dataset_size # normalize reg so that it is 'reg per dataset'
                ev = eval_classification(lls, mean_logprobs)

        # CUDA peak-memory stats are only meaningful for (and only accept) CUDA devices
        if t.device(device).type == 'cuda':
            max_gb = t.cuda.max_memory_allocated(device=device) / 10**9
        else:
            max_gb = 0.

        stats = dict()
        for mode in ['train'] + ([] if not do_val else ['val']) + ['test']:
            obj = ev[f'{mode}_ll'] + reg
            stats[f'{mode}_obj'] = obj.item()
            stats[f'{mode}_ll'] = ev[f'{mode}_ll'].item()
            if ds.is_classification_task:
                stats[f'{mode}_acc'] = (ytrue[mode] == ev[f'y{mode}']).mean(dtype=typ).item()
            else:
                stats[f'{mode}_mse'] = (ytrue[mode] - ev[f'y{mode}']).square().mean(dtype=typ).item()

        to_print_stdout = dict(epoch=epoch_i) | stats
        to_print_stdout = to_print_stdout | dict(lr=opt.param_groups[0]['lr'],
                                                    mem=max_gb,
                                                    t=time.time() - epochstarttime)
        buf = TU.epoch_metrics_buf(to_print_stdout)
        scheduler.step()
        epoch_stats = dict(epoch_i=epoch_i) | stats
        print(buf.getvalue())


    print("========")
    timeend = time.time()
    elapsed = timeend - timestart
    print("ELAPSED = ", elapsed)
    print("========")
    print(TU.epoch_metrics_buf(epoch_stats).getvalue())
    print("========")
    print("TRAINLOOPEND")

    return_kernel = args.get('return_kernel', 'no') != 'no'
    if return_kernel:
        spec = args['return_kernel']
        sample_size = None  # 'yes' => keep the kernel over every vertex
        if spec != 'yes':
            try:
                sample_size = int(spec.split('-')[-1])
            except (ValueError, AttributeError):
                print("return_kernel must be 'no', 'yes' or 'yes-<sample_size>'")
                raise

        model.eval()
        if ds.is_classification_task:
            sortixs = t.argsort(y)
        else:
            sortixs = t.arange(len(y))
        # the kernel is detached below anyway, so skip building an autograd graph
        with t.no_grad():
            G = model(Xt, adj_sp=adj_sp, returns='final-kernel')
        Ft = G.ft[sortixs,:].clone().detach().cpu() ## move to cpu to avoid large memory on gpu
        K = (Ft @ Ft.T)
        y = ytrue['all'][sortixs].clone().detach().cpu()

        yoh = F.one_hot(y).float()
        yyT_kernel = yoh @ yoh.T
        cka_alignment = U.cka(K, yyT_kernel)

        if sample_size is not None:
            # sample random vertices: too expensive to save them all!
            # NOTE(review): only K is subsampled; the stored y still covers all vertices.
            import numpy as np
            nvertices = K.shape[0]
            rng = np.random.default_rng([ord(x) for x in args['dataset']])# fix random seed for each dataset
            indices = rng.integers(0, nvertices, size=sample_size)
            indices = np.sort(indices)

            K = K[indices,:][:,indices]

        _kernels = dict(K=K, y=y,
                        cka=cka_alignment.item())
        epoch_stats['kernels'] = _kernels
    return epoch_stats

"""
hyperparameter search utils below
"""

def _search(default_params, grid_search_params, seeds_splits, fname):
    """
    Train every (hyperparameter combo, seed, split) and persist the results to ``fname``.

    For each ``ps`` in ``grid_search_params`` the run config is ``default_params | ps``.
    Results are stored under ``AU.hash_dict`` of that config (seed/split excluded) as
    ``{'metrics': [stats, ...], 'params': config}``. (seed, split) pairs already present
    in the stored metrics are skipped, so an interrupted search can be resumed.

    A failed run is reported and skipped; it is never recorded. Results are written after
    each combo. On Ctrl-C the runs finished so far for the current combo are saved first,
    then the search returns.
    """
    results = AU.maybe_read_pkl_file(fname)
    for ps in grid_search_params:
        params = default_params | ps
        hash_params = AU.hash_dict(params)

        statss = results[hash_params]['metrics'] if hash_params in results else []

        def already_trained(split, seed):
            return any(x['seed'] == seed and x['split'] == split for x in statss)

        interrupted = False
        for (seed, split) in seeds_splits:
            if already_trained(split, seed):
                print(f"already trained with params/split/seed: -/{split}/{seed}")
                continue

            params['seed'] = seed
            params['split'] = split
            stats = None
            try:
                stats = _train_model(**params)
            except KeyboardInterrupt:
                print("Exiting")
                interrupted = True
            except Exception:
                traceback.print_exc()
                warn(f"failed to train with params = {params}")
            # seed/split must not leak into the stored (hashed) params
            del params['split']
            del params['seed']
            if interrupted:
                break
            if stats is None:
                # failed run: don't record anything (and never touch a previous run's stats)
                continue
            stats['seed'] = seed
            stats['split'] = split
            statss.append(stats)
        results[hash_params] = dict(metrics=statss, params=params)
        AU.write_pkl_file(results, fname)
        if interrupted:
            return

_default_search_params = dict(
    gram_sample_factor=0.,
    normalize_features=False,
    dtype='float32',
    num_layers=2,
    device='cuda',
    model = 'kipf',
    max_lr = 1e-2,
    kernel='relu',
    scheduler='cosine',
    center_learned=False,
    center='id',
    dof='inf',
    Pi=100,
    learn_Xi='yes',
    adj_lambda=0.,
    scale_inputs=True,
)
def grid_search_kipf(**kwargs):
    """
    Grid-search the Kipf-style DKM over mixup scheme and degrees of freedom.

    Results are appended to ``hyperparam_results/kipf_results/<dataset>.pkl``.
    """
    ds_name = kwargs['dataset']
    fname = f"hyperparam_results/kipf_results/{ds_name}.pkl"

    # bigger datasets are trained with minibatches,
    # so it is not unreasonable to decrease the number of epochs
    if ds_name in DM.bigger_datasets:
        num_epochs = 200
    else:
        num_epochs = 300

    grid_search_params = []
    for mixup_scheme in ['fixed-indep', 'fixed-full']:
        for dof in ['inf', 1e0, 1e-1, 1e1, 1e3, 1e-2, 1e2]:
            grid_search_params.append({
                'dataset': ds_name,
                'num_epochs': num_epochs,
                'mixup_scheme': mixup_scheme,
                'dof': dof,
            })

    ## define # of splits etc.
    if ds_name in DM.planetoid_datasets:
        seeds_splits = [(seed, 'public') for seed in range(10)]
    elif ds_name in DM.bigger_datasets:
        seeds_splits = [(seed, 'public') for seed in range(3)]
    else:
        seeds_splits = []
        for split in range(10):
            for seed in range(3):
                seeds_splits.append((seed, f'test-{seed}-{split}'))

    print("number of combos to test", len(grid_search_params))
    _search(_default_search_params, grid_search_params, seeds_splits, fname)

def arch_grid_search(**kwargs):
    """
    Compare architectures ('res', 'kipf', 'kipfres') and adjacency mixing strengths.

    The grid runs twice: once on top of the best Kipf DKM settings
    (``DM.best_kipf``) and once on top of the best NNGP settings
    (``DM.best_nngp_kipf``). Results go to
    ``hyperparam_results/arch_results/<dataset>.pkl``.
    """
    ds_name = kwargs['dataset']
    fname = f"hyperparam_results/arch_results/{ds_name}.pkl"

    # minibatching on bigger datasets makes fewer epochs reasonable
    is_big = ds_name in DM.bigger_datasets
    num_epochs = 150 if is_big else 200
    run_cfg = dict(dataset=ds_name, num_epochs=num_epochs)
    dkm_default_params = {**_default_search_params, **DM.best_kipf[ds_name], **run_cfg}
    nngp_default_params = {**_default_search_params, **DM.best_nngp_kipf[ds_name], **run_cfg}

    # adj_lambda is only swept for the Kipf-style models
    grid_search_params = []
    for model in ['res', 'kipf', 'kipfres']:
        lambdas = [0., 0.1, 0.5] if model in ['kipfres', 'kipf'] else [0.]
        for lam in lambdas:
            grid_search_params.append({'model': model, 'adj_lambda': lam})

    ## define # of splits etc.
    if ds_name in DM.planetoid_datasets:
        seeds_splits = [(seed, 'public') for seed in range(100, 110)]
    elif is_big:
        seeds_splits = [(seed, 'public') for seed in range(100, 102)]
    else:
        seeds_splits = [(seed, f'test-{seed}-{split}') for seed in (100, 101) for split in range(10)]

    print("number of combos to test", len(grid_search_params))
    for defaults in (dkm_default_params, nngp_default_params):
        _search(defaults, grid_search_params, seeds_splits, fname)

def center_search(**kwargs):
    """
    Search over the centering scheme ('id' / 'batch') and whether it is learned.

    Builds on the best architecture and Kipf settings from ``DM``, separately for
    the DKM and the NNGP variants. Results go to
    ``hyperparam_results/center_results/<dataset>.pkl``.
    """
    ds_name = kwargs['dataset']
    fname = f"hyperparam_results/center_results/{ds_name}.pkl"

    # minibatching on bigger datasets makes fewer epochs reasonable
    num_epochs = 150 if ds_name in DM.bigger_datasets else 200

    # merge order matters for overlapping keys: later layers win
    dkm_default_params = dict(_default_search_params)
    for layer in (DM.best_arch[ds_name],
                  DM.best_kipf[ds_name],
                  dict(num_epochs=num_epochs, dataset=ds_name)):
        dkm_default_params.update(layer)

    nngp_default_params = dict(_default_search_params)
    for layer in (DM.best_nngp_kipf[ds_name],
                  DM.best_nngp_arch[ds_name],
                  dict(num_epochs=num_epochs, dataset=ds_name)):
        nngp_default_params.update(layer)

    grid_search_params = []
    for center in ['id', 'batch']:
        for center_learned in [True, False]:
            grid_search_params.append({'center': center, 'center_learned': center_learned})

    if ds_name in DM.planetoid_datasets:
        seeds_splits = [(seed, 'public') for seed in range(200, 210)]
    elif ds_name in DM.bigger_datasets:
        seeds_splits = [(seed, 'public') for seed in range(200, 202)]
    else:
        seeds_splits = [(seed, f'test-{seed}-{split}') for seed in (200, 201) for split in range(10)]

    print("number of combos to test", len(grid_search_params))
    _search(dkm_default_params, grid_search_params, seeds_splits, fname)
    _search(nngp_default_params, grid_search_params, seeds_splits, fname)


def Pi_search(**kwargs):
    """
    Search over the number of inducing points ``Pi``.

    Builds on the best architecture, Kipf and centering settings from ``DM``,
    separately for the DKM and the NNGP variants. Results go to
    ``hyperparam_results/Pi_results/<dataset>.pkl``.
    """
    ds_name = kwargs['dataset']
    fname = f"hyperparam_results/Pi_results/{ds_name}.pkl"
    candidate_Pis = [50, 100, 200, 300, 400]

    # minibatching on bigger datasets makes fewer epochs reasonable
    num_epochs = 150 if ds_name in DM.bigger_datasets else 200
    run_cfg = dict(num_epochs=num_epochs, dataset=ds_name)

    dkm_default_params = {
        **_default_search_params,
        **DM.best_arch[ds_name],
        **DM.best_kipf[ds_name],
        **DM.best_center[ds_name],
        **run_cfg,
    }
    nngp_default_params = {
        **_default_search_params,
        **DM.best_nngp_arch[ds_name],
        **DM.best_nngp_kipf[ds_name],
        **DM.best_nngp_center[ds_name],
        **run_cfg,
    }

    grid_search_params = [{'Pi': n_inducing} for n_inducing in candidate_Pis]

    if ds_name in DM.planetoid_datasets:
        seeds_splits = [(seed, 'public') for seed in range(300, 310)]
    elif ds_name in DM.bigger_datasets:
        seeds_splits = [(seed, 'public') for seed in range(300, 302)]
    else:
        seeds_splits = [(seed, f'test-{seed}-{split}') for seed in (300, 301) for split in range(10)]

    print("number of combos to test", len(grid_search_params))
    for defaults in (dkm_default_params, nngp_default_params):
        _search(defaults, grid_search_params, seeds_splits, fname)

def final_acc(**kwargs):
    """
    Evaluate the fully tuned DKM and NNGP configurations on fresh seeds.

    Uses the best architecture, Kipf, centering and ``Pi`` settings from ``DM``;
    results go to ``hyperparam_results/final_results/<dataset>.pkl``.
    """
    ds_name = kwargs['dataset']
    fname = f"hyperparam_results/final_results/{ds_name}.pkl"

    # minibatching on bigger datasets makes fewer epochs reasonable
    num_epochs = 150 if ds_name in DM.bigger_datasets else 200

    dkm_default_params = dict(_default_search_params)
    for layer in (DM.best_arch[ds_name], DM.best_kipf[ds_name],
                  DM.best_center[ds_name], DM.best_Pi[ds_name],
                  dict(dataset=ds_name, num_epochs=num_epochs)):
        dkm_default_params.update(layer)

    nngp_default_params = dict(_default_search_params)
    for layer in (DM.best_nngp_arch[ds_name], DM.best_nngp_kipf[ds_name],
                  DM.best_nngp_center[ds_name], DM.best_nngp_Pi[ds_name],
                  dict(dataset=ds_name, num_epochs=num_epochs)):
        nngp_default_params.update(layer)

    # a single combo: the tuned defaults themselves
    grid_search_params = [{}]

    if ds_name in DM.planetoid_datasets:
        seeds_splits = [(seed, 'public') for seed in range(1000, 1010)]
    elif ds_name in DM.bigger_datasets:
        seeds_splits = [(seed, 'public') for seed in range(1000, 1005)]
    else:
        # NOTE(review): seeds 1000 and 1004 only (the other searches use consecutive
        # seed pairs; bigger datasets use 1000..1004) -- confirm this is intended.
        seeds_splits = []
        for seed in (1000, 1004):
            for split in range(10):
                seeds_splits.append((seed, f'test-{seed}-{split}'))

    print("number of combos to test", len(grid_search_params))
    _search(dkm_default_params, grid_search_params, seeds_splits, fname)
    _search(nngp_default_params, grid_search_params, seeds_splits, fname)

def get_shaped_kernels(**kwargs):
    """
    Train Kipf DKMs over several ``dof`` values and save their final-layer kernels.

    Only 'cora' and 'roman-empire' have a (seed, split) configured. Any other dataset
    raises ``ValueError`` before any training starts. Each run passes
    ``return_kernel='yes-400'`` to ``_train_model``, which asks for a 400-vertex
    sample of the final kernel. Results go to ``shaped/<dataset>.pkl``.

    Raises:
        ValueError: if ``dataset`` has no configured (seed, split).
    """
    args = kwargs
    ds_name = args['dataset']

    ## define # of splits etc. -- validated first so unsupported datasets fail clearly
    seeds_splits_by_dataset = {
        'cora': [(200, 'public')],
        'roman-empire': [(200, 'test-200-0')],
    }
    if ds_name not in seeds_splits_by_dataset:
        raise ValueError(
            f"get_shaped_kernels has no seed/split configured for dataset {ds_name!r}; "
            f"expected one of {sorted(seeds_splits_by_dataset)}"
        )
    seeds_splits = seeds_splits_by_dataset[ds_name]
    fname = f"shaped/{ds_name}.pkl"

    num_epochs = 300
    arch_params = {'model': 'kipf', 'adj_lambda': 0.3}
    default_params = _default_search_params | dict(
        dataset=ds_name,
        num_epochs=num_epochs,
        return_kernel='yes-400',
    ) | arch_params

    mixup_scheme = DM.best_kipf[ds_name]['mixup_scheme']
    grid_search_params = [
        {
            'mixup_scheme': mixup_scheme,
            'dof': dof,
        }
        for dof in [0., 'inf', 1e0, 1e1, 1e3, 1e2]
    ]

    print("number of combos to test", len(grid_search_params))
    _search(default_params, grid_search_params, seeds_splits, fname)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--split', type=str, default='', help='name or number of dataset split')
    parser.add_argument("--num-epochs", type=int, default=200)
    parser.add_argument('--device', type=str, nargs='?', default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--dtype', type=str, nargs='?', default='float32', choices=['float32', 'float64'])
    parser.add_argument("--seed", type=int, default=0)
    # model params
    parser.add_argument("--dof", type=float, default=1.)
    parser.add_argument("--Pi", type=int, default=100)
    parser.add_argument("--num-layers", type=int, default=2)
    parser.add_argument("--model", "--arch", type=str, default='kipf', choices=['res', 'kipf', 'kipfres'])
    parser.add_argument("--kernel", type=str, default='relu', choices=['relu', 'id'])
    parser.add_argument("--center", type=str, default='id', choices=['id', 'batch'])
    parser.add_argument("--center-learned", action='store_true')
    parser.add_argument("--return-kernel", type=str, default='no')
    parser.add_argument("--learn-Xi", type=str, default='yes')
    parser.add_argument("--mixup-scheme", type=str, default='fixed-indep', choices=['none', 'fixed-indep', 'fixed-full'], help='mixup scheme')
    parser.add_argument("--adj-lambda", type=float, default=0.)
    # store_true with default=True could never be switched off; BooleanOptionalAction
    # keeps --scale-inputs (default on) and adds --no-scale-inputs
    parser.add_argument("--scale-inputs", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--weight-decay", type=float, default=0.)
    # convenience flags for experiments
    parser.add_argument("--get-shaped", action='store_true')
    parser.add_argument("--arch-grid-search", action='store_true', default=False)
    parser.add_argument("--kipf-grid-search", action='store_true', default=False)
    parser.add_argument("--center-search", action='store_true', default=False)
    parser.add_argument("--Pi-search", action='store_true', default=False)
    parser.add_argument("--final-acc", action='store_true', default=False)

    args = parser.parse_args()
    args_dict = vars(args)

    def with_keyboard_int(f):
        """Run ``f()``, printing "Exiting" instead of a traceback on Ctrl-C."""
        try: f()
        except KeyboardInterrupt: print("Exiting")

    if args.arch_grid_search:
        with_keyboard_int(lambda: arch_grid_search(**args_dict))
    elif args.get_shaped:
        with_keyboard_int(lambda: get_shaped_kernels(**args_dict))
    elif args.kipf_grid_search:
        with_keyboard_int(lambda: grid_search_kipf(**args_dict))
    elif args.center_search:
        with_keyboard_int(lambda: center_search(**args_dict))
    elif args.Pi_search:
        with_keyboard_int(lambda: Pi_search(**args_dict))
    elif args.final_acc:
        with_keyboard_int(lambda: final_acc(**args_dict))
    else:
        # single training run; train_model prints its own stats / failure marker
        train_model(**args_dict)