from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from transfer_data import train_test_sets_generation, train_test_sets_generation_ijcnn1
from sklearn.metrics import roc_auc_score

import os
from itertools import count
import time
import random
import numpy as np
import math

from models.models import *
from models.preact_resnet import *

from torchvision.utils import save_image

# Restrict this process to the first GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Every model and tensor below is moved to the GPU with .cuda(), so the
# script cannot continue without CUDA.
if not torch.cuda.is_available():
    print('cuda is required but cuda is not available')
    # Exit with a non-zero status so schedulers/wrappers see the failure
    # (a bare exit() reports success and depends on the site module).
    raise SystemExit(1)

#== parser start
parser = argparse.ArgumentParser(description='PyTorch')
# One (flag, type, default) row per command-line option, registered in order.
_OPTION_TABLE = (
    # base setting 1: fixed
    ('--job-id', int, 1),
    ('--seed', int, 10000),
    # base setting 2: fixed
    ('--test-batch-size', int, 9998),
    ('--momentum', float, 0.9),
    ('--weight-decay', float, 1e-4),
    ('--data-path', str, './dataset/'),
    # experiment setting
    ('--dataset', str, 'ijcnn1'),      # e.g. mnist_binary, mnist
    ('--data-aug', int, 0),
    ('--model', str, 'AUC_linear'),    # LeNet AUC_LeNet AUC_MLP AUC_linear
    # method setting
    ('--lr', float, 300),
    ('--lr_c', float, 300),
    ('--proj_wab', float, 1000),
    ('--proj_c', float, 1000),
    ('--batch-size', int, 512),
    ('--ssize', int, 64),
    ('--method', int, 0),              # 0: standard, 1: q-SGD
)
for _flag, _kind, _default in _OPTION_TABLE:
    parser.add_argument(_flag, type=_kind, default=_default)
args = parser.parse_args()
#== parser end
# Dataset files live under <data-path><dataset>.
data_path = args.data_path + args.dataset
# exist_ok avoids the check-then-create race when several jobs (different
# --job-id) start at the same time and create the same directories.
os.makedirs(data_path, exist_ok=True)

result_path = './results/'
os.makedirs(result_path, exist_ok=True)
# Result file stem: dataset_aug_model_method_batch[_ssize]_jobid
result_path += args.dataset + '_' + str(args.data_aug) + '_' + args.model
result_path += '_' + str(args.method) + '_' + str(args.batch_size)
if args.method != 0:
    result_path += '_' + str(args.ssize)
result_path += '_' + str(args.job_id)

# The log file stays open for the whole run; train/test summaries are
# appended to it by output() and the main loop.
filep = open(result_path + '.txt', 'w')
# Archive this script (minus its first line) at the top of the log so the
# exact code that produced the results is kept with them.
with open(__file__) as f:
    filep.write('\n'.join(f.read().split('\n')[1:]))
filep.write('\n\n')

out_str = str(args)
print(out_str)
filep.write(out_str + '\n')

# Fallback seed; only reached if the parser default for --seed is None.
if args.seed is None:
    args.seed = 605
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.enabled = True

out_str = 'initial seed = ' + str(args.seed)
print(out_str)
filep.write(out_str + '\n\n')

#===============================================================
#=== dataset setting
#===============================================================
# Loader defaults shared by every dataset; individual branches override them.
kwargs = {}
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])
train_Sampler = None
test_Sampler = None
Shuffle = True

# Each branch sets the input geometry (nh, nw, nc), the model output width
# (num_class), the last epoch index (end_epoch) and builds train_data/test_data.
# With --data-aug 1 the image datasets train longer and use augmentation.
if args.dataset == 'mnist':
    nh = 28
    nw = 28
    nc = 1
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.RandomAffine(15, scale=(0.85, 1.15)),
            transforms.ToTensor()
        ])
    train_data = datasets.MNIST(data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.MNIST(data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'cifar10':
    nh = 32
    nw = 32
    nc = 3
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'cifar100':
    nh = 32
    nw = 32
    nc = 3
    num_class = 100
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor()
        ])
    train_data = datasets.CIFAR100(root=data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR100(root=data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'svhn':
    nh = 32
    nw = 32
    nc = 3
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor()
        ])
    train_data = datasets.SVHN(data_path, split='train', download=True, transform=train_transform)
    test_data = datasets.SVHN(data_path, split='test', download=True, transform=test_transform)
