from _base import *
from _model import *
import argparse
from torch.utils.data import TensorDataset, DataLoader
from _hp_ddd import lr_dic, k_dic
import os


# ---------------------------------------------------------------------------
# Script setup: banner, NumPy print formatting, command-line arguments and
# per-dataset overrides used by the evaluation loop below.
# ---------------------------------------------------------------------------
print('----cal_prob----')
path = os.getcwd()  # model checkpoints are resolved relative to the CWD
np.set_printoptions(precision=3, suppress=True)

parser = argparse.ArgumentParser()

parser.add_argument('--file', default='boston', type=str, help='file name')
parser.add_argument('--out', default=25, type=int, help='n_output')
parser.add_argument('--gmm_h1', default=50, type=int, help='hidden1 of gmm')
parser.add_argument('--gmm_h2', default=50, type=int, help='hidden2 of gmm')
parser.add_argument('--ddd_h1', default=50, type=int, help='hidden1 of ddd')
parser.add_argument('--ddd_h2', default=50, type=int, help='hidden2 of ddd')
parser.add_argument('--gmm', default=5, type=int, help='n_gmm')
parser.add_argument('--ens', default=5, type=int, help='n_ens')
# tau is a probability (default 0.95), so it must be parsed as a float;
# with type=int any fractional value passed on the command line was rejected.
parser.add_argument('--tau', default=0.95, type=float, help='target probability')

args = parser.parse_args()

print(args)

# Per-dataset hyper-parameters looked up by dataset name (a missing name
# raises KeyError here).
args.lr = lr_dic[args.file]  # use the same learning rate used in GMM
args.k = k_dic[args.file]

# Per-dataset evaluation settings. NOTE: for 'protein' the hidden sizes are
# forced to 100, silently overriding any --gmm_h*/--ddd_h* values given on
# the command line (presumably to match how those checkpoints were trained —
# TODO confirm).
if args.file == 'protein':
    seed_list = range(1, 6)
    args.gmm_h1 = 100
    args.gmm_h2 = 100
    args.ddd_h1 = 100
    args.ddd_h2 = 100

elif args.file == 'naval':
    seed_list = range(1, 6)

else:
    seed_list = range(1, 21)


def _sort_components_by_prob(m, s, p):
    """Reorder mixture components per row by descending mixing probability.

    ``m``, ``s`` and ``p`` are expected to be 2-D tensors of identical shape
    (rows x components). The same per-row permutation (from ``p``) is applied
    to all three, so each (m, s, p) triple stays aligned after sorting.

    Returns:
        Tuple ``(m_sorted, s_sorted, p_sorted)``.
    """
    # Compute the ordering once and reuse it for all three tensors.
    order = torch.argsort(p, dim=1, descending=True)
    return (torch.gather(m, 1, order),
            torch.gather(s, 1, order),
            torch.gather(p, 1, order))


# Per-seed test-set metrics, collected as Python lists (appending to a NumPy
# array in a loop copies the whole array on every iteration).
picp_te_list = []
mpiw_te_list = []

# Pure inference: no_grad stops autograd from recording graphs for the GMM
# and DDD forward passes.
with torch.cuda.device(0), torch.no_grad():
    for seed in seed_list:
        print('seed: ', seed)
        np.random.seed(seed)
        X_tr, X_va, X_te, Y_tr, Y_va, Y_te, y_al, y_range = load_data(args.file, seed, 0.1, 0.1)

        # Run every GMM ensemble member on the test set and stack their
        # outputs column-wise. Only the test split feeds the metrics below,
        # so train/validation outputs are not computed.
        m_parts, s_parts, p_parts = [], [], []
        for ens_num in range(args.ens):
            gmm = GMM(n_feature=X_te.shape[1], n_hidden1=args.gmm_h1, n_hidden2=args.gmm_h2, n_gmm=args.gmm)
            gmm_file = os.path.join(path, 'model', 'gmm', str(args.file),
                                    str(seed) + '-' + str(ens_num) + '.tar')
            gmm.load_state_dict(torch.load(gmm_file))
            gmm.cuda()
            gmm.eval()

            m_te, s_te, p_te = gmm(X_te)
            m_parts.append(m_te)
            s_parts.append(s_te)
            p_parts.append(p_te)

        m_ens_te, s_ens_te, p_ens_te = _sort_components_by_prob(
            torch.cat(m_parts, dim=1),
            torch.cat(s_parts, dim=1),
            torch.cat(p_parts, dim=1),
        )

        # DDD consumes the sorted (m, s, p) triples of all ens * gmm
        # components, hence 3 * ens * gmm inputs.
        ddd = DDD(3*args.ens*args.gmm, args.ddd_h1, args.ddd_h2, args.out)
        ddd_file = os.path.join(path, 'model', 'ddd', 'ens', str(args.file), str(seed) + '.tar')
        ddd.load_state_dict(torch.load(ddd_file))
        ddd.cuda()
        ddd.eval()

        # Call the module (not .forward) so any registered nn.Module hooks run.
        L_te, U_te = ddd(m_ens_te, s_ens_te, p_ens_te)
        picp_te = ddd.cal_acc(L_te, U_te, Y_te)
        mpiw_te = ddd.cal_mpiw(L_te, U_te)

        picp_te_list.append(picp_te.item())
        mpiw_te_list.append(mpiw_te.item())

picp_te_list = np.array(picp_te_list)
mpiw_te_list = np.array(mpiw_te_list)

print('file: ', args.file)
print(picp_te_list)
print('PICP: ', np.mean(picp_te_list), np.std(picp_te_list))
print(mpiw_te_list)
print('MPIW: ', np.mean(mpiw_te_list), np.std(mpiw_te_list))
