import matplotlib.pyplot as plt
import numpy as np
import torch
import os
import shutil
import copy

from torchvision import transforms, datasets
import argparse
import random
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch import nn
from PIL import Image
from utils import supervisor, tools, default_args, imagenet, create_subset
import config

# Command-line interface. Defaults (and, where given, the allowed values) of
# the shared options are read from utils.default_args.
parser = argparse.ArgumentParser()

# --- dataset / attack / defense description ---
parser.add_argument('-dataset', type=str,
                    choices=default_args.parser_choices['dataset'],
                    default=default_args.parser_default['dataset'])
parser.add_argument('-poison_type', type=str,
                    default=default_args.parser_default['poison_type'],
                    choices=default_args.parser_choices['poison_type'])
parser.add_argument('-poison_rate', type=float,
                    default=default_args.parser_default['poison_rate'],
                    choices=default_args.parser_choices['poison_rate'])
parser.add_argument('-cover_rate', type=float,
                    default=default_args.parser_default['cover_rate'],
                    choices=default_args.parser_choices['cover_rate'])
parser.add_argument('-alpha', type=float, default=default_args.parser_default['alpha'])
parser.add_argument('-test_alpha', type=float)   # falls back to -alpha when omitted
parser.add_argument('-trigger', type=str)        # falls back to config.trigger_default when omitted
parser.add_argument('-model_path')
parser.add_argument('-cleanser', type=str, choices=default_args.parser_choices['cleanser'])
parser.add_argument('-defense', type=str, choices=default_args.parser_choices['defense'])
parser.add_argument('-no_normalize', action='store_true')
parser.add_argument('-no_aug', action='store_true')
parser.add_argument('-devices', type=str, default='0')
parser.add_argument('-seed', type=int, default=default_args.seed)

# --- model-editing options ---
# '-use_mask' is a plain flag (it used to be a required bool option).
parser.add_argument('-use_mask', action='store_true')
parser.add_argument('-bef_layer_name', type=str)
parser.add_argument('-used_data_len', type=int, default=1)

args = parser.parse_args()

# Select the visible GPUs from the command line.
os.environ["CUDA_VISIBLE_DEVICES"] = "%s" % args.devices

# Without an explicit trigger, use the default registered in config for this
# dataset / poison type.
args.trigger = (args.trigger if args.trigger is not None
                else config.trigger_default[args.dataset][args.poison_type])

# DataLoader settings: ImageNet gets more worker processes than the other datasets.
kwargs = {'num_workers': 32 if args.dataset == 'imagenet' else 4,
          'pin_memory': True}

tools.setup_seed(args.seed)

(data_transform_aug,
 data_transform,
 trigger_transform,
 normalizer,
 denormalizer) = supervisor.get_transforms(args)

# Per-dataset settings: (num_classes, epochs, LR-milestone epochs, batch_size).
_DATASET_RECIPES = {
    'cifar10': (10, 200, [100, 150], 128),
    'gtsrb': (43, 100, [40, 80], 128),
    'imagenette': (10, 100, [40, 80], 128),
    'imagenet': (1000, 90, [30, 60], 256),
}

# cifar100 is recognised but not supported yet.
if args.dataset == 'cifar100':
    num_classes = 100
    raise NotImplementedError('<To Be Implemented> Dataset = %s' % args.dataset)

# Anything else that has no recipe is rejected outright.
if args.dataset not in _DATASET_RECIPES:
    print('<Undefined Dataset> Dataset = %s' % args.dataset)
    raise NotImplementedError('<To Be Implemented> Dataset = %s' % args.dataset)

num_classes, epochs, _milestone_epochs, batch_size = _DATASET_RECIPES[args.dataset]
milestones = torch.tensor(_milestone_epochs)

# These three values are the same for every supported dataset.
momentum = 0.9
weight_decay = 1e-4
learning_rate = 0.1

# poison_set_dir = supervisor.get_spurious_set_dir(args)
# Checkpoint location for the current arguments; the cleanse/defense flags are
# set whenever a cleanser/defense was named on the command line.
model_path = supervisor.get_model_dir(args, cleanse=(args.cleanser is not None), defense=(args.defense is not None))

