from _base import *
from _model import *
import scipy.stats as stats
import os
import argparse
from torch.utils.data import TensorDataset, DataLoader
from _hp_ddd_single import lr_dic, k_dic


# ---------------------------------------------------------------------------
# Experiment configuration.
#
# Trains a DDD interval head on top of pre-trained, frozen GMM models (one per
# seed and ensemble member); see the training loop below.
# ---------------------------------------------------------------------------
path = os.getcwd()
np.set_printoptions(precision=3, suppress=True)

parser = argparse.ArgumentParser()

# non-fixed hyper-parameters
parser.add_argument('--mode', default=2, type=int, help='sqr or not')
parser.add_argument('--file', default='bike', type=str, help='file name')
parser.add_argument('--k', default=20, type=float, help='lambda')
parser.add_argument('--epoch', default=500, type=int, help='n_epoch')
parser.add_argument('--out', default=5, type=int, help='n_output')
parser.add_argument('--lr', default=0.007, type=float, help='learning rate')
parser.add_argument('--dropout_p', default=0, type=float, help='dropout probability')

# fixed hyper-parameters
parser.add_argument('--batch', default=100, type=int, help='n_batch')
parser.add_argument('--gmm_h1', default=50, type=int, help='hidden1 of gmm')
parser.add_argument('--gmm_h2', default=50, type=int, help='hidden2 of gmm')
parser.add_argument('--ddd_h1', default=50, type=int, help='hidden1 of ddd')
parser.add_argument('--ddd_h2', default=50, type=int, help='hidden2 of ddd')
parser.add_argument('--wd', default=0, type=float, help='weight decay')
parser.add_argument('--gmm', default=5, type=int, help='n_gmm(fixed in this experiment)')
parser.add_argument('--ens', default=5, type=int, help='n_ens(fixed in this experiment)')
# tau is a fractional value (default 0.95); it was previously declared as int,
# which made any explicit `--tau 0.9` on the command line fail to parse.
parser.add_argument('--tau', default=0.95, type=float, help='tau(fixed in this experiment)')

args = parser.parse_args()
# NOTE: --lr and --k are always overridden by the per-dataset tables imported
# from _hp_ddd_single, so values passed on the command line are ignored.
args.lr = lr_dic[args.file]  # use the same learning rate used in GMM
args.k = k_dic[args.file]

# Every dataset is run over the same seeds; only some hyper-parameters differ.
seed_list = range(1, 6)

if args.file == 'protein':
    # Larger dataset: bigger batches and wider hidden layers.
    args.batch = 1000
    args.gmm_h1 = 100
    args.gmm_h2 = 100
    args.ddd_h1 = 100
    args.ddd_h2 = 100

elif args.file == 'naval':
    print('naval')
    args.epoch = 2000
    # NOTE(review): `patience` is not read anywhere else in this script;
    # confirm whether it was meant to configure early stopping.
    patience = np.inf

print(args)

def _evaluate(model, m, s, p, y, k):
    """Evaluate a DDD head on one full split without building an autograd graph.

    Args:
        model: the DDD module (switched to eval mode here).
        m, s, p: the frozen GMM outputs for the split.
        y: the split's targets, passed to ``cal_acc``.
        k: the loss weight passed to ``model.loss``.

    Returns:
        ``(loss, acc, mpiw)`` as Python floats, from ``model.loss``,
        ``model.cal_acc`` and ``model.cal_mpiw`` respectively.
    """
    model.eval()
    with torch.no_grad():
        lower, upper = model.forward(m, s, p)
        prob = model.cal_prob(lower, upper, m, s, p)
        loss = model.loss(lower, upper, prob, k)
        acc = model.cal_acc(lower, upper, y)
        mpiw = model.cal_mpiw(lower, upper)
    return loss.item(), acc.item(), mpiw.item()


def _print_running(values):
    """Print an array of per-run metrics and, when it is non-empty, its mean."""
    print(values)
    if values.size:  # np.mean of an empty array warns and yields nan
        print(np.mean(values))


# Per-run summary metrics, one entry per (seed, ensemble member), all taken at
# the checkpoint with the lowest validation loss:
#   cost_va_list       best validation loss
#   picp_va_list/_te   coverage from ddd.cal_acc on validation / test
#   mpiw_va_list/_te   width from ddd.cal_mpiw on validation / test
cost_va_list = np.array([])
picp_va_list = np.array([])
mpiw_va_list = np.array([])

picp_te_list = np.array([])
mpiw_te_list = np.array([])

