import time
import copy
import torch
from Pn_util import *
from gae_score_estimation import *
from gae_pd_score_estimation import *
from util import *
from Pn_DataUtil import *

def dae_Pn_trainer(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, use_minibatch = False, scheduler = None, saveAfter = None,
                printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, valIdx = None, loggingFileName = None):
    """Train ``model`` on P(n) data and keep the weights with the lowest estimated score error.

    Every ``checkEstErrorPeriod`` epochs (and on the last epoch) the estimated
    score error is evaluated on the training samples, or on the samples selected
    by ``valIdx`` when it is given. The ``state_dict`` giving the smallest value
    seen so far is copied and finally returned.

    Args:
        trainDataLoader: loader whose ``dataset.train_data`` holds the vectorised
            training samples. In minibatch mode it must yield 3-tuples whose
            first element is the batch.
        model: object exposing ``calculate_loss``, ``noise_std`` and ``state_dict``.
        optimizer: optimizer; ``optimizer.step(closure)`` is used in full-batch mode.
        max_iter_num: number of epochs to run.
        use_gpu: if False, all precomputed tensors (and minibatches) are moved to
            the CPU. The precomputation itself still runs on CUDA.
        use_minibatch: iterate over ``trainDataLoader`` instead of taking one
            full-batch step per epoch.
        scheduler: optional LR scheduler; stepped after every optimizer step.
        saveAfter: if given, checkpoints after epoch 0 are only taken when
            ``epoch > saveAfter``.
        printEpochPeriod: period (in epochs) of the progress message.
        checkEstErrorPeriod: period (in epochs) of the score-error evaluation.
        testdataset: optional dataset with a ``train_data`` attribute; its
            estimated score error is tracked as well.
        testscore: optional reference score for the test samples; enables the
            additional error trajectory described under Returns.
        valIdx: optional indices of training samples held out for validation.
        loggingFileName: if given, passed to ``set_logger`` and used by ``print_info``.

    Returns:
        ``(best_model, gscore_est_error_set)``, extended by
        ``gscore_est_error_testset`` when ``testdataset`` is given, and further by
        ``gscore_error_testset`` when ``testscore`` is given as well. The error
        trajectories are ``torch.FloatTensor`` objects with one entry per evaluation.
    """
    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    x = trainDataLoader.dataset.train_data.cuda()
    N = x.shape[0]

    ####################
    # Per-sample quantities produced by the project's P_n helpers; they are
    # computed once up front and reused at every error evaluation.
    traininput = trainDataLoader.dataset.train_data.cuda()
    X = vec2mat(trainDataLoader.dataset.train_data.cuda())
    X_inv = torch.inverse(X)
    metricInv_sqrt_train = metricInv_sqrt_P_n(X, X_inv)
    metricInv_train = metricInv_P_n(X)
    metricInvDeriv_train = metricInvDeriv_P_n(X)
    christoffel_sum_train = christoffelSum_P_n(X, X_inv)
    christoffel_sumDeriv_train = christoffelSumDeriv_P_n(X, X_inv)

    # variable to store 'estimated' score estimation error (on training data if valIdx is None, on validation data otherwise)
    gscore_est_error_set = []
    if testdataset is not None:
        testinput = testdataset.train_data.cuda()
        X_test = vec2mat(testdataset.train_data.cuda())
        X_test_inv = torch.inverse(X_test)
        metricInv_sqrt_test = metricInv_sqrt_P_n(X_test, X_test_inv)
        metricInv_test = metricInv_P_n(X_test)
        metricInvDeriv_test = metricInvDeriv_P_n(X_test)
        christoffel_sum_test = christoffelSum_P_n(X_test, X_test_inv)
        christoffel_sumDeriv_test = christoffelSumDeriv_P_n(X_test, X_test_inv)
        gscore_est_error_testset = []
        if testscore is not None:
            gscore_error_testset = []
    ####################

    # sample fixed noise (reused for the initial and periodic loss reports)
    epsilon = torch.cuda.FloatTensor(x.size()).normal_(0.0, model.noise_std)
    if not use_gpu:
        x = x.cpu()
        traininput = traininput.cpu()
        metricInv_sqrt_train = metricInv_sqrt_train.cpu()
        metricInv_train = metricInv_train.cpu()
        metricInvDeriv_train = metricInvDeriv_train.cpu()
        christoffel_sum_train = christoffel_sum_train.cpu()
        christoffel_sumDeriv_train = christoffel_sumDeriv_train.cpu()
        epsilon = epsilon.cpu()
        if testdataset is not None:
            testinput = testinput.cpu()
            metricInv_sqrt_test = metricInv_sqrt_test.cpu()
            metricInv_test = metricInv_test.cpu()
            metricInvDeriv_test = metricInvDeriv_test.cpu()
            christoffel_sum_test = christoffel_sum_test.cpu()
            christoffel_sumDeriv_test = christoffel_sumDeriv_test.cpu()

    if valIdx is not None:
        # split train_data into training set and validation set
        trainIdx = torch.arange(N)
        for v in valIdx:
            trainIdx = trainIdx[trainIdx!=v]
        valinput = traininput[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        metricInv_val = metricInv_train[valIdx]
        metricInvDeriv_val = metricInvDeriv_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        christoffel_sumDeriv_val = christoffel_sumDeriv_train[valIdx]

        # train data
        traininput = traininput[trainIdx]
        x = x[trainIdx]
        X = X[trainIdx]
        X_inv = X_inv[trainIdx]
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        metricInv_train = metricInv_train[trainIdx]
        metricInvDeriv_train = metricInvDeriv_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        christoffel_sumDeriv_train = christoffel_sumDeriv_train[trainIdx]
        N = x.shape[0]
        epsilon = epsilon[trainIdx]

        # set new train data loader that only serves the remaining training samples
        Pndataset2 = PndataFromMat(X)
        trainDataLoader = torch.utils.data.DataLoader(Pndataset2, batch_size=trainDataLoader.batch_size,
                                              shuffle=True, num_workers = 2)

    # report initial loss through print_info so that it also reaches the log
    # file, in the same format the sibling trainers use
    loss = model.calculate_loss(x, epsilon)
    print_info('initial loss: {:f}'.format(loss.item()/N), logger)

    def closure():
        # Full-batch objective; no fixed noise is passed here.
        optimizer.zero_grad()
        loss = model.calculate_loss(x)
        loss.backward()
        return loss

    start = time.time()
    for epoch in range(max_iter_num):
        if use_minibatch:
            for data, _, _ in trainDataLoader:
                optimizer.zero_grad()
                # Keep the batch on the same device as the other training
                # tensors instead of unconditionally sending it to CUDA.
                batch = data.cuda() if use_gpu else data.cpu()
                loss = model.calculate_loss(batch)
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
        else:
            optimizer.step(closure)
            if scheduler is not None:
                scheduler.step()
        if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):
            if valIdx is None:
                est_train = dae_estimate_score(traininput, model) - christoffel_sum_train
                estDeriv_train = dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_train
                cur_error = estimate_gscore_error(est_train, estDeriv_train, metricInv_train, metricInv_sqrt_train,
                                  metricInvDeriv_train, christoffel_sum_train, diagonal_metric=False)
                gscore_est_error_set.append(cur_error)
            else:
                # if valIdx is not None, use 'estimated' score estimation error on validation set for model selection
                est_val = dae_estimate_score(valinput, model) - christoffel_sum_val
                estDeriv_val = dae_estimate_score_deriv(valinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_val
                cur_error = estimate_gscore_error(est_val, estDeriv_val, metricInv_val, metricInv_sqrt_val,
                                  metricInvDeriv_val, christoffel_sum_val, diagonal_metric=False)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = dae_estimate_score(testinput, model) - christoffel_sum_test
                estDeriv_test = dae_estimate_score_deriv(testinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_test
                cur_testerror = estimate_gscore_error(est_test, estDeriv_test, metricInv_test,
                metricInv_sqrt_test, metricInvDeriv_test, christoffel_sum_test, diagonal_metric=False)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # mean over samples of the squared norm of (testscore - est_test) @ metricInv_sqrt_test
                    diff = torch.bmm((testscore - est_test).view(-1,1,x.shape[1]), metricInv_sqrt_test).view(-1,x.shape[1])
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())

            # Epoch 0 always passes the period check above, so it seeds the
            # best-so-far bookkeeping.
            if epoch == 0:
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val:
                # save models after sufficient iterations
                if saveAfter is None or epoch > saveAfter:
                    best_model = copy.deepcopy(model.state_dict())
                    min_val = gscore_est_error_set[-1]
                    min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            loss = model.calculate_loss(x, epsilon)
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, float(loss.item()/N)), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            return best_model, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        return best_model, gscore_est_error_set, gscore_est_error_testset
    return best_model, gscore_est_error_set