arch = supervisor.get_arch(args)
model = arch(num_classes=num_classes)
# Deserialize onto the CPU first: the checkpoint may have been saved from a GPU
# index that is not visible under the current CUDA_VISIBLE_DEVICES, in which
# case a plain torch.load() would fail. The weights are moved to the GPU below.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# model = nn.DataParallel(model)
model = model.cuda()

# ======================================================================

import warnings
from argparse import Namespace
# Silence every warning for the remainder of the script.
warnings.filterwarnings("ignore")

import torch as ch

from model_editing.helpers import classifier_helpers
import model_editing.helpers.data_helpers as dh
import model_editing.helpers.rewrite_helpers as rh
import model_editing.helpers.vis_helpers as vh

# Rewriting method handed to rh.edit_classifier through train_args['mode_rewrite'].
REWRITE_MODE = 'editing'

# ret = classifier_helpers.get_default_paths(DATASET_NAME, arch=ARCH)
# DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = ret


# ret = classifier_helpers.load_classifier(MODEL_PATH, MODEL_CLASS, ARCH, DATASET_NAME, LAYERNUM)
# Wrap the loaded model for editing. `context_model` is the model that gets
# edited and evaluated below; `target_model_dict` is indexed by layer name to
# obtain the per-layer `target_model` passed to rh.edit_classifier.
context_model, target_model_dict = classifier_helpers.load_classifier(model, bef_layer_name=args.bef_layer_name)

# ---------- Load the poisoned / clean / validation / test datasets -----------
# ---------------------- load dataset -------------------------
# Build the datasets/loaders used below:
#   val_loader       - source of the exemplars used for editing
#   test_set_loader  - held-out split used to measure clean / attacked accuracy
#   poison_transform - applies the backdoor trigger to a batch
if args.dataset != 'imagenet':
    # args.poison_rate = 0.0
    poison_set_dir = supervisor.get_poison_set_dir(args)
    print(poison_set_dir)

    # Poisoned sets come in two on-disk layouts: images under 'data' (old
    # version) or under 'imgs' (new version). When both exist, 'imgs' wins.
    poisoned_set_img_dir = None
    for img_subdir in ('data', 'imgs'):
        candidate = os.path.join(poison_set_dir, img_subdir)
        if os.path.exists(candidate):
            poisoned_set_img_dir = candidate
    if poisoned_set_img_dir is None:
        # Fail with a clear message instead of a NameError further down.
        raise FileNotFoundError(
            "no 'data' or 'imgs' directory found under %s" % poison_set_dir)

    poisoned_set_label_path = os.path.join(poison_set_dir, 'labels')
    poison_indices_path = os.path.join(poison_set_dir, 'poison_indices')
    print('dataset : %s' % poisoned_set_img_dir)

    # poisoned_set = tools.IMG_Dataset(data_dir=poisoned_set_img_dir,
    #                                  label_path=poisoned_set_label_path, transforms=data_transform if args.no_aug else data_transform_aug)

    poisoned_set = tools.IMG_Dataset(data_dir=poisoned_set_img_dir, label_path=poisoned_set_label_path, transforms=data_transform)

    poisoned_train_set_loader = torch.utils.data.DataLoader(
        poisoned_set,
        batch_size=batch_size, shuffle=False, worker_init_fn=tools.worker_init, **kwargs)

    # ------ Set Up Clean Set ------
    args_dup = copy.copy(args)
    args_dup.poison_rate = 0.0
    # NOTE(review): clean_set_dir is computed but never used; the "clean" set
    # below is read from poison_set_dir. Confirm this is intended.
    clean_set_dir = supervisor.get_spurious_set_dir(args_dup)
    # Reuse the image directory detected above so the old ('data') layout
    # works here as well; for the new layout this is still <poison_set_dir>/imgs.
    clean_set_img_dir = poisoned_set_img_dir
    clean_set_label_path = os.path.join(poison_set_dir, 'labels')
    clean_train_set = tools.IMG_Dataset(data_dir=clean_set_img_dir, label_path=clean_set_label_path, transforms=data_transform)
    clean_train_set_loader = torch.utils.data.DataLoader(
        clean_train_set,
        batch_size=batch_size, shuffle=True, worker_init_fn=tools.worker_init, **kwargs)

    # ------ Set Up Val Set ------
    t_loader = clean_train_set_loader  # clean_train_set_loader poisoned_train_set_loader
    # presumably a 10% subset of t_loader, cached under poison_set_dir - confirm in create_subset
    val_loader = create_subset.create_val_loader(t_loader, val_ratio=0.1, saved_dir=poison_set_dir)

    # ------ Set Up Test Set for Debug & Evaluation ------
    test_set_dir = os.path.join('clean_set', args.dataset, 'test_split')
    test_set_img_dir = os.path.join(test_set_dir, 'data')
    test_set_label_path = os.path.join(test_set_dir, 'labels')
    test_set = tools.IMG_Dataset(data_dir=test_set_img_dir,
                                 label_path=test_set_label_path, transforms=data_transform)
    test_set_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size, shuffle=False, worker_init_fn=tools.worker_init, **kwargs)

    # Poison Transform for Testing (-test_alpha overrides -alpha when given)
    poison_transform = supervisor.get_poison_transform(poison_type=args.poison_type, dataset_name=args.dataset,
                                                       target_class=config.target_class[args.dataset], trigger_transform=data_transform,
                                                       is_normalized_input=True,
                                                       alpha=args.alpha if args.test_alpha is None else args.test_alpha,
                                                       trigger_name=args.trigger, args=args)

