import time
import copy
import torch
from sph_n_DataUtil import *
from gae_score_estimation import *
from gae_sph_n_ambient_score_estimation import *
from util import *

def get_indices_to_use(data, bound, lower_bound=None):
    """Return a boolean mask of samples whose entries all lie strictly inside the bounds.

    Every dimension after the first is reduced, so the mask has one entry per
    index along dim 0. A sample is kept only if all of its entries are strictly
    greater than the lower bound and strictly less than ``bound``.

    Args:
        data: tensor whose first dimension indexes samples. A 1-D tensor is
            compared element-wise.
        bound: exclusive upper bound. Its negation is the default lower bound.
        lower_bound: optional exclusive lower bound; defaults to ``-bound``.

    Returns:
        Bool tensor with shape ``data.shape[:1]``.
    """
    if lower_bound is None:
        lower_bound = -bound
    if data.dim() > 1:
        # Collapse all per-sample dims so one reduction yields the extremes.
        # No clone is needed: max/min do not modify their input.
        flat = data.flatten(start_dim=1)
        data_max = flat.max(dim=1).values
        data_min = flat.min(dim=1).values
    else:
        data_max = data_min = data
    return torch.logical_and(data_min > lower_bound, data_max < bound)

def dae_trainer_batchall(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, scheduler = None, saveAfter = None, 
                printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, loggingFileName = None):
    """Train with one full-batch optimizer step per iteration and keep the best checkpoint.

    Every ``checkEstErrorPeriod`` iterations, and on the last one, the score on
    the training data is estimated with ``dae_estimate_score``. Its estimated
    error is then computed with ``dae_estimate_score_error``. The state dict with
    the lowest estimated training error is kept; ties go to the later iteration.
    After iteration 0, a new checkpoint is taken only when ``epoch > saveAfter``
    (if ``saveAfter`` is set).

    Args:
        trainDataLoader: only ``trainDataLoader.dataset.train_data`` is used.
        model: must provide ``noise_std`` and ``calculate_loss(x[, epsilon])``.
        optimizer: called as ``optimizer.step(closure)`` once per iteration.
        max_iter_num: number of iterations; must be >= 1.
        use_gpu: move the data and the fixed noise sample to CUDA.
        scheduler: stepped once per iteration if given.
        saveAfter: optional iteration threshold for taking new checkpoints.
        printEpochPeriod: the loss is reported every this many iterations and
            on the last one.
        checkEstErrorPeriod: the estimated error is computed every this many
            iterations and on the last one.
        testdataset: optional dataset whose ``train_data`` is evaluated the same way.
        testscore: optional reference score for ``testdataset``. The mean squared
            error against the estimate is recorded. It must already be on the
            same device as the estimate.
        loggingFileName: if given, passed to ``set_logger``; messages go through
            ``print_info``.

    Returns:
        ``(best_state_dict, score_est_errors)``. The test-set estimated errors are
        appended when ``testdataset`` is given, and the test-set reference errors
        are appended when ``testscore`` is also given. Error sequences are
        ``FloatTensor``s.

    Raises:
        ValueError: if ``max_iter_num < 1``. The best-model bookkeeping is
            initialised on iteration 0, so there would be nothing to return.
    """
    if max_iter_num < 1:
        raise ValueError("max_iter_num must be at least 1, got {}".format(max_iter_num))

    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    x = trainDataLoader.dataset.train_data
    N = x.shape[0]

    ####################
    score_est_error_set = []
    if testdataset is not None:
        testinput = testdataset.train_data
        score_est_error_testset = []
        if testscore is not None:
            score_error_testset = []
    ####################

    # sample fixed noise; it is only used for the reported loss values
    epsilon = torch.FloatTensor(x.size()).normal_(0.0, model.noise_std)
    if use_gpu:
        x = x.cuda()
        epsilon = epsilon.cuda()
        if testdataset is not None:
            testinput = testinput.cuda()
    # The training tensor is also the input for score-error checks. Alias it
    # instead of copying the same data to the GPU a second time.
    traininput = x

    # print initial loss
    loss = model.calculate_loss(x, epsilon)
    print_info('initial loss: {:f}'.format(float(loss.item()/N)), logger)

    def closure():
        optimizer.zero_grad()
        # No epsilon is passed here, unlike the reported loss.
        # Presumably the model draws fresh noise each step (TODO confirm).
        loss = model.calculate_loss(x)
        loss.backward()
        return loss

    start = time.time()
    for epoch in range(max_iter_num):
        optimizer.step(closure)
        if scheduler is not None:
            scheduler.step()
        if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1:
            est_train = dae_estimate_score(traininput, model)
            cur_error = dae_estimate_score_error(traininput, est_train, model, model.noise_std**2)
            score_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = dae_estimate_score(testinput, model)
                cur_testerror = dae_estimate_score_error(testinput, est_test, model, model.noise_std**2)
                score_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # NOTE(review): testscore is never moved to the GPU here.
                    # It must already share est_test's device.
                    diff = (testscore - est_test)
                    score_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())

            if epoch == 0:
                best_model = copy.deepcopy(model.state_dict())
                min_val = score_est_error_set[-1]
                min_epoch = epoch
            elif score_est_error_set[-1] <= min_val and (saveAfter is None or epoch > saveAfter):
                # save models only after sufficient iterations
                best_model = copy.deepcopy(model.state_dict())
                min_val = score_est_error_set[-1]
                min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            loss = model.calculate_loss(x, epsilon)
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, float(loss.item()/N)), logger)
    score_est_error_set = torch.FloatTensor(score_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    if testdataset is not None:
        score_est_error_testset = torch.FloatTensor(score_est_error_testset)
        if testscore is not None:
            score_error_testset = torch.FloatTensor(score_error_testset)
            return best_model, score_est_error_set, score_est_error_testset, score_error_testset
        return best_model, score_est_error_set, score_est_error_testset
    return best_model, score_est_error_set


def dae_Sn_trainer_batchall(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, scheduler = None, saveAfter = None, 
                printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, valIdx = None, loggingFileName = None):
    """Full-batch DAE training with geometry-corrected score estimates; keeps the best checkpoint.

    Geometric quantities are precomputed once for the training data, and for the
    test data if given: the inverse metric, its square root and derivative, and
    the Christoffel sums and their derivative. Each iteration runs
    ``optimizer.step(closure)``.

    Every ``checkEstErrorPeriod`` iterations, and on the last one:
    - the score is estimated as ``dae_estimate_score`` minus the Christoffel sum;
    - ``estimate_gscore_error`` is evaluated on the training set, or on the
      validation split when ``valIdx`` is given.

    The state dict with the lowest such error is kept; ties go to the later
    iteration. After iteration 0, a new checkpoint is taken only when
    ``epoch > saveAfter`` (if set). Samples whose inverse metric has any entry
    outside (-1e5, 1e5) are excluded from these error estimates.

    Args:
        trainDataLoader: only ``trainDataLoader.dataset.train_data`` is used.
        model: must provide ``noise_std`` and ``calculate_loss(x[, epsilon])``.
        optimizer: called as ``optimizer.step(closure)`` once per iteration.
        max_iter_num: number of iterations; must be >= 1.
        use_gpu: move data, geometric quantities and the fixed noise to CUDA.
        scheduler: stepped once per iteration if given.
        saveAfter: optional iteration threshold for taking new checkpoints.
        printEpochPeriod: the loss is reported every this many iterations and
            on the last one.
        checkEstErrorPeriod: the estimated error is computed every this many
            iterations and on the last one.
        testdataset: optional dataset whose ``train_data`` is evaluated the same way.
        testscore: optional reference score for ``testdataset``. Its difference
            from the estimate, weighted by the inverse-metric square root, gives
            the recorded squared error.
        valIdx: optional integer indices into ``train_data``. They are held out
            as a validation set that is used for model selection.
        loggingFileName: if given, passed to ``set_logger``; messages go through
            ``print_info``.

    Returns:
        ``(best_state_dict, gscore_est_errors)``. The test-set estimated errors
        are appended when ``testdataset`` is given, and the test-set reference
        errors are appended when ``testscore`` is also given.

    Raises:
        ValueError: if ``max_iter_num < 1``. The best-model bookkeeping is
            initialised on iteration 0.
    """
    if max_iter_num < 1:
        raise ValueError("max_iter_num must be at least 1, got {}".format(max_iter_num))

    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    ####################
    traininput = trainDataLoader.dataset.train_data
    metricInv_sqrt_train = metricInvSqrt_torch(traininput)
    metricInv_train = metricInv_sqrt_train**2
    metricInvDeriv_train = metricInvDeriv_torch(traininput)
    christoffel_sum_train = christoffelSum_torch(traininput)
    christoffel_sumDeriv_train = christoffelSumDeriv_torch(traininput)
    N = traininput.shape[0]

    # Samples with an inverse-metric entry of magnitude >= ub are left out of
    # the score-error estimates.
    ub = 1e5
    useIdx_train = get_indices_to_use(metricInv_train, ub)

    gscore_est_error_set = []
    if testdataset is not None:
        testinput = testdataset.train_data
        metricInv_sqrt_test = metricInvSqrt_torch(testinput)
        metricInv_test = metricInv_sqrt_test**2
        metricInvDeriv_test = metricInvDeriv_torch(testinput)
        christoffel_sum_test = christoffelSum_torch(testinput)
        christoffel_sumDeriv_test = christoffelSumDeriv_torch(testinput)
        useIdx_test = get_indices_to_use(metricInv_test, ub)
        gscore_est_error_testset = []
        if testscore is not None:
            gscore_error_testset = []
    ####################
    # sample fixed noise; it is only used for the reported loss values
    epsilon = torch.FloatTensor(traininput.size()).normal_(0.0, model.noise_std)
    if use_gpu:
        traininput = traininput.cuda()
        metricInv_sqrt_train = metricInv_sqrt_train.cuda()
        metricInv_train = metricInv_train.cuda()
        metricInvDeriv_train = metricInvDeriv_train.cuda()
        christoffel_sum_train = christoffel_sum_train.cuda()
        christoffel_sumDeriv_train = christoffel_sumDeriv_train.cuda()
        epsilon = epsilon.cuda()
        if testdataset is not None:
            testinput = testinput.cuda()
            metricInv_sqrt_test = metricInv_sqrt_test.cuda()
            metricInv_test = metricInv_test.cuda()
            metricInvDeriv_test = metricInvDeriv_test.cuda()
            christoffel_sum_test = christoffel_sum_test.cuda()
            christoffel_sumDeriv_test = christoffel_sumDeriv_test.cuda()

    if valIdx is not None:
        # Split train_data into training and validation sets. A boolean mask
        # removes every validation index in a single pass.
        keep = torch.ones(N, dtype=torch.bool)
        keep[torch.as_tensor(valIdx, dtype=torch.long).cpu()] = False
        trainIdx = torch.arange(N)[keep]
        valinput = traininput[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        metricInv_val = metricInv_train[valIdx]
        metricInvDeriv_val = metricInvDeriv_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        christoffel_sumDeriv_val = christoffel_sumDeriv_train[valIdx]
        useIdx_val = useIdx_train[valIdx]

        # train data
        traininput = traininput[trainIdx]
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        metricInv_train = metricInv_train[trainIdx]
        metricInvDeriv_train = metricInvDeriv_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        christoffel_sumDeriv_train = christoffel_sumDeriv_train[trainIdx]
        useIdx_train = useIdx_train[trainIdx]
        N = traininput.shape[0]
        epsilon = epsilon[trainIdx]

    # print initial loss
    loss = model.calculate_loss(traininput, epsilon)
    print_info('initial loss: {:f}'.format(float(loss.item()/N)), logger)

    def closure():
        optimizer.zero_grad()
        # No epsilon is passed here, unlike the reported loss.
        # Presumably the model draws fresh noise each step (TODO confirm).
        loss = model.calculate_loss(traininput)
        loss.backward()
        return loss

    start = time.time()
    for epoch in range(max_iter_num):
        optimizer.step(closure)
        if scheduler is not None:
            scheduler.step()
        if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1:
            if valIdx is None:
                est_train = dae_estimate_score(traininput, model) - christoffel_sum_train
                estDeriv_train = dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_train
                cur_error = estimate_gscore_error(est_train[useIdx_train], estDeriv_train[useIdx_train], metricInv_train[useIdx_train], 
                                                  metricInv_sqrt_train[useIdx_train], metricInvDeriv_train[useIdx_train], 
                                                  christoffel_sum_train[useIdx_train], diagonal_metric=True)
                gscore_est_error_set.append(cur_error)
            else:
                # with a validation split, model selection uses the estimated
                # score error on the validation set
                est_val = dae_estimate_score(valinput, model) - christoffel_sum_val
                estDeriv_val = dae_estimate_score_deriv(valinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_val
                cur_error = estimate_gscore_error(est_val[useIdx_val], estDeriv_val[useIdx_val], metricInv_val[useIdx_val], 
                                                  metricInv_sqrt_val[useIdx_val], metricInvDeriv_val[useIdx_val], 
                                                  christoffel_sum_val[useIdx_val], diagonal_metric=True)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = dae_estimate_score(testinput, model) - christoffel_sum_test
                estDeriv_test = dae_estimate_score_deriv(testinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_test
                cur_testerror = estimate_gscore_error(est_test[useIdx_test], estDeriv_test[useIdx_test], metricInv_test[useIdx_test], 
                metricInv_sqrt_test[useIdx_test], metricInvDeriv_test[useIdx_test], christoffel_sum_test[useIdx_test], diagonal_metric=True)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # NOTE(review): testscore is never moved to the GPU here.
                    # It must already share est_test's device.
                    diff = (testscore - est_test) * metricInv_sqrt_test
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())

            if epoch == 0:
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val and (saveAfter is None or epoch > saveAfter):
                # save models only after sufficient iterations
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            loss = model.calculate_loss(traininput, epsilon)
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, float(loss.item()/N)), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            return best_model, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        return best_model, gscore_est_error_set, gscore_est_error_testset
    return best_model, gscore_est_error_set

def rcae_Sn_trainer(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, use_minibatch = True, scheduler = None, saveAfter = None, 
                printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, valIdx = None, 
                    return_lossTrj = False, loggingFileName = None):
    """Train with minibatch or full-batch steps, tracking geometry-corrected score error.

    The geometric quantities are precomputed the same way as in the full-batch
    ``_Sn`` trainer. Each iteration either:
    - iterates ``trainDataLoader`` with one ``optimizer.step()`` per minibatch
      (``use_minibatch=True``), or
    - takes one full-batch step.

    The scheduler is stepped after each optimizer step. Every
    ``checkEstErrorPeriod`` iterations, and on the last one,
    ``estimate_gscore_error`` is evaluated on the training set, or on the
    validation split when ``valIdx`` is given. The state dict with the lowest
    error is kept; ties go to the later iteration. After iteration 0, a new
    checkpoint is taken only when ``epoch > saveAfter`` (if set).

    Args:
        trainDataLoader: ``dataset.train_data`` is used for full-batch losses and
            error checks. Iterating it must yield 3-tuples whose first element
            is a data batch.
        model: must provide ``noise_std`` and ``calculate_loss(x)``.
        optimizer: called as ``optimizer.step()`` after each backward pass.
        max_iter_num: number of iterations (epochs); must be >= 1.
        use_gpu: move data, geometric quantities and minibatches to CUDA.
        use_minibatch: choose minibatch or full-batch updates.
        scheduler: stepped after every optimizer step if given.
        saveAfter: optional iteration threshold for taking new checkpoints.
        printEpochPeriod: the loss is reported every this many iterations and
            on the last one.
        checkEstErrorPeriod: the estimated error is computed every this many
            iterations and on the last one.
        testdataset: optional dataset whose ``train_data`` is evaluated the same way.
        testscore: optional reference score for ``testdataset``. Its difference
            from the estimate, weighted by the inverse-metric square root, gives
            the recorded squared error.
        valIdx: optional integer indices into ``train_data``. They are held out
            as a validation set that is used for model selection.
        return_lossTrj: if True, per-sample average losses are recorded after
            every optimizer step.
        loggingFileName: if given, passed to ``set_logger``; messages go through
            ``print_info``.

    Returns:
        ``(best_state_dict, (lossTrj, lossTrj_valset, lossTrj_testset), gscore_est_errors)``.
        The test-set estimated errors are appended when ``testdataset`` is given,
        and the test-set reference errors are appended when ``testscore`` is also
        given.

    Raises:
        ValueError: if ``max_iter_num < 1``. The best-model bookkeeping is
            initialised on iteration 0.
    """
    if max_iter_num < 1:
        raise ValueError("max_iter_num must be at least 1, got {}".format(max_iter_num))

    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    ####################
    traininput = trainDataLoader.dataset.train_data
    metricInv_sqrt_train = metricInvSqrt_torch(traininput)
    metricInv_train = metricInv_sqrt_train**2
    metricInvDeriv_train = metricInvDeriv_torch(traininput)
    christoffel_sum_train = christoffelSum_torch(traininput)
    christoffel_sumDeriv_train = christoffelSumDeriv_torch(traininput)
    N = traininput.shape[0]

    # Samples with an inverse-metric entry of magnitude >= ub are left out of
    # the score-error estimates.
    ub = 1e5
    useIdx_train = get_indices_to_use(metricInv_train, ub)

    lossTrj = []
    lossTrj_valset = []
    lossTrj_testset = []
    gscore_est_error_set = []
    if testdataset is not None:
        testinput = testdataset.train_data
        N_test = testinput.shape[0]
        metricInv_sqrt_test = metricInvSqrt_torch(testinput)
        metricInv_test = metricInv_sqrt_test**2
        metricInvDeriv_test = metricInvDeriv_torch(testinput)
        christoffel_sum_test = christoffelSum_torch(testinput)
        christoffel_sumDeriv_test = christoffelSumDeriv_torch(testinput)
        useIdx_test = get_indices_to_use(metricInv_test, ub)
        gscore_est_error_testset = []
        if testscore is not None:
            gscore_error_testset = []
    ####################

    if use_gpu:
        traininput = traininput.cuda()
        metricInv_sqrt_train = metricInv_sqrt_train.cuda()
        metricInv_train = metricInv_train.cuda()
        metricInvDeriv_train = metricInvDeriv_train.cuda()
        christoffel_sum_train = christoffel_sum_train.cuda()
        christoffel_sumDeriv_train = christoffel_sumDeriv_train.cuda()
        if testdataset is not None:
            testinput = testinput.cuda()
            metricInv_sqrt_test = metricInv_sqrt_test.cuda()
            metricInv_test = metricInv_test.cuda()
            metricInvDeriv_test = metricInvDeriv_test.cuda()
            christoffel_sum_test = christoffel_sum_test.cuda()
            christoffel_sumDeriv_test = christoffel_sumDeriv_test.cuda()
    if valIdx is not None:
        # Split train_data into training and validation sets. A boolean mask
        # removes every validation index in a single pass.
        keep = torch.ones(N, dtype=torch.bool)
        keep[torch.as_tensor(valIdx, dtype=torch.long).cpu()] = False
        trainIdx = torch.arange(N)[keep]
        valinput = traininput[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        metricInv_val = metricInv_train[valIdx]
        metricInvDeriv_val = metricInvDeriv_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        christoffel_sumDeriv_val = christoffel_sumDeriv_train[valIdx]
        useIdx_val = useIdx_train[valIdx]
        N_val = valinput.shape[0]

        # train data
        traininput = traininput[trainIdx]
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        metricInv_train = metricInv_train[trainIdx]
        metricInvDeriv_train = metricInvDeriv_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        christoffel_sumDeriv_train = christoffel_sumDeriv_train[trainIdx]
        useIdx_train = useIdx_train[trainIdx]
        N = traininput.shape[0]

    # print initial loss
    loss = model.calculate_loss(traininput)
    print_info('initial loss: {:f}'.format(loss.item()/N), logger)
    lossTrj.append(loss.item()/N)
    if testdataset is not None:
        lossTrj_testset.append(model.calculate_loss(testinput).item()/N_test)

    start = time.time()
    for epoch in range(max_iter_num):
        if use_minibatch:
            # NOTE(review): this iterates the full trainDataLoader. When valIdx
            # is given, the validation samples still take part in the gradient steps.
            for data, _, _ in trainDataLoader:
                # respect use_gpu instead of always moving batches to CUDA
                if use_gpu:
                    data = data.cuda()
                optimizer.zero_grad()
                loss = model.calculate_loss(data)
                loss.backward()
                optimizer.step()
                if return_lossTrj:
                    loss_avg = model.calculate_loss(traininput).item()/N
                    lossTrj.append(loss_avg)
                    if testdataset is not None:
                        lossTrj_testset.append(model.calculate_loss(testinput).item()/N_test)
                    if valIdx is not None:
                        lossTrj_valset.append(model.calculate_loss(valinput).item()/N_val)
                if scheduler is not None:
                    scheduler.step()
        else:
            optimizer.zero_grad()
            loss = model.calculate_loss(traininput)
            loss.backward()
            optimizer.step()
            loss_avg = loss.item()/N
            if return_lossTrj:
                lossTrj.append(loss_avg)
                if testdataset is not None:
                    lossTrj_testset.append(model.calculate_loss(testinput).item()/N_test)
                if valIdx is not None:
                    lossTrj_valset.append(model.calculate_loss(valinput).item()/N_val)
            if scheduler is not None:
                scheduler.step()
        if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1:
            if valIdx is None:
                est_train = dae_estimate_score(traininput, model) - christoffel_sum_train
                estDeriv_train = dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_train
                cur_error = estimate_gscore_error(est_train[useIdx_train], estDeriv_train[useIdx_train], metricInv_train[useIdx_train], 
                                                  metricInv_sqrt_train[useIdx_train], metricInvDeriv_train[useIdx_train], 
                                                  christoffel_sum_train[useIdx_train], diagonal_metric=True)
                gscore_est_error_set.append(cur_error)
            else:
                # with a validation split, model selection uses the estimated
                # score error on the validation set
                est_val = dae_estimate_score(valinput, model) - christoffel_sum_val
                estDeriv_val = dae_estimate_score_deriv(valinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_val
                cur_error = estimate_gscore_error(est_val[useIdx_val], estDeriv_val[useIdx_val], metricInv_val[useIdx_val], 
                                                  metricInv_sqrt_val[useIdx_val], metricInvDeriv_val[useIdx_val], 
                                                  christoffel_sum_val[useIdx_val], diagonal_metric=True)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = dae_estimate_score(testinput, model) - christoffel_sum_test
                estDeriv_test = dae_estimate_score_deriv(testinput, model, model.noise_std**2, force_cpu=False) \
                - christoffel_sumDeriv_test
                cur_testerror = estimate_gscore_error(est_test[useIdx_test], estDeriv_test[useIdx_test], metricInv_test[useIdx_test], 
                metricInv_sqrt_test[useIdx_test], metricInvDeriv_test[useIdx_test], christoffel_sum_test[useIdx_test], diagonal_metric=True)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # NOTE(review): testscore is never moved to the GPU here.
                    # It must already share est_test's device.
                    diff = (testscore - est_test) * metricInv_sqrt_test
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())

            if epoch == 0:
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val and (saveAfter is None or epoch > saveAfter):
                # save models only after sufficient iterations
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            # In minibatch mode without trajectory tracking, loss_avg has not
            # been computed yet this iteration.
            if not return_lossTrj and use_minibatch:
                loss_avg = model.calculate_loss(traininput).item()/N
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, loss_avg), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    lossTrjs = (lossTrj, lossTrj_valset, lossTrj_testset)

    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset
    return best_model, lossTrjs, gscore_est_error_set
        
def gae_SnAmb_trainer_batchall(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, scheduler = None, saveAfter = None,
                printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, valIdx = None, loggingFileName = None):
    """Full-batch training on position inputs, selecting the checkpoint by estimated score error.

    ``train_data`` is used directly as the position input. The following are
    precomputed once:
    - its ``getCoord_torch`` coordinates, with their inverse-metric square root
      and Christoffel sums;
    - the position Jacobian from ``getPosJacobianFromPos_torch``.

    Each iteration runs ``optimizer.step(closure)``. Every
    ``checkEstErrorPeriod`` iterations, and on the last one,
    ``gae_sph_n_amb_estimate_score_error`` is evaluated on the training
    positions, or on the validation split when ``valIdx`` is given. The state
    dict with the lowest error is kept; ties go to the later iteration. After
    iteration 0, a new checkpoint is taken only when ``epoch > saveAfter``
    (if set).

    Args:
        trainDataLoader: only ``trainDataLoader.dataset.train_data`` is used; it
            is indexed as 2-D when sampling the fixed noise.
        model: must provide ``noise_std`` and ``calculate_loss(x[, epsilon])``.
        optimizer: called as ``optimizer.step(closure)`` once per iteration.
        max_iter_num: number of iterations; must be >= 1.
        use_gpu: move positions, geometric quantities and the fixed noise to CUDA.
        scheduler: stepped once per iteration if given.
        saveAfter: optional iteration threshold for taking new checkpoints.
        printEpochPeriod: the loss is reported every this many iterations and
            on the last one.
        checkEstErrorPeriod: the estimated error is computed every this many
            iterations and on the last one.
        testdataset: optional dataset whose ``train_data`` is evaluated the same way.
        testscore: optional reference score for ``testdataset``. Its difference
            from the estimate, weighted by the inverse-metric square root, gives
            the recorded squared error.
        valIdx: optional integer indices into ``train_data``. They are held out
            as a validation set that is used for model selection.
        loggingFileName: if given, passed to ``set_logger``; messages go through
            ``print_info``.

    Returns:
        ``(best_state_dict, gscore_est_errors)``. The test-set estimated errors
        are appended when ``testdataset`` is given, and the test-set reference
        errors are appended when ``testscore`` is also given.

    Raises:
        ValueError: if ``max_iter_num < 1``.
        Exception: if the computed coordinates contain NaN or inf.
    """
    if max_iter_num < 1:
        raise ValueError("max_iter_num must be at least 1, got {}".format(max_iter_num))

    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)

    x = trainDataLoader.dataset.train_data
    N = x.shape[0]

    ####################
    # traininput holds the getCoord_torch coordinates of x. It is only used to
    # build the geometric quantities below and is never moved to the GPU.
    traininput = getCoord_torch(x)
    metricInv_sqrt_train = metricInvSqrt_torch(traininput)
    christoffel_sum_train = christoffelSum_torch(traininput)
    dx_dxth_train = getPosJacobianFromPos_torch(x, eps=1e-6)

    if traininput.isnan().sum() > 0:
        raise Exception("nan included in traininput")
    if traininput.isinf().sum() > 0:
        raise Exception("inf included in traininput")

    # variable to store 'estimated' score estimation error
    gscore_est_error_set = []
    if testdataset is not None:
        testPosInput = testdataset.train_data
        testinput = getCoord_torch(testPosInput)
        metricInv_sqrt_test = metricInvSqrt_torch(testinput)
        christoffel_sum_test = christoffelSum_torch(testinput)
        dx_dxth_test = getPosJacobianFromPos_torch(testPosInput, eps=1e-6)

        if testinput.isnan().sum() > 0:
            raise Exception("nan included in testinput")
        if testinput.isinf().sum() > 0:
            raise Exception("inf included in testinput")

        # variable to store 'estimated' score estimation error for test data
        gscore_est_error_testset = []
        if testscore is not None:
            # variable to store score estimation error for test data
            gscore_error_testset = []
    ####################

    # sample fixed noise; it is only used for the reported loss values
    epsilon = torch.FloatTensor(x.size()[0],x.size()[1]).normal_(0.0, model.noise_std)
    if use_gpu:
        x = x.cuda()
        metricInv_sqrt_train = metricInv_sqrt_train.cuda()
        christoffel_sum_train = christoffel_sum_train.cuda()
        dx_dxth_train = dx_dxth_train.cuda()
        epsilon = epsilon.cuda()
        if testdataset is not None:
            testPosInput = testPosInput.cuda()
            metricInv_sqrt_test = metricInv_sqrt_test.cuda()
            christoffel_sum_test = christoffel_sum_test.cuda()
            dx_dxth_test = dx_dxth_test.cuda()
    # x is also the position input for score estimation. Alias it instead of
    # keeping a second GPU copy of the same data.
    trainPosInput = x

    if valIdx is not None:
        # Split train_data into training and validation sets. A boolean mask
        # removes every validation index in a single pass.
        keep = torch.ones(N, dtype=torch.bool)
        keep[torch.as_tensor(valIdx, dtype=torch.long).cpu()] = False
        trainIdx = torch.arange(N)[keep]
        valPosInput = trainPosInput[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        dx_dxth_val = dx_dxth_train[valIdx]

        # train data
        x = x[trainIdx]
        trainPosInput = x
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        dx_dxth_train = dx_dxth_train[trainIdx]
        N = x.shape[0]
        epsilon = epsilon[trainIdx]

    # print initial loss
    loss = model.calculate_loss(x, epsilon)
    print_info('initial loss: {:f}'.format(float(loss.item()/N)), logger)

    def closure():
        optimizer.zero_grad()
        # No epsilon is passed here, unlike the reported loss.
        # Presumably the model draws fresh noise each step (TODO confirm).
        loss = model.calculate_loss(x)
        loss.backward()
        return loss

    start = time.time()
    for epoch in range(max_iter_num):
        optimizer.step(closure)
        if scheduler is not None:
            scheduler.step()
        if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):
            if valIdx is None:
                est_train = gae_sph_n_amb_estimate_score(trainPosInput, model, dx_dxth_train)
                cur_error = gae_sph_n_amb_estimate_score_error(trainPosInput, est_train, model, 
                                 metricInv_sqrt_train, christoffel_sum_train, 
                                     dx_dxth = dx_dxth_train)
                gscore_est_error_set.append(cur_error)
            else:
                # with a validation split, model selection uses the estimated
                # score error on the validation set
                est_val = gae_sph_n_amb_estimate_score(valPosInput, model, dx_dxth_val)
                cur_error = gae_sph_n_amb_estimate_score_error(valPosInput, est_val, model, 
                                 metricInv_sqrt_val, christoffel_sum_val, 
                                     dx_dxth = dx_dxth_val)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = gae_sph_n_amb_estimate_score(testPosInput, model, dx_dxth_test)
                cur_testerror = gae_sph_n_amb_estimate_score_error(testPosInput, est_test, model, 
                             metricInv_sqrt_test, christoffel_sum_test, 
                                 dx_dxth = dx_dxth_test)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    # NOTE(review): testscore is never moved to the GPU here.
                    # It must already share est_test's device.
                    diff = (testscore - est_test) * metricInv_sqrt_test
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())

            if epoch == 0:
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val and (saveAfter is None or epoch > saveAfter):
                # save models only after sufficient iterations
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch

        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            loss = model.calculate_loss(x, epsilon)
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, float(loss.item()/N)), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)

    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            return best_model, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        return best_model, gscore_est_error_set, gscore_est_error_testset
    return best_model, gscore_est_error_set

