from _base import *
from _model import *
import numpy as np
import torch
from _hp_mve import lr_dic
from torch.utils.data import TensorDataset, DataLoader
import argparse
import os

# Directory the script was launched from; used later as the root for saved
# model checkpoints.
path = os.getcwd()

# One row per command-line option: (flag, default, type, help text).
# The order here is the order in which the options are registered.
_CLI_OPTIONS = (
    ('--file', 'boston', str, 'file name'),
    ('--epoch', 3000, int, 'n_epoch'),
    ('--batch_size', 100, int, 'batch_size'),
    ('--gmm', 1, int, 'n_gmm'),
    ('--dropout_p', 0, float, 'dropout probability'),
    ('--ens_num', 5, int, 'n_ensemble'),
    ('--wd', 0, float, 'weight decay'),
    ('--patience', 3000, float, 'early stopping patience'),
    ('--gmm_h1', 50, int, 'hidden1 of gmm'),
    ('--gmm_h2', 50, int, 'hidden2 of gmm'),
)

parser = argparse.ArgumentParser()
for _flag, _default, _type, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type, help=_help)

# Parsed options; later code reads and, for some datasets, overwrites
# attributes on this namespace.
args = parser.parse_args()


# Per-dataset run configuration.
#
# `batch_size` and `patience` are always taken from the command line first so
# that every branch below leaves both names defined; the dataset-specific
# branches then override only what they need to.  (Previously 'protein' never
# defined `patience` and 'naval' never defined `batch_size`, which raised
# NameError as soon as training started.)
batch_size = args.batch_size
patience = args.patience

if args.file == 'protein':
    # Larger mini-batches, wider hidden layers, and fewer seeds for this dataset.
    batch_size = 1000
    seed_list = range(1, 6)
    args.gmm_h1 = 100
    args.gmm_h2 = 100

elif args.file == 'naval':
    # Infinite patience: early stopping never triggers, so every epoch runs.
    patience = np.inf
    seed_list = range(1, 6)

else:
    seed_list = range(1, 21)
    # NOTE: this pins the hidden sizes to 50 and therefore overrides any
    # --gmm_h1 / --gmm_h2 value given on the command line for these datasets.
    args.gmm_h1 = 50
    args.gmm_h2 = 50

print(args)
n_epochs = args.epoch

