
import os
import torch
import random
import numpy as np
from tqdm import trange
from helper.fid_score import calculate_fid_given_paths

def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    The same ``seed`` is fed, in order, to ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed`` and ``torch.cuda.manual_seed``.
    Per the PyTorch docs, the last call targets the current CUDA device and is
    silently ignored when CUDA is unavailable.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

def compute_fid(files1, files2, batch_size=None, num_workers=None):
    """Compute the FID between two collections of images.

    Args:
        files1: first image collection. It is passed through as
            ``paths[0]`` to ``calculate_fid_given_paths``. It is presumably a
            list of image file paths; the helper is not visible here.
        files2: second image collection, passed as ``paths[1]``.
        batch_size: batch size forwarded to the helper. Defaults to
            ``len(files1)``, which processes each set in one batch and keeps
            the original behaviour.
        num_workers: data-loading workers. Defaults to the number of usable
            CPUs, capped at 8.

    Returns:
        Whatever ``calculate_fid_given_paths`` returns, i.e. the FID value.

    Raises:
        ValueError: if either collection is empty.
    """
    # Validate before any work. An empty set would otherwise produce
    # batch_size=0 and fail deep inside data loading with an obscure error.
    if not files1 or not files2:
        raise ValueError(
            f'compute_fid needs two non-empty image collections '
            f'(got {len(files1)} and {len(files2)} files)'
        )
    # Fixed seed so repeated evaluations are reproducible.
    seed_everything(seed=42)
    if batch_size is None:
        batch_size = len(files1)
    if num_workers is None:
        try:
            cpu_count = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity is not available on every platform
            # (e.g. macOS, Windows); fall back to the total CPU count.
            cpu_count = os.cpu_count() or 1
        num_workers = min(cpu_count, 8)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return calculate_fid_given_paths(
        dims=2048,
        paths=[files1, files2],
        batch_size=batch_size,
        num_workers=num_workers,
        device=device,
    )

def _collect_pairs(data_dir, num_samples):
    """Return aligned lists of baseline/final image paths found under ``data_dir``.

    Sample ``i`` (for ``i`` in ``range(num_samples)``) lives in
    ``<data_dir>/<i>/``. It is kept only when both ``baseline.png`` and
    ``final.png`` exist, so index ``k`` of both lists refers to the same sample.
    """
    base_images, final_images = [], []
    for i in trange(num_samples):
        sample_dir = os.path.join(data_dir, str(i))
        base_path = os.path.join(sample_dir, 'baseline.png')
        final_path = os.path.join(sample_dir, 'final.png')
        if os.path.exists(base_path) and os.path.exists(final_path):
            base_images.append(base_path)
            final_images.append(final_path)
    return base_images, final_images


def main(data_root='../gen/images_bin', tokens=(32, 64, 128, 256, 512), num_samples=300):
    """Print the FID between baseline and final images for each token count.

    Args:
        data_root: directory containing one ``tok<N>`` subdirectory per token
            count. The alternative dataset previously toggled in this script
            is ``'../gen/images_quant'``.
        tokens: token counts to evaluate, in order.
        num_samples: sample indices ``0 .. num_samples - 1`` are checked.
    """
    for tok in tokens:
        data_dir = f'{data_root}/tok{tok}'
        base_images, final_images = _collect_pairs(data_dir, num_samples)
        if not base_images:
            # Nothing to compare (e.g. missing directory); FID is undefined.
            print(f'data_dir {data_dir} | N 0 | skipped: no complete baseline/final pairs')
            continue
        fid = compute_fid(base_images, final_images)
        print(f'data_dir {data_dir} | N {len(base_images)} | fid {fid:.2f}')


# Guard needed: DataLoader workers started with the 'spawn' method re-import
# the main module, which would otherwise re-run the whole evaluation loop.
if __name__ == '__main__':
    main()


    