import argparse
import numpy as np
np.set_printoptions(precision=4)

import torch
from models import InterpRealNVP, batch_EM
import util
from loader import DataLoader

from joblib import load, dump
from tqdm import tqdm
from pathlib import Path


def _init_flow_and_optimizer(num_neurons, args):
    """Build a fresh normalizing-flow model and an Adam optimizer for it.

    Args:
        num_neurons: Input width passed to ``util.init_flow_model`` (in ``main``
            this is the number of features of one training sample).
        args: Parsed command-line namespace; ``num_nf_layers`` and ``lr`` are read.

    Returns:
        A ``(flow, optimizer)`` tuple. The optimizer only covers parameters
        with ``requires_grad`` set.
    """
    flow = util.init_flow_model(num_neurons, args.num_nf_layers, InterpRealNVP, num_neurons)
    trainable = [p for p in flow.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=args.lr)
    return flow, optimizer


def main(args):
    """Train MCFlow end to end and save the per-epoch loss/MSE history.

    For each of ``args.n_epochs`` epochs this:
      1. trains via ``util.endtoend_train`` and records the three returned losses;
      2. evaluates on the test split under ``torch.no_grad()`` (monitoring only)
         and prints the test RMSE (square root of the returned MSE);
      3. refreshes the loader's imputed values with ``ldr.reset_imputed_values``;
      4. for every epoch except the last, re-creates the flow and its optimizer,
         calls ``EM.reset()`` when ``args.reset`` is true, and always calls
         ``EM.rate_init()``.

    The history dict is finally written with ``joblib.dump`` under
    ``./results/uci/<args.data>/``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
    """
    # Dataset starts in training mode (mode=0); the same object is switched
    # to test mode (mode=1) for evaluation below.
    ldr = DataLoader(args, mode=0)
    data_loader = torch.utils.data.DataLoader(ldr, batch_size=args.batch_size, shuffle=True, drop_last=False)

    # Model input width = feature count of one training sample.
    num_neurons = int(ldr.train[0].shape[0])
    flow, nf_optimizer = _init_flow_and_optimizer(num_neurons, args)
    EM = batch_EM(gamma=args.gamma, beta=args.beta)

    print("\n*********************************")
    print("Starting {} experiment\n".format(args.data))

    te_mse_lst = []
    loss1_lst = []
    loss2_lst = []
    loss3_lst = []

    save_dir = Path("./results/uci", "{}".format(args.data))
    save_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(args.n_epochs):
        loss1, loss2, loss3 = util.endtoend_train(flow=flow,
                                                  batch_EM=EM,
                                                  nf_optimizer=nf_optimizer,
                                                  loader=data_loader,
                                                  args=args)
        loss1_lst.append(loss1)
        loss2_lst.append(loss2)
        loss3_lst.append(loss3)

        # Not needed for training: only shows how the test RMSE evolves.
        # The shared dataset is switched to test data and always switched
        # back, even if evaluation raises.
        with torch.no_grad():
            ldr.mode = 1  # use testing data
            try:
                te_mse = util.endtoend_test(flow, EM, data_loader, args)
            finally:
                ldr.mode = 0  # back to training data
        te_mse_lst.append(te_mse)
        print("Epoch {:<2}: L1: {:9.2f} || L2: {:9.2f} || L3: {:9.2f} || Test RMSE: {:.5f}".format(
            epoch, loss1, loss2, loss3, te_mse ** 0.5))

        ldr.reset_imputed_values(EM, flow, args)
        # Each following epoch starts from a freshly initialized flow; nothing
        # is re-initialized after the final epoch.
        if epoch != args.n_epochs - 1:
            flow, nf_optimizer = _init_flow_and_optimizer(num_neurons, args)
            if args.reset:
                EM.reset()
            EM.rate_init()

    # Save history; the file name encodes the run configuration.
    history = {'loss1': loss1_lst, 'loss2': loss2_lst, 'loss3': loss3_lst, 'mse': te_mse_lst}
    filename = 'history_f{}_{}_s{}_batch{}_beta{:.0e}_g{}_r{}.joblib'.format(
        args.fold_id, args.miss_pattern, args.seed, args.batch_size,
        args.beta, args.gamma, int(args.reset))
    dump(history, str(save_dir / filename))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Data / experiment selection
    parser.add_argument('--data', default='news')
    parser.add_argument('--seed', type=int, default=1000, help='Reproducibility')
    parser.add_argument('--drp-percent', type=float, default=0.2)
    parser.add_argument('--fold-id', type=int, default=0)
    parser.add_argument('--miss-pattern', default='MCAR')

    # Optimization
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--batch-size-test', type=int, default=10000)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--n-epochs', type=int, default=25)
    # Without `type=float` a value passed on the command line would arrive as
    # a string; the default (None) is unchanged.
    parser.add_argument('--grad-clip', type=float, default=None)

    # Model / EM hyper-parameters
    parser.add_argument('--num-nf-layers', type=int, default=3)
    parser.add_argument('--gamma', type=float, default=0.8)
    parser.add_argument('--alpha', type=float, default=1e6)
    parser.add_argument('--beta', type=float, default=0.)
    parser.add_argument('--reset', type=util.str2bool, default=True)
    args = parser.parse_args()

    # Reproducibility: seed NumPy and PyTorch (CPU and all CUDA devices).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    main(args)
    print("Experiment completed")
    