# Train an ensemble of MVE networks for every seed, keep each member's
# best-validation checkpoint, and report the coverage / width of the 95%
# interval obtained from the combined ensemble prediction on the test split.
with torch.cuda.device(0):
    # Metrics accumulated over the whole run.
    loss_list = np.array([])   # test loss of each ensemble member at its best validation step
    picp_list = np.array([])   # per-member interval coverage on the test split
    mpiw_list = np.array([])   # per-member mean interval width on the test split
    acc_list = np.array([])    # per-seed coverage of the ensemble interval
    width_list = np.array([])  # per-seed mean width of the ensemble interval

    # Create the checkpoint directory up front so torch.save cannot fail on a
    # missing path.
    model_dir = path + '/model/mve/' + str(args.file)
    os.makedirs(model_dir, exist_ok=True)

    for seed in seed_list:
        # Seed both NumPy and PyTorch: weight initialisation and DataLoader
        # shuffling draw from torch's generator, so seeding NumPy alone does
        # not make a run repeatable.
        np.random.seed(seed)
        torch.manual_seed(seed)

        # NOTE(review): the tensors returned by load_data are fed straight to
        # a CUDA model, so they are presumably already on the GPU - confirm.
        X_tr, X_va, X_te, Y_tr, Y_va, Y_te, Y_al, y_range = load_data(args.file, seed, 0.1, 0.1)
        dataloader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=batch_size, shuffle=True)

        # Test-set predictions of every ensemble member for this seed.
        member_mu = []
        member_std = []

        for ens_num in range(args.ens_num):

            es = EarlyStopping(patience=patience, verbose=1)

            # Per-member "best so far" state.  It is reset for every member so
            # a member that never records a checkpoint cannot silently reuse
            # the previous member's predictions, and the sentinel is +inf so
            # the first finite validation loss always counts as an improvement.
            loss_va_save = float('inf')
            loss_te_save = None
            picp = None
            mpiw = None
            best_mu = None
            best_std = None
            earlystop = False

            net = MVE(n_feature=X_tr.shape[1], n_hidden1=args.gmm_h1, n_hidden2=args.gmm_h2, n_gmm=args.gmm,
                      dropout_prob=args.dropout_p)
            net.cuda()
            optimizer = torch.optim.Adam(net.parameters(), lr=lr_dic[args.file], weight_decay=args.wd)

            for epoch in range(n_epochs):

                for X_batch, Y_batch in dataloader:
                    net.train()
                    loss = net.loss(X_batch, Y_batch)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Validation runs after every mini-batch (not once per
                    # epoch), so the early-stopping counter advances per
                    # optimisation step.  No gradients are needed here.
                    net.eval()
                    with torch.no_grad():
                        loss_va = net.loss(X_va, Y_va).item()

                    earlystop = es.validate(loss_va)

                    if loss_va < loss_va_save:
                        loss_va_save = loss_va

                        with torch.no_grad():
                            loss_te_save = net.loss(X_te, Y_te).item()
                            best_mu, best_std = net(X_te)

                            # 95% interval under a Gaussian assumption.
                            U = best_mu + 1.96 * best_std
                            L = best_mu - 1.96 * best_std

                            picp = torch.mean(torch.sum((L.lt(Y_te) * Y_te.lt(U)).float(), dim=1))
                            mpiw = torch.mean(torch.sum(U - L, dim=1))

                        # learning status
                        print('data: ', args.file)
                        print(args)
                        print('---------------------------')
                        print('n: ', args.gmm, 'seed: ', seed, 'lr: ', lr_dic[args.file], 'ens_num: ', ens_num)
                        print('Epoch {:4d}/{} Cost: {:.6f}'.format(epoch, n_epochs, loss_va))
                        print('val_cost: ', loss_va_save)
                        print('test_cost: ', loss_te_save)
                        print('---------------------------')
                        print(loss_list)
                        print(np.mean(loss_list))
                        print('PICP: ', picp)
                        print('MPIW: ', mpiw)
                        print('acc_list: ', acc_list)
                        print(np.mean(acc_list), np.std(acc_list))
                        print('width_list: ', width_list)
                        print(np.mean(width_list), np.std(width_list))
                        torch.save(net.state_dict(),
                                   model_dir + '/' + str(seed) +
                                   '-' + str(ens_num) + '.tar')

                # Early stopping is only acted on at the end of an epoch.
                if earlystop:
                    print('break')
                    break

            if best_mu is None:
                # No validation step ever produced a usable loss (for example
                # the loss was NaN throughout, or zero epochs were requested).
                raise RuntimeError(
                    'No best-validation checkpoint was recorded for seed {} / ensemble member {}'.format(
                        seed, ens_num))

            # Record this member's best-validation results exactly once,
            # whether training ended by early stopping or by running out of
            # epochs.
            loss_list = np.append(loss_list, loss_te_save)
            picp_list = np.append(picp_list, picp.item())
            mpiw_list = np.append(mpiw_list, mpiw.item())
            member_mu.append(best_mu)
            member_std.append(best_std)

        # Combine the members as an equally weighted mixture:
        #   mean = E[mu_i],  var = E[mu_i^2 + std_i^2] - mean^2
        mu_all = torch.cat(member_mu, dim=1)
        std_all = torch.cat(member_std, dim=1)
        ens_mu = torch.mean(mu_all, dim=1)
        ens_var = torch.mean(mu_all ** 2 + std_all ** 2, dim=1) - ens_mu ** 2
        # Rounding can push a near-zero variance slightly negative; clamp so
        # the square root cannot return NaN.
        ens_std = torch.sqrt(torch.clamp(ens_var, min=0))

        U = (ens_mu + 1.96 * ens_std).reshape(-1, 1)
        L = (ens_mu - 1.96 * ens_std).reshape(-1, 1)
        acc = torch.mean((L.lt(Y_te) * Y_te.lt(U)).float())
        acc_list = np.append(acc_list, acc.item())
        width = torch.mean(U - L)
        width_list = np.append(width_list, width.item())

# Final summary over all seeds: coverage and width, each as mean and std.
print('acc_list: ', acc_list)
print(np.mean(acc_list), np.std(acc_list))
print('width_list: ', width_list)
print(np.mean(width_list), np.std(width_list))

print('Done')