import os
import subprocess
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import torch
from pytorch_fid.fid_score import calculate_fid_given_paths
from pytorch_fid.inception import InceptionV3

from PSNR import cal_psnr_given_paths

# Command-line options for the checkpoint sweep. Most options are forwarded
# to pytorch_fid's calculate_fid_given_paths by pick_best_ckpt below.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=50,
                    help='Batch size to use')
parser.add_argument('--num-workers', type=int,
                    help=('Number of processes to use for data loading. '
                          'Defaults to `min(8, num_cpus)`'))
parser.add_argument('--device', type=str, default=None,
                    help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
# NOTE(review): --save-stats is parsed but never read by pick_best_ckpt.
parser.add_argument('--save-stats', action='store_true',
                    help=('Generate an npz archive from a directory of samples. '
                          'The first path is used as input and the second as output.'))
# The pair scored after every sampling run. The first entry is presumably
# where sample_wmgm.py writes its samples (confirm); the second is the
# reference validation set. Supplying the pair as the default, rather than
# assigning args.path after parsing, lets an explicit --path take effect.
parser.add_argument('--path', type=str, nargs=2,
                    default=["output", "raw_data/valid/valid"],
                    help=('Paths to the generated images or '
                          'to .npz statistic files'))

args = parser.parse_args()

# Prefer CUDA when available unless a device was requested explicitly.
if args.device is None:
    device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
else:
    device = torch.device(args.device)

if args.num_workers is None:
    try:
        num_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is not available under Windows, use
        # os.cpu_count instead (which may not return the *available* number
        # of CPUs).
        num_cpus = os.cpu_count()

    # os.cpu_count() may return None; fall back to loading in the main process.
    num_workers = min(num_cpus, 8) if num_cpus is not None else 0
else:
    num_workers = args.num_workers


def pick_best_ckpt(args, num_epochs=100):
    """Sample from each checkpoint epoch and report the best FID and PSNR.

    For every epoch ``i`` in ``range(num_epochs)`` this runs
    ``sample_wmgm.py --test_epoch i`` as a child process. It then scores the
    directory pair ``args.path`` with FID (lower is better) and PSNR (higher
    is better), printing one line per epoch and finally the best epoch for
    each metric.

    Args:
        args: Parsed command-line namespace. Reads ``path`` (two paths),
            ``batch_size`` and ``dims``.
        num_epochs: Number of checkpoint epochs to evaluate, ``0`` to
            ``num_epochs - 1``. Defaults to 100.

    Returns:
        Tuple ``(best_fid_ckpt, min_fid, best_psnr_ckpt, max_psnr)``. An epoch
        entry is ``None`` (and its metric is +/-inf) when no epoch produced a
        usable value for that metric.

    Note:
        Uses the module-level ``device`` and ``num_workers`` computed when the
        script is loaded.
    """
    # Infinite sentinels guarantee the first scored epoch becomes the current
    # best. The former finite ones (100 / 0) left the best epoch as None when
    # every FID was >= 100 or every PSNR <= 0, crashing the final report.
    min_fid = float('inf')
    max_psnr = float('-inf')
    best_fid_ckpt, best_psnr_ckpt = None, None
    for i in range(num_epochs):
        # Argument list with no shell, run by the same interpreter as this script.
        result = subprocess.run(
            [sys.executable, 'sample_wmgm.py', '--test_epoch', str(i)])
        if result.returncode != 0:
            # Scoring after a failed sampling run would attribute whatever is
            # left in args.path[0] (presumably an older epoch's samples) to
            # epoch i, so skip it instead.
            print('Epoch %d: sampling failed with exit code %d, skipped'
                  % (i, result.returncode))
            continue
        fid_value = calculate_fid_given_paths(args.path,
                                              args.batch_size,
                                              device,
                                              args.dims,
                                              num_workers)
        psnr_value = cal_psnr_given_paths(args.path[0], args.path[1])
        if fid_value < min_fid:
            min_fid = fid_value
            best_fid_ckpt = i
        if psnr_value > max_psnr:
            max_psnr = psnr_value
            best_psnr_ckpt = i

        print('Epoch %d: FID %.4f, PSNR %.4f' % (i, fid_value, psnr_value))

    # Report each metric independently so a missing best never hits '%d' with None.
    if best_fid_ckpt is not None:
        print("Minimum FID is %.4f at epoch %d\n" % (min_fid, best_fid_ckpt))
    else:
        print("No epoch produced a usable FID\n")
    if best_psnr_ckpt is not None:
        print("Maximum PSNR is %.4f at epoch %d\n" % (max_psnr, best_psnr_ckpt))
    else:
        print("No epoch produced a usable PSNR\n")

    return best_fid_ckpt, min_fid, best_psnr_ckpt, max_psnr


def _main():
    """Run the checkpoint sweep with the arguments parsed at load time."""
    pick_best_ckpt(args)


if __name__ == "__main__":
    _main()
