from _base import *
from _model import *
import argparse
from torch.utils.data import TensorDataset, DataLoader
from _hp_mve import lr_dic
import os
path = os.getcwd()

# Banner marking the start of this script's output.
print('----cal_prob----')

# Print NumPy arrays with 3 decimals and without scientific notation.
np.set_printoptions(precision=3, suppress=True)

# Command-line options:
#   --file    dataset name (also selects the checkpoint sub-directory)
#   --gmm_h1  width of the first hidden layer of the network
#   --gmm_h2  width of the second hidden layer of the network
#   --gmm     number of mixture components passed to the network
#   --ens     number of ensemble members to load per seed
#   --tau     target probability
parser = argparse.ArgumentParser()

parser.add_argument('--file', default='protein', type=str, help='file name')
parser.add_argument('--gmm_h1', default=50, type=int, help='hidden1 of gmm')
parser.add_argument('--gmm_h2', default=50, type=int, help='hidden2 of gmm')
parser.add_argument('--gmm', default=1, type=int, help='n_gmm')
parser.add_argument('--ens', default=5, type=int, help='n_ens')
# A probability is fractional, so it must be parsed as a float: with
# type=int argparse rejected every explicit value such as `--tau 0.9`
# ("invalid int value"). The default is unchanged.
parser.add_argument('--tau', default=0.95, type=float, help='target probability')

args = parser.parse_args()

# Echo the effective configuration so it is recorded alongside the results.
print(args)

# Dataset-specific overrides applied on top of the parsed arguments.
is_protein = args.file == 'protein'
is_naval = args.file == 'naval'

# 'protein' and 'naval' are evaluated on seeds 1..5, every other dataset
# on seeds 1..20.
seed_list = range(1, 6) if (is_protein or is_naval) else range(1, 21)

if is_protein:
    # All four hidden-layer widths are forced to 100 for this dataset,
    # regardless of what was given on the command line.
    for attr_name in ('gmm_h1', 'gmm_h2', 'ddd_h1', 'ddd_h2'):
        setattr(args, attr_name, 100)
elif is_naval:
    args.epoch = 2000




# Evaluate the saved MVE ensemble on the test split of every seed.
#
# For each seed, the `args.ens` checkpoints are loaded, their predicted means
# and standard deviations on X_te are combined into one mean/std per test
# point, and an interval mean +/- 1.96 * std is scored by
#   * PICP: fraction of test targets strictly inside the interval, and
#   * MPIW: mean width of the interval.
# The per-seed values end up in `picp_te_list` / `mpiw_te_list`.
picp_values = []
mpiw_values = []

# NOTE(review): GPU index 1 is hard-coded, so a second CUDA device must exist.
# Gradients are never needed here; no_grad avoids building autograd graphs
# for every forward pass.
with torch.cuda.device(1), torch.no_grad():
    for seed in seed_list:
        print('seed: ', seed)
        np.random.seed(seed)
        # Only the test split (X_te, Y_te) is used below.
        X_tr, X_va, X_te, Y_tr, Y_va, Y_te, y_al, y_range = load_data(args.file, seed, 0.1, 0.1)

        # One entry per ensemble member; concatenated once after the loop
        # instead of growing a tensor with repeated torch.cat calls.
        mu_parts = []
        std_parts = []

        for ens_num in range(args.ens):
            net = MVE(n_feature=X_te.shape[1], n_hidden1=args.gmm_h1, n_hidden2=args.gmm_h2, n_gmm=args.gmm)
            # Checkpoints are stored as <cwd>/model/mve/<file>/<seed>-<member>.tar
            checkpoint = (path + '/model/mve/' + str(args.file) + '/' + str(seed) +
                          '-' + str(ens_num) + '.tar')
            net.load_state_dict(torch.load(checkpoint))
            net.cuda()
            net.eval()

            # assumes X_te already lives on the GPU and the network returns
            # column-shaped (N, k) mean/std tensors -- TODO confirm
            m_te, s_te = net(X_te)

            mu_parts.append(m_te.detach())
            std_parts.append(s_te.detach())

        # Members are stacked along dim=1: one column block per member.
        mu_list = torch.cat(mu_parts, dim=1)
        std_list = torch.cat(std_parts, dim=1)

        # Moments of the equally weighted mixture of the members:
        #   mean = E[mu_i],  var = E[mu_i^2 + sigma_i^2] - mean^2
        mu = torch.mean(mu_list, dim=1)
        var = torch.mean(mu_list ** 2 + std_list ** 2, dim=1) - mu ** 2
        # Floating-point cancellation can push a near-zero variance slightly
        # below zero, which would turn the sqrt into NaN; clamp it first.
        std = torch.sqrt(torch.clamp(var, min=0))

        # 1.96 is the two-sided 95% normal quantile; args.tau is not used here.
        U = (mu + 1.96 * std).reshape(-1, 1)
        L = (mu - 1.96 * std).reshape(-1, 1)
        # assumes Y_te is (N, 1) so it lines up with the reshaped bounds --
        # TODO confirm
        inside = L.lt(Y_te) & Y_te.lt(U)
        acc = torch.mean(inside.float())
        width = torch.mean(U - L)

        picp_values.append(acc.item())
        mpiw_values.append(width.item())

# Exposed as float arrays so they print with the configured NumPy options.
picp_te_list = np.array(picp_values, dtype=float)
mpiw_te_list = np.array(mpiw_values, dtype=float)

# Summary over all seeds: the raw per-seed values followed by their mean and
# (population) standard deviation, for both metrics.
print('file: ', args.file)
print(picp_te_list)
# The label previously read 'PCIP', a misspelling of the metric held in
# picp_te_list.
print('PICP: ', np.mean(picp_te_list), np.std(picp_te_list))
print(mpiw_te_list)
print('MPIW: ', np.mean(mpiw_te_list), np.std(mpiw_te_list))
