"""General-purpose test script for image-to-image translation.

Once you have trained your model with train.py, you can use this script to test the model.
It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'.

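It first creates the model and dataset from the given options, hard-coding some test-time parameters.
It then runs inference for '--num_test' images and saves the results to an HTML file.
(The evaluate() function in this module instead processes the whole test set and then computes
FID scores for both translation directions.)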

Example (You need to train models first or download pre-trained models from our website):
    Test a CycleGAN model (both sides):
        python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

    Test a CycleGAN model (one side only):
        python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

    The option '--model test' is used for generating CycleGAN results only for one side.
    This option will automatically set '--dataset_mode single', which only loads the images from one set.
    On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
    which is sometimes unnecessary. The results will be saved at ./results/.
    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.

    Test a pix2pix model:
        python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

See options/base_options.py and options/test_options.py for more test options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""

# NOTE(review): the string literal below is a bare expression statement. It is
# not assigned or referenced in this module and has no effect at runtime. It keeps
# a copy of a TestOptions.initialize() definition for reference only.
# evaluate() does not parse TestOptions. It hard-codes the corresponding options
# (results_dir, aspect_ratio, phase, eval) on a copy of the options it receives.
'''
def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm has different behavioir during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run | disabled in WKD, we test fid with all the testting images for stable results')
        # rewrite devalue values
        parser.set_defaults(model='test')
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser
'''
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
import copy
from fid_score import *
from inception import*

def evaluate(model, opt, fid_batch_size=32, fid_device="cuda", fid_dims=2048):
    """Run ``model`` over the whole test set, save its outputs and return FID scores.

    ``opt`` is deep-copied and test-time settings are hard-coded on the copy, so the
    caller's options object is not modified. Every visual returned by
    ``model.get_current_imgs()`` is saved via ``save_images`` into an HTML page rooted
    at ``<results_dir>/<name>/test_temporal[_iter<load_iter>]``. The image files
    (expected under ``<web_dir>/images/``) are then sorted into ``fakeA``, ``fakeB``,
    ``realA`` and ``realB`` sub-directories. Finally the FID between fake and real
    images is computed for each direction.

    Args:
        model: the model to evaluate. It must provide ``eval``, ``train``,
            ``set_input``, ``test``, ``get_current_imgs`` and ``get_image_paths``.
            It is switched back to train mode before the FID computation.
        opt: the options object. Only a deep copy of it is modified.
        fid_batch_size: batch size passed to ``calculate_fid_given_paths``.
        fid_device: device passed to ``calculate_fid_given_paths``.
        fid_dims: feature dimensionality passed to ``calculate_fid_given_paths``.

    Returns:
        tuple: ``(fid_a, fid_b)``, the FID of fakeA vs. realA and of fakeB vs. realB.
    """
    # Work on a copy so the caller's options are left untouched.
    opt = copy.deepcopy(opt)
    # Test-time settings, hard-coded instead of parsing TestOptions.
    opt.results_dir = './results/'
    opt.aspect_ratio = 1.0
    opt.phase = 'test'
    opt.eval = True
    opt.epoch = "temporal"  # used below to name the output directory and page title
    opt.distill = False
    opt.num_threads = 0   # test code only supports num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options

    # The model passed in is evaluated directly; it is not re-created or reloaded here.
    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch))
    if opt.load_iter > 0:  # load_iter is 0 by default
        web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter)
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))

    # Eval mode only affects layers like batchnorm and dropout (used by pix2pix;
    # CycleGAN uses instancenorm without dropout, so it should be unaffected).
    if opt.eval:
        model.eval()
    # opt.num_test is deliberately ignored: every test image is processed so the
    # FID is computed over the full test set.
    for i, data in enumerate(dataset):
        model.set_input(data)               # unpack data from data loader
        model.test()                        # run inference
        visuals = model.get_current_imgs()  # get image results
        img_path = model.get_image_paths()  # get image paths
        if i % 20 == 0:
            print('processing (%04d)-th image... %s' % (i, img_path))
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
    webpage.save()  # save the HTML
    model.train()   # restore training mode for the caller

    image_dir = os.path.join(web_dir, 'images')
    fake_A_path = os.path.join(image_dir, 'fakeA')
    fake_B_path = os.path.join(image_dir, 'fakeB')
    real_A_path = os.path.join(image_dir, 'realA')
    real_B_path = os.path.join(image_dir, 'realB')
    for path in (fake_A_path, fake_B_path, real_A_path, real_B_path):
        os.makedirs(path, exist_ok=True)

    # (file-name tag, destination dir, skip files whose name contains 'teacher'?)
    # Tags are checked in this order and the first match wins.
    buckets = (
        ('real_B', real_B_path, False),
        ('fake_B', fake_B_path, True),
        ('fake_A', fake_A_path, True),
        ('real_A', real_A_path, False),
    )
    for item in os.listdir(image_dir):
        for tag, target_dir, skip_teacher in buckets:
            if tag in item and not (skip_teacher and 'teacher' in item):
                # os.replace instead of shelling out to `mv`: there are no quoting
                # problems with unusual file names, and failures raise instead of
                # being silently ignored.
                os.replace(os.path.join(image_dir, item), os.path.join(target_dir, item))
                break

    print('calculating fids between', fake_A_path, real_A_path, 'and between', fake_B_path, real_B_path)
    fid_a = calculate_fid_given_paths([fake_A_path, real_A_path], fid_batch_size, fid_device, fid_dims)
    fid_b = calculate_fid_given_paths([fake_B_path, real_B_path], fid_batch_size, fid_device, fid_dims)
    return fid_a, fid_b