from adversarial.additive_specific import gen_adv_classical, evaluate_attack
from config import *
import os

from tools import Log
from tools.data_loader import load_correct_data, load_ori_data
from tools.data_process import save_img
from tools.model_loader import *
from models.circuits import depth_dict, qubit_dict

# --- Configuration and model loading -------------------------------------
conf = get_arguments()
model_n = conf.structure
if conf.structure == 'drnn':
    # drnn is always run with interleaved encoding, overriding the CLI value.
    conf.encoding = 'interleaved'
model_d = depth_dict[conf.structure]
model_q = qubit_dict[conf.structure]

# Result directory for this configuration. When resizing is enabled, the
# reduction tag is appended directly after the class index; otherwise the
# name is identical minus that tag. Built once so both variants cannot drift.
reduction_tag = conf.reduction if conf.resize else ''
run_name = (
    f'qubits_{model_q}_{conf.encoding}_{conf.class_idx}{reduction_tag}'
    f'_depth_{model_d}_sample_{conf.finite}_noise_{conf.noise}'
)
p_c = os.path.join(conf.adv_dir, conf.dataset, model_n, run_name)

m_params, model = load_params_from_path(conf)

# A non-zero `finite` setting switches the attack to estimated gradients.
estimate = conf.finite != 0
print(f'Dataset: {conf.dataset}, model: {model_n}, attack: {conf.attack}, estimated gradient: {estimate}')

# --- Clean (unattacked) test samples, cached under <p_c>/original ---------
# The first run builds the set via load_correct_data (presumably the samples
# the model classifies correctly -- confirm) and saves it as images; later
# runs reload it with load_ori_data.
ori_p_c = os.path.join(p_c, 'original')
if os.path.isdir(ori_p_c) and os.listdir(ori_p_c):
    # Cache present and non-empty. The first returned value (an index list)
    # is not used by this script.
    _, test_x, test_y = load_ori_data(conf, ori_p_c)
else:
    # Load BEFORE creating the directory. If loading fails, no empty
    # "original" directory is left behind to be mistaken for a cache on the
    # next run. An empty leftover directory from an older crash is also
    # treated as "not cached" by the check above.
    # NOTE(review): a save_img failure part-way through still leaves a
    # partial cache; writing to a temp dir and renaming would close that gap.
    test_x, test_y = load_correct_data(conf, model)
    os.makedirs(ori_p_c, exist_ok=True)
    save_img(ori_p_c, test_x, test_y)

# --- Run the configured attack and evaluate it ---------------------------
# Outputs (log.txt and figures) go to <p_c>/<attack name>.
attack_p_c = os.path.join(p_c, conf.attack)
# exist_ok avoids the check-then-create race when several runs start at once.
os.makedirs(attack_p_c, exist_ok=True)
log = Log(os.path.join(attack_p_c, 'log.txt'))
adv_x = gen_adv_classical(
    model, test_x, test_y, conf.attack, log,
    save_fig_path=attack_p_c, save_fig=True, est_grad=estimate,
)
evaluate_attack(test_x, adv_x, log, attack_name=conf.attack)
