# -*- coding: utf-8 -*-
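"""
Compute the gradient Gram matrix ("NTK") of a network loaded from a .npy file
on a CIFAR-10 batch, and save it to disk.

Created on Fri Oct  1 00:01:20 2021

@author: chunr
"""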
import argparse #
import torch # 
from pbb.utils import runexp #
import numpy as np #
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as td
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from pbb.models import NNet4l,trainNNet_NTK, CNNet4l, ProbNNet4l, ProbCNNet4l, ProbCNNet9l, CNNet9l, CNNet13l, ProbCNNet13l, ProbCNNet15l, CNNet15l, trainNNet, testNNet, Lambda_var, trainPNNet, computeRiskCertificates, testPosteriorMean, testStochastic, testEnsemble
from pbb.bounds import PBBobj
from pbb import data

# ---------------------------------------------------------------------------
# Command-line configuration.
# NOTE(review): the parser description "haha" is a placeholder shown by --help.
# ---------------------------------------------------------------------------
parse = argparse.ArgumentParser(description="haha")
parse.add_argument('--number_for_prior', type=float, default=0.2, help='number_for_prior')
parse.add_argument('--number_for_posterior', type=int, default=2, help='number_for_posterior')
# Previously hard-coded; the default keeps the old behaviour.
parse.add_argument('--net_path', type=str, default='XXxxxx.npy',
                   help='path to the pickled .npy file holding the network(s); net[0] is used')
args = parse.parse_args()

Shot_per_class = args.number_for_posterior
Perc_prior = args.number_for_prior

# Prior standard deviation. Used both for rho_prior and in the output file name,
# so it is defined once to keep the two in sync.
SIGMA_PRIOR = 0.03

# Placeholders; saved unchanged (as empty lists) in the output file below.
prior_train = []
prior_trainls = []

# allow_pickle is required because the file stores Python objects (networks).
# Only load files from a trusted source: unpickling can execute arbitrary code.
net = np.load(args.net_path, allow_pickle=True)

torch.manual_seed(7)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

loader_kargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = torch.device("cpu")  # uncomment to force CPU execution

train, test = data.loaddataset('cifar10')
# Inverse softplus of SIGMA_PRIOR, i.e. log(1 + exp(rho_prior)) == SIGMA_PRIOR.
# NOTE(review): not used by the code visible here.
rho_prior = math.log(math.exp(SIGMA_PRIOR) - 1.0)
# Only val_bound_one_batch is consumed below; the other loaders are unused here.
train_loader, test_loader, valid_loader, val_bound_one_batch, _, val_bound = data.loadbatches(
    train, test, loader_kargs, 250, prior=True, perc_train=1.0, perc_prior=Perc_prior,
    shot_per_class=Shot_per_class)

# grads_left_x is presumably an (n_samples, n_features) matrix of per-sample
# gradient features (TODO confirm against trainNNet_NTK); ll/rr are opaque here.
grads_left_x, ll, rr = trainNNet_NTK(net[0], 1, val_bound_one_batch, device=DEVICE, verbose=True)

# Gram matrix K[n, m] = <g_n, g_m>. No autograd graph is needed for it.
with torch.no_grad():
    NTK_x = torch.einsum('nc,mc->nm', [grads_left_x, grads_left_x])

# np.save on a plain ragged list makes NumPy coerce it to an array. That raises on
# NumPy >= 1.24 (inhomogeneous shape) and fails outright for CUDA / grad-tracking
# tensors. So store a detached CPU copy and fill a 1-D object array element by
# element, which keeps every entry as the original Python object.
ntk_cpu = NTK_x.detach().cpu()
to_save = [prior_train, prior_trainls, prior_trainls, ntk_cpu, ll, rr, ntk_cpu]
saved = np.empty(len(to_save), dtype=object)
for idx, item in enumerate(to_save):
    saved[idx] = item
out_name = str(SIGMA_PRIOR) + "_" + str(Shot_per_class) + "_" + str(Perc_prior) + "_cnn_.npy"
np.save(out_name, saved, allow_pickle=True)



