import os

# Default to the first GPU. This is set before torch is imported; the
# __main__ guard at the bottom of the file overwrites it with args.gpu.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Fix the torch RNG seed for the whole run.
torch.manual_seed(42)
from time import time
# NOTE(review): `F`, `Model` and `IntegratedGradients` are used below but not
# imported by name; presumably they come from the wildcard imports - verify.
from src.model import *
from src.utils import makedirs, create_logger, tensor2cuda, evaluate, save_model
from src.utils import BlockMnist

from src.argument import parser, print_args
from src.saliency_methods import *
from blockmnist.visualize_inspect import leakage_measure

# Box bounds applied to `input + perturbation`.
# `upper_limit`/`lower_limit` are read by attack_pgd; `upper`/`lower` are read
# by init_rand_delta and Trainer.
upper_limit, lower_limit = 1, 0
upper, lower = 1., 0.


def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise into ``[lower_limit, upper_limit]``.

    The upper bound is applied first and the lower bound second, so wherever
    the two bounds cross (lower > upper) the lower bound wins.

    Args:
        X: tensor to clip.
        lower_limit: tensor of lower bounds, broadcast against ``X``.
        upper_limit: tensor of upper bounds, broadcast against ``X``.

    Returns:
        A new tensor holding the clipped values.
    """
    capped = torch.min(X, upper_limit)
    return torch.max(capped, lower_limit)


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, norm):
    """Craft a PGD perturbation for ``X`` that increases the cross-entropy loss.

    The perturbation starts from random noise inside the ``epsilon`` ball,
    takes ``attack_iters`` ascent steps of size ``alpha`` and is projected
    back after every step onto both the ``epsilon`` ball and the box
    ``[lower_limit, upper_limit]`` for ``X + delta``.

    Args:
        model: callable mapping an input batch to class logits.
        X: input batch. The ``l_2`` branch reshapes per-sample norms to
            ``(N, 1, 1, 1)``, so a 4-D batch is expected there.
        y: target class indices, one per sample.
        epsilon: radius of the perturbation ball.
        alpha: step size of each ascent step.
        attack_iters: number of ascent steps (0 returns the random start).
        norm: ``"l_inf"`` or ``"l_2"``.

    Returns:
        A detached tensor shaped like ``X`` holding the perturbation, on the
        same device as ``X``.

    Raises:
        ValueError: if ``norm`` is neither ``"l_inf"`` nor ``"l_2"``.
    """
    if norm not in ("l_inf", "l_2"):
        raise ValueError("norm must be 'l_inf' or 'l_2', got %r" % (norm,))

    # Explicit reference instead of relying on a wildcard-imported `F`.
    cross_entropy = torch.nn.functional.cross_entropy

    # Allocate on X's own device rather than forcing the current CUDA device.
    delta = torch.zeros_like(X)
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    else:
        # Random direction scaled to a random radius in [0, epsilon].
        delta.normal_()
        d_flat = delta.view(delta.size(0), -1)
        n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon
    delta = clamp(delta, lower_limit - X, upper_limit - X).detach()

    for _ in range(attack_iters):
        delta.requires_grad_(True)
        loss = cross_entropy(model(X + delta), y)
        # Differentiate w.r.t. delta only, so the attack does not accumulate
        # gradients into the model parameters.
        (grad,) = torch.autograd.grad(loss, delta)

        with torch.no_grad():
            if norm == "l_inf":
                stepped = torch.clamp(delta + alpha * torch.sign(grad), min=-epsilon, max=epsilon)
            else:
                g_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                scaled_g = grad / (g_norm + 1e-10)
                # renorm over dim 0 caps each sample's L2 norm at epsilon.
                stepped = (
                    (delta + scaled_g * alpha)
                    .view(delta.size(0), -1)
                    .renorm(p=2, dim=0, maxnorm=epsilon)
                    .view_as(delta)
                )
            # Keep X + delta inside the valid box; the result is a fresh leaf.
            delta = clamp(stepped, lower_limit - X, upper_limit - X)

    delta = delta.detach()
    with torch.no_grad():
        all_loss = cross_entropy(model(X + delta), y, reduction='none')

    # Presumably a remnant of a multi-restart variant: the reference loss is
    # zero, so every sample with a finite loss keeps its perturbation, while a
    # sample whose loss is NaN falls back to a zero perturbation.
    max_loss = torch.zeros(y.shape[0], device=X.device)
    max_delta = torch.zeros_like(X)
    keep = all_loss >= max_loss
    max_delta[keep] = delta[keep]
    return max_delta


def init_rand_delta(data, norm='l_inf', epsilon=0.3):
    """Draw a random perturbation for ``data`` inside the ``epsilon`` ball.

    Args:
        data: input batch. The ``l_2`` branch reshapes per-sample norms to
            ``(N, 1, 1, 1)``, so a 4-D batch is expected there.
        norm: ``"l_inf"`` for uniform noise in ``[-epsilon, epsilon]`` or
            ``"l_2"`` for a random direction scaled to a random radius in
            ``[0, epsilon]``. Any other value leaves the noise at zero.
        epsilon: radius of the perturbation ball.

    Returns:
        A tensor shaped like ``data``, on the same device, clipped so that
        ``data + delta`` stays inside ``[lower, upper]``.
    """
    # ------------- random noise --------------
    # Allocate next to `data` instead of forcing the current CUDA device, so
    # the helper also works for CPU tensors and for data on a non-default GPU.
    delta = torch.zeros_like(data)
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    elif norm == "l_2":
        delta.normal_()
        d_flat = delta.view(delta.size(0), -1)
        n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r / n * epsilon
    # Unknown norms fall through on purpose (kept for backward compatibility).
    delta = clamp(delta, lower - data, upper - data)
    return delta


class Trainer():
    """Train and evaluate a classifier with an input-attribution penalty.

    Attributes:
        args: run configuration; the methods read ``learning_rate``,
            ``max_epoch``, ``n_checkpoint_step`` and ``model_folder``.
        logger: logger used for progress and evaluation messages.
        attack_params: keyword arguments forwarded to ``attack_pgd`` during
            training and for the per-epoch robustness probe.
        attack_test_params: keyword arguments forwarded to ``attack_pgd`` in
            ``test``.
        explainer_Grad, explainer_IG: initialised to ``None``; not used by
            the methods of this class.
    """

    def __init__(self, args, logger, attack_params, attack_test_params):
        self.args = args
        self.logger = logger
        self.attack_params = attack_params
        self.attack_test_params = attack_test_params
        self.explainer_Grad = None
        self.explainer_IG = None

    def train(self, model, tr_loader, va_loader, reg_method, pt_name, adv_train=False):
        """Run the training loop, evaluating and checkpointing every epoch.

        Each batch is optionally perturbed with ``attack_pgd`` and the loss is
        the cross-entropy plus ten times the L2 norm of the attribution
        returned by ``IntegratedGradients(model, k=1,
        scale_by_inputs=False).shap_values``.

        After every epoch the method logs clean and adversarial accuracy on
        the last training batch, calls ``leakage_measure``, evaluates on
        ``va_loader`` and saves checkpoints (periodic ones every
        ``args.n_checkpoint_step`` epochs, plus the best adversarial accuracy
        so far).

        Args:
            model: network to optimise in place.
            tr_loader: iterable of ``(data, label)`` training batches. It must
                yield at least one batch, because the per-epoch probe reuses
                the last batch.
            va_loader: iterable of ``(data, label)`` batches passed to
                ``test``.
            reg_method: weight decay of 0.01 is enabled when this contains
                ``'WDecay'``.
            pt_name: tag embedded in checkpoint file names.
            adv_train: when true, train on adversarially perturbed inputs.
        """
        args = self.args
        logger = self.logger
        input_grad_exp = IntegratedGradients(model, k=1, scale_by_inputs=False)

        wd = 0.01 if 'WDecay' in reg_method else 0.
        opt = torch.optim.Adam(model.parameters(), args.learning_rate, weight_decay=wd)

        begin_time = time()

        best_adv_accu = 0.
        for epoch in range(1, args.max_epoch+1):

            for data, label in tr_loader:
                data, label = tensor2cuda(data), tensor2cuda(label)
                if adv_train:
                    adv_delta = attack_pgd(model, data, label, **self.attack_params)
                    adv_delta = clamp(adv_delta, lower-data, upper-data)
                    input_data = adv_delta + data
                else:
                    input_data = data

                model.train()
                output = model(input_data)
                loss = F.cross_entropy(output, label)

                # attribution-norm penalty, weighted by 10
                input_grad_px = input_grad_exp.shap_values(input_data, stability=False)
                input_grad_px = torch.norm(input_grad_px, p=2)
                loss += 10 * input_grad_px

                opt.zero_grad()
                loss.backward()
                opt.step()

            # ------ accuracy on the last training batch of the epoch ------
            with torch.no_grad():
                # stand_output = model(data, _eval=True)
                model.eval()
                stand_output = model(data)
            pred = torch.max(stand_output, dim=1)[1]
            std_acc = evaluate(pred.cpu().numpy(), label.cpu().numpy()) * 100

            # ------------ adv test in training ----------------
            adv_delta = attack_pgd(model, data.detach(), label, **self.attack_params)
            adv_delta = clamp(adv_delta, lower - data, upper - data)
            adv_data = data+adv_delta
            with torch.no_grad():
                model.eval()
                adv_output = model(adv_data)
            pred = torch.max(adv_output, dim=1)[1]
            adv_acc = evaluate(pred.cpu().numpy(), label.cpu().numpy()) * 100

            # time since the previous evaluation finished (training plus the
            # last-batch probe above); the timer is reset after validation
            logger.info('epoch: %d, spent %.2f s, tr_loss: %.3f' % (epoch, time() - begin_time, loss.item()))
            logger.info('standard acc: %.3f %%, robustness acc: %.3f %%' % (std_acc, adv_acc))

            leakage_measure(model)

            # -------------------------------------- testing --------------------------------------------
            va_acc, va_adv_acc = self.test(model, va_loader)
            va_acc, va_adv_acc = va_acc * 100.0, va_adv_acc * 100.0

            logger.info('\n' + '='*30 + ' evaluation ' + '='*30)
            logger.info('test acc: %.3f %%, test adv acc: %.3f %%' % (va_acc, va_adv_acc))
            logger.info('='*26 + ' end of evaluation ' + '='*26 + '\n')

            begin_time = time()

            if epoch % args.n_checkpoint_step == 0:
                leakage_measure(model)
                file_name = os.path.join(args.model_folder, 'checkpoint_'+pt_name+'_%d_%d.pth' % (epoch, int(va_adv_acc)))
                save_model(model, file_name)
                logger.info('checkpoint saved')

            if best_adv_accu < va_adv_acc:
                best_adv_accu = va_adv_acc
                file_name = os.path.join(args.model_folder, 'checkpoint_'+pt_name+'_best'+'.pth')
                save_model(model, file_name)

    def test(self, model, loader):
        """Measure clean and adversarial accuracy of ``model`` over ``loader``.

        Args:
            model: network to evaluate; it is switched to eval mode.
            loader: iterable of ``(data, label)`` batches.

        Returns:
            ``(clean_accuracy, adversarial_accuracy)`` as fractions of the
            number of evaluated samples.

        Raises:
            ZeroDivisionError: if ``loader`` yields no samples.
        """
        total_acc = 0.0
        num = 0
        total_adv_acc = 0.0

        # The loop as a whole cannot run under no_grad because the attack
        # needs input gradients; only the plain forward passes are wrapped so
        # that no autograd graph is kept for them.
        for data, label in loader:
            data, label = tensor2cuda(data), tensor2cuda(label)
            # output = model(data, _eval=True)
            model.eval()
            with torch.no_grad():
                output = model(data)

            pred = torch.max(output, dim=1)[1]
            te_acc = evaluate(pred.cpu().numpy(), label.cpu().numpy(), 'sum')

            total_acc += te_acc
            num += output.shape[0]

            # the attack maximises the loss w.r.t. the ground-truth label
            adv_delta = attack_pgd(model, data, label, **self.attack_test_params)
            adv_delta = clamp(adv_delta, lower - data, upper - data)
            adv_data = data+adv_delta

            # adv_output = model(adv_data, _eval=True)
            model.eval()
            with torch.no_grad():
                adv_output = model(adv_data)

            adv_pred = torch.max(adv_output, dim=1)[1]
            adv_acc = evaluate(adv_pred.cpu().numpy(), label.cpu().numpy(), 'sum')
            total_adv_acc += adv_acc

        return total_acc / num, total_adv_acc / num


def main(args):
    """Set up folders, logging and the model, then train or test.

    ``args.todo`` selects the mode: ``'train'`` runs ``Trainer.train`` on
    BlockMNIST, ``'test'`` evaluates every checkpoint listed in ``mode_name``.
    Any other value returns after the set-up steps.

    Args:
        args: run configuration; ``log_folder`` and ``model_folder`` are
            added to it as attributes.
    """
    run_tag = '%s_%s' % (args.dataset, args.affix)

    log_dir = os.path.join(args.log_root, run_tag)
    ckpt_dir = os.path.join(args.model_root, run_tag)
    for folder in (log_dir, ckpt_dir):
        makedirs(folder)
    args.log_folder = log_dir
    args.model_folder = ckpt_dir
    logger = create_logger(log_dir, args.todo, 'info')

    print_args(args, logger)

    # Named attack presets; each value is the keyword set of attack_pgd.
    attack_params = {
        'PGD_MNIST_inf': dict(epsilon=0.3, alpha=0.01, attack_iters=20, norm='l_inf'),
        'PGD_MNIST': dict(epsilon=1.0, alpha=0.2, attack_iters=20, norm='l_2'),
        'PGD_MNIST_Train': dict(epsilon=0.3, alpha=0.01, attack_iters=3, norm='l_inf'),
        'FGSM_MNIST': dict(epsilon=0.2, alpha=0.2*1.25, attack_iters=1, norm='l_2'),
    }
    attack_config = attack_params['FGSM_MNIST']
    attack_test_config = attack_params['PGD_MNIST']

    # ------------------------ config -----------------------
    pt_name = 'BMNIST_MLP_ST_ReLU_Ours'
    adv_train = False
    if 'ReLU' in pt_name:
        act_func = 'ReLU'
    else:
        act_func = 'softplus'
    model = Model(i_c=1, n_c=10, act=act_func)  # ReLU

    if torch.cuda.is_available():
        model.cuda()
        # parallel
        model = torch.nn.DataParallel(model)

    def build_test_loader():
        """Create a fresh, unshuffled loader over the BlockMNIST test split."""
        pipeline = transforms.Compose([
            transforms.ToTensor(),
            BlockMnist(test=True),
        ])
        dataset = torchvision.datasets.MNIST(args.data_root,
                                             train=False,
                                             transform=pipeline,
                                             download=True)
        return DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    if args.todo == 'train':
        train_pipeline = transforms.Compose([
            transforms.ToTensor(),
            BlockMnist(),
        ])
        tr_dataset = torchvision.datasets.MNIST(args.data_root,
                                                train=True,
                                                transform=train_pipeline,
                                                download=True)
        tr_loader = DataLoader(tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)  # True

        # evaluation during training
        te_loader = build_test_loader()
        trainer = Trainer(args, logger, attack_config, attack_test_config)
        trainer.train(model, tr_loader, te_loader, reg_method=args.reg_name, pt_name=pt_name, adv_train=adv_train)
        return

    if args.todo != 'test':
        return

    # Checkpoint file names (relative to 'checkpoint/mnist_/') to evaluate.
    mode_name = [
        # 'model.pth',
    ]

    accu_lst = []
    adv_accu_lst = []
    leak_lst = []

    for check_point_name in mode_name:
        print(check_point_name)
        logger.info(check_point_name)
        state = torch.load('checkpoint/mnist_/' + check_point_name)
        model.load_state_dict(state, strict=True)

        f_leak = leakage_measure(model)

        for attack in ['PGD_MNIST_inf', ]:
            eval_attack_cfg = attack_params[attack]

            # The loader is rebuilt for every checkpoint/attack pair.
            te_loader = build_test_loader()
            trainer = Trainer(args, logger, attack_config, eval_attack_cfg)

            va_acc, va_adv_acc = trainer.test(model, te_loader)
            va_acc, va_adv_acc = va_acc * 100.0, va_adv_acc * 100.0
            logger.info('\n' + '=' * 30 + ' evaluation ' + '=' * 30)
            logger.info('test acc: %.3f %%, test adv acc: %.3f %%' % (va_acc, va_adv_acc))
            logger.info('=' * 28 + ' end of evaluation ' + '=' * 28 + '\n')

            adv_accu_lst.append(float('%.3f' % va_adv_acc))
        logger.info('leakage: %s' % f_leak)
        accu_lst.append(float('%.3f' % va_acc))
        leak_lst.append(float(f_leak))


if __name__ == '__main__':
    # Build the run configuration via src.argument.parser.
    args = parser()
    # Replace the '0' default set at the top of the file with args.gpu.
    # NOTE(review): torch is already imported at this point; this presumably
    # only takes effect if CUDA has not been initialised yet - confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    main(args)
