from config import *
from tools.data_process import save_img
from tools.entanglement import *
from tools.data_loader import load_correct_data, load_ori_data
from tools.model_loader import load_params_from_path
from adversarial.algorithm import *
from models.circuits import depth_dict, qubit_dict
from pennylane.math import fidelity, trace_distance, reduce_statevector
from tools.log import Log


# ---------------------------------------------------------------------------
# Run configuration: command-line arguments are parsed at import time.
# ---------------------------------------------------------------------------
conf = get_arguments()
device = torch.device('cpu')

d = depth_dict[conf.structure]   # circuit depth registered for this structure
test_img_num = conf.num_test

# Attack hyper-parameters.
lr = 0.05                  # step size multiplied into the input gradient
anti_predict_weight = 1    # passed to DLFuzz2 / DLFuzz3
cov_weight = 1             # weight of the entanglement term in the loss (other values tried: 2, 5)
ent_k = 1                  # forwarded to entQ

# Shape used to un-flatten an input before saving it as an image; the last
# axis is moved to the front by permute(2, 0, 1), so it is the channel axis.
if conf.structure == 'drnn':
    img_shape = (8, 8, 1)
else:
    img_shape = (16, 16, 1)


def _save_attack_img(x, attack_p, i, true_y, pred_y):
    """Save flattened image ``x`` into ``attack_p`` as ``{i}_{true_y}_{pred_y}.png``.

    ``x`` is reshaped to the module-level ``img_shape`` and its last axis is
    moved to the front (presumably (H, W, C) -> (C, H, W)) before being passed
    to ``save_image``.  ``true_y`` and ``pred_y`` must support ``.item()``.
    """
    img = x.reshape(img_shape).permute(2, 0, 1)
    name = f'{i}_{true_y.item()}_{pred_y.item()}.png'
    save_image(img, os.path.join(attack_p, name))


def QuanTest_attack(test_x, test_y, model, params, attack_p, max_iters=500):
    """Run the QuanTest gradient attack on every test image and save results.

    For each image the flattened input ``x`` is repeatedly moved along the
    gradient of ``DLFuzz objective + cov_weight * QEA`` (QEA as returned by
    ``entQ`` on the circuit's input/output states), clamped to [0, 1], until
    the predicted label differs from ``test_y[i]`` or ``max_iters`` steps have
    been taken.  The final image is saved either way as
    ``{i}_{true_label}_{predicted_label}.png``; an image saved because the
    iteration cap was hit therefore has identical true/predicted labels.

    Args:
        test_x: Batch of images indexed along dim 0; each is flattened before
            being fed to the model.
        test_y: True labels aligned with ``test_x`` (``.item()`` is called on
            each entry).
        model: Object exposing ``circuit_state(x, params, exec_=...)`` and
            ``predict(batch)``.
        params: Circuit parameters forwarded to ``model.circuit_state``.
        attack_p: Output directory for ``log.txt`` and the saved images;
            created if it does not exist.
        max_iters: Maximum perturbation steps per image (previously a
            hard-coded 500).

    Returns:
        dict with ``adv_num`` (number of adversarial images generated),
        ``fidelity`` / ``trace_distance`` (per-adversarial-image lists, only
        for images where the metrics could be computed) and ``QEA_sum``
        (summed ``ent_out - ent_in`` over adversarial images).
    """
    os.makedirs(attack_p, exist_ok=True)  # save_image does not create dirs
    log = Log(os.path.join(attack_p, 'log.txt'))

    # The objective depends only on the task arity, so pick it once.
    objective = DLFuzz2 if len(conf.class_idx) == 2 else DLFuzz3

    adv_num = 0
    f_list = []
    t_list = []
    QEA_sum = 0

    for i in range(test_x.shape[0]):
        log(f'Start for {i}th img...')
        x = torch.flatten(test_x[i], start_dim=0)
        x.requires_grad_(True)

        # Reference input state and prediction of the unperturbed image.
        in_state = model.circuit_state(x, params, exec_=False)
        ori_outputs = torch.tensor(model.predict(x.unsqueeze(0))[0])

        iters = 0
        while True:
            iters += 1
            x.requires_grad_(True)

            now_in_state = model.circuit_state(x, params, exec_=False)
            now_out_state = model.circuit_state(x, params, exec_=True)
            now_ent_c, now_ent_in, now_ent_out = entQ(now_in_state, now_out_state, ent_k)
            log(f'iter {iters}: QEA: {now_ent_c}, ent_in: {now_ent_in}, ent_out: {now_ent_out}')

            now_outputs = model.predict(x.unsqueeze(0))[0]
            obj_orie = objective(now_outputs, ori_outputs, anti_predict_weight)

            loss = obj_orie + cov_weight * now_ent_c
            loss.backward()
            perturb = x.grad * lr

            # Detach so the next iteration starts a fresh graph with a clean .grad.
            x = torch.clamp((x + perturb), 0, 1).detach()

            new_y = torch.argmax(torch.tensor(model.predict(x.unsqueeze(0))[0]))
            if new_y != test_y[i]:
                log('gen an adv img!')
                log(f'ori_y: {test_y[i]}, new_y: {new_y}')
                adv_num += 1

                # Measure the image actually being saved: now_in_state above
                # was computed from x *before* this step's perturbation.
                adv_in_state = model.circuit_state(x, params, exec_=False)
                adv_out_state = model.circuit_state(x, params, exec_=True)
                _, adv_ent_in, adv_ent_out = entQ(adv_in_state, adv_out_state, ent_k)
                QEA_sum += adv_ent_out - adv_ent_in

                # Reducing onto every qubit yields the full density matrix.
                idx = list(range(int(math.log2(adv_in_state.shape[0]))))
                try:
                    rho_adv = reduce_statevector(adv_in_state, indices=idx)
                    rho_ori = reduce_statevector(in_state, indices=idx)
                    f = fidelity(rho_adv, rho_ori)
                    t = trace_distance(rho_adv, rho_ori)
                except Exception as e:  # best effort: still save the image
                    log(f'Computation of reduce_statevector failed ({e!r}), directly save adv img.')
                else:
                    log(f'fidelity: {f}, trace distance: {t}')
                    f_list.append(f)
                    t_list.append(t)
                _save_attack_img(x, attack_p, i, test_y[i], new_y)
                break

            if iters >= max_iters:
                # Not adversarial: saved with new_y == true label.
                _save_attack_img(x, attack_p, i, test_y[i], new_y)
                break

    log(f'!!generated {adv_num} adv img out of {test_x.shape[0]} img!')
    if f_list:
        log(f'avg fidelity: {sum(f_list) / len(f_list)}, '
            f'avg trace distance: {sum(t_list) / len(t_list)}')
    if adv_num:
        log(f'avg QEA change (ent_out - ent_in): {QEA_sum / adv_num}')
    return {'adv_num': adv_num, 'fidelity': f_list,
            'trace_distance': t_list, 'QEA_sum': QEA_sum}


