"""
classical attack method on classical data and quantum models, perturb before encoding
additive perturbations for specific input
<pennylane: 0.30.0>
"""
import os.path

import torch
from sklearn.metrics import accuracy_score

from adversarial.algorithm import *
from .utils import *

# Module-level torch device, fixed to the CPU. It is not referenced by the
# code visible in this file.
device = torch.device('cpu')

def gen_adv_classical(model, test_x, test_y, attack, log, save_fig_path='', save_fig=True, est_grad=False):
    """Run a classical adversarial attack on ``test_x`` and report its success rate.

    Args:
        model: target model; ``model.predict(batch)`` must return a tensor whose
            argmax over dim 1 is taken as the predicted label.
        test_x: batch of clean inputs handed to the attack routine.
        test_y: tensor of reference labels, one per input.
        attack: attack name, one of ``'FGSM'``, ``'BIM'``, ``'DLFuzz'`` or ``'CW'``.
        log: logging callable; it is only forwarded to the CW attack.
        save_fig_path: directory for the saved images; an empty string means
            the current working directory.
        save_fig: when true, every adversarial image is written as
            ``<index>_<label>_<predicted>.png``.
        est_grad: forwarded to FGSM, BIM and CW (DLFuzz does not receive it).

    Returns:
        The adversarial images produced by the selected attack.

    Raises:
        ValueError: if ``attack`` is not one of the supported names.
    """
    if attack == 'FGSM':
        adv_imgs = FGSM(model, test_x, test_y, eps=64/255, est_grad=est_grad)
    elif attack == 'BIM':
        adv_imgs = BIM(model, test_x, test_y, eps=64/255, alpha=12/255, steps=50, est_grad=est_grad)
    elif attack == 'DLFuzz':
        adv_imgs = DLFuzz(model, test_x, test_y, w=5, steps=100)
    elif attack == 'CW':
        adv_imgs = CW(model, test_x, test_y, log, c=30, steps=500, lr=0.1, est_grad=est_grad)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"Unsupported attack {attack!r}; expected one of 'FGSM', 'BIM', 'DLFuzz', 'CW'")

    now_y = torch.argmax(model.predict(adv_imgs).detach(), dim=1)
    # Fraction of inputs whose prediction differs from the reference label.
    now_acc = (test_y != now_y).sum() / test_y.shape[0]
    print(f'Attack Success Rate: {now_acc*100}%')

    if save_fig:
        # os.makedirs('') raises, so only create a directory when one was named.
        if save_fig_path:
            os.makedirs(save_fig_path, exist_ok=True)
        gen_num_c = 0
        for i, a in enumerate(adv_imgs):
            if now_y[i] != test_y[i]:
                gen_num_c += 1
            # Every image is saved, misclassified or not; the file name records both labels.
            file_name = str(i) + '_' + str(test_y[i].item()) + '_' + str(now_y[i].item()) + '.png'
            save_image(a.detach(), os.path.join(save_fig_path, file_name))
        print(f'generate {gen_num_c} adv imgs!!')
    return adv_imgs



def evaluate_attack(ori_imgs, adv_imgs, log, attack_name='FGSM'):
    """Log the mean and standard deviation of the per-image SSIM of an attack.

    Args:
        ori_imgs: iterable of clean images; each one is squeezed before comparison.
        adv_imgs: indexable collection of adversarial images, aligned with
            ``ori_imgs`` by position.
        log: callable receiving the summary line.
        attack_name: name printed in the progress banner.

    Returns:
        None. The summary is only reported through ``log``.
    """
    print(f'----- evaluate {attack_name}')
    with torch.no_grad():
        # SSIM between each clean image and the adversarial image at the same index.
        adv_SSIM = [SSIM(ori_x.squeeze(), adv_imgs[i].squeeze())
                    for i, ori_x in enumerate(ori_imgs)]
        log(
            f'SSIM: avg: {np.mean(adv_SSIM)}, std: {np.std(adv_SSIM)}')


