
import os

from napkin_mnist.FID_scores.pytorch_fid.fid_score import calculate_fid

# FID evaluation script.
#
# For every variant in ``version`` and every checkpoint epoch in ``epochs``,
# call ``calculate_fid`` on the reference path ``<root>/FID_scores/trueP_Y``
# and the generated-sample path ``<root>/FID_scores/<alg><epoch:03d>``.
# The results are collected in ``fid`` and printed at the end.
# NOTE(review): both paths are presumably image directories as expected by
# pytorch_fid -- confirm against calculate_fid.

# Project root. Defaults to the original hard-coded location; set the
# IDGEN_ROOT environment variable to run from a different checkout.
root = os.environ.get("IDGEN_ROOT", "/root/PycharmProjects/IDGEN/napkin_mnist")
# Reference path: the same for every variant and epoch, so build it once
# instead of on every loop iteration.
path1 = f"{root}/FID_scores/trueP_Y"

# Checkpoint epochs to evaluate (a single-epoch run such as [295] was also
# used previously).
epochs = [0, 50, 100, 150, 200, 250, 300]

# Variants to evaluate. Earlier experiments used
# ["rep_image", "full_image", "mod_image"].
version = ["IDGEN"]

# Maps each variant name to the list of values returned by calculate_fid,
# in the same order as ``epochs``.
fid = {}
for alg in version:
    fid[alg] = []
    for epoch in epochs:
        print(f'{alg}: Calculating FID for epoch :{epoch}')
        # Sample paths are named <alg><zero-padded 3-digit epoch>,
        # e.g. "IDGEN050" for alg="IDGEN", epoch=50.
        path2 = f"{root}/FID_scores/{alg}{epoch:03d}"
        fid[alg].append(calculate_fid(path1, path2))

for alg in fid:
    print(alg, fid[alg])


# Recorded output of an earlier run with version = ["ncm"]:
# ncm: Calculating FID for epoch :0
# 100%|██████████| 232/232 [00:07<00:00, 30.06it/s]
# 100%|██████████| 60/60 [00:02<00:00, 24.20it/s]
# FID:  109.64472189270089
# ncm: Calculating FID for epoch :50
# 100%|██████████| 232/232 [00:06<00:00, 36.88it/s]
# 100%|██████████| 60/60 [00:02<00:00, 24.85it/s]
# FID:  170.04853292875444
# ncm: Calculating FID for epoch :100
# 100%|██████████| 232/232 [00:06<00:00, 36.85it/s]
# 100%|██████████| 60/60 [00:02<00:00, 24.15it/s]
# FID:  98.51317595316115
# ncm: Calculating FID for epoch :150
# 100%|██████████| 232/232 [00:06<00:00, 36.70it/s]
# 100%|██████████| 60/60 [00:02<00:00, 24.03it/s]
# FID:  106.89447081242254
# ncm: Calculating FID for epoch :200
# 100%|██████████| 232/232 [00:06<00:00, 36.60it/s]
# 100%|██████████| 60/60 [00:02<00:00, 24.10it/s]
# FID:  102.30110946517222
# ncm: Calculating FID for epoch :250
# 100%|██████████| 232/232 [00:06<00:00, 36.56it/s]
# 100%|██████████| 60/60 [00:02<00:00, 23.97it/s]
# FID:  133.57297189275045

# Output of a separate single-epoch "ncm" run (its "Calculating FID" line was not recorded):
# 100%|██████████| 232/232 [00:07<00:00, 30.44it/s]
# 100%|██████████| 60/60 [00:02<00:00, 25.96it/s]
# FID:  161.39099424596236
# ncm [161.39099424596236]

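# Final printed result for version = ["IDGEN"], one FID value per epoch in [0, 50, 100, 150, 200, 250, 300]:
# IDGEN [27.04184721048469, 28.245207310276555, 28.352652378930145, 27.961319950896296, 28.011457875796765, 28.075970840888658, 27.56625944907009]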