# cnn_ntk_init_posterior: a brute-force grid search that computes the generalization bound of a convolutional neural network trained on CIFAR-10.
import torch# 
from pbb.utils import runexp#
import argparse #
import numpy as np #

# Command-line interface. --number_for_prior is read into perc_prior and
# --kl_penalty into KL_PENALTY; both are forwarded to runexp and embedded in
# the name of the saved results file.
parse = argparse.ArgumentParser(
    description="Brute-force grid search for the generalization bound of a "
                "CNN on CIFAR-10",
)
parse.add_argument('--number_for_prior', type=float, default=0.2,
                   help='value passed to runexp as perc_prior '
                        '(default: %(default)s)')
# The default is deliberately a float: argparse applies `type` only to string
# defaults, so an int default would reach the script as an int. Because the
# value is stringified into the output filename, an int default produced
# "1_..." while an explicit `--kl_penalty 1` produced "1.0_..." for the same
# setting.
parse.add_argument('--kl_penalty', type=float, default=1.0,
                   help='value passed to runexp as kl_penalty '
                        '(default: %(default)s)')
args = parse.parse_args()

# Value forwarded to runexp as perc_prior, taken straight from the CLI.
perc_prior = args.number_for_prior

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which can execute code on load; only point this at files you
# produced yourself. 'XXXxxx.npy' looks like a redacted/placeholder path and
# must be replaced before running.
net = np.load(file='XXXxxx.npy', allow_pickle=True)
# Only the first stored entry is used; it is the first argument to runexp.
net0 = net[0]

# --- Device ----------------------------------------------------------------
# Ask CUDA once; run on the GPU when it is usable, otherwise on the CPU, and
# echo the availability flag so the log records which case applied.
_cuda_available = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _cuda_available else "cpu")
print(_cuda_available)

# --- Run length / sampling --------------------------------------------------
# NOTE(review): BATCH_SIZE is not passed to runexp in this script.
BATCH_SIZE = 250
TRAIN_EPOCHS = 70
MC_SAMPLES = 150_000

# --- Forwarded to runexp as delta / delta_test ------------------------------
# Presumably confidence levels for the bound; confirm in pbb.utils.runexp.
DELTA = 0.025
DELTA_TEST = 0.01

# --- Settings whose names refer to the prior -------------------------------
PRIOR = 'learnt'
SIGMAPRIOR = 0.03
LEARNING_RATE_PRIOR = 25
MOMENTUM_PRIOR = 0.99

# --- Remaining positional/keyword arguments for runexp ----------------------
PMIN = 1e-5
LEARNING_RATE = 1
MOMENTUM = 0.95
KL_PENALTY = args.kl_penalty


# Each running example of this experiment family uses its own settings; this
# one trains a 'cnn' on 'cifar10' with the 'fquad' objective and the
# hyper-parameters defined above.
print(perc_prior)
results = runexp(
    net0, 'cifar10', 'fquad', PRIOR, 'cnn',
    SIGMAPRIOR, PMIN, LEARNING_RATE, MOMENTUM,
    LEARNING_RATE_PRIOR, MOMENTUM_PRIOR,
    delta=DELTA,
    delta_test=DELTA_TEST,
    mc_samples=MC_SAMPLES,
    train_epochs=TRAIN_EPOCHS,
    device=DEVICE,
    perc_train=1.0,
    verbose=True,
    perc_prior=perc_prior,
    kl_penalty=KL_PENALTY,
    dropout_prob=0.2,
    layers=13,
)
# runexp yields exactly four values; their precise meaning is defined in
# pbb.utils.runexp (names suggest a 0-1 risk bound and three error rates).
risk_01, ens_err, post_err, stch_err = results

# Output file: "<KL_PENALTY>_<SIGMAPRIOR>_<perc_prior>_fcn_.npy".
# NOTE(review): the "_fcn_" tag is kept for compatibility with existing
# result files, even though this run uses the 'cnn' architecture; a sibling
# FCN script writing the same pattern would overwrite these results.
out_name = '_'.join(str(v) for v in (KL_PENALTY, SIGMAPRIOR, perc_prior)) + "_fcn_.npy"
np.save(out_name, [risk_01, ens_err, post_err, stch_err])

