import numpy as np
import time
import copy
import torch

def printTrainingStatus(epoch, start, loss):
    """Print a blank line, then the epoch number with elapsed time, then the loss.

    Args:
        epoch: epoch index, formatted as an integer.
        start: reference timestamp (as from ``time.time()``); the printed time is
            the number of seconds elapsed since it.
        loss: value formatted with ``%f``.
    """
    print()
    elapsed = time.time() - start
    print("epoch: %d ---- time %.2f" % (epoch, elapsed))
    print("loss: %f" % loss)

def dae_trainer_batchall(trainDataLoader, model, optimizer, scheduler, max_iter_num, use_gpu = True, 
                saveModel = False, printEpochPeriod = 1000, useFixedNoise = True):
    """Train ``model`` full-batch on the whole ``trainDataLoader.dataset.train_data``.

    Mini-batches are not used. Every iteration calls ``optimizer.step(closure)``
    once on the complete data tensor and then ``scheduler.step()``.

    Args:
        trainDataLoader: object whose ``.dataset.train_data`` is a tensor. Its
            first dimension is used as ``N``, and reported losses are divided by ``N``.
        model: object providing ``calculate_loss(x, epsilon)`` (fixed noise) or
            ``calculate_loss(x)``. With ``useFixedNoise`` it must also have
            ``noise_std`` (a float or a tensor).
        optimizer: torch optimizer. ``step`` is called with a closure.
        scheduler: learning-rate scheduler, stepped once per iteration.
        max_iter_num: number of optimisation iterations.
        use_gpu: move the data (and sampled noise) to CUDA.
        saveModel: if true, return a deep copy of ``model.state_dict()``.
        printEpochPeriod: print progress every this many iterations, and always on
            the last one. A falsy value (e.g. 0) disables the periodic print
            instead of raising ZeroDivisionError.
        useFixedNoise: sample one noise tensor up front and reuse it at every step.

    Returns:
        The deep-copied state dict if ``saveModel``. Otherwise, the last evaluated
        loss divided by ``N``, as a detached tensor.
    """
    N = trainDataLoader.dataset.train_data.size()[0]
    x = trainDataLoader.dataset.train_data.clone()
    if use_gpu:
        x = x.cuda()

    epsilon = None
    loss = None

    def evaluate_loss():
        # Losses computed only for reporting are detached so their autograd graph
        # is freed immediately rather than kept alive until the next print.
        if useFixedNoise:
            return model.calculate_loss(x, epsilon).detach()
        return model.calculate_loss(x).detach()

    if useFixedNoise:
        # sample fixed noise once; it is reused by every optimisation step
        if use_gpu:
            if isinstance(model.noise_std, float):
                epsilon = torch.cuda.FloatTensor(x.size()).normal_(0.0, model.noise_std)
            else:
                epsilon = torch.cuda.FloatTensor(x.size()).normal_(0.0, 1.0) * model.noise_std.cuda()
        else:
            epsilon = torch.FloatTensor(x.size()).normal_(0.0, 1.0) * model.noise_std

        # print initial loss
        loss = evaluate_loss()
        print('initial loss: %f' % (loss/N))

    def closure():
        optimizer.zero_grad()
        if useFixedNoise:
            step_loss = model.calculate_loss(x, epsilon)
        else:
            step_loss = model.calculate_loss(x)
        step_loss.backward()
        return step_loss

    start = time.time()
    for epoch in range(max_iter_num):
        optimizer.step(closure)
        periodic = bool(printEpochPeriod) and epoch % printEpochPeriod == 0
        if periodic or epoch == max_iter_num-1:
            print("iter: %d ---- time %.1f" % (epoch, time.time() - start))
            loss = evaluate_loss()
            print('loss: %f' % (loss/N))
        scheduler.step()
    if saveModel:
        # deep copy the model
        model_wts = copy.deepcopy(model.state_dict())
        return model_wts
    if loss is None:
        # No iteration ran and no initial loss was computed (max_iter_num <= 0
        # without fixed noise). Evaluate once instead of raising UnboundLocalError.
        loss = evaluate_loss()
    return loss/N