elif args.dataset == 'fashionmnist':
    nh = 28
    nw = 28
    nc = 1
    num_class = 10
    end_epoch = 20
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
    train_data = datasets.FashionMNIST(data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.FashionMNIST(data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'kmnist':
    nh = 28
    nw = 28
    nc = 1
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.ToTensor()
        ])
    train_data = datasets.KMNIST(data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.KMNIST(data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'semeion':
    nh = 16
    nw = 16
    nc = 1
    num_class = 10  # the digits from 0 to 9 (written by 80 people twice)
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(16, padding=1),
            transforms.RandomAffine(4, scale=(1.05, 1.05)),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor()
        ])
    train_data = datasets.SEMEION(data_path, transform=train_transform, download=True)
    # Train and test samples come from the same SEMEION data and are split
    # by the index array in random_index.npy. The test view is a separate
    # dataset object so that evaluation uses test_transform rather than the
    # random training augmentation.
    test_data = datasets.SEMEION(data_path, transform=test_transform, download=True)
    random_index = np.load(data_path+'/random_index.npy')
    train_size = 1000
    train_Sampler = SubsetRandomSampler(random_index[range(train_size)])
    test_Sampler = SubsetRandomSampler(random_index[range(train_size, len(test_data))])
    # DataLoader does not accept shuffle=True together with a sampler.
    Shuffle = False
elif args.dataset == 'fakedata':
    nh = 24
    nw = 24
    nc = 3
    num_class = 10
    end_epoch = 50
    train_size = 1000
    test_size = 1000
    # One synthetic dataset; the first train_size items are used for
    # training and the remainder for testing.
    train_data = datasets.FakeData(size=train_size+test_size, image_size=(nc, nh, nw), num_classes=num_class, transform=train_transform)
    test_data = train_data
    train_Sampler = SubsetRandomSampler(range(train_size))
    test_Sampler = SubsetRandomSampler(range(train_size, len(test_data)))
    Shuffle = False
elif args.dataset == 'mnist_binary':
    nh = 28
    nw = 28
    nc = 1
    num_class = 1
    end_epoch = 50
    # n_train_pos / n_train_neg are used as normalisers by the AUC loss.
    train_data, test_data, n_train_pos, n_train_neg = train_test_sets_generation(args.dataset, data_path)
    Shuffle = False
elif args.dataset == 'ijcnn1':
    nh = 1
    nw = 1
    nc = 22
    num_class = 1
    end_epoch = 1000
    train_data, test_data, n_train_pos, n_train_neg = train_test_sets_generation_ijcnn1(args.dataset, data_path, train_ratio=0.8)
    Shuffle = False
else:
    print('specify dataset')
    # Non-zero status: an unknown dataset is a usage error, not success.
    raise SystemExit(1)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_Sampler, shuffle=Shuffle, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, sampler=test_Sampler, shuffle=False, **kwargs)

#===============================================================
#=== model setting
#===============================================================
# Build the network selected by --model on the GPU. dx is the flattened
# input size used by the fully-connected models.
if args.model in ('LeNet', 'AUC_LeNet'):
    model = LeNet(nc, nh, nw, num_class).cuda()
elif args.model == 'PreActResNet18':
    model = PreActResNet18(nc, num_class).cuda()
elif args.model in ('Linear', 'SVM'):
    dx = nh * nw * nc
    model = Linear(dx, num_class).cuda()
elif args.model == 'AUC_MLP':
    dx = nh * nw * nc
    model = MLP(dx, num_class).cuda()