def rcae_Pn_trainer(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, use_minibatch = True, scheduler = None, saveAfter = None,
                printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, valIdx = None,
                    return_lossTrj = False, loggingFileName = None):
    """Train ``model`` on P(n) data, tracking loss trajectories and the estimated score error.

    Every ``checkEstErrorPeriod`` epochs (and on the last epoch) the estimated
    score error is evaluated on the training samples, or on the samples selected
    by ``valIdx`` when it is given. The ``state_dict`` giving the smallest value
    seen so far is copied and finally returned.

    Args:
        trainDataLoader: loader whose ``dataset.train_data`` holds the vectorised
            training samples. In minibatch mode it must yield 3-tuples whose
            first element is the batch.
        model: object exposing ``calculate_loss``, ``noise_std`` and ``state_dict``.
        optimizer: optimizer stepped once per batch (or once per epoch in
            full-batch mode).
        max_iter_num: number of epochs to run.
        use_gpu: if False, all precomputed tensors (and minibatches) are moved to
            the CPU. The precomputation itself still runs on CUDA.
        use_minibatch: iterate over ``trainDataLoader`` instead of taking one
            full-batch step per epoch.
        scheduler: optional LR scheduler; stepped after every optimizer step.
        saveAfter: if given, checkpoints after epoch 0 are only taken when
            ``epoch > saveAfter``.
        printEpochPeriod: period (in epochs) of the progress message.
        checkEstErrorPeriod: period (in epochs) of the score-error evaluation.
        testdataset: optional dataset with a ``train_data`` attribute.
        testscore: optional reference score for the test samples.
        valIdx: optional indices of training samples held out for validation.
        return_lossTrj: if True, the per-sample loss on the training (and, when
            available, validation/test) data is appended to the trajectories
            after every optimizer step. The trajectories are returned either
            way and always contain the initial loss.
        loggingFileName: if given, passed to ``set_logger`` and used by ``print_info``.

    Returns:
        ``(best_model, lossTrjs, gscore_est_error_set)`` where ``lossTrjs`` is the
        tuple ``(lossTrj, lossTrj_valset, lossTrj_testset)`` of Python lists. The
        tuple is extended by ``gscore_est_error_testset`` when ``testdataset`` is
        given, and further by ``gscore_error_testset`` when ``testscore`` is
        given as well.
    """
    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    x = trainDataLoader.dataset.train_data.cuda()
    N = x.shape[0]

    ####################
    # Per-sample quantities produced by the project's P_n helpers; they are
    # computed once up front and reused at every error evaluation.
    traininput = trainDataLoader.dataset.train_data.cuda()
    X = vec2mat(trainDataLoader.dataset.train_data.cuda())
    X_inv = torch.inverse(X)
    metricInv_sqrt_train = metricInv_sqrt_P_n(X, X_inv)
    metricInv_train = metricInv_P_n(X)
    metricInvDeriv_train = metricInvDeriv_P_n(X)
    christoffel_sum_train = christoffelSum_P_n(X, X_inv)
    christoffel_sumDeriv_train = christoffelSumDeriv_P_n(X, X_inv)

    lossTrj = []
    lossTrj_valset = []
    lossTrj_testset = []
    # 'estimated' score estimation error (training data if valIdx is None, validation data otherwise)
    gscore_est_error_set = []
    if testdataset is not None:
        testinput = testdataset.train_data.cuda()
        N_test = testinput.shape[0]
        X_test = vec2mat(testdataset.train_data.cuda())
        X_test_inv = torch.inverse(X_test)
        metricInv_sqrt_test = metricInv_sqrt_P_n(X_test, X_test_inv)
        metricInv_test = metricInv_P_n(X_test)
        metricInvDeriv_test = metricInvDeriv_P_n(X_test)
        christoffel_sum_test = christoffelSum_P_n(X_test, X_test_inv)
        christoffel_sumDeriv_test = christoffelSumDeriv_P_n(X_test, X_test_inv)
        gscore_est_error_testset = []
        if testscore is not None:
            gscore_error_testset = []
    ####################

    if not use_gpu:
        x = x.cpu()
        traininput = traininput.cpu()
        metricInv_sqrt_train = metricInv_sqrt_train.cpu()
        metricInv_train = metricInv_train.cpu()
        metricInvDeriv_train = metricInvDeriv_train.cpu()
        christoffel_sum_train = christoffel_sum_train.cpu()
        christoffel_sumDeriv_train = christoffel_sumDeriv_train.cpu()
        if testdataset is not None:
            testinput = testinput.cpu()
            metricInv_sqrt_test = metricInv_sqrt_test.cpu()
            metricInv_test = metricInv_test.cpu()
            metricInvDeriv_test = metricInvDeriv_test.cpu()
            christoffel_sum_test = christoffel_sum_test.cpu()
            christoffel_sumDeriv_test = christoffel_sumDeriv_test.cpu()

    if valIdx is not None:
        # split train_data into training set and validation set
        trainIdx = torch.arange(N)
        for v in valIdx:
            trainIdx = trainIdx[trainIdx!=v]
        valinput = traininput[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        metricInv_val = metricInv_train[valIdx]
        metricInvDeriv_val = metricInvDeriv_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        christoffel_sumDeriv_val = christoffel_sumDeriv_train[valIdx]
        N_val = valinput.shape[0]

        # train data
        traininput = traininput[trainIdx]
        x = x[trainIdx]
        X = X[trainIdx]
        X_inv = X_inv[trainIdx]
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        metricInv_train = metricInv_train[trainIdx]
        metricInvDeriv_train = metricInvDeriv_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        christoffel_sumDeriv_train = christoffel_sumDeriv_train[trainIdx]
        N = x.shape[0]

        # set new train data loader that only serves the remaining training samples
        Pndataset2 = PndataFromMat(X)
        trainDataLoader = torch.utils.data.DataLoader(Pndataset2, batch_size=trainDataLoader.batch_size,
                                              shuffle=True, num_workers = 2)

    # print initial loss; it is also the first entry of every trajectory
    loss = model.calculate_loss(x)
    print_info('initial loss: {:f}'.format(loss.item()/N), logger)
    lossTrj.append(loss.item()/N)
    if testdataset is not None:
        lossTrj_testset.append(model.calculate_loss(testinput).item()/N_test)
    if valIdx is not None:
        lossTrj_valset.append(model.calculate_loss(valinput).item()/N_val)

    start = time.time()
    for epoch in range(max_iter_num):
        if use_minibatch:
            for data, _, _ in trainDataLoader:
                optimizer.zero_grad()
                # Keep the batch on the same device as the other training
                # tensors instead of unconditionally sending it to CUDA.
                batch = data.cuda() if use_gpu else data.cpu()
                loss = model.calculate_loss(batch)
                loss.backward()
                optimizer.step()
                if return_lossTrj:
                    # full-data loss after this step (not the minibatch loss)
                    loss_avg = model.calculate_loss(x).item()/N
                    lossTrj.append(loss_avg)
                    if testdataset is not None:
                        lossTrj_testset.append(model.calculate_loss(testinput).item()/N_test)
                    if valIdx is not None:
                        lossTrj_valset.append(model.calculate_loss(valinput).item()/N_val)
                if scheduler is not None:
                    scheduler.step()
        else:
            optimizer.zero_grad()
            loss = model.calculate_loss(x)
            loss.backward()
            optimizer.step()
            # this is the loss evaluated before the step that was just taken
            loss_avg = loss.item()/N
            if return_lossTrj:
                lossTrj.append(loss_avg)
                if testdataset is not None:
                    lossTrj_testset.append(model.calculate_loss(testinput).item()/N_test)
                if valIdx is not None:
                    lossTrj_valset.append(model.calculate_loss(valinput).item()/N_val)
            if scheduler is not None:
                scheduler.step()
        if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):
            if valIdx is None:
                est_train = dae_estimate_score(traininput, model) - christoffel_sum_train
                estDeriv_train = dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_train
                cur_error = estimate_gscore_error(est_train, estDeriv_train, metricInv_train, metricInv_sqrt_train,
                                  metricInvDeriv_train, christoffel_sum_train, diagonal_metric=False)
                gscore_est_error_set.append(cur_error)
            else:
                # if valIdx is not None, use 'estimated' score estimation error on validation set for model selection
                est_val = dae_estimate_score(valinput, model) - christoffel_sum_val
                estDeriv_val = dae_estimate_score_deriv(valinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_val
                cur_error = estimate_gscore_error(est_val, estDeriv_val, metricInv_val, metricInv_sqrt_val,
                                  metricInvDeriv_val, christoffel_sum_val, diagonal_metric=False)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = dae_estimate_score(testinput, model) - christoffel_sum_test
                estDeriv_test = dae_estimate_score_deriv(testinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_test
                cur_testerror = estimate_gscore_error(est_test, estDeriv_test, metricInv_test,
                metricInv_sqrt_test, metricInvDeriv_test, christoffel_sum_test, diagonal_metric=False)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # mean over samples of the squared norm of (testscore - est_test) @ metricInv_sqrt_test
                    diff = torch.bmm((testscore - est_test).view(-1,1,x.shape[1]), metricInv_sqrt_test).view(-1,x.shape[1])
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())

            # Epoch 0 always passes the period check above, so it seeds the
            # best-so-far bookkeeping.
            if epoch == 0:
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val:
                # save models after sufficient iterations
                if saveAfter is None or epoch > saveAfter:
                    best_model = copy.deepcopy(model.state_dict())
                    min_val = gscore_est_error_set[-1]
                    min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            # In minibatch mode without trajectory recording nothing has
            # computed the full-data loss yet, so do it here.
            if not return_lossTrj and use_minibatch:
                loss_avg = model.calculate_loss(x).item()/N
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, loss_avg), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    lossTrjs = (lossTrj, lossTrj_valset, lossTrj_testset)

    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset
    return best_model, lossTrjs, gscore_est_error_set
        
