from _base import *
from _model import *
import argparse
from torch.utils.data import TensorDataset, DataLoader
from _hp_ddd import lr_dic, k_dic
import os


path = os.getcwd()
np.set_printoptions(precision=3, suppress=True)

parser = argparse.ArgumentParser()

# non-fixed hyper-parameter
parser.add_argument('--file', default='concrete', type=str, help='file name')
parser.add_argument('--k', default=8, type=float, help='lambda')
parser.add_argument('--epoch', default=1000, type=int, help='n_epoch')
parser.add_argument('--out', default=25, type=int, help='n_output')
parser.add_argument('--lr', default=0.003, type=float, help='learning rate')
parser.add_argument('--dropout_p', default=0, type=float, help='dropout probability')

# fixed hyper-parameter
parser.add_argument('--batch', default=100, type=int, help='n_batch')
parser.add_argument('--gmm_h1', default=50, type=int, help='hidden1 of gmm')
parser.add_argument('--gmm_h2', default=50, type=int, help='hidden2 of gmm')
parser.add_argument('--ddd_h1', default=50, type=int, help='hidden1 of ddd')
parser.add_argument('--ddd_h2', default=50, type=int, help='hidden2 of ddd')
parser.add_argument('--wd', default=0, type=float, help='weight decay')
parser.add_argument('--gmm', default=5, type=int, help='n_gmm')
parser.add_argument('--ens', default=5, type=int, help='n_ens')
# tau is a probability, so it must be parsed as a float; with type=int any
# command-line value such as "--tau 0.9" was rejected by argparse.
parser.add_argument('--tau', default=0.95, type=float, help='target probability')

args = parser.parse_args()

# --lr and --k are always replaced by the per-dataset values from _hp_ddd
# (the learning rate is the same one used to train the GMM stage), so the
# command-line values of these two options have no effect.
if args.file not in lr_dic or args.file not in k_dic:
    parser.error('no tuned lr/k in _hp_ddd for --file %r' % (args.file,))
args.lr = lr_dic[args.file]
args.k = k_dic[args.file]

if args.file == 'protein':
    seed_list = range(1, 6)
    # protein uses wider hidden layers for both the GMM and the DDD network
    args.gmm_h1 = 100
    args.gmm_h2 = 100
    args.ddd_h1 = 100
    args.ddd_h2 = 100

elif args.file == 'naval':
    print('naval')
    args.epoch = 2000
    # NOTE(review): `patience` is not referenced anywhere else in this script
    # (EarlyStopping is constructed with a fixed patience), so this currently
    # has no effect — confirm whether early stopping should be disabled here.
    patience = np.inf
    seed_list = range(1, 6)
else:
    seed_list = range(1, 21)

# Print the configuration only after every override so the log shows the
# values that are actually used for training.
print(args)

def _sort_by_weight(m, s, p):
    """Reorder the components in every row of (m, s, p) by descending weight p.

    The permutation is computed once from ``p`` and applied to all three
    tensors, so each (m, s, p) triple stays aligned even when ``p`` has ties.
    """
    order = torch.argsort(p, dim=1, descending=True)
    rows = torch.arange(p.shape[0], device=p.device)[:, None]
    return m[rows, order], s[rows, order], p[rows, order]


def _ensemble_outputs(seed, n_feature, splits):
    """Evaluate every saved GMM ensemble member on each input in ``splits``.

    For each split, the (m, s, p) outputs of all ``args.ens`` members are
    concatenated along dim 1 (member 0 first) and then reordered with
    ``_sort_by_weight``. Returns one (m, s, p) tuple per split, in order.
    Checkpoints are read from ``path/model/gmm/<file>/<seed>-<ens_num>.tar``.
    """
    parts = [([], [], []) for _ in splits]
    for ens_num in range(args.ens):
        gmm = GMM(n_feature=n_feature, n_hidden1=args.gmm_h1, n_hidden2=args.gmm_h2,
                  n_gmm=args.gmm, dropout_prob=args.dropout_p)
        ckpt = os.path.join(path, 'model', 'gmm', str(args.file),
                            str(seed) + '-' + str(ens_num) + '.tar')
        gmm.load_state_dict(torch.load(ckpt))
        gmm.cuda()
        gmm.eval()
        # The GMMs are frozen inputs to the DDD; no autograd graph is needed.
        with torch.no_grad():
            for (ms, ss, ps), X in zip(parts, splits):
                m, s, p = gmm(X)
                ms.append(m)
                ss.append(s)
                ps.append(p)
    return [_sort_by_weight(torch.cat(ms, dim=1), torch.cat(ss, dim=1), torch.cat(ps, dim=1))
            for ms, ss, ps in parts]


def _print_running(values):
    """Print the per-seed values collected so far and their mean (nan while empty)."""
    print(values)
    print(np.mean(values) if values.size else float('nan'))


# Per-seed results, one entry appended per seed.
cost_va_list = np.array([])
picp_va_list = np.array([])
mpiw_va_list = np.array([])

picp_te_list = np.array([])
mpiw_te_list = np.array([])