elif args.model == 'AUC_linear':
    dx = nh * nw * nc
    model = Linear(dx, num_class, False).cuda()
    # Second copy of the linear model; output() loads the averaged weights
    # into it to report AUC_avg.
    model_test = Linear(dx, num_class, False).cuda()
else:
    print('specify model')
    # Non-zero status: an unknown model is a usage error, not success.
    raise SystemExit(1)
    
#===============================================================
#=== utils def
#===============================================================
def lr_decay_func(optimizer, lr_decay=0.1):
    """Multiply the learning rate of every parameter group by ``lr_decay``.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts that
            each hold an ``'lr'`` entry.
        lr_decay: multiplicative factor applied to each learning rate.

    Returns:
        The same optimizer object, modified in place.
    """
    for param_group in optimizer.param_groups:
        # Use the caller's factor; it was previously ignored in favour of a
        # hard-coded 0.1.
        param_group['lr'] *= lr_decay
    return optimizer
def lr_scheduler(optimizer, epoch, lr_decay=0.1, interval=10):
    """Apply the step learning-rate schedule for the current epoch.

    A decay step is taken only when data augmentation is enabled
    (``args.data_aug == 1``) and ``epoch`` is 10 or 100; otherwise the
    optimizer is returned untouched. ``interval`` is accepted but unused.

    Returns:
        The optimizer (as returned by ``lr_decay_func`` when a decay fires).
    """
    if args.data_aug != 1:
        return optimizer
    if epoch in (10, 100):
        optimizer = lr_decay_func(optimizer, lr_decay=lr_decay)
    return optimizer
def lr_decay_func_AUC(optimizer, epoch):
    """Scale every parameter group's learning rate by ``1 / epoch`` in place.

    Returns the same optimizer object. ``epoch`` must be non-zero.
    """
    for group in optimizer.param_groups:
        current_lr = group['lr']
        group['lr'] = current_lr * (1 / epoch)
    return optimizer

class multiClassHingeLoss(nn.Module):
    """Per-sample multi-class hinge loss with margin 1.

    For each sample the loss is the sum over the non-target classes of
    ``max(0, score_j - score_target + 1)``, divided by the number of classes.
    """

    def __init__(self):
        super(multiClassHingeLoss, self).__init__()

    def forward(self, output, y):
        """Compute the loss.

        Args:
            output: score tensor indexed as ``[sample, class]``.
            y: integer class index per sample.

        Returns:
            1-D tensor with one loss value per sample (not reduced).
        """
        # Work on whatever device the scores live on instead of forcing
        # CUDA, so the loss also runs on CPU tensors.
        device = output.device
        target = y.detach().to(device)
        rows = torch.arange(0, target.size()[0], device=device).long()
        output_y = output[rows, target].view(-1, 1)
        loss = output - output_y + 1.0
        # The target class itself contributes no margin violation.
        loss[rows, target] = 0
        loss[loss < 0] = 0
        loss = torch.sum(loss, dim=1) / output.size()[1]
        return loss
hinge_loss = multiClassHingeLoss()
    
#===============================================================
#=== train optimization def
#===============================================================
# Scalar variables a, b, c of the AUC objective. They are not registered
# with the optimizer; train() updates them by hand.
para_a, para_b, para_c = (torch.zeros(1).cuda() for _ in range(3))

# Plain SGD over the model weights only.
optimizer = optim.SGD(model.parameters(), lr=args.lr)

# Histories of per-step gradient norms (weights, a, b, c), appended to by
# train() and summarised by output().
grad_w, grad_a, grad_b, grad_c = [], [], [], []

# Running sum of the weights and the matching step counter; output() uses
# sum_w / index as the averaged model.
sum_w = torch.zeros((1, nh * nw * nc)).cuda()
index = 1