def gae_Pn_trainer(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, use_minibatch = False, scheduler = None, saveAfter = None,
                printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, augment_weight=None, valIdx = None,
                 expandOutput = False, duplicate_num = None, gdae_mode = None, RBF_idx = None, returnLastModel = False, loggingFileName = None):
    """Train a geometric autoencoder ``model`` on P(n) data and select weights by estimated score error.

    Every ``checkEstErrorPeriod`` epochs (and on the last epoch) the estimated
    score error is evaluated on the training samples, or on the samples selected
    by ``valIdx`` when it is given. The ``state_dict`` giving the smallest value
    seen so far is copied and returned (unless ``returnLastModel`` is set).

    Args:
        trainDataLoader: loader whose ``dataset`` exposes ``train_data``,
            ``train_data_sqrt`` and ``train_data_invsqrt``. In minibatch mode it
            must yield 12-tuples in the order unpacked inside the loop below.
        model: object exposing ``calculate_loss``, ``estimate_score_error``,
            ``noise_std``, ``exp_approx``, ``log_approx`` and ``state_dict``
            (plus the RBF methods when ``gdae_mode == 'RBF'``).
        optimizer: optimizer; ``optimizer.step(closure)`` is used in full-batch mode.
        max_iter_num: number of epochs to run.
        use_gpu: if False, all precomputed tensors (and minibatches) are moved to
            the CPU. The precomputation itself still runs on CUDA.
        use_minibatch: iterate over ``trainDataLoader`` instead of taking one
            full-batch step per epoch.
        scheduler: optional LR scheduler; stepped after every optimizer step.
        saveAfter: if given, checkpoints after epoch 0 are only taken when
            ``epoch > saveAfter``.
        printEpochPeriod: period (in epochs) of the progress message.
        checkEstErrorPeriod: period (in epochs) of the score-error evaluation.
        testdataset: optional dataset exposing the same three attributes as the
            training dataset.
        testscore: optional reference score for the test samples.
        augment_weight: if given, ``augment_weight * model.estimate_score_error(...)``
            is added to the training loss.
        valIdx: optional indices of training samples held out for validation.
        expandOutput: if True, the four outputs of ``model.get_contractive_term``
            are recorded at every evaluation and returned.
        duplicate_num: forwarded to ``model.calculate_loss``.
        gdae_mode: ``'LogInput'`` feeds the vectorised matrix logarithm to the
            model, ``'RBF'`` enables the model's RBF quantities; any other value
            uses the plain input.
        RBF_idx: indices (into the training tensors after the validation split)
            of the samples used as RBF centers; required in ``'RBF'`` mode.
        returnLastModel: return the final weights instead of the selected ones.
        loggingFileName: if given, passed to ``set_logger`` and used by ``print_info``.

    Returns:
        ``(best_model, gscore_est_error_set)``, extended in this order by
        ``gscore_est_error_testset`` (if ``testdataset``), ``gscore_error_testset``
        (if also ``testscore``) and ``expandedOutputs`` (if ``expandOutput``).

    Raises:
        Exception: if ``gdae_mode == 'RBF'`` and ``RBF_idx`` is None.
    """
    if gdae_mode == 'RBF' and RBF_idx is None:
        raise Exception("RBF_idx should be given")
    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    def to_device(t):
        # Minibatches follow the same device choice as the precomputed tensors.
        return t.cuda() if use_gpu else t.cpu()

    x = trainDataLoader.dataset.train_data.cuda()
    X = vec2mat(trainDataLoader.dataset.train_data.cuda())
    eps = 1e-14
    S, U = batch_eigsym(X)
    # clamp eigenvalues from below before they are handed to the log/sqrt helpers
    S[S<eps] = eps
    if gdae_mode == 'LogInput':
        # caution: x is overwritten with the vectorised Log(X) from here on
        x = mat2vec(Log_mat(X, S = S, U = U))

    X_sqrt = trainDataLoader.dataset.train_data_sqrt.cuda()
    X_invsqrt = trainDataLoader.dataset.train_data_invsqrt.cuda()

    N = x.shape[0]
    vec_dim = x.shape[1]
    dim = X.shape[1]

    ####################
    metric_train = metric_P_n(X)
    metricInv_train = metricInv_P_n(X)
    metricInv_sqrt_train = metricInv_sqrt_P_n(X)
    metricDeriv_train = metricDeriv_P_n(X)
    christoffel_sum_train = christoffelSum_P_n(X)
    tempdir = torch.cuda.FloatTensor(x.shape).zero_()
    X_sqrt_dirderiv_set = torch.cuda.FloatTensor(N, dim, dim, vec_dim).zero_()
    dLog_xdx = torch.cuda.FloatTensor(N, vec_dim, vec_dim).zero_()
    # One directional derivative per coordinate direction: tempdir is the i-th
    # unit vector (for every sample) during iteration i.
    for i in range(vec_dim):
        tempdir[:,i] = 1
        Xdot = vec2mat(tempdir)
        Xdot_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot), U)
        X_sqrt_dirderiv_set[:,:,:,i] = get_sqrt_sym_DirDeriv(X, Xdot, S = S, U = U, Xdot_trans = Xdot_trans)
        dLog_xdx[:,:,i] = mat2vec(LogDirDeriv(X, Xdot, S = S, U = U, Xdot_trans = Xdot_trans))
        tempdir[:,i] = 0

    # variable to store 'estimated' score estimation error (on training data if valIdx is None, on validation data otherwise)
    gscore_est_error_set = []
    if testdataset is not None:
        x_test = testdataset.train_data.cuda()
        X_test = vec2mat(testdataset.train_data.cuda())
        S_test, U_test = batch_eigsym(X_test)
        S_test[S_test<eps] = eps
        if gdae_mode == 'LogInput':
            # caution: x_test is overwritten with the vectorised Log(X_test)
            x_test = mat2vec(Log_mat(X_test, S = S_test, U = U_test))

        X_test_sqrt = testdataset.train_data_sqrt.cuda()
        X_test_invsqrt = testdataset.train_data_invsqrt.cuda()
        metric_test = metric_P_n(X_test)
        metricInv_sqrt_test = metricInv_sqrt_P_n(X_test)
        metricDeriv_test = metricDeriv_P_n(X_test)
        christoffel_sum_test = christoffelSum_P_n(X_test)
        # Size every test buffer by the number of test samples; using the
        # training count only works when both sets happen to be equally large.
        N_test = x_test.shape[0]
        tempdir_test = torch.cuda.FloatTensor(x_test.shape).zero_()
        X_test_sqrt_dirderiv_set = torch.cuda.FloatTensor(N_test, dim, dim, vec_dim).zero_()
        dLog_xdx_test = torch.cuda.FloatTensor(N_test, vec_dim, vec_dim).zero_()
        for i in range(vec_dim):
            tempdir_test[:,i] = 1
            Xdot = vec2mat(tempdir_test)
            Xdot_trans = torch.matmul(torch.matmul(U_test.permute(0,2,1), Xdot), U_test)
            X_test_sqrt_dirderiv_set[:,:,:,i] = get_sqrt_sym_DirDeriv(X_test, Xdot, S = S_test, U = U_test, Xdot_trans = Xdot_trans)
            dLog_xdx_test[:,:,i] = mat2vec(LogDirDeriv(X_test, Xdot, S = S_test, U = U_test, Xdot_trans = Xdot_trans))
            tempdir_test[:,i] = 0

        # variable to store 'estimated' score estimation error for test data
        gscore_est_error_testset = []
        if testscore is not None:
            # variable to store score estimation error for test data
            gscore_error_testset = []
    if expandOutput:
        contract_set = []
        recon_set = []
        drdx_sqnorm_set = []
        dvdx_sqnorm_set = []
    ####################

    # sample fixed noise (reused for the initial and periodic loss reports)
    epsilon = torch.cuda.FloatTensor(x.size()[0],x.size()[1]).normal_(0.0, model.noise_std)
    if not use_gpu:
        x = x.cpu()
        X = X.cpu()
        X_sqrt = X_sqrt.cpu()
        X_invsqrt = X_invsqrt.cpu()
        metric_train = metric_train.cpu()
        metricInv_train = metricInv_train.cpu()
        metricInv_sqrt_train = metricInv_sqrt_train.cpu()
        metricDeriv_train = metricDeriv_train.cpu()
        christoffel_sum_train = christoffel_sum_train.cpu()
        X_sqrt_dirderiv_set = X_sqrt_dirderiv_set.cpu()
        dLog_xdx = dLog_xdx.cpu()
        epsilon = epsilon.cpu()
        S = S.cpu()
        U = U.cpu()
        if testdataset is not None:
            x_test = x_test.cpu()
            X_test = X_test.cpu()
            X_test_sqrt = X_test_sqrt.cpu()
            X_test_invsqrt = X_test_invsqrt.cpu()
            metric_test = metric_test.cpu()
            metricInv_sqrt_test = metricInv_sqrt_test.cpu()
            metricDeriv_test = metricDeriv_test.cpu()
            christoffel_sum_test = christoffel_sum_test.cpu()
            X_test_sqrt_dirderiv_set = X_test_sqrt_dirderiv_set.cpu()
            dLog_xdx_test = dLog_xdx_test.cpu()
            S_test = S_test.cpu()
            U_test = U_test.cpu()

    if valIdx is not None:
        # split train_data into training set and validation set
        trainIdx = torch.arange(N)
        for v in valIdx:
            trainIdx = trainIdx[trainIdx!=v]

        x_val = x[valIdx]
        X_val = X[valIdx]
        X_val_sqrt = X_sqrt[valIdx]
        X_val_invsqrt = X_invsqrt[valIdx]
        metric_val = metric_train[valIdx]
        metricInv_val = metricInv_train[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        metricDeriv_val = metricDeriv_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        X_val_sqrt_dirderiv_set = X_sqrt_dirderiv_set[valIdx]
        dLog_xdx_val = dLog_xdx[valIdx]
        S_val = S[valIdx]
        U_val = U[valIdx]

        # train data
        x = x[trainIdx]
        X = X[trainIdx]
        X_sqrt = X_sqrt[trainIdx]
        X_invsqrt = X_invsqrt[trainIdx]
        metric_train = metric_train[trainIdx]
        metricInv_train = metricInv_train[trainIdx]
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        metricDeriv_train = metricDeriv_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        X_sqrt_dirderiv_set = X_sqrt_dirderiv_set[trainIdx]
        dLog_xdx = dLog_xdx[trainIdx]
        S = S[trainIdx]
        U = U[trainIdx]
        N = x.shape[0]
        epsilon = epsilon[trainIdx]

        # set new train data loader that only serves the remaining training samples
        Pndataset2 = PndataTangentGaussianMixtureExpanded(trainDataLoader.dataset, trainIdx)
        trainDataLoader = torch.utils.data.DataLoader(Pndataset2, batch_size=trainDataLoader.batch_size,
                                              shuffle=True, num_workers = 2)

    # Mode-specific extra inputs: ``other_quantities_at_x*`` go to the score
    # estimators, ``other_quantities_for_loss_at_x*`` go to calculate_loss.
    if gdae_mode == 'LogInput':
        other_quantities_at_x = [dLog_xdx]
        other_quantities_for_loss_at_x = [S, U]
        if testdataset is not None:
            other_quantities_at_x_test = [dLog_xdx_test]
            other_quantities_for_loss_at_x_test = [S_test, U_test]
        if valIdx is not None:
            other_quantities_at_x_val = [dLog_xdx_val]
            other_quantities_for_loss_at_x_val = [S_val, U_val]
    elif gdae_mode == 'RBF':
        # remove center points in training data later?
        model.set_RBF_centers(X[RBF_idx], X_sqrt[RBF_idx], X_invsqrt[RBF_idx])
        d2fdx2, dfdx, f = model.get_RBF_2nd_derivative(x, X_sqrt, X_invsqrt, metric_train, metricDeriv_train)
        other_quantities_at_x = [f, dfdx]
        other_quantities_for_loss_at_x = [f, dfdx, d2fdx2]
        if testdataset is not None:
            d2fdx2_test, dfdx_test, f_test = model.get_RBF_2nd_derivative(x_test, X_test_sqrt, X_test_invsqrt, metric_test, metricDeriv_test)
            other_quantities_at_x_test = [f_test, dfdx_test]
            other_quantities_for_loss_at_x_test = [f_test, dfdx_test, d2fdx2_test]
        if valIdx is not None:
            # Bind the validation result to its own name; it must not
            # overwrite the test-set second derivative.
            d2fdx2_val, dfdx_val, f_val = model.get_RBF_2nd_derivative(x_val, X_val_sqrt, X_val_invsqrt, metric_val, metricDeriv_val)
            other_quantities_at_x_val = [f_val, dfdx_val]
            other_quantities_for_loss_at_x_val = [f_val, dfdx_val, d2fdx2_val]
    else:
        other_quantities_at_x = None
        other_quantities_at_x_test = None
        other_quantities_at_x_val = None
        other_quantities_for_loss_at_x = [S, U]
        if testdataset is not None:
            other_quantities_for_loss_at_x_test = [S_test, U_test]
        if valIdx is not None:
            other_quantities_for_loss_at_x_val = [S_val, U_val]

    # print initial loss
    loss = model.calculate_loss(x, X, X_sqrt, X_invsqrt, epsilon, duplicate_num=duplicate_num, other_quantities_for_loss_at_x = other_quantities_for_loss_at_x)
    if augment_weight is not None:
        loss1 = model.estimate_score_error(x, X_sqrt, christoffel_sum_train, X_sqrt_dirderiv_set, other_quantities_at_x = other_quantities_at_x)
        print_info("initial loss0: {:f} ---- loss1: {:f}".format(float(loss.item()/N), float(loss1.item())), logger)
    else:
        print_info('initial loss: {:f}'.format(float(loss.item()/N)), logger)

    def closure():
        # Full-batch objective; no fixed noise is passed here.
        optimizer.zero_grad()
        loss = model.calculate_loss(x, X, X_sqrt, X_invsqrt, duplicate_num=duplicate_num, other_quantities_for_loss_at_x = other_quantities_for_loss_at_x)
        if augment_weight is not None:
            loss += augment_weight * model.estimate_score_error(x, X_sqrt, christoffel_sum_train, X_sqrt_dirderiv_set, other_quantities_at_x = other_quantities_at_x)
        loss.backward()
        return loss

    start = time.time()
    for epoch in range(max_iter_num):
        # to avoid numerical instability, use approximation after enough epochs
        # (``loss`` here is whatever was computed last: the initial/printed
        # loss or the most recent minibatch loss)
        if model.exp_approx not in [1,2,3,4] and model.log_approx not in [1,2,3,4]:
            if float(loss.item()/N) < 0.01 or epoch > 10000:
                model.exp_approx = 3
                model.log_approx = 3
        if use_minibatch:
            for data in trainDataLoader:
                cur_x, cur_logx, cur_X_sqrt, cur_X_invsqrt, cur_metric, _, cur_metricDeriv, \
                cur_X_sqrt_dirderiv_set, cur_dLog_xdx, cur_christoffel_sum, cur_S, cur_U = data
                optimizer.zero_grad()
                # calculate_loss needs the matrix form of the batch in every
                # mode, so build it before the mode-specific branches.
                cur_X = to_device(vec2mat(cur_x))
                cur_X_sqrt = to_device(cur_X_sqrt)
                cur_X_invsqrt = to_device(cur_X_invsqrt)
                cur_other_quantities_at_x = None
                cur_other_quantities_for_loss_at_x = None
                if gdae_mode == 'RBF':
                    cur_x = to_device(cur_x)
                    cur_d2fdx2, cur_dfdx, cur_f = model.get_RBF_2nd_derivative(cur_x, cur_X_sqrt, cur_X_invsqrt,
                                                                               to_device(cur_metric), to_device(cur_metricDeriv))
                    cur_other_quantities_for_loss_at_x = [cur_f, cur_dfdx, cur_d2fdx2]
                    cur_other_quantities_at_x = [cur_f, cur_dfdx]
                elif gdae_mode == 'LogInput':
                    # the model consumes the log-domain vector in this mode
                    cur_x = to_device(cur_logx)
                    cur_other_quantities_at_x = [to_device(cur_dLog_xdx)]
                    cur_other_quantities_for_loss_at_x = [to_device(cur_S), to_device(cur_U)]
                else:
                    cur_x = to_device(cur_x)
                    cur_other_quantities_for_loss_at_x = [to_device(cur_S), to_device(cur_U)]

                loss = model.calculate_loss(cur_x, cur_X, cur_X_sqrt, cur_X_invsqrt, duplicate_num=duplicate_num,
                                            other_quantities_for_loss_at_x = cur_other_quantities_for_loss_at_x)
                if augment_weight is not None:
                    loss += augment_weight * model.estimate_score_error(cur_x, cur_X_sqrt, to_device(cur_christoffel_sum),
                                                                        to_device(cur_X_sqrt_dirderiv_set), other_quantities_at_x = cur_other_quantities_at_x)
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
        else:
            optimizer.step(closure)
            # One scheduler step per optimizer step, as in the sibling
            # trainers; minibatch mode already steps it inside the batch loop.
            if scheduler is not None:
                scheduler.step()
        if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):
            if valIdx is None:
                est_train = gae_P_n_estimate_score(x, X_sqrt, metric_train, model)
                cur_error = gae_P_n_estimate_score_error(x, X_sqrt, est_train, model, model.noise_std**2,
                                 metricInv_sqrt_train, X_sqrt_dirderiv_set, christoffel_sum_train,
                                     diagonal_metric=False, other_quantities_at_x = other_quantities_at_x)
                gscore_est_error_set.append(cur_error)
            else:
                # if valIdx is not None, use 'estimated' score estimation error on validation set for model selection
                est_val = gae_P_n_estimate_score(x_val, X_val_sqrt, metric_val, model)
                cur_error = gae_P_n_estimate_score_error(x_val, X_val_sqrt, est_val, model, model.noise_std**2,
                                 metricInv_sqrt_val, X_val_sqrt_dirderiv_set, christoffel_sum_val,
                                     diagonal_metric=False, other_quantities_at_x = other_quantities_at_x_val)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = gae_P_n_estimate_score(x_test, X_test_sqrt, metric_test, model)
                cur_testerror = gae_P_n_estimate_score_error(x_test, X_test_sqrt, est_test, model, model.noise_std**2,
                             metricInv_sqrt_test, X_test_sqrt_dirderiv_set, christoffel_sum_test,
                                 diagonal_metric=False, other_quantities_at_x = other_quantities_at_x_test)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # mean over samples of the squared norm of (testscore - est_test) @ metricInv_sqrt_test
                    diff = torch.bmm((testscore - est_test).view(-1,1,x.shape[1]), metricInv_sqrt_test).view(-1,x.shape[1])
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())
            if expandOutput:
                cur_c, cur_r, cur_drdx, cur_dvdx = model.get_contractive_term(x, X_sqrt, metricInv_train, X_sqrt_dirderiv_set, expandOutput=expandOutput,
                                                                              other_quantities_at_x = other_quantities_at_x)
                contract_set.append(cur_c)
                recon_set.append(cur_r)
                drdx_sqnorm_set.append(cur_drdx)
                dvdx_sqnorm_set.append(cur_dvdx)
            # Epoch 0 always passes the period check above, so it seeds the
            # best-so-far bookkeeping.
            if epoch == 0:
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val:
                # save models after sufficient iterations
                if saveAfter is None or epoch > saveAfter:
                    best_model = copy.deepcopy(model.state_dict())
                    min_val = gscore_est_error_set[-1]
                    min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            # NOTE(review): unlike the initial loss, duplicate_num is not
            # forwarded here; kept as-is, confirm whether that is intended.
            loss = model.calculate_loss(x, X, X_sqrt, X_invsqrt, epsilon, other_quantities_for_loss_at_x = other_quantities_for_loss_at_x)
            if augment_weight is not None:
                loss1 = model.estimate_score_error(x, X_sqrt, christoffel_sum_train, X_sqrt_dirderiv_set, other_quantities_at_x = other_quantities_at_x)
                print_info("iter: {:d} ---- time {:.1f} ---- loss0: {:f} ---- loss1: {:f}".format(epoch, time.time() - start, float(loss.item()/N), float(loss1.item())), logger)
            else:
                print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, float(loss.item()/N)), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    if returnLastModel:
        # caller asked for the final weights rather than the selected ones
        best_model = copy.deepcopy(model.state_dict())

    if expandOutput:
        contract_set = torch.FloatTensor(contract_set)
        recon_set = torch.FloatTensor(recon_set)
        drdx_sqnorm_set = torch.FloatTensor(drdx_sqnorm_set)
        dvdx_sqnorm_set = torch.FloatTensor(dvdx_sqnorm_set)
        expandedOutputs = (contract_set, recon_set, drdx_sqnorm_set, dvdx_sqnorm_set)
    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            if expandOutput:
                return best_model, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset, expandedOutputs
            return best_model, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        if expandOutput:
            return best_model, gscore_est_error_set, gscore_est_error_testset, expandedOutputs
        return best_model, gscore_est_error_set, gscore_est_error_testset
    if expandOutput:
        return best_model, gscore_est_error_set, expandedOutputs
    return best_model, gscore_est_error_set


