from pytorch_fid.fid_score import calculate_activation_statistics, calculate_frechet_distance
from pytorch_fid.inception import InceptionV3   

import os

def calculate_fid_given_image_sets(image_sets, batch_size, device, dims, num_workers=1):
    """Calculate the Frechet Inception Distance (FID) between two sets of images.

    Args:
        image_sets: A pair ``(set_a, set_b)``; each element is a sequence of
            image file paths. Exactly two sets are required.
        batch_size: Batch size forwarded to ``calculate_activation_statistics``.
        device: Device the InceptionV3 model is moved to (via ``.to``) and
            that is forwarded to ``calculate_activation_statistics``.
        dims: Feature dimensionality; must be a key of
            ``InceptionV3.BLOCK_INDEX_BY_DIM``.
        num_workers: Worker count forwarded to
            ``calculate_activation_statistics``.

    Returns:
        The value returned by ``calculate_frechet_distance`` for the
        activation statistics of the two sets.

    Raises:
        ValueError: If ``image_sets`` does not contain exactly two sets.
        RuntimeError: If any image path does not exist.
    """
    # Validate the arity first: it is free, whereas building the model and
    # running a statistics pass before failing (or silently ignoring extra
    # sets) is expensive and misleading.
    if len(image_sets) != 2:
        raise ValueError(
            f'Expected exactly 2 image sets, got {len(image_sets)}'
        )
    first_set, second_set = image_sets

    # Fail fast on missing files before any model work is done.
    for img_set in (first_set, second_set):
        for p in img_set:
            if not os.path.exists(p):
                # f-string rather than `%`: `'...%s' % p` raises TypeError
                # instead of the intended RuntimeError when `p` is a tuple.
                raise RuntimeError(f'Invalid path: {p}')

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).to(device)

    # Keyword arguments keep the call correct even if the third-party
    # signature gains or reorders optional parameters.
    m1, s1 = calculate_activation_statistics(
        first_set, model,
        batch_size=batch_size, dims=dims, device=device, num_workers=num_workers,
    )
    m2, s2 = calculate_activation_statistics(
        second_set, model,
        batch_size=batch_size, dims=dims, device=device, num_workers=num_workers,
    )

    return calculate_frechet_distance(m1, s1, m2, s2)