'''Define the container class that loads data and trains the JKO flow blocks.'''
import utils
import nets
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import visualize as viz
import math
import os
import pdb
import AutoEncoder as AE
import data
from torch.distributions.multivariate_normal import MultivariateNormal
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class full_container_simulation():
    def __init__(self, args, checkpoint_kwards=[True, True]):
        args.savename = f'JKO_{args.int_mtd}_{args.netname[:-3]}_Xdim={args.Xdim}_reparam={args.reparam_type}'
        if args.Xdim > 1 and args.word != '':
            if 'img' in args.word:
                prefix = args.word.split('.')[0]
                args.savename += f'_{prefix}_{args.netname}'
            else:
                args.savename += f'_{args.word}_{args.netname}'
            if args.continuous_param:
                args.savename += '_cont_param'
        if args.cond_gen:
            self.param_net_ls = []
        self.args = args
        self.load_checkpoint, self.save_checkpoint = checkpoint_kwards
        ### NEW ###
        # prune-and-refine
        # If True, we map X_b through f_b several times before getting X_b+1
        # We do so by breaking [0,T_b] to C numbers of [0,T_b/C] and flow using the same f_b
        self.refinement = False
        self.eta = 0.5  # For updating T_b along the new direction. Small = less contribution
        ### High-dim data_loader ###
        self.train_loader = None
        self.test_loader = None
        self.X_test = None  # For getting MMD statistics
        self.reload_high_dim_data = True  # If we reload these data
        self.X_test_PCA = None  # For visualizing projected components
        self.alpha_MMD = None
        self.high_dim_data = ['bsds300', 'miniboone', 'power', 'gas']
        #### Some extra argument ####
        # This is only used if we want to check whether different "topological components of different modes are mapped to same place"
        self.args.color_X = False
        self.args.use_NeuralODE = False  # Default use simple backprop

    def training(self):
        '''
            Train the flow blocks, phase by phase.

            Loop nesting: phase -> block -> epoch (`args.niters`) -> mini-batch.
            For every block the per-batch objective is
            lossDiv + lossV + lossW2; blocks trained earlier are only used to
            push the data forward (under no_grad). Per-epoch averaged losses
            are collected in `self.ls_all`, and per-phase results are kept in
            `self.ls_all_dict` / `self.FlowNet_dict`.

            Raises ValueError if `args.netname` is not one of the supported
            network names.
        '''
        # NOTE, they are currently not saved, but it is fine because results at each phase are saved
        self.ls_all_dict, self.FlowNet_dict = {}, {}
        use_refine = self.refinement
        self.refinement = False  # At phase 0, never use refinement
        # Note, the loops below are:
        #   Over phase -> Over blocks -> Over each epoch -> Over each batch
        # In the future, we can have more separate methods to handle these loops to inc. readability.
        for p in range(self.args.num_phase):
            # 0. Start at a given phase
            if use_refine and p > 0:
                self.refinement = True
            if p > 0:
                # We choose to always save model trained after the initial phase
                self.save_checkpoint = True
            self.p, self.args.p = p, p
            self.resume_checkpoint()
            # resume_checkpoint may have replaced self.args with the saved one
            args = self.args
            self.get_Tb()  # Get limits of integration at each block. This is a main step where we did the reparametrization
            for j in range(args.num_blocks, args.num_max_blocks):
                start_b = time.time()
                # 1. Define network and warm-start/continuously parametrize parameters
                if args.netname == 'FCnet':
                    odefunc = nets.FCnet(args)
                elif args.netname == 'ODEnet':
                    odefunc = nets.ODEnet(args)
                elif args.netname == 'Chebnet':
                    odefunc = nets.GNN(args)
                elif args.netname == 'Convnet':
                    odefunc = nets.Convnet(args)
                else:
                    # Fail with a clear message instead of an unbound `odefunc`
                    raise ValueError(
                        f'Unknown netname {args.netname!r}; expected FCnet, ODEnet, Chebnet or Convnet')
                if args.netname in ['ODEnet', 'Chebnet', 'Convnet']:
                    args.continuous_param = False  # Have not figured out how to do this
                print(f'{nets.get_n_params(odefunc)/1000}K params in Block {j}')
                model = nets.CNF(odefunc).to(device)
                self.FlowNet.append(model)
                # Actual number of block used
                args.num_blocks = len(self.FlowNet)
                # Deal with parameters
                model, optimizer = self.param_operation(args, model)
                # NOTE: learning-rate scheduling is disabled. The original code
                # built a StepLR for high-dim data and then immediately reset
                # the variable to None, so no scheduler was ever stepped; the
                # dead construction is dropped (it also failed when `optimizer`
                # is the placeholder returned under continuous_param). The
                # `scheduler.step(MMD_metric)` call below matches the
                # ReduceLROnPlateau signature if scheduling is re-enabled.
                scheduler = None
                print(
                    f'=========== Phase {p+1}: Training of Block {j+1}/{args.num_max_blocks} ============')
                print(f'=========== Data is {args.word} ===========')
                ls = []
                if args.word in self.high_dim_data:
                    ls_test = []
                # train block j
                for itr in range(args.niters):
                    # Print time took per epoch
                    start = time.time()
                    model.train()
                    # 2. Get data and feed forward to train JKO loss
                    xraw, yraw = self.get_data(args)
                    if j == 0 and itr == 0:
                        # Create the LOSS function corresponding to value of yraw. It includes the special case of non-conditional generation
                        self.get_H_cond_Y(xraw, yraw)
                    lossls, lossDivls, lossVls, lossW2ls = [], [], [], []
                    for data_batch in self.train_loader:
                        xraw, yraw = data_batch
                        # Pass through AutoEncoder (or not)
                        xinput = self.through_AE(
                            xraw.clone().to(device), args, itr, encode=True)
                        if j > 0:
                            # Push through previous blocks
                            # do not need gradients from previous blocks
                            with torch.no_grad():
                                for jj in range(j):
                                    # get args.T based on args.T0
                                    utils.reparam_t(args, jj)
                                    predz, _ = utils.refinement_map(
                                        self.FlowNet[jj], xinput, args, test=True, refinement=self.refinement)
                                    xinput = predz[-1]
                        if args.continuous_param:
                            self.param_net_optim_zero_grad()  # zero out param nets grad
                        else:
                            optimizer.zero_grad()  # train current block
                        # get args.T based on args.T0
                        utils.reparam_t(args, j)
                        predz, dlogpx = utils.refinement_map(
                            model, xinput, args, test=False, refinement=self.refinement)
                        # 3. Get losses
                        #    lossDiv: divergence loss, change in log lik
                        #    lossV: V_Z(mapping) (or V_{H|Y}(mapping) for cond. gen)
                        #    lossW2: W2^2 loss
                        lossW2 = utils.quick_l2(predz[-1] - predz[0]) / args.T
                        predz, dlogpx = predz[-1], dlogpx[-1]
                        if args.cond_gen:
                            lossV = self.get_V_loss(predz, yraw)/len(yraw)
                        else:
                            lossV = utils.quick_l2(predz)
                        lossDiv = dlogpx.mean()  # divergence loss, change in log lik
                        loss = lossDiv + lossV + lossW2
                        loss.backward()
                        if args.continuous_param:
                            # parameter net grad descent
                            self.param_net_optim_step()
                            # The blocks are rebuilt from the parameter nets, so refresh `model`
                            model = self.assign_continuous_param()
                        else:
                            optimizer.step()
                        lossls.append(loss.cpu().detach().numpy())
                        lossDivls.append(lossDiv.cpu().detach().numpy())
                        lossVls.append(lossV.cpu().detach().numpy())
                        lossW2ls.append(lossW2.cpu().detach().numpy())
                    lossave, lossDivave, lossVave, lossW2ave = np.mean(lossls), np.mean(
                        lossDivls), np.mean(lossVls), np.mean(lossW2ls)
                    # 4. Visualize X, \hat{X}, Z, \hat{Z} at intermediate stages
                    #    We may also just check MMD and test log-likelihood on test data
                    if itr == 0 or (itr+1) % args.log_interval == 0:
                        val_time = time.time()
                        print(
                            f'Current integral limits are [0,{args.T}] under reparam type {args.reparam_type} for T_b')
                        # Under continuous_param, `optimizer` is a placeholder
                        # (not an optimizer); the parameter nets own the real
                        # optimizers, so report the lr of the first of those.
                        if hasattr(optimizer, 'param_groups'):
                            curr_lr = optimizer.param_groups[0]['lr']
                        elif self.param_net_ls:
                            curr_lr = self.param_net_ls[0][1].param_groups[0]['lr']
                        else:
                            curr_lr = args.lr
                        print(
                            f"Current lr is {curr_lr}")
                        print(
                            f'iter {itr} w / {self.num_batches} batches/iterations')
                        print(f'--Ave(over batch) loss sum {lossave}')
                        print(
                            f'--Ave(over batch) loss divergence {lossDivave}')
                        print(f'--Ave(over batch) loss V {lossVave}')
                        print(f'--Ave(over batch) loss W2 {lossW2ave}')
                        num_e = 1 if itr == 0 else args.log_interval
                        print(
                            f'{num_e} epoch(es) with {args.train_data_size} data took {time.time()-start} secs.')
                        # Due to backward integral for visualization, args.T was args.T0
                        utils.reparam_t(args, j)
                        if args.Xdim == 2 or args.cond_gen:
                            # Considered graph generation & conditional generation!
                            viz.for_and_back(
                                self, args, self.base_dist_ls)
                        else:
                            if args.word == 'MNIST':
                                _, _ = viz.vision_data_visualize(
                                    self, args, type='interpolate')
                                _, _ = viz.vision_data_visualize(
                                    self, args, type='random_sample')
                            elif args.word == 'CIFAR10':
                                # No interpolation scheme
                                _, _ = viz.vision_data_visualize(
                                    self, args, type='random_sample')
                            else:
                                # Do not visualize, but compute MMD loss
                                self.use_kde = False
                                if itr == args.niters - 1:
                                    self.use_kde = True
                                MMD_metric_dict, logrhoX = self.get_test_MMD_and_loglik()
                                print('###############Test Metrics###############')
                                for alpha, MMD_metric in MMD_metric_dict.items():
                                    print(
                                        f'--Test MMD loss at alpha={alpha}: {MMD_metric.item():.2e}')
                                print(
                                    f'--Test neg loglikelihood {-logrhoX.item():.2e}')
                                if scheduler is not None:
                                    scheduler.step(MMD_metric)
                        # Increment by time to evaluate during this block
                        args.val_time += (time.time() - val_time)/60
                    ls.append(np.float32(
                        [lossDivave, lossVave, lossW2ave, lossave]))
                    if args.word in self.high_dim_data:
                        # MMD_metric / logrhoX come from the most recent
                        # evaluation above (they are refreshed only on logging
                        # epochs; MMD_metric is the last alpha of that loop).
                        ls_test.append(np.float32(
                            [MMD_metric.item(), logrhoX.item()]))
                        _ = viz.plot_MMD_loglik(ls_test, args)
                        plt.show()
                        plt.close()
                print(
                    f'=========== Phase {p+1}: Training Block {j+1}/{args.num_max_blocks} took {time.time()-start_b} secs ===========')
                # Increment by time to train this block
                args.tot_time += (time.time()-start_b)/60
                self.ls_all.append(ls)
                viz.plot_losses(self.ls_all, self.args)
                if args.word in self.high_dim_data:
                    self.ls_all_test.append(ls_test)
                    _ = viz.plot_MMD_loglik(
                        self.ls_all_test, args, end_of_block_iter=True)
                    plt.show()
                    plt.close()
                # 5. Save checkpoint
                if self.save_checkpoint:
                    # In minutes of training all blocks so far
                    self.modify_saving(save=True)
                # Termination criteria by comparing W2
                Tcurr = args.T
                utils.reparam_t(args, 0)  # First W2
                if args.early_stop and np.array(self.ls_all[-1])[:, 2].mean() * Tcurr <= np.array(self.ls_all[0])[:, 2].mean() * args.T / 1000:
                    # Namely, current W2 error <= 0.1% of initial W2 error
                    print(
                        f'Training terminated, {j+1} blocks out of {args.num_max_blocks} max blocks used')
                    break
            self.ls_all_dict[p] = self.ls_all
            # Need to copy this list, as o/w weights warm-start can be incorrect
            self.FlowNet_dict[p] = self.FlowNet.copy()

    ''' For high-dimensional data '''

    def get_test_MMD_and_loglik(self):
        '''
            Score the current stack of blocks on the held-out test set.

            1). As in FFJORD, estimate the test log-likelihood: map X -> Z
                through all blocks, evaluate the standard-normal term on Z and
                subtract the accumulated change in log-density.
            2). As in OT-Flow, compare generated samples (Z ~ N(0,I) mapped
                back to X space) with the test data through MMD.

            Side effects: fixes the torch seed to 1103, prints the inversion
            error, draws a scatter/PCA plot and stores the log-likelihood in
            `self.args.loglik_test`.
            Returns (MMD_metric_dict, logrhoX), where the dict is read from
            `self.args.MMD_test` (presumably filled by utils.get_MMD_dict).
        '''
        n_blocks = len(self.FlowNet)
        test_pts = self.through_AE(
            self.X_test.to(device), self.args, 0, encode=True)
        n_test, dim = test_pts.shape
        # Fixed seed: same shuffle and same base samples at every evaluation
        torch.manual_seed(1103)
        test_pts = test_pts[torch.randperm(n_test)]

        # 1. Get test loglikelihood
        latent_path, dlogp_path = utils.map_for_or_back(
            test_pts, n_blocks, self.FlowNet, self.args, reverse=False, return_dlogpx=True)
        # -divf, which gives us density change from X to Z
        change_in_logp = dlogp_path.sum()
        latent_hat = latent_path[-1]  # Estimate of normal
        # Quickly check inversion error by mapping the estimate back to X space
        recon_path = utils.map_for_or_back(
            latent_hat, n_blocks, self.FlowNet, self.args, reverse=True, return_dlogpx=False)
        inversion_err = utils.quick_l2(test_pts-recon_path[-1]).pow(0.5)
        print(f'--Test absolute inversion error is {inversion_err.item():.2e}')
        # Loglikelihood of normal: quadratic term plus normalizing constant
        log_norm_const = -dim/2 * math.log(2*math.pi)
        logrhoX = (- utils.quick_l2(latent_hat) + log_norm_const) - change_in_logp
        self.args.loglik_test = logrhoX

        # Generate samples by pushing base noise backward through all blocks
        base_noise = torch.randn(n_test, dim).to(device)
        generated = utils.map_for_or_back(
            base_noise, n_blocks, self.FlowNet, self.args, reverse=True, return_dlogpx=False)[-1]
        planar_words = ['img_rose.png', '', 'img_tree.png', 'img_sierpinski_hard.png',
                        'img_sierpinski.png', 'img_5rings.png', 'img_checkerboard.png']
        if self.args.word not in planar_words:
            # PCA
            viz.get_PCA_plot(self, test_pts, generated)
        else:
            plt.rcParams['axes.titlesize'] = 24
            # Just scatter plot
            viz.quick_scatter(self, test_pts, generated)

        # 2. Get MMD
        # NOTE, the alpha here is the same as what OT-flow used
        utils.get_MMD_dict(self, test_pts, generated)
        MMD_metric_dict = self.args.MMD_test

        # 3. Slice and plot a few dimensions
        if self.args.word == 'miniboone':
            viz.slice_data_plt(test_pts, generated, 16, 17)
        return MMD_metric_dict, logrhoX

    def through_AE(self, xinput, args, itr, encode=True):
        if args.Xdim > 2 and args.word in ['MNIST', 'CIFAR10']:
            # Feature is flattened
            if itr == 0:
                if args.word == 'MNIST':
                    sPATH = 'MNIST_autoencode'
                    autoEnc = AE.Autoencoder(
                        encoding_dim=self.args.Xdim).to(device)
                if args.word == 'CIFAR10':
                    sPATH = 'CIFAR10_autoencode'
                    autoEnc = AE.AutoencoderCIFAR10(
                        encoding_dim=self.args.Xdim).to(device)
                AE_dir = os.path.join(
                    sPATH, f'autoenc_checkpt_d={args.Xdim}.pth')
                checkpt = torch.load(AE_dir, map_location='cpu')
                autoEnc.load_state_dict(checkpt["state_dict"], strict=False)
                self.autoEnc = autoEnc
                self.autoEnc.eval()  # sort of important
            with torch.no_grad():
                if args.word == 'MNIST':
                    xinput = xinput.view(xinput.shape[0], -1)
                if encode:
                    xinput = self.autoEnc.encode(xinput)
                else:
                    xinput = self.autoEnc.decode(xinput)
            return xinput
        else:
            # Do nothing
            return xinput

    ''' Reparametrization '''

    def get_Tb(self):
        args = self.args
        rtype = args.reparam_type
        NumB = args.num_max_blocks
        if rtype == 'constant':
            args.T_ls = [args.T0 for _ in range(NumB)]
        if rtype == 'multiple':
            args.T_ls = [args.T0*1.2**i for i in range(NumB)]
        if rtype == 'exponential':
            # Can be a bit conservative if num_max_blocks large
            args.T_ls = np.exp(np.log(args.num_int_pts)*np.arange(NumB)/NumB)
        if rtype == 'adaptive_W2':
            # Nearly the same as vector-space reparam.
            if self.p > 0:
                # We get W2 from the last epoch of all blocks, but take the square root to get 'arc length'
                # This is (2*Tb)^{-1}* W2^2
                lossW2old = np.array([loss[-1][2]
                                      for loss in self.ls_all_dict[self.p-1]])
                T_old_ls = np.array(args.T_dict[self.p-1])
                W2_allblocks = np.sqrt(2*T_old_ls*lossW2old)
                Wbar = W2_allblocks.mean()
                T_new_ls = Wbar*T_old_ls/W2_allblocks
                tildeW2 = np.cumsum(W2_allblocks[::-1])[::-1]
                smallW2 = tildeW2 <= Wbar/2
                self.args.num_max_blocks -= smallW2.sum()
                T_old_ls = T_old_ls[~smallW2]
                T_new_ls = T_new_ls[~smallW2]
                args.T_ls = T_old_ls + self.eta*(T_new_ls-T_old_ls)
            else:
                args.T_ls = [args.T0 for _ in range(NumB)]
        args.T_dict[self.p] = args.T_ls

    ''' Either warm start (with the previous block, or the same block from the previous phase)
        or continuously parametrize (i.e., weight sharing) '''

    def param_operation(self, args, model):
        if args.continuous_param:
            if len(self.param_net_ls) == 0:
                # Get nets that output parameters
                self.get_param_net_ls()
            # Assign continuous parameters using parameter nets
            model = self.assign_continuous_param()
            # NOTE: the CNF themselves have NO trainable parameters (because parameters) are outputs of the parameter net, so we cannot assign any optimizer
            optimizer = 0
        else:
            # Warm start latest block with previous block's weights
            self.assign_warm_start_param()
            optimizer = torch.optim.Adam(
                model.parameters(), lr=args.lr)
        return model, optimizer

    def assign_warm_start_param(self):
        if self.p > 0:
            # Warm start with the corresponding block of the previous phase
            self.FlowNet[-1].load_state_dict(
                self.FlowNet_dict[self.p-1][len(self.FlowNet)-1].state_dict())
        else:
            # Warm start latest block with previous block's weights
            if len(self.FlowNet) > 1:
                self.FlowNet[-1].load_state_dict(self.FlowNet[-2].state_dict())

    def assign_continuous_param(self):
        '''
            Rebuild every block in self.FlowNet from the parameter nets and
            return the newest one.

            Each block gets a scalar 'time' t_block; every parameter net is
            evaluated at that time and its output reshaped to the shape of the
            parameter it stands for. The blocks must be re-created because
            nn.Module.parameters() cannot be outputs of other networks AND
            remain trainable.

            ToDo: better determine how we should increment t.
            Each block integrates from 0 to 1, but incrementing t by 1 seems
            too large, so t grows by (args.T/args.num_int_pts)**2 per block
            (later blocks, which move more, also advance t more).
        '''
        args = self.args
        t_block = torch.zeros(1).to(device)
        for j, old_block in enumerate(self.FlowNet):
            utils.reparam_t(args, j)  # sets args.T for block j
            t_block = t_block + (args.T/args.num_int_pts)**2
            # One generated tensor per parameter of the old block, in order
            self.args.param_ls = [
                self.param_net_ls[k][0](t_block).reshape(self.param_net_ls[k][-1])
                for k, _ in enumerate(old_block.parameters())
            ]
            self.FlowNet[j] = nets.CNF(nets.FCnet(self.args)).to(device)
        # FlowNet has been rebuilt, so hand back the current (newest) model
        return self.FlowNet[-1]

    ''' Helpers for weight sharing '''

    def get_param_net_ls(self):
        '''
            Create one parameter net (with its own Adam optimizer) for every
            parameter tensor of the first block.

            All blocks share the same structure, so this only needs to run
            once. Each entry appended to self.param_net_ls is
            [param_net, param_optim, param_shape]; the nets later output the
            continuous parameters from which the blocks are rebuilt.
        '''
        args = self.args
        reference_block = self.FlowNet[0]  # initial (or common) block
        for theta in reference_block.parameters():
            # The net outputs as many values as the parameter has entries
            theta_net = nets.param_net_class(theta.numel()).to(device)
            theta_optim = torch.optim.Adam(theta_net.parameters(), lr=args.lr)
            self.param_net_ls.append([theta_net, theta_optim, theta.shape])

    def param_net_optim_zero_grad(self):
        for combo in self.param_net_ls:
            param_optim = combo[1]
            param_optim.zero_grad()

    def param_net_optim_step(self):
        # Sanity check of whether param net has gradient
        for combo in self.param_net_ls:
            pnet = combo[0]
            for i, param in enumerate(pnet.parameters()):
                if param.grad is None:
                    print('Some grad of param net is None')
        for combo in self.param_net_ls:
            param_optim = combo[1]
            param_optim.step()

    ''' Get V_{H|Y} loss based on condition Y '''

    def get_V_loss(self, predz, yraw_batch):
        '''
            Get log-prob based on H|Y below
        '''
        predz = predz.flatten(start_dim=0, end_dim=1)
        yraw_batch = yraw_batch.flatten(start_dim=0, end_dim=1)
        unique_Y = torch.unique(yraw_batch)
        lossV = 0
        for i in unique_Y:
            idx_i = yraw_batch == i
            lossV = lossV + \
                self.base_dist_ls[int(i.cpu().detach().numpy())].log_prob(
                    predz[idx_i]).sum()
        return -lossV

    def get_H_cond_Y(self, xraw, yraw):
        r'''
            Here, it is for changing V_Z, Z ~ N(0,I) to V_{H|Y}, H|Y ~ N(\mu|Y,\sigma*I)

            Builds self.base_dist_ls, one MultivariateNormal per label.
            A single unique label means unconditional generation: one standard
            normal of dimension xraw.shape[1]. Otherwise one Gaussian per label
            is built, centred at 1.5x the label's data mean, with a shared
            isotropic covariance derived from the smallest distance between
            two means. The conditional case is only built while
            self.param_net_ls is empty.
        '''
        # Initialize with the mean of data
        if len(torch.unique(yraw)) == 1:
            # No conditional distribution
            gen_dim = xraw.shape[1]
            base_mu = torch.zeros(gen_dim).to(device)
            base_cov = torch.eye(gen_dim).to(device)
            self.base_dist_ls = [MultivariateNormal(base_mu, base_cov)]
        else:
            # Note, this is explicitly written assuming X = (N,V,C), where X_i[v] \in \R^C
            # So H|Y \in \R^C
            if len(self.param_net_ls) == 0:
                # Labels are assumed to be 0..Y_dim-1 (they index the means below)
                Y_dim = len(torch.unique(yraw))
                gen_dim = xraw.shape[2]
                mean_vectors = []
                for i in range(Y_dim):
                    # mean_i = torch.transpose(xraw, 1, 2)[yraw == i].mean(axis=0)
                    mean_i = 1.5*xraw[yraw == i].mean(axis=0)
                    mean_vectors.append(mean_i)
                mean_vectors = torch.vstack(mean_vectors).to(device)
                dist_mat = torch.cdist(mean_vectors, mean_vectors)
                # Get minimum non-zero pairwise distance to decide the covariance
                min_dist = dist_mat[dist_mat > 0].min().cpu().detach().numpy()
                # By conc. inequality, if X = (X_1,...,X_d) where X_i ~ N(0,sigma),
                # Then E[\|X\|_2] = \sqrt{d}*\sigma
                # Thus, with very high prob., \|X\|_2 \in ball of radius 3*\sqrt{d}*\sigma
                # As we have two such balls, we would need \sigma < dist/(6*\sqrt{d})
                C = 15*np.sqrt(gen_dim)
                scaled_dist = min_dist/C
                base_cov = (torch.eye(gen_dim) * scaled_dist).to(device)
                # Small linear net whose weights are initialised to the class
                # means, so feeding a one-hot label returns that class's mean
                self.gen_net = nets.SmallGenNet(Y_dim, gen_dim).to(device)
                with torch.no_grad():
                    self.gen_net.fc.weight = torch.nn.Parameter(
                        torch.transpose(mean_vectors, 0, 1).to(device))
                    self.gen_net.fc.bias = torch.nn.Parameter(
                        torch.zeros(gen_dim).to(device))
                self.base_dist_ls = []
                for i in range(Y_dim):
                    basis_i = torch.zeros(Y_dim).to(device)
                    basis_i[i] = 1
                    # NOTE: if do NOT detach, would cause "backward again" error
                    base_mu = self.gen_net(basis_i).detach().to(device)
                    self.base_dist_ls.append(
                        MultivariateNormal(base_mu, base_cov))
    ''' Other helpers '''

    def get_data(self, args):
        '''
            Prepare self.train_loader for the coming epoch and return a pair
            (xraw, yraw).

            Loader-backed data (vision sets, or the tabular high-dim sets):
            the returned pair is only a placeholder, real batches come from
            self.train_loader. Otherwise (2D toy data or conditional
            generation) the full tensors are returned and also split into
            shuffled mini-batches. self.num_batches is updated in both cases.
        '''
        # Same grouping as `A and B or C`; `== False` is kept on purpose so
        # that non-bool values of cond_gen behave exactly as before.
        loader_backed = (args.Xdim > 2 and args.cond_gen == False) or args.word == 'MNIST'
        if loader_backed:
            if args.word not in ['MNIST', 'CIFAR10']:
                # Get train and test datasets used in OT-Flow and MAF
                if self.reload_high_dim_data:
                    # We load data only once due to data size, but randomly permute data at each iteration
                    self.train_full, self.X_test = data.tensor_high_dim(
                        args.word, frac=args.frac)
                    args.train_data_size = self.train_full.shape[0]
                    self.reload_high_dim_data = False
                index_batches = self.get_batches(self.train_full)
                y_placeholder = torch.zeros(args.train_data_size).to(device)
                self.train_loader = self.SimpleDloader(
                    self.train_full, y_placeholder, index_batches)
            else:
                # Repeat at each epoch to get better results
                self.train_loader, self.test_loader = data.vision_data(
                    batch_size=args.batch_size, frac=args.frac, dataname=args.word)
                args.train_data_size = 60000
                if self.X_test is None:
                    # Gather the inputs of all test batches once
                    self.X_test = torch.vstack(
                        [test_batch[0] for test_batch in self.test_loader]).to(device)
            # Placeholder
            xraw, yraw = torch.rand(1, args.Xdim), torch.zeros(1)
        else:
            if not args.cond_gen:
                self.xraw = utils.inf_train_gen(args)
                self.yraw = torch.zeros(len(self.xraw))
                self.X_test = utils.inf_train_gen(args).to(device)
            else:
                args.w, args.h = 1, 1
                if args.word != 'two_moon':
                    # Real solar data
                    if self.reload_high_dim_data:
                        self.xraw, self.yraw, self.edge_index = data.get_solar_data()
                        args.train_data_size = self.xraw.shape[0]
                        self.reload_high_dim_data = False
                    for block in self.FlowNet:
                        block.odefunc.edge_index = self.edge_index.to(device)
                else:
                    self.xraw, self.yraw = utils.inf_train_gen_cond_gen(args)
            xraw, yraw = self.xraw.to(device), self.yraw.to(device)
            self.train_loader = self.SimpleDloader(
                xraw, yraw, self.get_batches(xraw))
        self.num_batches = len(self.train_loader)
        return xraw, yraw

    def SimpleDloader(self, x, y, batches):
        dloader = []
        for batch_idx in batches:
            dloader_tmp = [x[batch_idx], y[batch_idx]]
            dloader.append(dloader_tmp)
        return dloader

    def get_batches(self, xfull):
        bsize = self.args.batch_size
        train_tot = xfull.shape[0]
        num_b = int(train_tot / bsize)
        idxes = np.random.choice(range(train_tot), train_tot, replace=False)
        # idxes = np.arange(train_tot)
        cum, batches = 0, []
        for i in range(num_b):
            batches.append(idxes[cum:cum + bsize])
            cum += bsize
        if train_tot - cum > 0:
            batches.append(idxes[cum:])
        return batches

    def resume_checkpoint(self):
        args = self.args
        filepath = f'{args.savename}_phase{self.p}'
        if self.load_checkpoint and os.path.exists(filepath):
            max_b = args.num_max_blocks
            self.modify_saving(save=False, filepath=filepath)
            # To train more blocks
            self.args.num_max_blocks = max_b
            self.args.early_stop = True if self.args.word == '' else False
        else:
            self.args.num_blocks = 0
            self.args.tot_time, self.args.val_time = 0, 0
            self.ls_all = []  # Record training losses
            if args.word in self.high_dim_data:
                self.ls_all_test = []  # New, record MMD metric & test loglik
            self.FlowNet = []
            # For continuous parameter initialization
            # This will contain continuously updated parameters for each specific block. Empty list means we do not build such blocks. It is true at the beginning for cont. param. because we need a reference block for building the parameter nets
            args.param_ls = []
            self.param_net_ls = []

    def modify_saving(self, save=True, filepath=''):
        '''
            Save the training state to, or restore it from, a checkpoint file.

            save=True : write losses, args, the state_dict of every block and
                        the optional extras (test metrics, parameter nets,
                        base distributions) to '<savename>_phase<p>'.
            save=False: read `filepath`, replace self.ls_all / self.args with
                        the stored ones, and rebuild every block before
                        loading its weights.

            Raises ValueError when a stored block has to be rebuilt but
            args.netname is not a supported network name.
        '''
        args = self.args
        if save:
            args.approx_train_time = args.tot_time-args.val_time
            checkpoint = {'Losses': self.ls_all,
                          'args': args}
            checkpoint['Full_result'] = [mod.state_dict() for mod in self.FlowNet]
            if args.word in self.high_dim_data:
                checkpoint['Test_metric'] = self.ls_all_test
            if args.continuous_param:
                checkpoint['Param_nets'] = self.param_net_ls
            if args.cond_gen:
                checkpoint['base_dist_ls'] = self.base_dist_ls
            torch.save(checkpoint, f'{args.savename}_phase{self.p}')
        else:
            # NOTE: the checkpoint stores pickled Python objects (args, and
            # possibly whole modules/optimizers), so only load files you trust.
            checkpoint = torch.load(filepath, map_location=device)
            self.ls_all, self.args = checkpoint['Losses'], checkpoint['args']
            # NOTE(review): from here on `args` is still the settings object
            # that was active BEFORE loading, not the one just restored into
            # self.args. This is kept as-is; confirm it is intended.
            self.FlowNet = []
            for block_state in checkpoint['Full_result']:
                if args.netname == 'FCnet':
                    odefunc = nets.FCnet(args)
                elif args.netname == 'ODEnet':
                    odefunc = nets.ODEnet(args)
                    args.continuous_param = False  # Have not figured out how to do this
                elif args.netname == 'Chebnet':
                    odefunc = nets.GNN(args)
                    args.continuous_param = False  # Have not figured out how to do this
                    # odefunc.reshape = True
                elif args.netname == 'Convnet':
                    # training() can build Convnet blocks, so they must be
                    # restorable as well (this case used to be missing)
                    odefunc = nets.Convnet(args)
                    args.continuous_param = False  # Have not figured out how to do this
                else:
                    raise ValueError(
                        f'Unknown netname {args.netname!r}; expected FCnet, ODEnet, Chebnet or Convnet')
                self.FlowNet.append(nets.CNF(odefunc).to(device))
                self.FlowNet[-1].load_state_dict(block_state)
            if args.word in self.high_dim_data:
                self.ls_all_test = checkpoint['Test_metric']
            if args.continuous_param:
                self.param_net_ls = checkpoint['Param_nets']
            if args.cond_gen:
                self.base_dist_ls = checkpoint['base_dist_ls']

###########
###########
###########
###########
###########
###########
###########
###########
###########
###########
###########
###########
