import time
import torch
import random
import itertools

import torch.nn as nn

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image, make_grid

from utils import *
from options import TrainOptions
from models import BUM
from losses import GanLoss, DLoss
from datasets import UnpairedImgDataset

print('---------------------------------------- step 1/5 : parameters preparing... ----------------------------------------')
# All run-time configuration below is read from the parsed training options.
opt = TrainOptions().parse()

# Seeding happens before any dataset, model or optimizer is constructed.
set_random_seed(opt.seed, deterministic=False)

# `delete` is only true for non-resumed runs (presumably so that a fresh run
# starts from an empty experiment directory -- confirm in prepare_dir).
models_dir, log_dir, train_images_dir, val_images_dir = prepare_dir(
    opt.results_dir,
    opt.experiment,
    delete=(not opt.resume),
)

writer = SummaryWriter(log_dir=log_dir)

print('---------------------------------------- step 2/5 : data loading... ------------------------------------------------')
# Worker / pinning settings shared by every DataLoader built below.
_loader_common = dict(num_workers=opt.num_workers, pin_memory=True)

print('training data loading...')
train_dataset = UnpairedImgDataset(
    data_source=opt.train_source,
    mode='train',
    random_resize=opt.random_resize,
    crop=opt.crop,
)
train_dataloader = DataLoader(train_dataset, batch_size=opt.train_bs, shuffle=True, **_loader_common)
print('successfully loading training pairs. =====> qty:{} bs:{}'.format(len(train_dataset),opt.train_bs))

print('validating data 1 loading...')
val_dataset1 = UnpairedImgDataset(
    data_source=opt.val_source1,
    mode='val1',
    random_resize=opt.random_resize,
)
val_dataloader1 = DataLoader(val_dataset1, batch_size=opt.val_bs, shuffle=False, **_loader_common)
print('successfully loading validating pairs 1. =====> qty:{} bs:{}'.format(len(val_dataset1),opt.val_bs))

print('validating data 2 loading...')
val_dataset2 = UnpairedImgDataset(
    data_source=opt.val_source2,
    mode='val2',
    random_resize=opt.random_resize,
)
val_dataloader2 = DataLoader(val_dataset2, batch_size=opt.val_bs, shuffle=False, **_loader_common)
print('successfully loading validating pairs 2. =====> qty:{} bs:{}'.format(len(val_dataset2),opt.val_bs))

print('---------------------------------------- step 3/5 : model defining... ----------------------------------------------')
model = BUM(opt.num_res, opt.model_mode).cuda()
# Report parameter counts for the whole model, then for G_AB and D_B.
for _module in (model, model.G_AB, model.D_B):
    print_para_num(_module)

# Optional warm start from a state dict on disk.
if opt.pretrained is not None:
    model.load_state_dict(torch.load(opt.pretrained))
    print('successfully loading pretrained model.')

print('---------------------------------------- step 4/5 : requisites defining... -----------------------------------------')
# Losses
criterion_cycle = nn.L1Loss()
criterion_sr = nn.L1Loss()
criterion_gan = GanLoss(gan_type=opt.gan_type)
criterion_d = DLoss(gan_type=opt.gan_type)

# Optimizers: one over both generators, one for discriminator B.
_adam_betas = (0.5, 0.999)
optimizer_G = torch.optim.Adam(
    itertools.chain(model.G_AB.parameters(), model.G_BA.parameters()),
    lr=opt.lr,
    betas=_adam_betas,
)
optimizer_D_B = torch.optim.Adam(model.D_B.parameters(), lr=opt.lr, betas=_adam_betas)

# Learning rate update schedulers: multiply the LR by `_lr_gamma` at each milestone.
_lr_milestones = [50, 100, 150, 200, 250, 300]
_lr_gamma = 0.5
scheduler_G = torch.optim.lr_scheduler.MultiStepLR(optimizer_G, milestones=_lr_milestones, gamma=_lr_gamma)
scheduler_D_B = torch.optim.lr_scheduler.MultiStepLR(optimizer_D_B, milestones=_lr_milestones, gamma=_lr_gamma)

print('---------------------------------------- step 5/5 : training... ----------------------------------------------------')
def main():
    """Drive the full training run: optional resume, then train/validate per epoch.

    Reads the module-level `opt`, `model`, optimizers, schedulers and `writer`.
    When `opt.resume` is set, state is restored from `models_dir/latest.pth`
    and training continues from the epoch after the one stored there.

    The TensorBoard writer is closed in a `finally` clause so that it is
    released even if training is interrupted or raises.
    """
    # Best validation scores so far: [score on val set 1, score on val set 2].
    # The list is mutated in place by `val` and stored in the checkpoint by `train`.
    optimal = [0., 0.]
    start_epoch = 1

    try:
        if opt.resume:
            state = torch.load(models_dir + '/latest.pth')
            model.load_state_dict(state['model'])
            optimizer_G.load_state_dict(state['optimizer_G'])
            optimizer_D_B.load_state_dict(state['optimizer_D_B'])
            scheduler_G.load_state_dict(state['scheduler_G'])
            scheduler_D_B.load_state_dict(state['scheduler_D_B'])
            # The checkpoint stores the last finished epoch; continue with the next one.
            start_epoch = state['epoch'] + 1
            optimal = state['optimal']
            print('Resume from epoch %d' % (start_epoch), optimal)

        for epoch in range(start_epoch, opt.n_epochs + 1):
            train(epoch, optimal)

            # Validate every `opt.val_gap` epochs.
            if epoch % opt.val_gap == 0:
                val(epoch, optimal)
    finally:
        # Always close the writer, including on exceptions / KeyboardInterrupt.
        writer.close()
    
