# This script computes the PAC-Bayesian alignment without taking rho and kl_penalty into account. Note: it is run with varying proportions of the data used for prior training (see --number_for_prior).

import torch
from pbb.utils import runexp
import numpy as np
import argparse

# Command-line options. Only the data split is configurable from the command
# line; every other hyperparameter is a hard-coded constant in this script.
parse = argparse.ArgumentParser(
    description="Compute the PAC-Bayesian alignment for a given prior/posterior data split."
)
# Float fraction (default 0.2); it is forwarded to runexp as perc_prior.
parse.add_argument(
    '--number_for_prior',
    type=float,
    default=0.2,
    help='proportion of the data used to learn the prior (forwarded to runexp as perc_prior)',
)
# Integer (default 125); it is forwarded to runexp as shot_per_class
# (presumably examples per class -- TODO confirm against pbb.utils.runexp).
parse.add_argument(
    '--number_for_posterior',
    type=int,
    default=125,
    help='value forwarded to runexp as shot_per_class',
)
args = parse.parse_args()


# Report whether CUDA is available (informational only).
print(torch.cuda.is_available())
# The experiment always runs on the CPU, whatever the CUDA availability above.
# NOTE(review): use torch.device("cuda" if torch.cuda.is_available() else "cpu")
# here if GPU execution is wanted.
DEVICE = torch.device("cpu")
# ---- Data split, taken from the command line ---------------------------------
perc_prior = args.number_for_prior  # alternative value: 0.05
shot_per_class = args.number_for_posterior  # alternative value: 125

# ---- Training / bound settings ----------------------------------------------
BATCH_SIZE = 250
TRAIN_EPOCHS = 10  # differs from the usual setting
DELTA = 0.025
DELTA_TEST = 0.01
PRIOR = 'learnt'

SIGMAPRIOR = 0.03
PMIN = 1e-5
KL_PENALTY = 0.1

# ---- Optimiser settings (the *_PRIOR variants are used for prior training) ---
LEARNING_RATE = 1  # differs from the usual setting; alternative value: 0.001
MOMENTUM = 0.95
LEARNING_RATE_PRIOR = 1.9  # differs from the usual setting; alternative value: 0.005
MOMENTUM_PRIOR = 0.99
prior_epochs = 40  # differs from the usual setting; alternative value: 20

# Number of Monte Carlo samples. The paper used 150,000, which usually takes
# several hours to compute.
MC_SAMPLES = 1000
# Run the experiment with the settings above. Hyperparameters vary between
# runs of this script, so the key ones are encoded in the output filename.
results = runexp(
    'mnist', 'fquad', PRIOR, 'fcn', SIGMAPRIOR, PMIN,
    LEARNING_RATE, MOMENTUM, LEARNING_RATE_PRIOR, MOMENTUM_PRIOR,
    delta=DELTA, kl_penalty=KL_PENALTY, prior_epochs=prior_epochs,
    delta_test=DELTA_TEST, mc_samples=MC_SAMPLES, train_epochs=TRAIN_EPOCHS,
    device=DEVICE, perc_train=1.0, verbose=True, dropout_prob=0.2,
    perc_prior=perc_prior, shot_per_class=shot_per_class,
)
# runexp returns seven values. The second and third get distinct names so
# that neither one is discarded by the unpacking.
# TODO(review): give prior_trainls_1 a meaningful name once runexp's return
# order is confirmed.
prior_train, prior_trainls_1, prior_trainls, align, ll, rr, net0 = results

# Store the seven results in an explicit 1-D object array, so np.save keeps
# each entry as-is instead of letting NumPy infer a shape from mixed/ragged
# values (which raises ValueError on NumPy >= 1.24).
# Load with np.load(path, allow_pickle=True).
to_save = (prior_train, prior_trainls_1, prior_trainls, align, ll, rr, net0)
saved = np.empty(len(to_save), dtype=object)
for idx, value in enumerate(to_save):
    saved[idx] = value  # element-wise assignment avoids NumPy broadcasting
np.save(f"{SIGMAPRIOR}_{shot_per_class}_{perc_prior}_fcn_.npy", saved)