class FeatureAttack():
    """
    Attack targeted at internal features of QNNs
    automatic learning rate adjustment: Reliable Evaluation of Adversarial Robustness with an Ensemble of Diverse Parameter-free Attacks
    """
    def __init__(self, attack_config, grad_est=False, save_dir=''):
        """Store the attack hyper-parameters and derive the checkpoint schedule.

        Args:
            attack_config: mapping with the keys ``'layer'``, ``'lr'``, ``'c'``,
                ``'budget'``, ``'targeted'``, ``'target_label'`` and
                ``'lr_strategy'``.
            grad_est: whether ``run`` uses the estimated-gradient attack.
            save_dir: root directory under which results are stored.
        """
        cfg = attack_config
        self.t_depth, self.lr, self.c = cfg['layer'], cfg['lr'], cfg['c']
        self.budget = cfg['budget']
        self.targeted, self.target_label = cfg['targeted'], cfg['target_label']
        self.lr_strategy = cfg['lr_strategy']
        self.save_dir, self.grad_est = save_dir, grad_est

        def budget_share(ratio):
            # A fraction of the iteration budget, never less than one step.
            return max(int(ratio * self.budget), 1)

        # Checkpoint schedule, all expressed as shares of the budget.
        self.n_iter_2 = budget_share(0.22)    # first checkpoint interval (p1 = 0.22)
        self.n_iter_min = budget_share(0.06)  # smallest allowed interval
        self.size_decr = budget_share(0.03)   # amount the interval shrinks by

    def init_path(self):
        """Build ``self.save_path`` from the hyper-parameters and open the log.

        The directory name encodes the depth, learning rate and ``c``; it is
        prefixed with ``estimate_`` when gradients are estimated and suffixed
        with the target label for targeted attacks. The learning-rate strategy
        forms a final sub-directory, and ``self.log`` writes to ``log.txt``
        inside it.
        """
        run_tag = '_'.join(['depth', str(self.t_depth), 'lr', str(self.lr), 'c', str(self.c)])
        if self.grad_est:
            run_tag = 'estimate_' + run_tag
        self.save_path = os.path.join(self.save_dir, run_tag)
        if self.targeted:
            self.save_path = self.save_path + '_target' + str(self.target_label)
        self.save_path = os.path.join(self.save_path, 'lr_' + self.lr_strategy)

        from tools import Log
        log_file = os.path.join(self.save_path, 'log.txt')
        self.log = Log(log_file)

    def check_oscillation(self, w_l, j, k, rho=0.75):  # condition 1: whether optimization is proceeding properly
        t = 0
        for counter5 in range(k):
            t += (w_l[j-counter5] > w_l[j-counter5-1])
        return t <= k*rho

    def check_cv(self, w_l, j, k):  # coefficient of variation
        window = -w_l[(j-k+1):(j+1)]
        cv = window.std()/(window.mean()+1e-8)
        return cv

    def attack_ideal_gradient(self, data, label, model, m_params, img_shape=(16, 16, 1), save_fig=True):
        adv_num = 0
        adv_list = []
        iters = []
        ssims = []
        fidelitys = []

        if save_fig and not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

        self.log('##### Ideal gradients (parameter shift rule) #####')
        s = time.time()
        for i, x in enumerate(data):
            if self.targeted and label[i] == self.target_label: continue
            self.log(f'------ Start for {i}th img')
            ori_x = x.clone()
            iter = 0
            lr = self.lr
            ori_outputs = torch.tensor(model.predict(ori_x.unsqueeze(0))[0])
            ori_in_state = model.circuit_state(torch.flatten(ori_x), m_params, exec_=False)
            ori_internal_prob = model.circuit_prob(torch.flatten(ori_x), m_params, depth_=self.t_depth)

            loss_steps = torch.zeros([self.budget, ])
            loss_best = -torch.inf
            x_best = x.clone()
            k = self.n_iter_2 + 0
            counter3 = 0  # checkpoint
            while iter < self.budget:
                new_label = torch.argmax(torch.tensor(model.predict(x.unsqueeze(0))[0]))
                if new_label != label[i]:  # originally misclassified
                    self.log(f'Adv img has generated for {iter} iterations!')
                    adv_num += 1
                    dis = SSIM(ori_x.squeeze(), x.squeeze())
                    now_in_state = model.circuit_state(torch.flatten(x), m_params, exec_=False)
                    idx = list(range(0, int(math.log2(now_in_state.shape[0]))))
                    f = fidelity(reduce_statevector(ori_in_state, indices=idx),
                                 reduce_statevector(now_in_state, indices=idx))

                    self.log('# Comparison between ori and adv img:\n'
                        f'## Visual distance (SSIM): {dis},\n'
                        f'## Fidelity: {f}')
                    iters.append(iter)
                    ssims.append(dis)
                    fidelitys.append(f)
                    break
                ori_outputs = ori_outputs.detach()
                ori_internal_prob = ori_internal_prob.detach()
                x.requires_grad_(True)
                cur_outputs = model.predict(x.unsqueeze(0))[0]
                cur_internal_prob = model.circuit_prob(torch.flatten(x), m_params, depth_=self.t_depth)
                f_loss = DLFuzz2(cur_outputs, ori_outputs, 1) if not self.targeted else -targeted_attack(cur_outputs,self.target_label)
                js_loss = js(ori_internal_prob, cur_internal_prob)
                obj = f_loss + self.c * js_loss

                loss_steps[iter] = f_loss  # obj or f_loss
                self.log(f'iter: {iter}, lr: {lr}, f_loss: {f_loss}, js_loss: {js_loss}')
                obj.backward()
                if f_loss > loss_best:
                    loss_best = f_loss
                    x_best = x.detach().clone()

                perturb = x.grad * lr
                x = torch.clamp(x + perturb, 0, 1).detach()

                counter3 += 1
                if self.lr_strategy == 'const':
                    pass
                elif self.lr_strategy == 'decay':
                    lr = self.lr * (0.9**((iter+1)//self.budget))
                elif self.lr_strategy == 'auto':
                    if counter3 == k:
                        self.log(f'***Examine checkpoint, current checkpoint interval: {k}***')
                        cv = self.check_cv(loss_steps, iter, k)
                        self.log(f'CV: {cv}')
                        if cv < 0.01:
                            lr *= 2.0
                            self.log(f'Inefficient optimization!! lr is updated to {lr}')
                        elif cv > 0.1:
                            fl_oscillation = self.check_oscillation(loss_steps, iter, k, 0.25)
                            if fl_oscillation:
                                lr /= 2.0
                                self.log(f'Unstable optimization!! lr is updated to {lr}')
                        x = x_best.clone()  # back to best x with highest loss
                        self.log(f'Loss has been reset to {loss_best}!!')
                        k = max(k - self.size_decr, self.n_iter_min)  # update checkpoint interval
                        counter3 = 0
                iter += 1


            adv_img = x.reshape(img_shape).permute(2, 0, 1)
            if save_fig:
                save_image(adv_img, os.path.join(self.save_path, str(i) + '_' + str(label[i].item()) + '_' + str(
                    new_label.item()) + '.png'))
            adv_list.append(adv_img)

        e = time.time()
        self.log(f'!!generated {adv_num} adv img out of {data.shape[0]} img!')
        iters = torch.tensor(iters, dtype=torch.float)
        ssims = torch.tensor(ssims, dtype=torch.float)
        fidelitys = torch.tensor(fidelitys, dtype=torch.float)
        self.log(f'**** Iteration: avg {iters.mean()}, std {iters.std()}, min: {iters.min()}, max: {iters.max()}\n'
            f'**** SSIM: avg {ssims.mean()}, std {ssims.std()}, min: {ssims.min()}, max: {ssims.max()}\n'
            f'**** Fidelity: avg {fidelitys.mean()}, std {fidelitys.std()}, min: {fidelitys.min()}, max: {fidelitys.max()}')
        self.log(f'average time: {(e - s) / data.shape[0]:.2f} s')
        return torch.stack(adv_list)

    def attack_estimated_gradient(self, data, label, model, m_params, img_shape=(16, 16, 1), save_fig=True):
        def obj_loss(x, ori_out, ori_prob):
            cur_out = model.predict(x.unsqueeze(0))[0]
            cur_prob = model.circuit_prob(torch.flatten(x), m_params, depth_=self.t_depth)
            f_loss = DLFuzz2(cur_out, ori_out, 1) if not self.targeted else -targeted_attack(cur_out,
                                                                                        self.target_label)
            js_loss = js(ori_prob, cur_prob)
            obj = f_loss + self.c * js_loss
            return obj, f_loss, js_loss

        adv_num = 0
        adv_list = []
        iters = []
        ssims = []

        if save_fig and not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

        self.log('##### Estimated gradient #####')
        s = time.time()
        for i, x in enumerate(data):
            if self.targeted and label[i] == self.target_label: continue
            self.log(f'------ Start for {i}th img')
            ori_x = x.clone()
            iter = 0
            ori_outputs = torch.tensor(model.predict(ori_x.unsqueeze(0))[0]).detach()
            ori_internal_prob = model.circuit_prob(torch.flatten(ori_x), m_params, depth_=self.t_depth).detach()

            lr = self.lr
            g_pre = 0
            loss_steps = torch.zeros([self.budget, ])
            loss_best = -torch.inf
            x_best = x.clone()
            k = self.n_iter_2 + 0
            counter3 = 0  # checkpoint
            while iter < self.budget:
                new_label = torch.argmax(torch.tensor(model.predict(x.unsqueeze(0))[0]))
                if new_label != label[i]:  # originally misclassified
                    self.log(f'Adv img has generated for {iter} iterations!')
                    adv_num += 1
                    dis = SSIM(ori_x.squeeze(), x.squeeze())
                    self.log('# Comparison between ori and adv img:\n'
                        f'## Visual distance (SSIM): {dis}')
                    iters.append(iter)
                    ssims.append(dis)
                    break
                obj, f_loss, js_loss = obj_loss(x, ori_outputs, ori_internal_prob)
                loss_steps[iter] = obj.detach()
                self.log(f'iter: {iter}, lr: {lr}, f_loss: {f_loss}, js_loss: {js_loss}')
                if obj > loss_best:
                    loss_best = obj
                    x_best = x.detach().clone()

                grad = nes_bandits(x, ori_outputs, ori_internal_prob, obj_loss, g_pre).detach()
                perturb = grad * lr
                x = torch.clamp(x + perturb, 0, 1).detach()
                g_pre = grad

                counter3 += 1
                if self.lr_strategy == 'const':
                    pass
                elif self.lr_strategy == 'decay':
                    lr = self.lr * (0.9 ** ((iter+1) // self.budget))
                elif self.lr_strategy == 'auto':
                    if counter3 == k:
                        self.log(f'***Examine checkpoint, current checkpoint interval: {k}***')
                        cv = self.check_cv(loss_steps, iter, k)
                        self.log(f'CV: {cv}')
                        if cv < 0.01:
                            lr *= 2.0
                            self.log(f'Inefficient optimization!! lr is updated to {lr}')
                        elif cv > 0.1:
                            fl_oscillation = self.check_oscillation(loss_steps, iter, k, 0.25)
                            if fl_oscillation:
                                lr /= 2.0
                                self.log(f'Unstable optimization!! lr is updated to {lr}')
                        x = x_best.clone()  # back to best x with highest loss
                        self.log(f'Loss has been reset to {loss_best}!!')
                        k = max(k - self.size_decr, self.n_iter_min)  # update checkpoint interval
                        counter3 = 0
                iter += 1

            adv_img = x.reshape(img_shape).permute(2, 0, 1)
            if save_fig:
                save_image(adv_img, os.path.join(self.save_path, str(i) + '_' + str(label[i].item()) + '_' + str(
                    new_label.item()) + '.png'))
            adv_list.append(adv_img)

        e = time.time()
        self.log(f'!!generated {adv_num} adv img out of {data.shape[0]} img!')
        iters = torch.tensor(iters, dtype=torch.float)
        ssims = torch.tensor(ssims, dtype=torch.float)
        self.log(f'**** Iteration: avg {iters.mean()}, std {iters.std()}, min: {iters.min()}, max: {iters.max()}\n'
            f'**** SSIM: avg {ssims.mean()}, std {ssims.std()}, min: {ssims.min()}, max: {ssims.max()}')
        self.log(f'average time: {(e - s) / data.shape[0]:.2f} s')
        return torch.stack(adv_list)

    def run(self, x, y, model, m_params, img_shape):
        """Prepare the output path and log, then launch the configured attack.

        The estimated-gradient attack is used when ``self.grad_est`` is set,
        otherwise the autograd-based one. The attack's return value is
        discarded, so this method returns None.
        """
        self.init_path()
        attack = self.attack_estimated_gradient if self.grad_est else self.attack_ideal_gradient
        attack(x, y, model, m_params, img_shape)