else:
    poison_set_dir = supervisor.get_poison_set_dir(args)
    poison_indices_path = os.path.join(poison_set_dir, 'poison_indices')
    poisoned_set_img_dir = os.path.join(poison_set_dir, 'data')
    print('dataset : %s' % poison_set_dir)

    poison_indices = torch.load(poison_indices_path)

    train_set_dir = os.path.join(config.imagenet_dir, 'train')
    test_set_dir = os.path.join(config.imagenet_dir, 'val')

    from utils import imagenet
    # data_transform_aug
    poisoned_train_set = imagenet.imagenet_dataset(directory=train_set_dir, data_transform=data_transform_aug, poison_directory=poisoned_set_img_dir,
                                                   poison_indices=poison_indices, target_class=config.target_class['imagenet'],
                                                   num_classes=1000)

    # ------ Set Up Clean Set ------
    clean_train_set = datasets.ImageFolder(root=train_set_dir, transform=data_transform)
    # clean_train_loader = torch.utils.data.DataLoader(
    #     clean_train_set,
    #     batch_size=batch_size, shuffle=True, worker_init_fn=tools.worker_init, **kwargs)
    # # ------ Set Up Subset ------
    # poisoned_train_set = create_subset.create_val_set(poisoned_train_set, eval_ratio=0.05)

    # For ImageNet the exemplars are drawn from the full clean training folder.
    val_loader = torch.utils.data.DataLoader(
        clean_train_set,
        batch_size=batch_size, shuffle=True, worker_init_fn=tools.worker_init, **kwargs)

    # ------ Set Up Test Set for Debug & Evaluation ------
    poison_transform = supervisor.get_poison_transform(poison_type=args.poison_type, dataset_name=args.dataset,
                                                       target_class=config.target_class[args.dataset], trigger_transform=data_transform,
                                                       is_normalized_input=True,
                                                       alpha=args.alpha if args.test_alpha is None else args.test_alpha,
                                                       trigger_name=args.trigger, args=args)

    test_set = imagenet.imagenet_dataset(directory=test_set_dir, shift=False, data_transform=data_transform,
                                         label_file=imagenet.test_set_labels, num_classes=1000)

    test_split_meta_dir = os.path.join('clean_set', args.dataset, 'test_split')
    test_indices = torch.load(os.path.join(test_split_meta_dir, 'test_indices'))

    # Evaluate only on the samples listed in the stored test split.
    test_set = torch.utils.data.Subset(test_set, test_indices)
    test_set_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size, shuffle=False, worker_init_fn=tools.worker_init, **kwargs)

# ======================================================================

# ------------ $$$ load a set of imgs, masks and modified_imgs for model editing ------------
# ------------ $$$ CD is labels of ImageNet ------------
# train_data, test_data = dh.get_vehicles_on_snow_data(DATASET_NAME, CD)
# print("Train exemplars")
# vh.show_image_row([train_data['imgs'], train_data['masks'], train_data['modified_imgs']],
#                   ['Original', 'Mask', 'Modified'], fontsize=20)

