import numpy as np
import torch
from tqdm import tqdm
# Generated samples under evaluation and the reference image batch.
# NOTE(review): SAMPLES_PATH is a redacted placeholder — fill in before running.
SAMPLES_PATH = "pathxxxx/xxxxxx/git/VAR/samples/d24-d24beta01biasedNCAbz256-ar-ckpt-last-cfg-1.0-seed-1.npz"
REF_PATH = "../../LlamaGen/fid_stats/VIRTUAL_imagenet256_labeled.npz"
EXPECTED_NUM_SAMPLES = 50000


def _load_images_nchw(path):
    """Load the ``arr_0`` array from an ``.npz`` file as a channels-first tensor.

    The stored array is transposed with axes (0, 3, 1, 2), i.e. it is treated
    as (N, H, W, C) and returned as (N, C, H, W).
    NOTE(review): presumably uint8 RGB images; torchmetrics' FID/IS expect
    uint8 input unless ``normalize=True`` — confirm the dtype of the files.
    """
    # The context manager closes the zip handle that np.load opens for .npz;
    # indexing the NpzFile reads the array fully into memory first.
    with np.load(path) as data:
        images = data["arr_0"]
    return torch.from_numpy(np.transpose(images, (0, 3, 1, 2)))


s = _load_images_nchw(SAMPLES_PATH)
# Explicit check rather than `assert`, which is removed under `python -O`.
if s.shape[0] != EXPECTED_NUM_SAMPLES:
    raise ValueError(
        f"expected {EXPECTED_NUM_SAMPLES} generated samples, got {s.shape[0]}"
    )
print(s.shape)

ref = _load_images_nchw(REF_PATH)
from torchmetrics.image.fid import FrechetInceptionDistance

# Standard FID uses the 2048-d pool3 features of InceptionV3. Smaller settings
# (64/192/768) read shallower layers and give numbers that are not comparable
# with published FIDs.
# NOTE(review): torchmetrics' Inception preprocessing may still differ slightly
# from the reference (TensorFlow) evaluator used in papers — confirm if exact
# parity with reported numbers matters.
FID_FEATURE_DIM = 2048
# Number of chunks each image tensor is split into (not the chunk size).
FID_NUM_CHUNKS = 100

fid = FrechetInceptionDistance(feature=FID_FEATURE_DIM)
print(ref.shape)
# The reference batch is the "real" distribution; generated samples are "fake".
for batch in tqdm(ref.chunk(FID_NUM_CHUNKS)):
    fid.update(batch, real=True)
for batch in tqdm(s.chunk(FID_NUM_CHUNKS)):
    fid.update(batch, real=False)
print(fid.compute())

from torchmetrics.image.inception import InceptionScore

# Inception Score is computed over the generated samples only, fed to the
# metric in 100 roughly equal chunks.
inception = InceptionScore()
for sample_batch in tqdm(torch.chunk(s, 100)):
    inception.update(sample_batch)
# compute() returns the score's (mean, std) over the metric's splits.
print(inception.compute())