"""General-purpose training script for image-to-image translation.

This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').

It first creates model, dataset, and visualizer given the option.
It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves the models.
The script supports continue/resume training. Use '--continue_train' to resume your previous training.

Example:
    Train a CycleGAN model:
        python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
    Train a pix2pix model:
        python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

See options/base_options.py and options/train_options.py for more training options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import torch
import algorithms
from util import misc
import datasets
def default_hparams(dataset_name):
    """Return the hyperparameters used to build the dataset and the ERM network.

    Args:
        dataset_name: the value of ``opt.dataset``.

    Returns:
        A newly created dict with the keys 'input_shape', 'num_classes', 'opt',
        'lr', 'weight_decay' and 'sch_size'.

    Raises:
        NotImplementedError: if no hyperparameters are defined for ``dataset_name``.
    """
    if dataset_name == 'ColoredMNIST':
        return {
            'input_shape': (3, 28, 28),
            'num_classes': 2,
            'opt': 'SGD',
            'lr': 1e-1,
            'weight_decay': 1e-4,
            'sch_size': 600,
        }
    # Fail explicitly: without this, an unlisted dataset left `hparams` unbound
    # and only surfaced later as a confusing NameError.
    raise NotImplementedError('no hyperparameters defined for dataset %r' % (dataset_name,))


if __name__ == '__main__':
    opt = TrainOptions().parse()   # get training options

    # ---- dataset -----------------------------------------------------------
    hparams = default_hparams(opt.dataset)
    if opt.dataset not in vars(datasets):
        raise NotImplementedError('dataset %r is not provided by the datasets module' % (opt.dataset,))
    dataset = vars(datasets)[opt.dataset](opt.data_dir, opt.bias, [], hparams)

    # Split every environment into a held-out part ("out") and a training part ("in").
    # The split seed is derived from the trial seed and the environment index.
    in_splits = []
    out_splits = []
    for env_i, env in enumerate(dataset):
        out, in_ = misc.split_dataset(env, int(len(env) * opt.holdout_fraction),
                                      misc.seed_hash(opt.trial_seed, env_i))
        in_splits.append(in_)
        out_splits.append(out)

    for i, env in enumerate(in_splits):
        print("Env:{}, dataset len:{}".format(i, len(env)))

    # class of env: ImageFolder
    # Only the splits of the first environment are wrapped in loaders.
    train_loader = torch.utils.data.DataLoader(dataset=in_splits[0], batch_size=opt.batch_size, shuffle=True,
                                               num_workers=dataset.N_WORKERS)
    # NOTE(review): eval_loaders is built but not consumed anywhere in this script.
    eval_loaders = torch.utils.data.DataLoader(dataset=out_splits[0], batch_size=opt.batch_size, shuffle=True,
                                               num_workers=dataset.N_WORKERS)

    # ---- auxiliary ERM network passed to the model as `net_spur` -------------
    algorithm_class = algorithms.get_algorithm_class('ERM')
    algorithm_spur = algorithm_class(hparams['input_shape'], hparams['num_classes'], 1, hparams)

    # Load onto the CPU first so the checkpoint does not depend on the GPU index it
    # was saved from; the module is moved to the GPU right after its weights are set.
    checkpoint = torch.load(f'stage3_model_{opt.bias[0]}', map_location='cpu')
    algorithm_spur.load_state_dict(checkpoint['model_dict'])
    algorithm_spur.cuda()

    # ---- model / visualizer ------------------------------------------------
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots
    total_iters = 0                # the total number of training iterations (counted in samples)
    dataset_size = len(train_loader.dataset)

    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
        visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch
        # NOTE(review): the per-epoch learning-rate update is disabled here, so the
        # schedulers created in model.setup() are never stepped by this script -- confirm intent.
        #model.update_learning_rate()    # update learning rates in the beginning of every epoch.

        for i, data in enumerate(train_loader):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            # Measured on every iteration so that the value logged below always belongs
            # to the iteration being reported instead of an earlier one.
            t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data, net_spur=algorithm_spur)         # unpack data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
                # NOTE(review): model.eval() is called from inside the training loop when
                # opt.eval is set; nothing in this script switches the model back -- confirm intent.
                if opt.eval:
                    model.eval()

            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()

        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