with torch.cuda.device(0):
    for seed in seed_list:
        np.random.seed(seed)
        X_tr, X_va, X_te, Y_tr, Y_va, Y_te, y_al, y_range = load_data(args.file, seed, 0.1, 0.1)
        # Seed torch too (after the split, so the data split is unaffected) so
        # DDD initialisation and DataLoader shuffling are reproducible.
        torch.manual_seed(seed)

        for ens_num in range(args.ens):
            # Load the pre-trained GMM for this (seed, member); it stays frozen.
            gmm = GMM(n_feature=X_tr.shape[1], n_hidden1=args.gmm_h1, n_hidden2=args.gmm_h2,
                      n_gmm=args.gmm, dropout_prob=0)
            gmm.load_state_dict(torch.load(
                path + '/model/gmm/' + str(args.file) + '/' + str(seed) + '-' + str(ens_num) + '.tar'))
            gmm.cuda()
            gmm.eval()

            # The GMM is not optimised here, so compute its outputs once without
            # autograd. This keeps its graph out of every DDD backward pass and
            # removes the need for retain_graph=True.
            with torch.no_grad():
                m_tr, s_tr, p_tr = gmm(X_tr)
                m_va, s_va, p_va = gmm(X_va)
                m_te, s_te, p_te = gmm(X_te)

            dataloader = DataLoader(TensorDataset(m_tr, s_tr, p_tr), batch_size=args.batch, shuffle=True)

            ddd = DDD(3 * args.gmm, args.ddd_h1, args.ddd_h2, args.out)
            ddd.cuda()
            optimizer = torch.optim.Adam(ddd.parameters(), lr=args.lr, weight_decay=args.wd)

            # Best-checkpoint bookkeeping, reset for every ensemble member so a
            # run that never improves records NaN instead of reusing (or
            # crashing on) the previous member's values.
            best_loss_va = float('inf')
            best_acc_va = best_mpiw_va = float('nan')
            save_acc = save_mpiw = float('nan')

            for epoch in range(args.epoch):
                if epoch % 10 == 0:
                    print(epoch)

                for m, s, p in dataloader:  # mini-batches of GMM outputs (m, s, p)
                    ddd.train()
                    L, U = ddd.forward(m, s, p)
                    prob = ddd.cal_prob(L, U, m, s, p)
                    loss_tr = ddd.loss(L, U, prob, args.k)

                    optimizer.zero_grad()
                    loss_tr.backward()
                    optimizer.step()

                    # Validation after every optimiser step; checkpoint on improvement.
                    loss_va, acc_va, mpiw_va = _evaluate(ddd, m_va, s_va, p_va, Y_va, args.k)
                    if loss_va < best_loss_va:
                        best_loss_va = loss_va
                        best_acc_va = acc_va
                        best_mpiw_va = mpiw_va

                        _, save_acc, save_mpiw = _evaluate(ddd, m_te, s_te, p_te, Y_te, args.k)

                        print('----------saving--------------')
                        print(args)
                        print('seed: ', seed, 'data: ', args.file)
                        print('epoch: ', epoch, 'n_gmm: ', args.gmm, 'k: ', args.k, 'n_out: ', args.out, 'lr: ', args.lr)
                        print('loss_save: ', best_loss_va)

                        print('------------------------------------')
                        print('valid_acc: ', acc_va)
                        print('valid_mpiw: ', mpiw_va)
                        print('------------------------------------')

                        torch.save(ddd.state_dict(), path + '/model/ddd/single/' + str(args.file) + '/' +
                                   str(seed) + '-' + str(ens_num) + '-' + 'compare' + '.tar')

                        print('---width saving---')
                        print('accuracy: ', save_acc)
                        print('width: ', save_mpiw)

                        print('------------------------------------')
                        _print_running(cost_va_list)
                        _print_running(picp_va_list)
                        _print_running(mpiw_va_list)
                        print('------------------------------------')
                        _print_running(picp_te_list)
                        _print_running(mpiw_te_list)
                        print('------------------------------------')

            # Record this member's metrics at its best-validation checkpoint
            # (previously the validation metrics came from the last step).
            cost_va_list = np.append(cost_va_list, best_loss_va)
            picp_va_list = np.append(picp_va_list, best_acc_va)
            print(np.mean(picp_va_list), np.std(picp_va_list))
            mpiw_va_list = np.append(mpiw_va_list, best_mpiw_va)
            print(np.mean(mpiw_va_list), np.std(mpiw_va_list))
            print(args)
            picp_te_list = np.append(picp_te_list, save_acc)
            print(picp_te_list)
            print(np.mean(picp_te_list), np.std(picp_te_list))
            mpiw_te_list = np.append(mpiw_te_list, save_mpiw)
            print(mpiw_te_list)
            print(np.mean(mpiw_te_list), np.std(mpiw_te_list))