def grcae_SnAmb_trainer(trainDataLoader, model, optimizer, max_iter_num, use_gpu = True, use_minibatch = True, 
                        scheduler = None, saveAfter = None,
                        printEpochPeriod = 1000, checkEstErrorPeriod = 20, testdataset = None, testscore = None, valIdx = None, 
                        return_lossTrj = False, loggingFileName = None):
    """Train ``model`` on ambient-coordinate sphere data and keep the best checkpoint.

    Every ``checkEstErrorPeriod`` epochs (and on the last epoch) the 'estimated'
    score estimation error is computed with ``gae_sph_n_amb_estimate_score_error``.
    It is computed on the validation split when ``valIdx`` is given, and on the
    training data otherwise. The model state with the lowest such error is kept.

    Args:
        trainDataLoader: DataLoader whose ``dataset.train_data`` holds all training
            points. It is iterated directly when ``use_minibatch`` is True.
        model: object exposing ``calculate_loss(x)`` and ``state_dict()``. The loss
            is divided by the number of points before logging, so it is presumably
            a sum over points (assumption).
        optimizer: torch optimizer, stepped once per minibatch or once per epoch.
        max_iter_num: number of epochs; must be >= 1.
        use_gpu: move the data tensors (and each minibatch) to CUDA.
        use_minibatch: if True, iterate ``trainDataLoader``; if False, take one
            full-batch step per epoch on the training points.
        scheduler: optional LR scheduler, stepped after every optimizer step.
        saveAfter: if given, a new best model is only recorded for epochs
            ``> saveAfter``. The epoch-0 checkpoint is always recorded.
        printEpochPeriod: logging period in epochs.
        checkEstErrorPeriod: period, in epochs, of the estimated-score-error check.
        testdataset: optional dataset with ``train_data``; its estimated error
            is tracked as well.
        testscore: optional reference score for ``testdataset``. When given, the
            ``metricInv_sqrt``-weighted squared error to the estimated score is
            also tracked.
        valIdx: optional indices into ``train_data`` that are held out as a
            validation set.
        return_lossTrj: record the average loss after every optimizer step.
        loggingFileName: optional file name passed to ``set_logger``.

    Returns:
        ``(best_model, lossTrjs, gscore_est_error_set)``, extended with
        ``gscore_est_error_testset`` when ``testdataset`` is given, and further
        with ``gscore_error_testset`` when ``testscore`` is also given.
        ``lossTrjs`` is the tuple ``(train, validation, test)`` of lists of
        average losses.

    Raises:
        ValueError: if ``max_iter_num < 1``, since no checkpoint would ever be
            recorded.
        Exception: if ``getCoord_torch`` of the train or test data contains
            nan or inf.
    """
    if max_iter_num < 1:
        # Without at least one epoch, best_model / min_val / min_epoch are never set.
        raise ValueError("max_iter_num must be at least 1, got {}".format(max_iter_num))

    if loggingFileName is None:
        logger = None
    else:
        logger = set_logger(loggingFileName)
    
    x = trainDataLoader.dataset.train_data
    N = x.shape[0]
    
    ####################
    trainPosInput = trainDataLoader.dataset.train_data
    traininput = getCoord_torch(x)
    metricInv_sqrt_train = metricInvSqrt_torch(traininput)
    christoffel_sum_train = christoffelSum_torch(traininput)
    dx_dxth_train = getPosJacobianFromPos_torch(trainPosInput, eps=1e-6)
    
    if traininput.isnan().sum() > 0:
        raise Exception("nan included in traininput")
    if traininput.isinf().sum() > 0:
        raise Exception("inf included in traininput")
    
    lossTrj = []
    lossTrj_valset = []
    lossTrj_testset = []
    # variable to store 'estimated' score estimation error
    gscore_est_error_set = []
    if testdataset is not None:
        testPosInput = testdataset.train_data
        N_test = testPosInput.shape[0]
        testinput = getCoord_torch(testdataset.train_data)
        metricInv_sqrt_test = metricInvSqrt_torch(testinput)
        christoffel_sum_test = christoffelSum_torch(testinput)
        dx_dxth_test = getPosJacobianFromPos_torch(testPosInput, eps=1e-6)
        
        if testinput.isnan().sum() > 0:
            raise Exception("nan included in testinput")
        if testinput.isinf().sum() > 0:
            raise Exception("inf included in testinput")
        
        # variable to store 'estimated' score estimation error for test data
        gscore_est_error_testset = []
        if testscore is not None:
            # variable to store score estimation error for test data
            gscore_error_testset = []
    ####################
    
    # move data and precomputed geometric quantities to the GPU
    if use_gpu:
        x = x.cuda()
        trainPosInput = trainPosInput.cuda()
        metricInv_sqrt_train = metricInv_sqrt_train.cuda()
        christoffel_sum_train = christoffel_sum_train.cuda()
        dx_dxth_train = dx_dxth_train.cuda()
        if testdataset is not None:
            testPosInput = testPosInput.cuda()
            metricInv_sqrt_test = metricInv_sqrt_test.cuda()
            christoffel_sum_test = christoffel_sum_test.cuda()
            dx_dxth_test = dx_dxth_test.cuda()
            if testscore is not None:
                # testscore is combined with the CUDA tensors est_test / metricInv_sqrt_test below
                testscore = testscore.cuda()
            
    if valIdx is not None:
        # split train_data into training set and validation set.
        # One boolean mask instead of filtering trainIdx once per validation index.
        trainMask = torch.ones(N, dtype=torch.bool)
        trainMask[torch.as_tensor(valIdx).cpu()] = False
        trainIdx = torch.arange(N)[trainMask]
        valPosInput = trainPosInput[valIdx]
        metricInv_sqrt_val = metricInv_sqrt_train[valIdx]
        christoffel_sum_val = christoffel_sum_train[valIdx]
        dx_dxth_val = dx_dxth_train[valIdx]
        N_val = valPosInput.shape[0]
        
        # train data
        x = x[trainIdx]
        trainPosInput = trainPosInput[trainIdx]
        metricInv_sqrt_train = metricInv_sqrt_train[trainIdx]
        christoffel_sum_train = christoffel_sum_train[trainIdx]
        dx_dxth_train = dx_dxth_train[trainIdx]
        N = x.shape[0]
        
    # print initial loss
    loss = model.calculate_loss(x)
    print_info('initial loss: {:f}'.format(float(loss.item()/N)), logger)
    lossTrj.append(loss.item()/N)
    
    if testdataset is not None:
        lossTrj_testset.append(model.calculate_loss(testPosInput).item()/N_test)
    if valIdx is not None:
        lossTrj_valset.append(model.calculate_loss(valPosInput).item()/N_val)
        
    start = time.time()
    for epoch in range(max_iter_num):
        if use_minibatch:
            # NOTE(review): trainDataLoader still iterates the full dataset, so when
            # valIdx is given the validation points are also used for minibatch
            # training. Confirm whether a Subset-based loader was intended.
            for data in trainDataLoader:
                if use_gpu:
                    data = data.cuda()
                optimizer.zero_grad()
                loss = model.calculate_loss(data)
                loss.backward()
                optimizer.step()
                if return_lossTrj:
                    loss_avg = model.calculate_loss(x).item()/N
                    lossTrj.append(loss_avg)
                    if testdataset is not None:
                        lossTrj_testset.append(model.calculate_loss(testPosInput).item()/N_test)
                    if valIdx is not None:
                        lossTrj_valset.append(model.calculate_loss(valPosInput).item()/N_val)
                if scheduler is not None:
                    scheduler.step()
        else:
            optimizer.zero_grad()
            loss = model.calculate_loss(x)
            loss.backward()
            optimizer.step()
            # loss of the step just taken (measured before the parameter update)
            loss_avg = loss.item()/N
            if return_lossTrj:
                lossTrj.append(loss_avg)
                if testdataset is not None:
                    lossTrj_testset.append(model.calculate_loss(testPosInput).item()/N_test)
                if valIdx is not None:
                    lossTrj_valset.append(model.calculate_loss(valPosInput).item()/N_val)
            if scheduler is not None:
                scheduler.step()
        if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):
            if valIdx is None:
                est_train = gae_sph_n_amb_estimate_score(trainPosInput, model, dx_dxth_train)
                cur_error = gae_sph_n_amb_estimate_score_error(trainPosInput, est_train, model, 
                                 metricInv_sqrt_train, christoffel_sum_train, 
                                     dx_dxth = dx_dxth_train)
                gscore_est_error_set.append(cur_error)
            else:
                # if valIdx is not None, use 'estimated' score estimation error on validation set for model selection
                est_val = gae_sph_n_amb_estimate_score(valPosInput, model, dx_dxth_val)
                cur_error = gae_sph_n_amb_estimate_score_error(valPosInput, est_val, model, 
                                 metricInv_sqrt_val, christoffel_sum_val, 
                                     dx_dxth = dx_dxth_val)
                gscore_est_error_set.append(cur_error)
            if testdataset is not None:
                est_test = gae_sph_n_amb_estimate_score(testPosInput, model, dx_dxth_test)
                cur_testerror = gae_sph_n_amb_estimate_score_error(testPosInput, est_test, model, 
                             metricInv_sqrt_test, christoffel_sum_test, 
                                 dx_dxth = dx_dxth_test)
                gscore_est_error_testset.append(cur_testerror)
                if testscore is not None:
                    diff = (testscore - est_test) * metricInv_sqrt_test
                    gscore_error_testset.append(torch.mean(torch.sum(diff*diff, dim = 1)).cpu())
           
            if epoch == 0:
                # epoch 0 is always checked (0 % period == 0), so this initializes the checkpoint
                best_model = copy.deepcopy(model.state_dict())
                min_val = gscore_est_error_set[-1]
                min_epoch = epoch
            elif gscore_est_error_set[-1] <= min_val:
                # save models after sufficient iterations
                if saveAfter is None or epoch > saveAfter:
                    best_model = copy.deepcopy(model.state_dict())
                    min_val = gscore_est_error_set[-1]
                    min_epoch = epoch
            
        if (epoch % printEpochPeriod == 0) or epoch == max_iter_num-1:
            if not return_lossTrj and use_minibatch:
                loss_avg = model.calculate_loss(x).item()/N
            print_info("iter: {:d} ---- time {:.1f} ---- loss: {:f}".format(epoch, time.time() - start, loss_avg), logger)
    gscore_est_error_set = torch.FloatTensor(gscore_est_error_set)
    print_info("min. estimated score error: {:f}, min. epoch: {:d}".format(min_val, min_epoch), logger)
    
    lossTrjs = (lossTrj, lossTrj_valset, lossTrj_testset)
    if testdataset is not None:
        gscore_est_error_testset = torch.FloatTensor(gscore_est_error_testset)
        if testscore is not None:
            gscore_error_testset = torch.FloatTensor(gscore_error_testset)
            return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset, gscore_error_testset
        return best_model, lossTrjs, gscore_est_error_set, gscore_est_error_testset
    return best_model, lossTrjs, gscore_est_error_set