import os
import json
import skimage.io
import skimage.transform
import skimage.metrics
import numpy as np

from util import html
from os import listdir
from collections import OrderedDict
from util.visualizer import Visualizer
from metrics.fid import calculate_fid_given_paths

def psnr(img1, img2, pixel_max=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1, img2: Array-likes of equal shape. They are converted to float64
            before differencing, so unsigned integer inputs (e.g. uint8) do not
            wrap around on subtraction.
        pixel_max: Maximum possible pixel value: 255.0 for 8-bit images,
            1.0 for images scaled to [0, 1].

    Returns:
        ``20 * log10(pixel_max / sqrt(MSE))`` as a float, or the sentinel 100
        when the images are identical (MSE == 0), where PSNR would be infinite.
    """
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    # Use numpy rather than ``math``, which this module does not import.
    return float(20 * np.log10(pixel_max / np.sqrt(mse)))

def save_json(json_file, filename):
    """Write the JSON-serializable object ``json_file`` to ``filename``.

    The output uses a 4-space indent, and keys keep their insertion order
    (no sorting), so an ``OrderedDict`` is written in its own order. Any
    existing file at ``filename`` is overwritten.
    """
    dump_options = {'indent': 4, 'sort_keys': False}
    with open(filename, mode='w') as out_stream:
        json.dump(json_file, out_stream, **dump_options)

def eval_all_metrics(opt, trainer, dataloader, visualizer, epoch, iteration):
    """Run inference on ``dataloader``, save the results, then compute metrics.

    Batches are passed through the generator with ``mode='inference'`` until
    ``i * opt.batchSize`` reaches ``opt.how_many``. For every sample, the input
    label and the synthesized image are saved through ``visualizer`` into an
    HTML page rooted at ``<results_dir>/<name>/<phase>_<which_epoch>``.
    After that, FID and SSIM/PSNR/RMSE are computed and written to JSON files
    tagged with ``epoch`` and ``iteration``.

    The model is put into eval mode for inference and is always returned to
    train mode afterwards, even if inference or saving raises.
    """
    # Create a webpage that summarizes all the results.
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir,
                        'Experiment = %s, Phase = %s, Epoch = %s' %
                        (opt.name, opt.phase, opt.which_epoch))

    # With GPUs the model is presumably wrapped (e.g. DataParallel), so the
    # underlying network is reached via ``.module``.
    if len(opt.gpu_ids) > 0:
        model = trainer.pix2pix_model.module
    else:
        model = trainer.pix2pix_model

    model.eval()
    try:
        for i, data_i in enumerate(dataloader):
            # Stop once the batches already processed cover ``how_many`` samples.
            if i * opt.batchSize >= opt.how_many:
                break

            generated = model(data_i, mode='inference')

            img_path = data_i['path']
            for b in range(generated.shape[0]):
                print('process image... %s' % img_path[b])
                visuals = OrderedDict([('input_label', data_i['label'][b]),
                                       ('synthesized_image', generated[b])])
                visualizer.save_images(webpage, visuals, img_path[b:b + 1])

        webpage.save()
    finally:
        # Restore training mode even on failure so the caller's training loop
        # is not left with a model stuck in eval mode.
        model.train()

    calculate_fid_for_all_tasks(opt, epoch, iteration)
    psnr_ssim_rmse(opt, epoch, iteration)

def calculate_fid_for_all_tasks(opt, epoch, iteration):
    """Compute FID between real and synthesized images and save it as JSON.

    Real images are read from ``opt.image_dir``. Synthesized images are read
    from ``<results_dir>/<name>/<phase>_<which_epoch>/images/synthesized_image``.
    The score is stored under the key ``<phase>_<which_epoch>_<epoch>_<iteration>``
    in ``<results_dir>/<name>/FID_epoch_<epoch>_ite_<iteration:05>.json``.
    """
    fid_values = OrderedDict()
    path_real = opt.image_dir
    path_fake = os.path.join(opt.results_dir, opt.name,
                             '%s_%s' % (opt.phase, opt.which_epoch),
                             'images', 'synthesized_image')
    fid_value = calculate_fid_given_paths(
        paths=[path_real, path_fake],
        img_size=opt.crop_size,
        batch_size=50)
    key = '%s_%s_%d_%d' % (opt.phase, opt.which_epoch, epoch, iteration)
    # Store a built-in float: some numpy scalar types (e.g. float32) are not
    # JSON-serializable.
    fid_values[key] = float(fid_value)

    # Report the FID value. exist_ok avoids a check-then-create race.
    out_dir = os.path.join(opt.results_dir, opt.name)
    os.makedirs(out_dir, exist_ok=True)
    filename = os.path.join(out_dir, 'FID_epoch_%d_ite_%.5i.json' % (epoch, iteration))
    save_json(fid_values, filename)

def _channelwise_ssim(real_img, fake_img, data_range):
    """Return SSIM for 2-D images, or the mean per-channel SSIM otherwise.

    Averaging SSIM over the last axis reproduces skimage's multichannel mode
    without relying on the ``multichannel``/``channel_axis`` keywords, whose
    availability differs between skimage versions.
    """
    ssim_kwargs = dict(data_range=data_range, gaussian_weights=True,
                       use_sample_covariance=False)
    if real_img.ndim == 2:
        return skimage.metrics.structural_similarity(real_img, fake_img, **ssim_kwargs)
    return float(np.mean([
        skimage.metrics.structural_similarity(real_img[..., c], fake_img[..., c], **ssim_kwargs)
        for c in range(real_img.shape[-1])
    ]))


def psnr_ssim_rmse(opt, epoch, iteration):
    """Score synthesized images against real images with SSIM, PSNR and NRMSE.

    For every file in ``opt.image_dir``, the synthesized counterpart
    ``<stem>.png`` is loaded from
    ``<results_dir>/<name>/<phase>_<which_epoch>/images/synthesized_image``.
    The real image is resized (order=3) to the synthesized image's spatial size.
    Both images are divided by 255 (this assumes 8-bit images), and the
    per-image scores are averaged. The means are written to ``SSIM_...``,
    ``PSNR_...`` and ``RMSE_...`` JSON files under ``<results_dir>/<name>``.

    Note: the "RMSE" value is skimage's *normalized* RMSE
    (``normalized_root_mse``), not the raw RMSE.
    """
    path_real = opt.image_dir
    path_fake = os.path.join(opt.results_dir, opt.name,
                             '%s_%s' % (opt.phase, opt.which_epoch),
                             'images', 'synthesized_image')
    ssim_score_list = []
    psnr_score_list = []
    rmse_score_list = []
    for img_name in listdir(path_real):
        # Match the synthesized file by stem. Unlike split('.')[0], splitext
        # keeps names that contain extra dots ("a.b.jpg" -> "a.b").
        stem = os.path.splitext(img_name)[0]
        fake_img = skimage.io.imread(os.path.join(path_fake, stem + '.png')) / 255
        real_img = skimage.io.imread(os.path.join(path_real, img_name))
        # Resize to the synthesized image's size rather than a hard-coded
        # 256x256, so other crop sizes are compared pixel-for-pixel.
        real_img = skimage.transform.resize(real_img, fake_img.shape[:2], order=3,
                                            preserve_range=True,
                                            anti_aliasing=False) / 255

        # Both images are floats in [0, 1], so data_range is 1.0. Without it,
        # older skimage assumes 2.0 for float SSIM and newer skimage raises.
        ssim_score_list.append(_channelwise_ssim(real_img, fake_img, data_range=1.0))
        psnr_score_list.append(
            skimage.metrics.peak_signal_noise_ratio(real_img, fake_img, data_range=1.0))
        rmse_score_list.append(skimage.metrics.normalized_root_mse(real_img, fake_img))

    key = '%s_%s_%d_%d' % (opt.phase, opt.which_epoch, epoch, iteration)
    # Built-in floats keep the values JSON-serializable regardless of dtype.
    ssim_values = OrderedDict([(key, float(np.mean(ssim_score_list)))])
    psnr_values = OrderedDict([(key, float(np.mean(psnr_score_list)))])
    rmse_values = OrderedDict([(key, float(np.mean(rmse_score_list)))])

    # Report SSIM / PSNR / RMSE values. exist_ok avoids a check-then-create race.
    out_dir = os.path.join(opt.results_dir, opt.name)
    os.makedirs(out_dir, exist_ok=True)

    filename_ssim = os.path.join(out_dir, 'SSIM_epoch_%d_ite_%.5i.json' % (epoch, iteration))
    filename_psnr = os.path.join(out_dir, 'PSNR_epoch_%d_ite_%.5i.json' % (epoch, iteration))
    filename_rmse = os.path.join(out_dir, 'RMSE_epoch_%d_ite_%.5i.json' % (epoch, iteration))

    save_json(ssim_values, filename_ssim)
    save_json(psnr_values, filename_psnr)
    save_json(rmse_values, filename_rmse)