def _train_label_masks(y):
    """Return (pos_mask, neg_mask) column tensors on the GPU for labels y.

    pos_mask is y with -1 replaced by 0; neg_mask is y with +1 replaced by 0
    and -1 replaced by 1. Both are cast to long and shaped (-1, 1).
    """
    pos = y.detach().cpu().clone()
    pos[pos == -1] = 0
    neg = y.detach().cpu().clone()
    neg[neg == 1] = 0
    neg[neg == -1] = 1
    return pos.long().cuda().view(-1, 1), neg.long().cuda().view(-1, 1)


def _train_auc_loss(h, pos, neg, a, b, c):
    """Mean of the per-sample AUC objective for scores h at (a, b, c)."""
    loss_1 = (1 / n_train_pos) * ((h - a) ** 2) * pos
    loss_2 = (1 / n_train_neg) * ((h - b) ** 2) * neg
    loss_3 = 2 * (1 + c) * ((1 / n_train_neg) * h * neg - (1 / n_train_pos) * h * pos) \
        - (1 / (n_train_neg + n_train_pos)) * (c ** 2)
    return torch.mean(loss_1 + loss_2 + loss_3)


def _train_grad_norm(opt):
    """L2 norm of all gradients currently held by the optimizer's params."""
    total = 0
    with torch.no_grad():
        for group in opt.param_groups:
            for param in group['params']:
                if param.requires_grad and param.grad is not None:
                    total += param.grad.norm().item() ** 2
    return total ** (1. / 2)


def _train_project_tensor(t, radius):
    """Rescale t onto the ball of the given radius (printing the old norm)."""
    norm = torch.norm(t).item()
    if norm > radius:
        print(norm)
        t = t * (radius / norm)
    return t


def _train_project_model(radius):
    """Rescale, in place, every model parameter whose norm exceeds radius."""
    for p in model.parameters():
        norm = torch.norm(p).item()
        if norm > radius:
            print(norm)
            p.data *= (radius / norm)


def _train_step_abc(h, pos, neg, at, base):
    """One manual step on (a, b, c).

    Gradients are evaluated at the point ``at`` using scores ``h`` and the
    step is taken from ``base``: descent on a and b with args.lr, ascent on c
    with args.lr_c, each followed by projection. Gradient norms are appended
    to grad_a / grad_b / grad_c.

    Returns:
        The new (a, b, c) tuple.
    """
    a, b, c = at
    a0, b0, c0 = base

    g_a = torch.mean((2 / n_train_pos) * (a - h) * pos)
    grad_a.append(torch.norm(g_a).item())
    new_a = _train_project_tensor(a0 - args.lr * g_a, args.proj_wab)

    g_b = torch.mean((2 / n_train_neg) * (b - h) * neg)
    grad_b.append(torch.norm(g_b).item())
    new_b = _train_project_tensor(b0 - args.lr * g_b, args.proj_wab)

    g_c = torch.mean((2 / n_train_neg) * h * neg
                     - (2 / n_train_pos) * h * pos
                     - (2 / (n_train_neg + n_train_pos)) * c)
    grad_c.append(torch.norm(g_c).item())
    new_c = _train_project_tensor(c0 + args.lr_c * g_c, args.proj_c)
    return new_a, new_b, new_c