def grcae_Pn_trainer(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, use_minibatch = True,
                     scheduler = None, saveAfter = None,
                     printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, valIdx = None,
                     input_mode = None, return_lossTrj = False, loggingFileName = None):
    """Train ``model`` with ``model.calculate_loss`` and keep the best state dict.

    Every ``checkEstErrorPeriod`` epochs (and on the last epoch) the 'estimated'
    score estimation error is evaluated via ``gae_P_n_estimate_score_error``;
    the state dict with the lowest value seen so far is kept as the best model.

    Args:
        trainDataLoader: loader whose ``dataset`` exposes ``train_data``,
            ``train_data_sqrt`` and ``train_data_invsqrt``. When
            ``use_minibatch`` is True it is also iterated, and each item must
            unpack into 12 fields (see the minibatch branch below).
        model: object providing ``calculate_loss``, ``state_dict``,
            ``noise_std``, ``exp_approx`` and ``log_approx``.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        max_iter_num: number of epochs to run.
        use_gpu: if False, all precomputed tensors are moved to CPU after
            being built. The precomputation itself still allocates CUDA
            tensors, so a CUDA device is required either way.
        use_minibatch: iterate ``trainDataLoader`` each epoch instead of
            taking one full-batch step.
        scheduler: optional; ``scheduler.step()`` is called after every
            optimizer step.
        saveAfter: if not None, after epoch 0 the best model is only
            replaced at epochs strictly greater than ``saveAfter``.
        printEpochPeriod: period (in epochs) of the progress log line.
        checkEstErrorPeriod: period (in epochs) of the score-error check.
        testdataset: optional dataset with the same attributes as the
            training dataset; enables test-set error tracking.
        testscore: optional reference score on the test data; when given,
            the error of the estimated score against it is also tracked.
        valIdx: optional indices into the training data that are held out as
            a validation set; model selection then uses the validation error.
        input_mode: when equal to ``'LogInput'`` the network input is the
            vectorised matrix logarithm of the data and the derivative of
            that map is forwarded to the loss / error routines.
        return_lossTrj: when True, losses are appended to the loss
            trajectories after every optimizer step. The trajectories tuple
            is returned regardless (it then only holds the initial losses).
        loggingFileName: optional file name passed to ``set_logger``.

    Returns:
        ``(best_model, lossTrjs, gscore_est_error_set)``, extended with
        ``gscore_est_error_testset`` when ``testdataset`` is given and further
        with ``gscore_error_testset`` when ``testscore`` is given as well.
        ``lossTrjs`` is ``(lossTrj, lossTrj_valset, lossTrj_testset)``.
    """
    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    x = trainDataLoader.dataset.train_data.cuda()
    X = vec2mat(trainDataLoader.dataset.train_data.cuda())
    eps = 1e-14
    S, U = batch_eigsym(X)
    # floor the eigenvalues so that they stay strictly positive
    S[S<eps] = eps
    if input_mode == 'LogInput':
        # caution: from here on x holds the vectorised Log(X), not the raw data
        x = mat2vec(Log_mat(X, S = S, U = U))

    X_sqrt = trainDataLoader.dataset.train_data_sqrt.cuda()
    X_invsqrt = trainDataLoader.dataset.train_data_invsqrt.cuda()

    N = x.shape[0]
    vec_dim = x.shape[1]
    dim = X.shape[1]

    ####################
    metric_train = metric_P_n(X)
    metricInv_train = metricInv_P_n(X)
    metricInv_sqrt_train = metricInv_sqrt_P_n(X)
    metricDeriv_train = metricDeriv_P_n(X)
    christoffel_sum_train = christoffelSum_P_n(X)
    tempdir = torch.cuda.FloatTensor(x.shape).zero_()
    X_sqrt_dirderiv_set = torch.cuda.FloatTensor(N, dim, dim, vec_dim).zero_()
    dLog_xdx = torch.cuda.FloatTensor(N, vec_dim, vec_dim).zero_()
    # directional derivatives along each coordinate direction of the vector form
    for i in range(vec_dim):
        tempdir[:,i] = 1
        Xdot = vec2mat(tempdir)
        Xdot_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot), U)
        X_sqrt_dirderiv_set[:,:,:,i] = get_sqrt_sym_DirDeriv(X, Xdot, S = S, U = U, Xdot_trans = Xdot_trans)
        dLog_xdx[:,:,i] = mat2vec(LogDirDeriv(X, Xdot, S = S, U = U, Xdot_trans = Xdot_trans))
        tempdir[:,i] = 0

    lossTrj = []
    lossTrj_valset = []
    lossTrj_testset = []
    # variable to store 'estimated' score estimation error (on training data if  valIdx is None, on validation data otherwise)
    gscore_est_error_set = []
    if testdataset is not None:
        x_test = testdataset.train_data.cuda()
        N_test = x_test.shape[0]
        X_test = vec2mat(testdataset.train_data.cuda())
        S_test, U_test = batch_eigsym(X_test)
        S_test[S_test<eps] = eps
        if input_mode == 'LogInput':
            # caution: from here on x_test holds the vectorised Log(X_test)
            x_test = mat2vec(Log_mat(X_test, S = S_test, U = U_test))

        X_test_sqrt = testdataset.train_data_sqrt.cuda()
        X_test_invsqrt = testdataset.train_data_invsqrt.cuda()
        metric_test = metric_P_n(X_test)
        metricInv_test = metricInv_P_n(X_test)
        metricInv_sqrt_test = metricInv_sqrt_P_n(X_test)
        metricDeriv_test = metricDeriv_P_n(X_test)
        christoffel_sum_test = christoffelSum_P_n(X_test)
        # All test-side buffers are sized by N_test (not N) so that a test set
        # whose size differs from the training set is handled correctly.
        tempdir_test = torch.cuda.FloatTensor(N_test, vec_dim).zero_()
        X_test_sqrt_dirderiv_set = torch.cuda.FloatTensor(N_test, dim, dim, vec_dim).zero_()
        dLog_xdx_test = torch.cuda.FloatTensor(N_test, vec_dim, vec_dim).zero_()
        for i in range(vec_dim):
            tempdir_test[:,i] = 1
            Xdot = vec2mat(tempdir_test)
            Xdot_trans = torch.matmul(torch.matmul(U_test.permute(0,2,1), Xdot), U_test)
            X_test_sqrt_dirderiv_set[:,:,:,i] = get_sqrt_sym_DirDeriv(X_test, Xdot, S = S_test, U = U_test, Xdot_trans = Xdot_trans)
            dLog_xdx_test[:,:,i] = mat2vec(LogDirDeriv(X_test, Xdot, S = S_test, U = U_test, Xdot_trans = Xdot_trans))
            tempdir_test[:,i] = 0

        # variable to store 'estimated' score estimation error for test data
        gscore_est_error_testset = []
        if testscore is not None:
            # variable to store score estimation error for test data
            gscore_error_testset = []
    ####################

    if not use_gpu:
        x = x.cpu()
        X = X.cpu()
        X_sqrt = X_sqrt.cpu()
        X_invsqrt = X_invsqrt.cpu()
        metric_train = metric_train.cpu()
        metricInv_train = metricInv_train.cpu()
        metricInv_sqrt_train = metricInv_sqrt_train.cpu()
        metricDeriv_train = metricDeriv_train.cpu()
        christoffel_sum_train = christoffel_sum_train.cpu()
        X_sqrt_dirderiv_set = X_sqrt_dirderiv_set.cpu()
        dLog_xdx = dLog_xdx.cpu()
        if testdataset is not None:
            x_test = x_test.cpu()
            X_test = X_test.cpu()
            X_test_sqrt = X_test_sqrt.cpu()
            X_test_invsqrt = X_test_invsqrt.cpu()
            metric_test = metric_test.cpu()
            metricInv_test = metricInv_test.cpu()
            metricInv_sqrt_test = metricInv_sqrt_test.cpu()
            metricDeriv_test = metricDeriv_test.cpu()
            christoffel_sum_test = christoffel_sum_test.cpu()
            X_test_sqrt_dirderiv_set = X_test_sqrt_dirderiv_set.cpu()
            dLog_xdx_test = dLog_xdx_test.cpu()

    if valIdx is not None:
        # split train_data into training set and validation set
        trainIdx = torch.arange(N)
        for v in valIdx:
            trainIdx = trainIdx[trainIdx!=v]

        x_val = x[valIdx]
        X_val = X[valIdx]
        X_val_sqrt = X_sqrt[valIdx]
        X_val_invsqrt = X_invsqrt[valIdx]
        metric_val = metric_train[valIdx]
        metricInv_val = metricInv_train[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        metricDeriv_val = metricDeriv_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        X_val_sqrt_dirderiv_set = X_sqrt_dirderiv_set[valIdx]
        dLog_xdx_val = dLog_xdx[valIdx]
        N_val = x_val.shape[0]

        # train data
        x = x[trainIdx]
        X = X[trainIdx]
        X_sqrt = X_sqrt[trainIdx]
        X_invsqrt = X_invsqrt[trainIdx]
        metric_train = metric_train[trainIdx]
        metricInv_train = metricInv_train[trainIdx]
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        metricDeriv_train = metricDeriv_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        X_sqrt_dirderiv_set = X_sqrt_dirderiv_set[trainIdx]
        dLog_xdx = dLog_xdx[trainIdx]
        N = x.shape[0]

    # Extra per-sample quantities forwarded to the loss / error routines.
    # They are always bound (None by default) so no branch below can hit an
    # undefined name.
    other_quantities_at_x = None
    other_quantities_at_x_test = None
    other_quantities_at_x_val = None
    if input_mode == 'LogInput':
        other_quantities_at_x = [dLog_xdx]
        if testdataset is not None:
            other_quantities_at_x_test = [dLog_xdx_test]
        if valIdx is not None:
            other_quantities_at_x_val = [dLog_xdx_val]

    def _append_heldout_losses():
        # Record the current per-sample loss on the test and validation sets.
        if testdataset is not None:
            lossTrj_testset.append(model.calculate_loss(x_test, X_test_sqrt, metricInv_test, X_test_sqrt_dirderiv_set,
                                                   other_quantities_for_loss_at_x = other_quantities_at_x_test).item()/N_test)
        if valIdx is not None:
            lossTrj_valset.append(model.calculate_loss(x_val, X_val_sqrt, metricInv_val, X_val_sqrt_dirderiv_set,
                                                   other_quantities_for_loss_at_x = other_quantities_at_x_val).item()/N_val)

    # print initial loss
    loss = model.calculate_loss(x, X_sqrt, metricInv_train, X_sqrt_dirderiv_set,
                                                   other_quantities_for_loss_at_x = other_quantities_at_x)
    print_info('initial loss: {:f}'.format(float(loss.item()/N)), logger)

    lossTrj.append(loss.item()/N)
    _append_heldout_losses()

    # Minibatches are moved to wherever the full-batch tensors live, so the
    # use_gpu flag is honoured in the minibatch branch as well.
    device = x.device

    start = time.time()
    for epoch in range(max_iter_num):
        # to avoid numerical instability, use approximation after enough epoches
        if model.exp_approx not in [1,2,3,4] and model.log_approx not in [1,2,3,4]:
            if float(loss.item()/N) < 0.01 or epoch > 10000:
                model.exp_approx = 3
                model.log_approx = 3
        if use_minibatch:
            for data in trainDataLoader:
                # NOTE(review): this branch always feeds the 2nd loader field
                # (named cur_logx) and always forwards dLog_xdx, independent
                # of input_mode - confirm this is intended for non-'LogInput'.
                _, cur_logx, cur_X_sqrt, _, _, cur_metricInv_train, _, cur_X_sqrt_dirderiv_set, cur_dLog_xdx, _, _, _ = data
                optimizer.zero_grad()
                loss = model.calculate_loss(cur_logx.to(device), cur_X_sqrt.to(device), cur_metricInv_train.to(device),
                                            X_sqrt_dirderiv_set = cur_X_sqrt_dirderiv_set.to(device),
                                            other_quantities_for_loss_at_x = [cur_dLog_xdx.to(device)])
                loss.backward()
                optimizer.step()
                if return_lossTrj:
                    # full-training-set loss after this step
                    loss_avg = model.calculate_loss(x, X_sqrt, metricInv_train, X_sqrt_dirderiv_set,
                                                   other_quantities_for_loss_at_x = other_quantities_at_x).item()/N
                    lossTrj.append(loss_avg)
                    _append_heldout_losses()
                if scheduler is not None:
                    scheduler.step()
        else:
            optimizer.zero_grad()
            loss = model.calculate_loss(x, X_sqrt, metricInv_train, X_sqrt_dirderiv_set,
                                                   other_quantities_for_loss_at_x = other_quantities_at_x)
            loss.backward()
            optimizer.step()
            loss_avg = loss.item()/N
            if return_lossTrj:
                lossTrj.append(loss_avg)
                _append_heldout_losses()
            if scheduler is not None:
                scheduler.step()
        if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):
            if valIdx is None:
                est_train = gae_P_n_estimate_score(x, X_sqrt, metric_train, model)
                cur_error = gae_P_n_estimate_score_error(x, X_sqrt, est_train, model, model.noise_std**2,
                                 metricInv_sqrt_train, X_sqrt_dirderiv_set, christoffel_sum_train,
                                     diagonal_metric=False, other_quantities_at_x = other_quantities_at_x)
                gscore_est_error_set.append(cur_error)
            else:
                # if valIdx is not None, use 'estimated' score estimation error on validation set for model selection
                est_val = gae_P_n_estimate_score(x_val, X_val_sqrt, metric_val, model)
                cur_error = gae_P_n_estimate_score_error(x_val, X_val_sqrt, est_val, model, model.noise_std**2,
                                 metricInv_sqrt_val, X_val_sqrt_dirderiv_set, christoffel_sum_val,
                                     diagonal_metric=False, other_quantities_at_x = other_quantities_at_x_val)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = gae_P_n_estimate_score(x_test, X_test_sqrt, metric_test, model)
                cur_testerror = gae_P_n_estimate_score_error(x_test, X_test_sqrt, est_test, model, model.noise_std**2,
                             metricInv_sqrt_test, X_test_sqrt_dirderiv_set, christoffel_sum_test,
                                 diagonal_metric=False, other_quantities_at_x = other_quantities_at_x_test)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # mean squared difference between reference and estimated
                    # score after multiplying by metricInv_sqrt_test
                    diff = torch.bmm((testscore - est_test).view(-1,1,x.shape[1]), metricInv_sqrt_test).view(-1,x.shape[1])
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())
            if epoch == 0:
                # epoch 0 always passes the period check above, so the best
                # model is initialised here unconditionally
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val:
                # save models after sufficient iterations
                if saveAfter is None or (saveAfter is not None and epoch > saveAfter):
                    best_model = copy.deepcopy(model.state_dict())
                    min_val = gscore_est_error_set[-1]
                    min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            if not return_lossTrj and use_minibatch:
                # loss_avg was not tracked per step in this mode; compute it now
                loss_avg = model.calculate_loss(x, X_sqrt, metricInv_train, X_sqrt_dirderiv_set,
                                                   other_quantities_for_loss_at_x = other_quantities_at_x).item()/N
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, loss_avg), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    lossTrjs = (lossTrj, lossTrj_valset, lossTrj_testset)
    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset
    return best_model, lossTrjs, gscore_est_error_set