def train(epoch, optimal):
    """Train for one epoch, advance the LR schedulers, and write the resume checkpoint.

    Args:
        epoch: Current epoch number as passed by `main` (starts at 1).
        optimal: Two-element list of best validation scores; it is only stored
            in the checkpoint here, not modified.

    Side effects:
        Updates `model`, `optimizer_G`, `optimizer_D_B`, `scheduler_G` and
        `scheduler_D_B`; writes scalars to `writer`; saves sample images to
        `train_images_dir`; overwrites `models_dir/latest.pth`.

    The checkpoint is written *after* the schedulers are stepped, so the
    saved scheduler and optimizer state describe the learning rate of the
    next epoch. Saving before stepping would make a resumed run re-use this
    epoch's learning rate and shift every milestone by one epoch per resume.

    NOTE(review): `optimal` is saved before `main` runs validation for this
    epoch, so the stored value does not include this epoch's validation.
    """
    model.train()

    max_iter = len(train_dataloader)

    # Running averages over the iterations since the last print.
    iter_D_meter = AverageMeter()
    iter_G_meter = AverageMeter()
    iter_gan_meter = AverageMeter()
    iter_cycle_meter = AverageMeter()
    iter_sr_meter = AverageMeter()
    iter_timer = Timer()

    for i, (imgA, imgB) in enumerate(train_dataloader):
        imgA, imgB = imgA.cuda(), imgB.cuda()
        cur_batch = imgA.shape[0]

        # One generator forward pass is shared by the D and G updates below.
        fakeB, reconA = model.forward_G(imgA)

        # -----------------------
        #  Train Discriminator B
        # -----------------------
        optimizer_D_B.zero_grad()
        # forward; fakeB is detached so this backward pass does not reach the generators
        fakeB_valid, imgB_valid = model.forward_D_B(fakeB.detach(), imgB)
        # compute loss, backward & update
        loss_D_B = criterion_d(fakeB_valid, imgB_valid)
        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = loss_D_B

        # ------------------
        #  Train Generator
        # ------------------

        # presumably disables gradients for D_B during the generator update -- see BUM.freeze_D_B
        model.freeze_D_B()

        optimizer_G.zero_grad()
        # compute loss
        loss_cycle = criterion_cycle(reconA, imgA)
        # L1 between each generator's output and its own input, averaged over both generators
        loss_sr = (criterion_sr(model.G_AB(imgB), imgB) + criterion_sr(model.G_BA(imgA), imgA)) / 2
        loss_gan = criterion_gan(model.D_B(fakeB))
        loss_G = loss_gan + opt.lambda_cycle * loss_cycle + opt.lambda_sr * loss_sr
        # backward & update
        loss_G.backward()
        optimizer_G.step()

        model.unfreeze_D_B()

        # record (value * batch size, batch size)
        # presumably AverageMeter keeps a weighted sum and a count -- confirm in utils
        iter_D_meter.update(loss_D.item()*cur_batch, cur_batch)
        iter_G_meter.update(loss_G.item()*cur_batch, cur_batch)
        iter_gan_meter.update(loss_gan.item()*cur_batch, cur_batch)
        iter_cycle_meter.update(loss_cycle.item()*cur_batch, cur_batch)
        iter_sr_meter.update(loss_sr.item()*cur_batch, cur_batch)

        # print
        if (i+1) % opt.print_gap == 0:
            # Global iteration index across epochs, used as the TensorBoard step.
            global_step = i+1 + (epoch - 1) * max_iter

            print('Training: Epoch[{:0>4}/{:0>4}] Iteration[{:0>4}/{:0>4}] Loss_D: {:.4f} Loss_G: {:.4f} Loss_gan: {:.4f} Loss_cycle: {:.4f} Loss_sr: {:.4f} Time: {:.4f}'.format(epoch, opt.n_epochs, i + 1, max_iter, iter_D_meter.average(), iter_G_meter.average(), iter_gan_meter.average(), iter_cycle_meter.average(), iter_sr_meter.average(), iter_timer.timeit()))
            # auto_reset=True: the meters start over after each logging window.
            writer.add_scalar('loss_D', iter_D_meter.average(auto_reset=True), global_step)
            writer.add_scalar('loss_G', iter_G_meter.average(auto_reset=True), global_step)
            writer.add_scalar('loss_gan', iter_gan_meter.average(auto_reset=True), global_step)
            writer.add_scalar('loss_cycle', iter_cycle_meter.average(auto_reset=True), global_step)
            writer.add_scalar('loss_sr', iter_sr_meter.average(auto_reset=True), global_step)

            # Input | translated | target-domain sample, concatenated along dim 3.
            save_image(torch.cat((imgA, fakeB, imgB), 3), train_images_dir + '/img_epoch_{:0>4}_iter_{:0>4}.png'.format(epoch, i+1), nrow=1, normalize=True, scale_each=True)

    # Log the learning rate that was in effect during this epoch (before stepping).
    writer.add_scalar('lr', scheduler_G.get_last_lr()[0], epoch)

    # Step first, save second: the checkpoint must hold the state the next epoch starts from.
    scheduler_G.step()
    scheduler_D_B.step()

    torch.save({'model': model.state_dict(), 'optimizer_G': optimizer_G.state_dict(), 'optimizer_D_B': optimizer_D_B.state_dict(), 'scheduler_G': scheduler_G.state_dict(), 'scheduler_D_B': scheduler_D_B.state_dict(), 'epoch': epoch, 'optimal': optimal}, models_dir + '/latest.pth')
    