def train(epoch):
    """Run one training epoch of the two-stage AUC update.

    Every mini-batch is split in two halves:

    * Stage 1 (first half): SGD step on the weights from the current point,
      projection, then manual steps on a, b, c. The resulting weights are
      added to the running sum ``sum_w`` (``index`` counts the additions).
    * Stage 2 (second half): the loss is evaluated at the stage-1 point, but
      the resulting gradient step is applied to the values saved before
      stage 1 (weights and a, b, c).
      NOTE(review): this resembles an extragradient scheme -- confirm.

    Args:
        epoch: current epoch number, forwarded to lr_scheduler.
    """
    global optimizer, para_a, para_b, para_c, sum_w, index
    model.train()
    optimizer = lr_scheduler(optimizer, epoch)

    for batch_idx, (x, y) in enumerate(train_loader):
        half = x.shape[0] // 2
        if half == 0:
            # A batch with fewer than two samples leaves the first half
            # empty; the mean over an empty tensor is NaN and would poison
            # a, b, c permanently, so skip it.
            continue

        x_first = x[:half].cuda().unsqueeze(1).float()
        x_second = x[half:].cuda().unsqueeze(1).float()
        pos_first, neg_first = _train_label_masks(y[:half])
        pos_second, neg_second = _train_label_masks(y[half:])

        # ---------------- Stage 1: step from the current point -------------
        loss = _train_auc_loss(model(x_first), pos_first, neg_first,
                               para_a, para_b, para_c)

        # Snapshot every parameter tensor (not just the last one) so that
        # stage 2 can restore models with more than one parameter correctly.
        w_old = [p.data.detach().clone() for p in model.parameters()]
        a_old, b_old, c_old = para_a, para_b, para_c

        optimizer.zero_grad()
        loss.backward()
        grad_w.append(_train_grad_norm(optimizer))
        optimizer.step()
        _train_project_model(args.proj_wab)

        h_new = model(x_first).detach()
        para_a, para_b, para_c = _train_step_abc(
            h_new, pos_first, neg_first,
            (para_a, para_b, para_c), (para_a, para_b, para_c))

        # Accumulate the stage-1 weights for the averaged model.
        # NOTE: sum_w has a single-row shape, so this assumes the model has
        # one weight tensor of matching shape.
        index = index + 1
        for para in model.parameters():
            sum_w += para.data

        # ---------------- Stage 2: gradient at the stage-1 point -----------
        # The forward pass must run BEFORE the weights are restored so that
        # the gradient is evaluated at the stage-1 weights.
        loss_second = _train_auc_loss(model(x_second), pos_second, neg_second,
                                      para_a, para_b, para_c)

        for p, saved in zip(model.parameters(), w_old):
            p.data = saved
        optimizer.zero_grad()
        loss_second.backward()
        optimizer.step()
        _train_project_model(args.proj_wab)

        h_new = model(x_second).detach()
        para_a, para_b, para_c = _train_step_abc(
            h_new, pos_second, neg_second,
            (para_a, para_b, para_c), (a_old, b_old, c_old))

    optimizer.zero_grad()