# ------------ $$$ different groups of test images ------------
# print("Flickr-sourced test set")
# for c, x in test_data.items():
#     vh.show_image_row([x[:5]], title=f'{CD[c]} ({c})')

# ---------------------- create patched model path ---------------------------
# Edited checkpoints are written to a 'patched_model' directory that sits next
# to the original checkpoint file.
head, tail = os.path.split(model_path)
patched_model_path = os.path.join(head, 'patched_model')
# exist_ok avoids the check-then-create race of `if not exists: mkdir`.
os.makedirs(patched_model_path, exist_ok=True)
# torch.save(context_model.state_dict(), os.path.join(patched_model_path, 'best_patched_model.pth'))


# -------------------------------------- load trigger sample -----------------------------------------------
from tqdm import tqdm
from PIL import Image
# Load the trigger image as a channel-first float tensor and derive a binary
# mask that is 0 wherever the trigger is black and 1 elsewhere.
trigger_path = 'triggers/' + args.trigger
# Use a context manager so the image file handle is released right away;
# np.array() forces the pixel data to be read while the file is still open.
with Image.open(trigger_path) as trigger_img:
    trigger = torch.tensor(np.array(trigger_img, dtype=np.float32))
# (H, W, C) -> (C, H, W); assumes the trigger image has a channel axis - TODO confirm
trigger = trigger.permute(2, 0, 1)
mask = torch.ones_like(trigger)
# Per-pixel sum over channels; anything below 1e-6 is snapped to exactly zero.
temp = torch.sum(trigger, dim=0)
# A tensor operand (instead of a Python scalar) keeps torch.where portable
# across PyTorch versions.
temp = torch.where(temp < 1e-6, torch.zeros_like(temp), temp)
zeros_ind = torch.where(temp == 0.)
# Zero the mask in every channel at the all-black pixel positions.
mask[:, zeros_ind[0], zeros_ind[1]] = 0.

# -------------------------------------- load target data -----------------------------------------------
# Collect up to `used_data_len` exemplars from val_loader whose label differs
# from the attack's target class, together with their triggered versions.
#   data          - the original images
#   poisoned_data - the same images after poison_transform
#   target        - the labels returned by poison_transform for those images
#   ori_target    - the original labels
#   train_masks   - the trigger mask repeated once per exemplar
used_data_len = args.used_data_len # 40 # 5 10 15 20 25 30 35 40
if used_data_len < 1:
    raise ValueError('-used_data_len must be >= 1, got %d' % used_data_len)

stored_data_len = 0

# Peek at one batch only to learn the per-sample tensor shape.
try:
    first_batch, _ = next(iter(val_loader))
except StopIteration:
    raise RuntimeError('val_loader is empty; cannot collect exemplars for editing') from None
input_tensor_shape = list(first_batch.shape)
input_tensor_shape[0] = used_data_len

used_data = torch.zeros(*input_tensor_shape)
used_poisoned_data = torch.zeros(*input_tensor_shape)
used_data_label = torch.zeros(used_data_len, dtype=torch.int64)
ori_data_label = torch.zeros(used_data_len, dtype=torch.int64)

poisoned_class = config.target_class[args.dataset]
for data, labels in tqdm(val_loader):
    # Positions of the samples that do not already belong to the target class.
    data_indices = (labels != poisoned_class).nonzero().view(-1)
    poisoned_data, target_labels = poison_transform.transform(data, labels)

    # Take only as many eligible samples as are still missing.
    picked = data_indices[:used_data_len - stored_data_len]
    n_picked = picked.shape[0]
    if n_picked > 0:
        dst = slice(stored_data_len, stored_data_len + n_picked)
        used_data[dst] = data[picked]
        used_poisoned_data[dst] = poisoned_data[picked]
        used_data_label[dst] = target_labels[picked]  # target_labels, labels
        ori_data_label[dst] = labels[picked]
        stored_data_len += n_picked

    if stored_data_len == used_data_len:
        break

if stored_data_len == 0:
    raise RuntimeError('no sample outside the target class was found in val_loader')
