import os

from adversarial.additive_specific import FeatureAttack
from config import *
from tools import Log

from tools.data_loader import load_correct_data, load_ori_data
from tools.data_process import save_img
from tools.model_loader import *
from models.circuits import depth_dict, qubit_dict

# Parse the run configuration and look up the circuit depth / qubit count
# registered for the chosen structure (both are used in the output path).
conf = get_arguments()
model_n = conf.structure
model_d, model_q = depth_dict[model_n], qubit_dict[model_n]

# The 'drnn' structure always runs with interleaved encoding, regardless of
# the encoding passed in the configuration.
if model_n == 'drnn':
    conf.encoding = 'interleaved'

# Restore the trained parameters together with the model they belong to.
m_params, model = load_params_from_path(conf)

# Hyper-parameters passed straight to FeatureAttack.
attack_config = {
    # Layer attacked by FeatureAttack. Author's guidance per structure:
    # 4 for qcl, 2 for qcnn, 1 for hqnn, 3 for drnn.
    # TODO(review): this is not derived from conf.structure, so it must be
    # edited by hand when switching structures.
    'layer': 4,
    'lr': 0.05,
    'c': 10,
    'budget': 500,
    'targeted': False,
    'target_label': 0,
    'lr_strategy': 'auto',
}

# Turn on FeatureAttack's gradient estimation (grad_est) whenever conf.finite
# is non-zero; conf.finite is the '_sample_' component of the output path,
# presumably a finite measurement-shot count -- confirm.
estimate = conf.finite != 0
# Run-specific output directory:
#   <adv_dir>/<dataset>/<structure>/qubits_<q>_<encoding>_<class_idx>[<reduction>]
#       _depth_<d>_sample_<finite>_noise_<noise>
# conf.reduction is only included (and only read) when conf.resize is set.
# NOTE(review): the reduction tag is glued directly onto class_idx with no '_'
# separator; kept as-is so existing result directories still match.
reduction_tag = conf.reduction if conf.resize else ''
p_c = os.path.join(
    conf.adv_dir,
    conf.dataset,
    model_n,
    f'qubits_{model_q}_{conf.encoding}_{conf.class_idx}{reduction_tag}'
    f'_depth_{model_d}_sample_{conf.finite}_noise_{conf.noise}',
)

# Image shape handed to the attack: 8x8x1 for drnn, 16x16x1 for every other
# structure.
img_shape = (8, 8, 1) if conf.structure == 'drnn' else (16, 16, 1)

# Cache of the original (unattacked) samples, stored under <p_c>/original.
ori_p_c = os.path.join(p_c, 'original')
if os.path.isdir(ori_p_c) and os.listdir(ori_p_c):
    # A populated cache exists: reuse it. idx_list is not used by this script.
    idx_list, ori_imgs, ori_labels = load_ori_data(conf, ori_p_c)
else:
    # Presumably the samples the model classifies correctly (per the helper's
    # name). Load BEFORE creating the directory so a failure here cannot leave
    # an empty 'original' folder that later runs would treat as a valid cache.
    # An existing-but-empty directory (e.g. from an interrupted run) is
    # repopulated instead of being silently loaded as empty.
    ori_imgs, ori_labels = load_correct_data(conf, model)
    os.makedirs(ori_p_c, exist_ok=True)
    save_img(ori_p_c, ori_imgs, ori_labels)

# Construct the attack from the hyper-parameters above; save_dir points at the
# run directory p_c and grad_est toggles gradient estimation.
feature_attack = FeatureAttack(
    attack_config,
    save_dir=p_c,
    grad_est=estimate,
)

# Attack the original samples against the loaded model and its parameters.
feature_attack.run(
    ori_imgs,
    ori_labels,
    model,
    m_params,
    img_shape,
)