def _validate_loader(dataloader, prefix, epoch):
    """Run G_AB over one validation loader and return (average score, elapsed time).

    Args:
        dataloader: Iterable yielding (imgA, imgB) batches.
        prefix: File-name prefix for the saved sample images ('val1' / 'val2').
        epoch: Current epoch number, used in the sample image file names.

    The score is whatever `get_metrics(fakeB, imgB)` returns (reported as PSNR
    by the caller), averaged by an `AverageMeter`. The timer is created just
    before the loop and read right after it, so the elapsed time covers this
    loader only (assumes `Timer` starts at construction -- confirm in utils).
    """
    psnr_meter = AverageMeter()
    timer = Timer()

    for i, (imgA, imgB) in enumerate(dataloader):
        imgA, imgB = imgA.cuda(), imgB.cuda()

        with torch.no_grad():
            fakeB = model.G_AB(imgA)

        psnr_meter.update(get_metrics(fakeB, imgB), imgB.shape[0])

        # Keep visual samples for the first two batches only.
        if i < 2:
            save_image(torch.cat((imgA, fakeB, imgB), 3), val_images_dir + '/{}_epoch_{:0>4}_iter_{:0>4}.png'.format(prefix, epoch, i+1), nrow=1, normalize=True, scale_each=True)

    return psnr_meter.average(), timer.timeit()


def val(epoch, optimal):
    """Validate on both validation sets, log scores, and save improved weights.

    Args:
        epoch: Current epoch number (used for logging and file names).
        optimal: Two-element list [best score on set 1, best score on set 2];
            updated in place when the current epoch beats the stored value.

    Side effects:
        Puts `model` in eval mode; writes sample images to `val_images_dir`;
        writes 'psnr1' / 'psnr2' scalars to `writer`; saves model weights to
        `models_dir` on a new best score and every 50th epoch.
    """
    model.eval()

    print(''); print('Validating...', end=' ')

    # Each loader gets its own timer so Time1 / Time2 are per-dataset timings.
    psnr1, time1 = _validate_loader(val_dataloader1, 'val1', epoch)
    psnr2, time2 = _validate_loader(val_dataloader2, 'val2', epoch)

    print('Epoch[{:0>4}/{:0>4}] PSNR1: {:.4f} PSNR2: {:.4f} Time1: {:.4f} Time2: {:.4f}'.format(epoch, opt.n_epochs, psnr1, psnr2, time1, time2)); print('')

    # New best on validation set 1.
    if optimal[0] < psnr1:
        optimal[0] = psnr1
        torch.save(model.state_dict(), models_dir + '/optimal_{:.2f}_{:.2f}_epoch_{:0>4}.pth'.format(optimal[0], psnr2, epoch))
    # New best on validation set 2.
    if optimal[1] < psnr2:
        optimal[1] = psnr2
        torch.save(model.state_dict(), models_dir + '/z_enhancing_{:.2f}_{:.2f}_epoch_{:0>4}.pth'.format(optimal[1], psnr1, epoch))

    writer.add_scalar('psnr1', psnr1, epoch)
    writer.add_scalar('psnr2', psnr2, epoch)

    # Periodic snapshot regardless of score.
    if epoch % 50 == 0:
        torch.save(model.state_dict(), models_dir + '/epoch_{:0>4}.pth'.format(epoch))
    
# Start the training loop only when this file is executed as a script.
# The module-level setup (option parsing, data loading, model and optimizer
# construction) still runs whenever the module is imported.
if __name__ == '__main__':
    main()
    