def gae_N_n_trainer_batchall(dtiDataLoader, model, optimizer, scheduler, max_iter_num, use_gpu = True, 
                saveModel = False, printEpochPeriod = 1000, weight_mode = None):
    """Train ``model`` full-batch on all of ``dtiDataLoader.dataset.posAndCov``.

    Every iteration calls ``optimizer.step(closure)`` once on the complete data
    and then ``scheduler.step()``. The loss is
    ``model.calculate_loss(x, cov_sqrt, covInv_sqrt, cov_logJacobian, cov_eigvec,
    cov_eigval, weight=x_weight)``. The same call, including ``weight``, is used
    both for optimisation and for the reported loss.

    Args:
        dtiDataLoader: object whose ``.dataset`` provides ``posAndCov`` (its first
            dimension is ``N``), ``covInv_sqrt``, ``cov_sqrt``, ``logJacobian``,
            ``cov_eigvec`` and ``cov_eigval`` tensors.
        model: object with ``calculate_loss(...)`` as above and a boolean
            ``use_logvec_input`` attribute (read only when ``use_gpu``).
        optimizer: torch optimizer. ``step`` is called with a closure.
        scheduler: learning-rate scheduler, stepped once per iteration.
        max_iter_num: number of optimisation iterations.
        use_gpu: move the data tensors to CUDA. ``logJacobian`` is moved only when
            ``model.use_logvec_input`` is true.
        saveModel: if true, return a deep copy of ``model.state_dict()``.
        printEpochPeriod: print progress every this many iterations, and always on
            the last one. A falsy value (e.g. 0) disables the periodic print.
        weight_mode: currently unused (see note below).

    Returns:
        The deep-copied state dict if ``saveModel``. Otherwise, the last evaluated
        loss divided by ``N``, as a detached tensor.
    """
    N = dtiDataLoader.dataset.posAndCov.size()[0]
    x = dtiDataLoader.dataset.posAndCov.clone()
    # NOTE(review): ``weight_mode`` is accepted but never used; per-sample weights
    # are always None here. TODO confirm the intended weighting scheme.
    x_weight = None
    covInv_sqrt = dtiDataLoader.dataset.covInv_sqrt.clone()
    cov_sqrt = dtiDataLoader.dataset.cov_sqrt.clone()
    cov_logJacobian = dtiDataLoader.dataset.logJacobian
    cov_eigvec = dtiDataLoader.dataset.cov_eigvec
    cov_eigval = dtiDataLoader.dataset.cov_eigval
    if use_gpu:
        x = x.cuda()
        if model.use_logvec_input:
            cov_logJacobian = cov_logJacobian.clone().cuda()
        covInv_sqrt = covInv_sqrt.cuda()
        cov_sqrt = cov_sqrt.cuda()
        cov_eigvec = cov_eigvec.cuda()
        cov_eigval = cov_eigval.cuda()

    def compute_loss():
        # Single definition of the objective so the reported loss matches the
        # optimised one (the original omitted ``weight`` when reporting).
        return model.calculate_loss(x, cov_sqrt, covInv_sqrt, cov_logJacobian,
                                    cov_eigvec, cov_eigval, weight = x_weight)

    def closure():
        optimizer.zero_grad()
        step_loss = compute_loss()
        step_loss.backward()
        return step_loss

    loss = None
    start = time.time()
    for epoch in range(max_iter_num):
        optimizer.step(closure)
        periodic = bool(printEpochPeriod) and epoch % printEpochPeriod == 0
        if periodic or epoch == max_iter_num-1:
            print("iter: %d ---- time %.1f" % (epoch, time.time() - start))
            # detached: free the reporting graph immediately
            loss = compute_loss().detach()
            print('loss: %f' % (loss/N))
        scheduler.step()
    if saveModel:
        # deep copy the model
        model_wts = copy.deepcopy(model.state_dict())
        return model_wts
    if loss is None:
        # No iteration ran (max_iter_num <= 0). Evaluate once instead of raising
        # UnboundLocalError.
        loss = compute_loss().detach()
    return loss/N