if __name__ == '__main__':
    print(f'Dataset {conf.dataset}, Model: {conf.structure}, class_idx: {conf.class_idx}')
    print('Parameter info:')
    print(f'test img num: {test_img_num}, lr: {lr}, anti_predict_weight: {anti_predict_weight}, cov_weight: {cov_weight}, '
          f'ent_k: {ent_k}')
    params, model = load_params_from_path(conf)
    model_d = depth_dict[conf.structure]
    model_q = qubit_dict[conf.structure]

    # The run directory name encodes the experiment configuration.
    run_tail = ('_depth_' + str(model_d) + '_sample_' + str(conf.finite)
                + '_noise_' + str(conf.noise))
    if conf.resize:
        # NOTE(review): this branch puts the structure name after 'qubits_'
        # where the other branch uses the qubit count; kept unchanged so that
        # previously generated directories are still found.
        run_name = ('qubits_' + str(conf.structure) + '_' + str(conf.encoding) + '_'
                    + str(conf.class_idx) + conf.reduction + run_tail)
    else:
        run_name = ('qubits_' + str(model_q) + '_' + str(conf.encoding) + '_'
                    + str(conf.class_idx) + run_tail)
    p = os.path.join(conf.adv_dir, conf.dataset, conf.structure, run_name)
    os.makedirs(p, exist_ok=True)

    # First run: select correctly classified samples and cache them on disk;
    # later runs reload the cached originals.
    ori_p_c = os.path.join(p, 'original')
    if not os.path.exists(ori_p_c):
        os.makedirs(ori_p_c)
        test_x, test_y = load_correct_data(conf, model)
        save_img(ori_p_c, test_x, test_y)
    else:
        _, test_x, test_y = load_ori_data(conf, ori_p_c)

    # The attack writes its log and images here, so it must exist up front.
    attack_p_c = os.path.join(p, 'QuanTest')
    os.makedirs(attack_p_c, exist_ok=True)
    QuanTest_attack(test_x, test_y, model, params, attack_p_c)