if stored_data_len < used_data_len:
    # Drop the unused, zero-filled rows rather than feeding blank images with
    # label 0 into the editing step.
    print('WARNING: only %d of the %d requested exemplars were found'
          % (stored_data_len, used_data_len))
    used_data = used_data[:stored_data_len]
    used_poisoned_data = used_poisoned_data[:stored_data_len]
    used_data_label = used_data_label[:stored_data_len]
    ori_data_label = ori_data_label[:stored_data_len]

# ---------------------------
data = used_data
poisoned_data = used_poisoned_data
target = used_data_label
ori_target = ori_data_label
train_masks = torch.stack([mask] * data.shape[0], dim=0)


# --------- attribution -----------
def attribution_measuring(context_model, img=data, ref_img=poisoned_data, label=target, d_name=args.dataset):
    """Run rh.layer_locate on a set of images and their reference images.

    The defaults are bound once, when this function is defined, to the
    exemplars collected above and to the dataset chosen on the command line.

    Args:
        context_model: model handed to rh.layer_locate.
        img: images, passed under the key 'imgs'.
        ref_img: reference images, passed under the key 'ref_imgs'.
        label: labels, passed under the key 'labels'.
        d_name: dataset name stored in the arguments namespace.

    Returns:
        None; the value returned by rh.layer_locate is discarded.
    """
    explaining_args = Namespace(k=100,  # Number of exemplars
                                dataset=d_name)
    rh.layer_locate(explaining_args,
                    {'imgs': img, 'ref_imgs': ref_img, 'labels': label},
                    context_model)



# if True:
#     attribution_measuring(context_model)
    # import sys
    # sys.exit()

# -------------------------------------------------------------------------------------

# Evaluate the unedited model once; both numbers act as the reference point for
# the editing budgets below.
ori_overall_clean_accu, ori_poisoned_accu = tools.test(
    model=model,
    test_loader=test_set_loader,
    num_classes=num_classes,
    poison_test=True,
    poison_transform=poison_transform,
    source_classes=None,
)

# Initial state of the editing loops.
cnt_locating, editing_max_dup = 0, 3
overall_accu_decay = clean_accu_gap = 0.0
poisoned_clean_gap = ori_poisoned_accu
print('cleansed and poisoned accuracy gap before editing: {:.4f}'.format(poisoned_clean_gap))

