import argparse
import os

from _base import *
from _model import *
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from _hp_ddd import lr_dic

# Working directory; used as the root for saved model checkpoints.
path = os.getcwd()
parser = argparse.ArgumentParser()

parser.add_argument('--file', default='concrete', type=str, help='file name')
parser.add_argument('--epoch', default=3000, type=int, help='n_epoch')
parser.add_argument('--batch_size', default=100, type=int, help='batch_size')
parser.add_argument('--gmm', default=5, type=int, help='n_gmm')
parser.add_argument('--dropout_p', default=0, type=float, help='dropout probability')
parser.add_argument('--ens_num', default=5, type=int, help='n_ensemble')
parser.add_argument('--wd', default=0, type=float, help='weight decay')
# float (not int) so that e.g. ``--patience inf`` disables early stopping.
parser.add_argument('--patience', default=3000, type=float, help='early stopping patience')
parser.add_argument('--gmm_h1', default=50, type=int, help='hidden1 of gmm')
parser.add_argument('--gmm_h2', default=50, type=int, help='hidden2 of gmm')


args = parser.parse_args()

# Dataset-specific overrides. Every override is written back onto ``args`` so
# that ``print(args)`` (here and in the training log) shows the values that are
# actually used, rather than the command-line values they replaced.
if args.file == 'naval':
    args.patience = np.inf  # early stopping disabled for naval
    seed_list = range(1, 6)

elif args.file == 'protein':
    args.batch_size = 1000  # forced for protein, regardless of --batch_size
    args.gmm_h1 = 100
    args.gmm_h2 = 100
    seed_list = range(1, 6)

else:
    seed_list = range(1, 21)

# Module-level names consumed by the training loop.
batch_size = args.batch_size
patience = args.patience

print(args)

with torch.cuda.device(0):
    result = {}  # set empty dic for result  # NOTE(review): not populated anywhere in the visible script

    # Test loss of the best-on-validation checkpoint, one entry per (seed, ensemble member).
    loss_list = np.array([])

    # Checkpoints go to <cwd>/model/gmm/<file>/<seed>-<ens_num>.tar; create the directory
    # up front so torch.save does not fail on a fresh checkout.
    save_dir = os.path.join(path, 'model', 'gmm', str(args.file))
    os.makedirs(save_dir, exist_ok=True)

    for seed in seed_list:
        # Seed numpy (used by the data pipeline — presumably load_data's split) and torch
        # (network init, dropout, DataLoader shuffling) so that each seed is reproducible.
        np.random.seed(seed)
        torch.manual_seed(seed)
        X_tr, X_va, X_te, Y_tr, Y_va, Y_te, Y_al, y_range = load_data(args.file, seed, 0.1, 0.1)
        dataset = TensorDataset(X_tr, Y_tr)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for ens_num in range(args.ens_num):

            es = EarlyStopping(patience=patience, verbose=1)
            # Per-member tracking state. Starting at +inf guarantees the first validation
            # loss counts as an improvement; resetting here prevents a previous member's
            # test loss from being reused. NaN marks "no checkpoint was ever selected".
            loss_va_save = float('inf')
            loss_te_save = float('nan')
            earlystop = False

            net = GMM(n_feature=X_tr.shape[1], n_hidden1=args.gmm_h1, n_hidden2=args.gmm_h2, n_gmm=args.gmm,
                      dropout_prob=args.dropout_p)
            net.cuda()
            optimizer = torch.optim.Adam(net.parameters(), lr=lr_dic[args.file], weight_decay=args.wd)

            for epoch in range(args.epoch):
                if epoch % 10 == 0:
                    print('epoch: ', epoch)

                for X_batch, Y_batch in dataloader:
                    net.train()
                    mu, std, p = net(X_batch)

                    # E-step: normalised per-component weights, treated as constants
                    # (no gradient flows through them).
                    with torch.no_grad():
                        w = p * norm_pdf(Y_batch, mu, std)
                        w = w / (torch.sum(w, dim=1).reshape(-1, 1) + 1e-7)

                    # M-step: weighted log-mixing-weight plus weighted component log-density.
                    loss = w * torch.log(p + 1e-7) + w * lognorm_pdf(Y_batch, mu, std)
                    loss = torch.mean(-torch.sum(loss, dim=1))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Validation loss after every optimisation step; gradients are not
                    # needed, so skip building an autograd graph.
                    net.eval()
                    with torch.no_grad():
                        loss_va = net.loss(X_va, Y_va).item()

                    earlystop = es.validate(loss_va)

                    if loss_va < loss_va_save:
                        loss_va_save = loss_va
                        with torch.no_grad():
                            loss_te_save = net.loss(X_te, Y_te).item()

                        # learning status
                        print('data: ', args.file)
                        print(args)
                        print('---------------------------')
                        print('n: ', args.gmm, 'seed: ', seed, 'lr: ', lr_dic[args.file], 'ens_num: ', ens_num)
                        print('Epoch {:4d}/{} Cost: {:.6f}'.format(epoch, args.epoch, loss_va))
                        print('val_cost: ', loss_va_save)
                        print('test_cost: ', loss_te_save)
                        print('---------------------------')
                        print(loss_list)
                        # Avoid numpy's "Mean of empty slice" warning before the first entry.
                        print(np.mean(loss_list) if loss_list.size else float('nan'))
                        torch.save(net.state_dict(),
                                   os.path.join(save_dir, str(seed) + '-' + str(ens_num) + '.tar'))

                # EarlyStopping is consulted after every batch, but training only stops at an
                # epoch boundary, based on the signal returned for the epoch's last batch.
                if earlystop:
                    print('break')
                    break

            # Exactly one entry per member, whether training stopped early or ran all epochs.
            loss_list = np.append(loss_list, loss_te_save)

