import argparse
import numpy as np
np.set_printoptions(precision=4)

import torch
from models import InterpRealNVP
import util
from loader import DataLoader
from models import batch_EM

from joblib import load, dump
from tqdm import tqdm
from pathlib import Path

class EM_scheduler:
    """Step-decayed scalar rate driven by ``torch.optim.lr_scheduler.StepLR``.

    No real model parameter is optimised here. A throwaway SGD optimizer exists
    only so that ``StepLR`` has a learning rate to decay, and that learning
    rate is the value handed back by :meth:`step`.

    Args:
        step_size: Number of :meth:`step` calls between two decays.
        gamma: Multiplicative decay factor applied every ``step_size`` calls.
        init: Starting value of the rate.
    """

    def __init__(self, step_size=3, gamma=0.8, init=0.99):
        # A small Linear layer only supplies parameters for the optimizer.
        # Its weights are randomly initialised, so building it advances the
        # torch RNG.
        placeholder = torch.nn.Linear(2, 1)
        self.rate_opt = torch.optim.SGD(placeholder.parameters(), lr=init)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.rate_opt, step_size=step_size, gamma=gamma
        )

    def step(self):
        """Advance the schedule by one tick and return the resulting rate."""
        # torch expects optimizer.step() before scheduler.step().
        # No gradients are ever computed for the placeholder, so this SGD
        # step leaves its parameters untouched.
        self.rate_opt.step()
        self.scheduler.step()
        lead_group = self.rate_opt.param_groups[0]
        return lead_group["lr"]

def _new_flow_and_optimizers(num_neurons, args):
    """Build a fresh flow model plus its Adam optimizer and cosine LR scheduler.

    The scheduler is always created together with, and bound to, the optimizer
    returned alongside it. This way callers never step a scheduler that drives
    a discarded optimizer.

    Args:
        num_neurons: Size of one training sample's first dimension, forwarded
            to ``util.init_flow_model`` exactly as the original calls did.
        args: Parsed CLI namespace; reads ``num_nf_layers`` and ``lr``.

    Returns:
        ``(flow, nf_optimizer, nf_scheduler)``.
    """
    flow = util.init_flow_model(num_neurons, args.num_nf_layers, InterpRealNVP, num_neurons)
    trainable = [p for p in flow.parameters() if p.requires_grad]
    nf_optimizer = torch.optim.Adam(trainable, lr=args.lr)
    nf_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(nf_optimizer, T_max=3, eta_min=5e-4)
    return flow, nf_optimizer, nf_scheduler


def main(args):
    """Train the flow + batch-EM imputation model and save per-epoch results.

    Each epoch does the following, in order:
      1. Runs ``util.endtoend_train_superbatch`` on the training split.
      2. Evaluates on the test split and prints the RMSE.
      3. Refreshes the imputed values via ``ldr.reset_imputed_values``.
      4. Dumps the imputed test set.
      5. Starts the next epoch from a brand-new flow/optimizer/scheduler.

    Files written under ``./results/<args.data>/``:
      * ``test{k}_m{drp_percent}_s{seed}.joblib`` -- ``ldr.test`` after epoch ``k`` (1-based)
      * ``history_m{drp_percent}_s{seed}.joblib`` -- per-epoch losses and test MSE

    Args:
        args: Parsed CLI namespace (see the ``__main__`` block).
    """
    # Dataset wrapper; mode=0 serves training data, mode=1 testing data.
    ldr = DataLoader(args, mode=0)
    data_loader = torch.utils.data.DataLoader(ldr, batch_size=args.batch_size, shuffle=True, drop_last=False)

    # Normalizing-flow network, its optimizer and its LR scheduler.
    num_neurons = int(ldr.train[0].shape[0])
    flow, nf_optimizer, nf_scheduler = _new_flow_and_optimizers(num_neurons, args)

    # Built after the flow on purpose: EM_scheduler creates a randomly
    # initialised Linear layer, so this order fixes the torch RNG stream.
    const_scheduler = EM_scheduler(step_size=3, gamma=0.8, init=0.99)
    EM = batch_EM(gamma=args.gamma, gamma_scheduler=None,
                  const=0.99, const_scheduler=const_scheduler,
                  beta=args.beta)

    print("\n*********************************")
    print("Starting {} experiment\n".format(args.data))

    te_mse_lst = []
    loss1_lst = []
    loss2_lst = []
    loss3_lst = []

    # One directory per dataset so runs on different datasets don't overwrite
    # each other (the default 'mnist' keeps the historical path).
    save_dir = "./results/{}/".format(args.data)
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    for epoch in range(args.n_epochs):
        loss1, loss2, loss3 = util.endtoend_train_superbatch(flow=flow,
                                                             batch_EM=EM,
                                                             nf_optimizer=nf_optimizer,
                                                             nf_scheduler=nf_scheduler,
                                                             loader=data_loader,
                                                             args=args)
        loss1_lst.append(loss1)
        loss2_lst.append(loss2)
        loss3_lst.append(loss3)

        # Not needed for training; only tracks how the test RMSE evolves.
        with torch.no_grad():
            ldr.mode = 1  # use testing data
            try:
                te_mse = util.endtoend_test(flow, EM, data_loader, args)
            finally:
                ldr.mode = 0  # always switch back to training data
        te_mse_lst.append(te_mse)
        print("Epoch {:<2}: L1: {:9.2f} || L2: {:9.2f} || L3: {:9.2f} || Test RMSE: {:.5f}".
              format(epoch, loss1, loss2, loss3, te_mse ** .5))

        ldr.reset_imputed_values(EM, flow, args)
        # Save the imputed test set for this epoch (1-based in the file name).
        dump(ldr.test, save_dir + 'test{}_m{}_s{}.joblib'.format(epoch + 1, args.drp_percent, args.seed))

        # Fresh flow for the next epoch. The scheduler is rebuilt together
        # with the optimizer so it anneals the optimizer actually in use.
        flow, nf_optimizer, nf_scheduler = _new_flow_and_optimizers(num_neurons, args)
        if args.reset:
            EM.reset()
        EM.rate_init()

    # Save training history.
    history = {'loss1': loss1_lst, 'loss2': loss2_lst, 'loss3': loss3_lst, 'mse': te_mse_lst}
    dump(history, save_dir + 'history_m{}_s{}.joblib'.format(args.drp_percent, args.seed))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in this exact order.
    cli_spec = [
        # Experiment / data selection.
        ('--data', dict(default='mnist')),
        # NOTE: per the original author, seeding does not make runs on image
        # datasets reproducible (cause not identified).
        ('--seed', dict(type=int, default=1000, help='Reproducibility')),
        ('--drp-percent', dict(type=float, default=0.3)),
        # Optimisation / batching.
        ('--batch-size', dict(type=int, default=512)),
        ('--batch-size-test', dict(type=int, default=2048)),
        ('--super-size', dict(type=int, default=3000)),
        ('--lr', dict(type=float, default=1e-3)),
        ('--n-epochs', dict(type=int, default=20)),
        ('--grad-clip', dict(type=float, default=50.)),
        # Model / EM hyper-parameters.
        ('--num-nf-layers', dict(type=int, default=3)),
        ('--gamma', dict(type=float, default=0.8)),
        ('--alpha', dict(type=float, default=5e8)),
        ('--beta', dict(type=float, default=0.)),
        ('--reset', dict(type=util.str2bool, default=False)),
    ]
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Reproducibility: seed NumPy and torch (CPU and all CUDA devices).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    main(args)
    print("Experiment completed")