# Settings.
# robust_budget = 0.01
# overall_budget = -1.0
REWRITE_MODE, layer_evaluating = 'editing', False
# Two modes:
#   layer_evaluating == True  -> edit every eligible layer on its own, starting
#                                from the original checkpoint each time, and
#                                report one result line per layer.
#   layer_evaluating == False -> rank the layers with rh.layer_locate and edit
#                                them in that order until the accuracy under
#                                attack drops to robust_budget or below.
if layer_evaluating:
    ###########################################

    robust_budget = 0.001   # stop once the accuracy under attack is at or below this
    overall_budget = 0.03   # largest tolerated drop in overall clean accuracy
    editing_lr = 1e-4
    editing_steps = 1 #20000
    final_layer_eval_results = {}
    for sub_model_name, sub_model in target_model_dict.items():

        # Skip layers whose name contains 'fc', 'bn', 'relu' or 'pool'.
        if (('fc' in sub_model_name) or ('bn' in sub_model_name) or
                ('relu' in sub_model_name)) or ('pool' in sub_model_name): # ('conv' in sub_model_name) or ('pool' in sub_model_name)
            continue

        # Every layer is evaluated starting from the original weights.
        context_model.load_state_dict(torch.load(model_path))

        LAYERNUM = sub_model_name
        DATASET_NAME = args.dataset
        # NOTE(review): hard-coded; not derived from supervisor.get_arch(args).
        ARCH = 'resnet18' # if args.dataset=='cifar10' or args.dataset=='imagenet' else 'resnet34'
        matrices_file = f"./model_editing/cache/covariances/backdoor_{DATASET_NAME}_{ARCH}_{LAYERNUM}"

        target_model = target_model_dict[sub_model_name]
        layer_name = sub_model_name
        # for re-initiating
        best_model_state_dict = copy.deepcopy(torch.load(model_path))
        overall_accu_decay = 0.0
        poisoned_clean_gap = ori_poisoned_accu
        editing_cnt = 0
        opt_diff = 0
        editing_dup_cnt = 0

        print('\n' + '*' * 20)
        # NOTE: the unconditional `break` at the end of this loop body means
        # it runs at most once per layer.
        while overall_accu_decay < overall_budget and poisoned_clean_gap > robust_budget:
            editing_cnt += 1
            print('-'*10 + 'editing {} iteration: {}'.format(layer_name, editing_cnt) + '-'*10)

            # --------- Perform re-write -----------
            train_data = {'imgs': data,  # values (10, 3, 224, 224) 0-1.
                          'modified_imgs': poisoned_data,  # keys (10, 3, 224, 224) 0-1.
                          'masks': train_masks,  # (10, 3, 224, 224) 0. & 1.
                          'labels': target  # (10, )
                          }
            train_args = {'ntrain': poisoned_data.shape[0],  # Number of exemplars
                          'arch': ARCH,  # Network architecture
                          'mode_rewrite': REWRITE_MODE,  # Rewriting method ['editing', 'finetune_local', 'finetune_global']
                          'layernum': LAYERNUM,  # Layer to modify
                          'nsteps': editing_steps,  # Number of rewriting steps
                          'lr': editing_lr,  # Learning rate 1e-4
                          'restrict_rank': True,  # Whether or not to perform low-rank update
                          'nsteps_proj': 10,  # Frequency of weight projection
                          'rank': 1,  # Rank of subspace to project weights
                          'use_mask': False,  # Whether or not to use mask (hard-coded; args.use_mask is not consulted)
                          'layer_name': layer_name,
                          }
            train_args = Namespace(**train_args)

            # val_loader = test_set_loader
            # No-op self-assignment, kept from the original.
            val_loader = val_loader

            context_model = rh.edit_classifier(train_args,
                                               train_data,
                                               context_model,
                                               target_model=target_model,
                                               val_loader=val_loader,
                                               caching_dir=matrices_file)

            overall_clean_accu, poisoned_accu_aft_edit = tools.test(model=context_model, test_loader=test_set_loader,
                                                                    poison_test=True, poison_transform=poison_transform,
                                                                    num_classes=num_classes, source_classes=None)

            poisoned_clean_gap = poisoned_accu_aft_edit
            overall_accu_decay = ori_overall_clean_accu - overall_clean_accu
            print('overall accuracy decline: {:.4f}'.format(overall_accu_decay))
            print('poisoned to cleansed accuracy gap: {:.4f}'.format(poisoned_clean_gap))

            # Remember the edited weights whenever the clean-accuracy budget holds.
            if overall_accu_decay < overall_budget:  # or poisoned_clean_gap > robust_budget
                best_model_state_dict = copy.deepcopy(context_model.state_dict())

            #### replace with ####
            # NOTE(review): this save is repeated by the if/else just below.
            if overall_accu_decay < overall_budget and poisoned_clean_gap > robust_budget:
                torch.save(context_model.state_dict(), os.path.join(patched_model_path, 'best_patched_model.pth'))
            #### added ######
            # c_acc = str(overall_clean_accu).split('.')[-1]
            # g_acc = str(poisoned_clean_gap).split('.')[-1]
            # result = c_acc[:5] + '-' + g_acc[:5]
            # model_name = 'best_refined_model_with_' + str(used_data_len) + '_' + result + '.pth'
            # if overall_accu_decay < overall_budget and poisoned_clean_gap > robust_budget:
            #     torch.save(context_model.state_dict(), os.path.join(patched_model_path, 'patched_models', model_name))
            # else:
            #     torch.save(context_model.state_dict(), os.path.join(patched_model_path, 'patched_models', model_name))

            # NOTE(review): the else branch reads 'best_patched_model.pth', which
            # this run may not have written yet (missing file, or one left over
            # from an earlier layer/run) - confirm.
            if overall_accu_decay < overall_budget and poisoned_clean_gap > robust_budget:
                torch.save(context_model.state_dict(), os.path.join(patched_model_path, 'best_patched_model.pth'))
            else:
                context_model.load_state_dict(torch.load(os.path.join(patched_model_path, 'best_patched_model.pth')))

            # Count consecutive edits whose improvement differs from the last
            # recorded one by at most 0.0002; both branches update opt_diff.
            if abs(opt_diff - (ori_poisoned_accu-poisoned_accu_aft_edit)) <= 0.0002:
                editing_dup_cnt += 1
                opt_diff = ori_poisoned_accu-poisoned_accu_aft_edit
            else:
                opt_diff = ori_poisoned_accu-poisoned_accu_aft_edit
                editing_dup_cnt = 0
            if editing_dup_cnt == editing_max_dup:
                break
            # Always leave after the first iteration.
            break

        # ----------------------- final test of the patched model ---------------------------
        context_model.load_state_dict(best_model_state_dict)
        overall_clean_accu, poisoned_accu = tools.test(model=context_model, test_loader=test_set_loader,
                                                       poison_test=True, poison_transform=poison_transform,
                                                       num_classes=num_classes, source_classes=None)
        final_result = ('overall clean accu.: <{:.4f}> | accu. under attack: <{:.4f}>'
                        .format(overall_clean_accu, poisoned_accu))
        final_layer_eval_results[layer_name] = final_result
    # Summary: one line per evaluated layer.
    for layer_name, results in final_layer_eval_results.items():
        print('LAYER {}: {}'.format(layer_name, results))
