import os

from config import *

from tools.data_loader import load_ori_data, load_correct_data
from tools.data_process import save_img
from tools.model_loader import *
from models.circuits import qubit_dict, depth_dict
from robustness_analysis.state_robustness_verifier import *


# Script entry point: load a trained model selected by the command-line
# configuration, obtain the original samples to analyse (from a cached
# directory if one exists, otherwise via load_correct_data), and run the
# Verifier from robustness_analysis.state_robustness_verifier with the
# 'QCQP' method. Results are written under the run directory `p_c`.

conf = get_arguments()
model_n = conf.structure
model_d = depth_dict[conf.structure]
model_q = qubit_dict[conf.structure]
if conf.structure == 'drnn':
    # The 'drnn' structure always uses interleaved encoding, overriding any
    # --encoding value (presumably a requirement of that circuit; confirm).
    conf.encoding = 'interleaved'

m_params, model = load_params_from_path(conf)

# Run directory name:
#   qubits_<q>_<encoding>_<class_idx>[<reduction>]_depth_<d>_sample_<finite>_noise_<noise>
# The reduction suffix appears only when inputs are resized. Building the
# name once keeps both variants in sync.
reduction_tag = conf.reduction if conf.resize else ''
run_name = (
    f'qubits_{model_q}_{conf.encoding}_{conf.class_idx}{reduction_tag}'
    f'_depth_{model_d}_sample_{conf.finite}_noise_{conf.noise}'
)
p_c = os.path.join(conf.adv_dir, conf.dataset, model_n, run_name)

# 'original' caches the samples used for verification.
ori_p_c = os.path.join(p_c, 'original')
if os.path.exists(ori_p_c):
    # Reuse previously saved samples. idx_list is returned by the loader
    # but is not used by this script.
    idx_list, ori_imgs, ori_labels = load_ori_data(conf, ori_p_c)
else:
    # Load before creating the directory: if load_correct_data fails, no
    # empty 'original' directory is left behind for a later run to mistake
    # for a valid cache.
    ori_imgs, ori_labels = load_correct_data(conf, model)
    os.makedirs(ori_p_c)
    save_img(ori_p_c, ori_imgs, ori_labels)

# NOTE(review): the input shape is hard-coded to 16x16 single-channel,
# independent of conf.dataset / conf.resize. Confirm this matches every
# configuration this script is run with.
img_shape = (16, 16, 1)
model_setting = {
    'model_n': model_n,
    'class_idx': conf.class_idx,
    'n_qubits': model_q,
    'img_shape': img_shape,
}
v = Verifier(save_path=p_c, model_setting=model_setting)
v.verify(model, m_params, ori_imgs, ori_labels, 'QCQP')