with torch.cuda.device(0):
    for seed in seed_list:
        np.random.seed(seed)
        X_tr, X_va, X_te, Y_tr, Y_va, Y_te, y_al, y_range = load_data(args.file, seed, 0.1, 0.1)

        ens_tr, ens_va, ens_te = _ensemble_outputs(seed, X_te.shape[1], (X_tr, X_va, X_te))
        m_ens_tr, s_ens_tr, p_ens_tr = ens_tr
        m_ens_va, s_ens_va, p_ens_va = ens_va
        m_ens_te, s_ens_te, p_ens_te = ens_te

        dataset = TensorDataset(m_ens_tr, s_ens_tr, p_ens_tr, Y_tr)
        dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=True)

        # Validation runs after every optimizer step, so the patience below is
        # counted in batches, not epochs.
        es = EarlyStopping(patience=1000, verbose=1)
        earlystop = False

        # Best validation loss of THIS seed and the metrics at that checkpoint.
        # Reset per seed so a seed that never checkpoints cannot inherit the
        # previous seed's results.
        loss_va_save = float('inf')
        save_acc_va = save_mpiw_va = float('nan')
        save_acc = save_mpiw = float('nan')

        ddd = DDD(3*args.ens*args.gmm, args.ddd_h1, args.ddd_h2, args.out)
        ddd.cuda()
        optimizer = torch.optim.Adam(ddd.parameters(), lr=args.lr, weight_decay=args.wd)

        for epoch in range(args.epoch):
            if epoch % 10 == 0:
                print(epoch)

            # The DDD loss is computed from (m, s, p) only; the batch targets
            # are not used for training.
            for m, s, p, _ in dataloader:
                ddd.train()

                L, U = ddd.forward(m, s, p)
                prob = ddd.cal_prob(L, U, m, s, p)/args.ens
                loss_tr = ddd.loss(L, U, prob, args.k)

                optimizer.zero_grad()
                loss_tr.backward()
                optimizer.step()

                # validation check
                ddd.eval()
                with torch.no_grad():
                    L_va, U_va = ddd.forward(m_ens_va, s_ens_va, p_ens_va)
                    prob_va = ddd.cal_prob(L_va, U_va, m_ens_va, s_ens_va, p_ens_va)/args.ens
                    loss_va = ddd.loss(L_va, U_va, prob_va, args.k)
                    acc_va = ddd.cal_acc(L_va, U_va, Y_va)
                    mpiw_va = ddd.cal_mpiw(L_va, U_va)
                loss_va_value = loss_va.item()
                earlystop = es.validate(loss_va_value)

                if loss_va_value < loss_va_save:
                    loss_va_save = loss_va_value

                    with torch.no_grad():
                        L_te, U_te = ddd.forward(m_ens_te, s_ens_te, p_ens_te)
                        acc_te = ddd.cal_acc(L_te, U_te, Y_te)
                        mpiw_te = ddd.cal_mpiw(L_te, U_te)

                    print('----------saving--------------')
                    print(args)
                    print('seed: ', seed, 'data: ', args.file)
                    print('epoch: ', epoch, 'n_gmm: ', args.gmm, 'k: ', args.k,  'n_out: ', args.out, 'lr: ', args.lr)
                    print('loss_save: ', loss_va_save)

                    print('------------------------------------')
                    print('valid_acc: ', acc_va.item())
                    print('valid_mpiw: ', mpiw_va.item())
                    print('------------------------------------')

                    torch.save(ddd.state_dict(),
                               os.path.join(path, 'model', 'ddd', 'ens', str(args.file), str(seed) + '.tar'))
                    # Keep validation and test metrics of the same checkpoint.
                    save_acc_va = acc_va.item()
                    save_mpiw_va = mpiw_va.item()
                    save_acc = acc_te.item()
                    save_mpiw = mpiw_te.item()

                    print('---width saving---')
                    print('accuracy: ', save_acc)
                    print('width: ', save_mpiw)

                    print('------------------------------------')
                    _print_running(cost_va_list)
                    _print_running(picp_va_list)
                    _print_running(mpiw_va_list)
                    print('------------------------------------')
                    _print_running(picp_te_list)
                    _print_running(mpiw_te_list)
                    print('------------------------------------')

            if earlystop:
                print('break')
                break
        else:
            print('epoch budget exhausted without early stopping')

        # Record this seed's results whether training stopped early or ran
        # all epochs (previously only early-stopped seeds were recorded).
        cost_va_list = np.append(cost_va_list, loss_va_save)
        picp_va_list = np.append(picp_va_list, save_acc_va)
        print(np.mean(picp_va_list), np.std(picp_va_list))
        mpiw_va_list = np.append(mpiw_va_list, save_mpiw_va)
        print(np.mean(mpiw_va_list), np.std(mpiw_va_list))
        print(args)
        picp_te_list = np.append(picp_te_list, save_acc)
        print(picp_te_list)
        print(np.mean(picp_te_list), np.std(picp_te_list))
        mpiw_te_list = np.append(mpiw_te_list, save_mpiw)
        print(mpiw_te_list)
        print(np.mean(mpiw_te_list), np.std(mpiw_te_list))