else:

    robust_budget = 0.001   # stop once the accuracy under attack is at or below this
    overall_budget = 0.03   # largest tolerated drop in overall clean accuracy
    editing_lr = 1e-4
    editing_steps = 2000

    # --------- Locating by Similarity -----------
    input_data = {'imgs': data,
                  'ref_imgs': poisoned_data,
                  'labels': target
                  }
    explaining_args = {'k': 100,  # Number of exemplars
                       'dataset': args.dataset,
                       }
    explaining_args = Namespace(**explaining_args)
    # Sequence of layer names; it is consumed front to back below.
    layer_name_ord = rh.layer_locate(explaining_args, input_data, context_model)

    while poisoned_clean_gap > robust_budget:
        # At most 8 layers are tried.
        # NOTE(review): assumes layer_name_ord has at least 8 entries - confirm.
        if cnt_locating > 7:
            break
        LAYERNUM = layer_name_ord[cnt_locating]
        target_model = target_model_dict[LAYERNUM]

        cnt_locating += 1
        print('-' * 10 + 'overall iteration: {}'.format(cnt_locating) + '-' * 10)
        print('locating layer name: {}'.format(LAYERNUM))

        DATASET_NAME = args.dataset
        # NOTE(review): hard-coded; not derived from supervisor.get_arch(args).
        ARCH = 'resnet18' # if args.dataset=='cifar10' else 'resnet34'
        matrices_file = f"./model_editing/cache/covariances/backdoor_{DATASET_NAME}_{ARCH}_{LAYERNUM}"

        # if os.path.exists(matrices_file):
        #     shutil.rmtree(matrices_file)

        editing_cnt = 0
        opt_diff = 0
        editing_dup_cnt = 0
        # Start below the budget so the inner loop is entered for every new layer.
        overall_accu_decay = overall_budget-1

        while overall_accu_decay < overall_budget and poisoned_clean_gap > robust_budget:
            editing_cnt += 1
            print('-'*10 + 'editing {} iteration: {}'.format(LAYERNUM, editing_cnt) + '-'*10)

            # --------- Perform re-write -----------
            train_data = {'imgs': data,  # values (10, 3, 224, 224) 0-1.
                          'modified_imgs': poisoned_data,  # keys (10, 3, 224, 224) 0-1.
                          'masks': train_masks,  # (10, 3, 224, 224) 0. & 1.
                          'labels': target  # (10, )
                          }
            train_args = {'ntrain': poisoned_data.shape[0],  # Number of exemplars
                          'arch': ARCH,  # Network architecture
                          'mode_rewrite': REWRITE_MODE,  # Rewriting method ['editing', 'finetune_local', 'finetune_global']
                          'layernum': LAYERNUM,  # Layer to modify
                          'nsteps': editing_steps,  # Number of rewriting steps
                          'lr': editing_lr,  # Learning rate 1e-4
                          'restrict_rank': True,  # Whether or not to perform low-rank update
                          'nsteps_proj': 10,  # Frequency of weight projection
                          'rank': 1,  # Rank of subspace to project weights
                          'use_mask': False,  # Whether or not to use mask (hard-coded; args.use_mask is not consulted)
                          'layer_name': LAYERNUM,
                          }
            train_args = Namespace(**train_args)

            # val_loader = test_set_loader
            # No-op self-assignment, kept from the original.
            val_loader = val_loader

            context_model = rh.edit_classifier(train_args,
                                               train_data,
                                               context_model,
                                               target_model=target_model,
                                               val_loader=val_loader,
                                               caching_dir=matrices_file)

            overall_clean_accu, poisoned_accu_aft_edit = tools.test(model=context_model, test_loader=test_set_loader,
                                                                    poison_test=True, poison_transform=poison_transform,
                                                                    num_classes=num_classes, source_classes=None)

            # Positive when this edit lowered the accuracy under attack.
            skip_flag = poisoned_clean_gap - poisoned_accu_aft_edit

            poisoned_clean_gap = poisoned_accu_aft_edit
            overall_accu_decay = ori_overall_clean_accu - overall_clean_accu
            print('overall accuracy decline: {:.4f}'.format(overall_accu_decay))
            print('poisoned to cleansed accuracy gap: {:.4f}'.format(poisoned_clean_gap))

            ######## for internal results ##########
            # c_acc = str(overall_clean_accu).split('.')[-1]
            # g_acc = str(poisoned_clean_gap).split('.')[-1]
            # result = c_acc[:5] + '-' + g_acc[:5]
            # model_name = 'best_refined_model_with_' + str(used_data_len) + '_' + result + '.pth'
            # if overall_accu_decay < overall_budget and poisoned_clean_gap > robust_budget:
            #     torch.save(context_model.state_dict(), os.path.join(patched_model_path, 'patched_models', model_name))
            # else:
            #     context_model.load_state_dict(torch.load(os.path.join(patched_model_path, 'patched_models', model_name)))
            #############

            # Keep the edit only while both loop conditions still hold;
            # otherwise roll back to the last saved checkpoint.
            # NOTE(review): an edit that brings the accuracy under attack down to
            # robust_budget or below also takes the else branch, so it is rolled
            # back rather than saved - confirm this is intended.
            # NOTE(review): the else branch assumes 'best_patched_model.pth'
            # already exists (possibly from an earlier run) - confirm.
            if overall_accu_decay < overall_budget and poisoned_clean_gap > robust_budget:
                torch.save(context_model.state_dict(), os.path.join(patched_model_path, 'best_patched_model.pth'))
            else:
                context_model.load_state_dict(torch.load(os.path.join(patched_model_path, 'best_patched_model.pth')))

            # Give up on this layer after editing_max_dup consecutive edits whose
            # improvement stays within 0.0002 of the last recorded one.
            if abs(opt_diff - (ori_poisoned_accu - poisoned_accu_aft_edit)) <= 0.0002:
                editing_dup_cnt += 1
            else:
                opt_diff = ori_poisoned_accu - poisoned_accu_aft_edit
                editing_dup_cnt = 0
            if editing_dup_cnt == editing_max_dup:
                break

            # The edit increased the accuracy under attack: move on to the next layer.
            if skip_flag < 0.:
                break

    # ----------------------- final test of the patched model ---------------------------
    # model = arch(num_classes=num_classes)
    # NOTE(review): fails if no checkpoint was saved above, and may pick up a
    # file written by an earlier run - confirm.
    context_model.load_state_dict(torch.load(os.path.join(patched_model_path, 'best_patched_model.pth')))
    # model = nn.DataParallel(model)
    # model = model.cuda()
    overall_clean_accu, poisoned_accu = tools.test(model=context_model, test_loader=test_set_loader,
                                                   poison_test=True, poison_transform=poison_transform,
                                                   num_classes=num_classes, source_classes=None)
    print('---------- Results of Patched Model ------------')
    # NOTE: the ratio divides by overall_clean_accu and raises if that is 0.
    print('overall clean accu.: <{:.4f}> | accu. under attack: <{:.4f}> | ratio: <{:.4f}>'
          .format(overall_clean_accu, poisoned_accu, poisoned_accu/overall_clean_accu))
