"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import sys
import data

from collections import OrderedDict
from options.train_options import TrainOptions
from util.iter_counter import IterationCounter
from util.stylegan_data import stylegan, generating_data 
from util.visualizer import Visualizer
from trainers.pix2pix_trainer import Pix2PixTrainer


# Optional: uncomment to restrict training to a specific GPU (here, GPU 1).
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="1"



# Parse command-line options.
opt = TrainOptions().parse()

# Echo the exact command line to help debugging / reproducing a run.
print(' '.join(sys.argv))

# Load the dataset.
dataloader = data.create_dataloader(opt)

# Create the trainer for our model.
trainer = Pix2PixTrainer(opt)

# When True, every iteration samples an extra batch from the StyleGAN networks
# loaded from opt.stylegan_ckpt and passes it (plus the StyleGAN
# discriminator) to both the generator and discriminator steps.
# The names are bound to None first so that the flag can be turned off without
# NameErrors further down.
# NOTE(review): whether Pix2PixTrainer accepts set_1_2=None /
# style_discriminator=None is not visible here - confirm before disabling.
StyleGAN = True
style_g_running = None
style_discriminator = None
if StyleGAN:
    style_g_running, style_discriminator = stylegan(opt.stylegan_ckpt)

# Create the tool for counting iterations.
iter_counter = IterationCounter(opt, len(dataloader))

# Create the tool for visualization.
visualizer = Visualizer(opt)

for epoch in iter_counter.training_epochs():
    iter_counter.record_epoch_start(epoch)
    for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
        iter_counter.record_one_iteration()

        # StyleGAN samples for this iteration; stays None when StyleGAN is off.
        # Reset every iteration so a stale batch is never reused.
        set_1_2 = None
        if StyleGAN:
            # Zero the StyleGAN networks' gradients before drawing a new batch.
            style_g_running.zero_grad()
            style_discriminator.zero_grad()
            set_1_2 = generating_data(style_g_running, style_discriminator, opt.batchSize)

        # Training
        # train generator (only once every opt.D_steps_per_G iterations)
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i, set_1_2, style_discriminator=style_discriminator)
        # train discriminator (every iteration)
        trainer.run_discriminator_one_step(data_i, set_1_2, style_discriminator=style_discriminator)

        # Visualizations
        if iter_counter.needs_printing():
            losses = trainer.get_latest_losses()
            visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                            losses, iter_counter.time_per_iter)
            visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

        if iter_counter.needs_displaying():
            visuals = OrderedDict([('input_label', data_i['label']),
                                   ('synthesized_image', trainer.get_latest_generated())])
            if set_1_2 is not None:
                # Element 1 of the first StyleGAN set is displayed as the
                # target image. Unverified layout sketch of each set:
                #   (gen_in, Fake_img_1, G_feature_1, D_feature_1,
                #    Fake_img_2, G_feature_2, D_feature_2,
                #    Fake_img_mix, G_feature_mix, D_feature_mix)
                # i.e. set1[1] is presumably a StyleGAN-generated image - confirm
                # against generating_data.
                set1, _ = set_1_2
                visuals['real_image'] = set1[1]
            visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)

        if iter_counter.needs_saving():
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

    trainer.update_learning_rate(epoch)
    iter_counter.record_epoch_end()

    # Checkpoint periodically and always after the final epoch.
    if epoch % opt.save_epoch_freq == 0 or \
       epoch == iter_counter.total_epochs:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, iter_counter.total_steps_so_far))
        trainer.save('latest')
        trainer.save(epoch)

print('Training was successfully finished.')
