from core import *
from log_and_plot import *
import argparse
from tqdm import tqdm

# Command-line interface.  The parsed values are copied into module-level
# names that the rest of this script reads.
#
# --real_dir and --img_dir are required: the script later takes the last
# path component of both, which would fail with AttributeError on None if
# either were omitted.  Failing here gives a usage message up front instead.
# The remaining options keep their original "None when omitted" behaviour.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--experiment_type',
    type=str,
    help="Experiment name; used as a directory component of the cache paths "
         "and forwarded to the loading helpers.",
)
parser.add_argument(
    '--real_dir',
    type=str,
    required=True,
    help="Directory of the reference ('real') images.",
)
parser.add_argument(
    '--img_dir',
    type=str,
    required=True,
    help="Directory of the images to compare against the reference set.",
)
parser.add_argument(
    '--n_attr',
    type=int,
    help="Number of attributes; forwarded to load_candidate and the plots, "
         "and embedded in the score file names.",
)
parser.add_argument(
    '--use_BLIP',
    type=int,
    help="Forwarded to load_candidate; 0 appends '_alloc' and 2 appends "
         "'_alloc2' to the experiment type.",
)
parser.add_argument(
    '--n_point',
    type=int,
    help="Forwarded to the density and plotting helpers.",
)
args = parser.parse_args()


experiment_type = args.experiment_type
img_dir_real = args.real_dir
img_dir_g = args.img_dir
n_attr = args.n_attr
use_BLIP = args.use_BLIP
n_point = args.n_point


def _last_path_component(path):
    """Return the final component of a '/'-separated path.

    Trailing slashes are ignored, so 'data/real/' and 'data/real' both give
    'real'.  A plain ``split('/')[-1]`` returns '' for the former, which
    makes every directory passed with a trailing slash map to the same
    file name.
    """
    return path.rstrip('/').split('/')[-1]


_real_dir_name = _last_path_component(img_dir_real)
_gen_dir_name = _last_path_component(img_dir_g)

# Both pickle paths are built from the experiment type *before* the
# '_alloc' / '_alloc2' suffix is appended below.
# NOTE(review): confirm that sharing these files across use_BLIP modes is intended.
text_pickle_path = f"pickles/Dclipscore/{experiment_type}/text_mean.pkl"
img_pickle_path = f"pickles/Dclipscore/{experiment_type}/{_real_dir_name}/img_mean.npy"

# use_BLIP selects a suffix for the experiment type; any value other than
# 0 or 2 leaves it unchanged.
if use_BLIP == 0:
    experiment_type = f"{experiment_type}_alloc"
elif use_BLIP == 2:
    experiment_type = f"{experiment_type}_alloc2"

# Load the attribute list (text_list) and its accompanying stats.
text_list, stats = load_candidate(experiment_types=experiment_type, n_attr=n_attr, use_blip=use_BLIP)

# Load the text mean and the mean of the reference images.
text_mean = get_text_mean(experiment_types=experiment_type, text_pickle_path=text_pickle_path)
img_mean = get_img_mean(img_dir=img_dir_real, img_pickle_path=img_pickle_path)

# Score files for the reference set and for the set under evaluation.
# NOTE(review): 'Dlipscore_npys' looks like a typo for 'Dclipscore_npys'; it is
# kept byte-for-byte so files written by earlier runs are still found.
DClipscore_npy_path = f"Dlipscore_npys/Dclipscore/{experiment_type}/{_real_dir_name}/blipscores_all_{n_attr}.npy"
DClipscore_npy_input_path = f"Dlipscore_npys/Dclipscore/{experiment_type}/{_gen_dir_name}_{n_attr}.npy"

# Per-image statistics (and file names) for both image sets, computed against
# the same reference image mean, text mean and attribute list.
oriDCLIPscore_img_stats, ori_filenames = get_img_stats(
    img_dir_real, img_mean, text_mean, text_list, DClipscore_npy_path
)
inputDCLIPscore_img_stats, filenames = get_img_stats(
    img_dir_g, img_mean, text_mean, text_list, DClipscore_npy_input_path
)

# Correlation-coefficient matrix of the reference statistics, with columns
# treated as the variables (rowvar=False).  The name `cov_space` is kept,
# although np.corrcoef yields correlations rather than covariances.
_reference_stats_np = oriDCLIPscore_img_stats.detach().cpu().numpy()
cov_space = np.corrcoef(_reference_stats_np, rowvar=False)

# The 1-D and 2-D density helpers receive exactly the same arguments.
_density_args = (
    oriDCLIPscore_img_stats,
    inputDCLIPscore_img_stats,
    cov_space,
    text_list,
    experiment_type,
    img_dir_g,
    n_point,
)
CE_differences_1d = img_stats_into_density_1d(*_density_args)
CE_differences_2d = img_stats_into_density_2d(*_density_args)


# Plot both sets of differences; everything after the first argument is shared.
_plot_args = (text_list, experiment_type, img_dir_real, img_dir_g, n_point, n_attr)
sorted_res_1d = plot_CE_difference_1d(CE_differences_1d, *_plot_args)
sorted_res_2d = plot_CE_difference_2d(CE_differences_2d, *_plot_args)