#===============================================================
#=== train/test output def
#===============================================================    
def output(data_loader):
    """Evaluate the model on a loader, log a summary and return (loss, AUC).

    For ``train_loader`` the model is put in train mode and only the mean
    loss is computed (the logged AUC field is a fixed placeholder and the
    returned AUC is 0). For ``test_loader`` the model is put in eval mode and
    ROC AUC (in percent) is reported for both the current weights and the
    averaged weights ``sum_w / index`` loaded into ``model_test``.

    AUC is computed once over all predictions of the loader: it is a ranking
    metric, so averaging per-batch AUCs does not give the dataset AUC and
    fails on batches that contain a single class.

    Args:
        data_loader: ``train_loader`` or ``test_loader``.

    Returns:
        Tuple ``(mean_loss, AUC)``.
    """
    global sum_w, index
    is_train = data_loader == train_loader
    is_test = data_loader == test_loader
    if is_train:
        model.train()
    elif is_test:
        model.eval()
        model_test.eval()
        # Load the averaged weights once; they do not change during evaluation.
        for p in model_test.parameters():
            p.data = sum_w * (1 / index)

    binary_data = args.dataset == 'mnist_binary' or args.dataset == 'ijcnn1'
    auc_model = args.model in ('AUC_LeNet', 'AUC_MLP', 'AUC_linear')

    total_loss = 0
    total_size = 0
    AUC = 0
    AUC_avg = 0
    labels, scores, scores_avg = [], [], []

    # Nothing here calls backward(), so skip building autograd graphs.
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(data_loader):
            if binary_data:
                x = x.cuda().unsqueeze(1).float()
                # 0/1 indicator columns for positive (+1) and negative (-1) labels.
                y_cp_pos = y.detach().cpu().clone()
                y_cp_pos[y_cp_pos == -1] = 0
                y_cp_neg = y.detach().cpu().clone()
                y_cp_neg[y_cp_neg == 1] = 0
                y_cp_neg[y_cp_neg == -1] = 1
                y_cp_pos = y_cp_pos.long().cuda().view(-1, 1)
                y_cp_neg = y_cp_neg.long().cuda().view(-1, 1)
                y = y.long().cuda()
            else:
                x = x.cuda()
                y = y.cuda()

            h1 = model(x)
            if args.model == 'SVM':
                total_loss += torch.mean(hinge_loss(h1, y)).item() * y.size(0)
            elif auc_model:
                loss_1 = (1 / n_train_pos) * ((h1 - para_a) ** 2) * y_cp_pos
                loss_2 = (1 / n_train_neg) * ((h1 - para_b) ** 2) * y_cp_neg
                loss_3 = 2 * (1 + para_c) * ((1 / n_train_neg) * h1 * y_cp_neg
                                             - (1 / n_train_pos) * h1 * y_cp_pos) \
                    - (1 / (n_train_neg + n_train_pos)) * (para_c ** 2)
                auc_loss = loss_1 + loss_2 + loss_3
                total_loss += torch.mean(auc_loss).item() * y.size(0)
            else:
                total_loss += F.cross_entropy(h1, y).item() * y.size(0)

            if is_test:
                # Collect predictions; AUC is computed after the loop.
                labels.append(y.cpu().numpy())
                scores.append(np.atleast_1d(np.squeeze(h1.cpu().numpy())))
                h2 = model_test(x)
                scores_avg.append(np.atleast_1d(np.squeeze(h2.cpu().numpy())))
            total_size += y.size(0)

    if is_test:
        y_all = np.concatenate(labels)
        AUC = 100. * roc_auc_score(y_all, np.concatenate(scores))
        AUC_avg = 100. * roc_auc_score(y_all, np.concatenate(scores_avg))
    total_loss /= total_size

    if is_train:
        # The AUC field of the training line is a fixed placeholder (1).
        if grad_w:
            out_str = 'tr_l={} AUC={:.2f} max_grad_w={:.5f} max_grad_a={:.7f} max_grad_b={:.7f} max_grad_c={:.7f}:' \
                .format(total_loss, 1, np.max(grad_w), np.max(grad_a), np.max(grad_b), np.max(grad_c))
        else:
            out_str = 'tr_l={} AUC={:.2f}:'.format(total_loss, 1)
    elif is_test:
        out_str = 'te_l={} AUC={:.3f} AUC_avg={:.3f}:'.format(total_loss, AUC, AUC_avg)
    print(out_str, end=' ')
    filep.write(out_str + ' ')
    return (total_loss, AUC)

#===============================================================
#=== start computation
#===============================================================    
#== for plot
# Per-epoch results: pl_result[epoch, row, col] with row 0 = train and
# row 1 = test, each holding the (loss, AUC) pair returned by output();
# row 2, col 0 holds the elapsed wall-clock seconds.
pl_result = np.zeros((end_epoch + 1, 3, 2))
#== main loop start
time_start = time.time()
# Epoch 0 only evaluates the initial model; training starts at epoch 1 and
# the loop ends after epoch end_epoch.
for epoch in range(end_epoch + 1):
    out_str = str(epoch)
    print(out_str, end=' ')
    filep.write(out_str + ' ')
    if epoch >= 1:
        train(epoch)
    pl_result[epoch, 0, :] = output(train_loader)
    pl_result[epoch, 1, :] = output(test_loader)
    time_current = time.time() - time_start
    pl_result[epoch, 2, 0] = time_current
    # Saved every epoch so partial results survive an interrupted run.
    np.save(result_path + '_' + 'pl', pl_result)
    out_str = 'time={:.1f}:'.format(time_current)
    print(out_str)
    filep.write(out_str + '\n')
    # Flush so the text log is also up to date if the job is killed.
    filep.flush()