import os
import numpy as np
import pandas as pd
from metric.metric import get_metrics
from pytorch_fid import fid_score

# Experiment configuration.
# NOTE: os.path / pytorch_fid do NOT expand a leading `~`, so every
# home-relative path is passed through os.path.expanduser explicitly;
# otherwise the existence check inside calculate_fid_given_paths fails.
data_type = 'sem'
# Previously evaluated experiments, kept for quick switching:
# exp_name = 'tai-sem-ldm-vq-f8-intra-10000'
# exp_name = 'base-tai-sem-ldm-vq-f8-intra-10000-new'
exp_name = 'CcGAN-SEM'

# Paths to the directories containing real and generated images.
real_images_path = os.path.expanduser(f'~/tai/sd/tai_data/fid_{data_type}')
generated_images_path = os.path.expanduser(
    f'~/Project/{exp_name}/generate_image/intra_labels'
)

# Compute the FID score between the two image directories using 2048-dim
# Inception features (pytorch_fid's default feature size). Requires a CUDA device.
fid_value = fid_score.calculate_fid_given_paths(
    [real_images_path, generated_images_path],
    batch_size=50,
    device='cuda',
    dims=2048,
)

print(f'FID score: {fid_value}')

# Compute accuracy of the generated images against their conditioning labels.
# All `~` paths are expanded explicitly: os.listdir and file loaders treat a
# leading `~` literally, which makes them fail with FileNotFoundError.
# expanduser is idempotent, so re-expanding an already-absolute path is harmless.
pred_root = os.path.expanduser(generated_images_path)
# NOTE(review): raw_ann_path is not used below; the ground-truth conditions
# are loaded from metric/intra_conds_10000.npy instead.
raw_ann_path = os.path.expanduser('~/tai/sd/tai_data/sum1126.csv')
metric_path = os.path.expanduser(f'~/tai/sd/metric/{data_type}_metric.pth')
pred_save_path = os.path.expanduser(
    f'~/Project/TitaniumDiff/logs/pred/{exp_name}.npy'
)

# One ground-truth condition per entry along the first axis of the saved
# array. The path is relative to the current working directory.
raw_ann = np.load('metric/intra_conds_10000.npy')
gt_cond_list = list(raw_ann)

# Generated images, in lexicographic filename order.
# NOTE(review): get_metrics presumably pairs images with conditions by index.
# If the filenames are unpadded integers (e.g. 2.png, 10.png), lexicographic
# order will not match generation order, so confirm the naming scheme.
pred_image_list = [
    os.path.join(pred_root, file_name)
    for file_name in sorted(os.listdir(pred_root))
]

# Surface a count mismatch early instead of letting it silently skew results.
if len(pred_image_list) != len(gt_cond_list):
    print(
        f'Warning: found {len(pred_image_list)} generated images but '
        f'{len(gt_cond_list)} ground-truth conditions'
    )

acc_dict = get_metrics(
    pred_image_list,
    gt_cond_list,
    metric_path=metric_path,
    pred_save_path=pred_save_path,
)
print('Accuracy:', acc_dict)
