import matplotlib.pyplot as plt
import numpy as np
import torch
import os
import shutil
import copy

from torchvision import transforms, datasets
import argparse
import random
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch import nn
from PIL import Image
from utils import supervisor, tools, default_args, imagenet, create_subset
import config
import time

# torch.manual_seed(42)


# Command-line interface of the script. Every option is optional at the
# argparse level; options that are simply omitted fall back to the defaults
# listed here (argparse's implicit default is None, or False for flags).
parser = argparse.ArgumentParser()

# Options whose allowed values and default both come from `default_args`.
for _opt_name, _opt_type in (('dataset', str),
                             ('poison_type', str),
                             ('poison_rate', float),
                             ('cover_rate', float)):
    parser.add_argument('-' + _opt_name, type=_opt_type,
                        choices=default_args.parser_choices[_opt_name],
                        default=default_args.parser_default[_opt_name])

parser.add_argument('-alpha', type=float, default=default_args.parser_default['alpha'])
# When given, -test_alpha replaces -alpha in the test-time poison transform.
parser.add_argument('-test_alpha', type=float)
# When omitted, the trigger is looked up in config.trigger_default after parsing.
parser.add_argument('-trigger', type=str)
parser.add_argument('-model_path')
parser.add_argument('-cleanser', type=str, choices=default_args.parser_choices['cleanser'])
parser.add_argument('-defense', type=str, choices=default_args.parser_choices['defense'])
parser.add_argument('-no_normalize', action='store_true')
parser.add_argument('-no_aug', action='store_true')
# Written verbatim into CUDA_VISIBLE_DEVICES.
parser.add_argument('-devices', type=str, default='0')
parser.add_argument('-seed', type=int, default=default_args.seed)

# Options specific to the fine-tuning / model-editing part of the script.
parser.add_argument('-use_mask', action='store_true')
# Substring of the parameter names that are kept trainable; also forwarded
# to classifier_helpers.load_classifier.
parser.add_argument('-bef_layer_name', type=str)
# Number of samples collected from the validation loader for fine-tuning.
parser.add_argument('-used_data_len', type=int)

args = parser.parse_args()

# Restrict the visible GPUs to the ones requested on the command line.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.devices)

# Fall back to the configured trigger for this dataset / poison type.
args.trigger = (
    args.trigger
    if args.trigger is not None
    else config.trigger_default[args.dataset][args.poison_type]
)

# Extra DataLoader keyword arguments; ImageNet gets more worker processes.
kwargs = {
    'num_workers': 32 if args.dataset == 'imagenet' else 4,
    'pin_memory': True,
}

tools.setup_seed(args.seed)

# The five transform objects come back as one tuple; unpack them by position.
_transform_bundle = supervisor.get_transforms(args)
(data_transform_aug,
 data_transform,
 trigger_transform,
 normalizer,
 denormalizer) = _transform_bundle

# Per-dataset training configuration:
#   dataset -> (num_classes, epochs, LR-milestone epochs, batch_size)
_DATASET_CONFIG = {
    'cifar10': (10, 200, [100, 150], 128),
    'gtsrb': (43, 100, [40, 80], 128),
    'imagenette': (10, 100, [40, 80], 128),
    'imagenet': (1000, 90, [30, 60], 256),
}

if args.dataset == 'cifar100':
    # Known dataset, but no training configuration has been defined for it.
    num_classes = 100
    raise NotImplementedError('<To Be Implemented> Dataset = %s' % args.dataset)

if args.dataset not in _DATASET_CONFIG:
    print('<Undefined Dataset> Dataset = %s' % args.dataset)
    raise NotImplementedError('<To Be Implemented> Dataset = %s' % args.dataset)

num_classes, epochs, _milestone_epochs, batch_size = _DATASET_CONFIG[args.dataset]
milestones = torch.tensor(_milestone_epochs)

# Optimizer settings shared by every supported dataset.
momentum = 0.9
weight_decay = 1e-4
learning_rate = 0.1

# poison_set_dir = supervisor.get_spurious_set_dir(args)
# Locate the checkpoint for the current configuration; whether a cleanser
# and/or a defense was requested is forwarded to the path lookup.
model_path = supervisor.get_model_dir(
    args,
    cleanse=args.cleanser is not None,
    defense=args.defense is not None,
)

# Instantiate the architecture, restore its weights, then move it to the GPU.
arch = supervisor.get_arch(args)
model = arch(num_classes=num_classes)
_checkpoint_state = torch.load(model_path)
model.load_state_dict(_checkpoint_state)
# model = nn.DataParallel(model)
model = model.cuda()

# ======================================================================

# Imports and setup for the model-editing helpers. The statement order is
# kept as-is: the project modules below are imported only after warnings
# have been silenced.
import warnings
from argparse import Namespace
# Suppress every Python warning from this point on.
warnings.filterwarnings("ignore")

import torch as ch

from model_editing.helpers import classifier_helpers
import model_editing.helpers.data_helpers as dh
import model_editing.helpers.rewrite_helpers as rh
import model_editing.helpers.vis_helpers as vh

# NOTE(review): not referenced anywhere else in the visible part of this
# script -- confirm whether it is still needed.
REWRITE_MODE = 'editing'

# ret = classifier_helpers.get_default_paths(DATASET_NAME, arch=ARCH)
# DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = ret


# ret = classifier_helpers.load_classifier(MODEL_PATH, MODEL_CLASS, ARCH, DATASET_NAME, LAYERNUM)
# Wrap the already-loaded classifier for editing. `context_model` is the
# object whose parameters are frozen / unfrozen and whose state dict is
# saved further below; `target_model_dict` is not used in the visible part
# of this script.
# $$$ context_model & target_model for gathering key & value
# NOTE(review): whether `context_model` shares parameters with `model` is
# decided inside load_classifier and cannot be told from here -- confirm.
context_model, target_model_dict = classifier_helpers.load_classifier(model, bef_layer_name=args.bef_layer_name)

# ---------------------- load datasets -------------------------
# Both branches below define the three objects the rest of the script uses:
#   val_loader       -- loader the fine-tuning samples are drawn from
#   test_set_loader  -- loader used to evaluate clean / poisoned accuracy
#   poison_transform -- transform that applies the trigger to a batch
if args.dataset != 'imagenet':
    # args.poison_rate = 0.0
    poison_set_dir = supervisor.get_poison_set_dir(args)
    print(poison_set_dir)

    # The image folder is called 'data' in the old on-disk layout and 'imgs'
    # in the new one; when both exist, 'imgs' takes precedence.
    poisoned_set_img_dir = None
    for _sub_dir in ('data', 'imgs'):
        _candidate = os.path.join(poison_set_dir, _sub_dir)
        if os.path.exists(_candidate):
            poisoned_set_img_dir = _candidate
    if poisoned_set_img_dir is None:
        # Fail with an explicit message instead of a NameError further down.
        raise FileNotFoundError(
            "Neither 'data' nor 'imgs' exists under the poisoned set directory: %s" % poison_set_dir)

    poisoned_set_label_path = os.path.join(poison_set_dir, 'labels')
    poison_indices_path = os.path.join(poison_set_dir, 'poison_indices')
    print('dataset : %s' % poisoned_set_img_dir)

    # poisoned_set = tools.IMG_Dataset(data_dir=poisoned_set_img_dir,
    #                                  label_path=poisoned_set_label_path, transforms=data_transform if args.no_aug else data_transform_aug)

    poisoned_set = tools.IMG_Dataset(data_dir=poisoned_set_img_dir, label_path=poisoned_set_label_path, transforms=data_transform)

    poisoned_train_set_loader = torch.utils.data.DataLoader(
        poisoned_set,
        batch_size=batch_size, shuffle=False, worker_init_fn=tools.worker_init, **kwargs)

    # ------ Set Up Clean Set ------
    args_dup = copy.copy(args)
    args_dup.poison_rate = 0.0
    clean_set_dir = supervisor.get_spurious_set_dir(args_dup)
    # NOTE(review): `clean_set_dir` is computed but the "clean" set below is
    # read from `poison_set_dir` (and always from its 'imgs' folder). Left
    # unchanged because the intended source cannot be confirmed from here.
    clean_set_img_dir = os.path.join(poison_set_dir, 'imgs')
    clean_set_label_path = os.path.join(poison_set_dir, 'labels')
    clean_train_set = tools.IMG_Dataset(data_dir=clean_set_img_dir, label_path=clean_set_label_path, transforms=data_transform)
    clean_train_set_loader = torch.utils.data.DataLoader(
        clean_train_set,
        batch_size=batch_size, shuffle=True, worker_init_fn=tools.worker_init, **kwargs)

    # ------ Set Up Val Set ------
    t_loader = clean_train_set_loader  # clean_train_set_loader poisoned_train_set_loader
    val_loader = create_subset.create_val_loader(t_loader, val_ratio=0.1, saved_dir=poison_set_dir)

    # ------ Set Up Test Set for Debug & Evaluation ------
    test_set_dir = os.path.join('clean_set', args.dataset, 'test_split')
    test_set_img_dir = os.path.join(test_set_dir, 'data')
    test_set_label_path = os.path.join(test_set_dir, 'labels')
    test_set = tools.IMG_Dataset(data_dir=test_set_img_dir,
                                 label_path=test_set_label_path, transforms=data_transform)
    test_set_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size, shuffle=False, worker_init_fn=tools.worker_init, **kwargs)

    # Poison Transform for Testing (-test_alpha, when given, overrides -alpha)
    poison_transform = supervisor.get_poison_transform(poison_type=args.poison_type, dataset_name=args.dataset,
                                                       target_class=config.target_class[args.dataset], trigger_transform=data_transform,
                                                       is_normalized_input=True,
                                                       alpha=args.alpha if args.test_alpha is None else args.test_alpha,
                                                       trigger_name=args.trigger, args=args)

else:
    poison_set_dir = supervisor.get_poison_set_dir(args)
    poison_indices_path = os.path.join(poison_set_dir, 'poison_indices')
    poisoned_set_img_dir = os.path.join(poison_set_dir, 'data')
    print('dataset : %s' % poison_set_dir)

    poison_indices = torch.load(poison_indices_path)

    train_set_dir = os.path.join(config.imagenet_dir, 'train')
    test_set_dir = os.path.join(config.imagenet_dir, 'val')

    from utils import imagenet
    # data_transform_aug
    poisoned_train_set = imagenet.imagenet_dataset(directory=train_set_dir, data_transform=data_transform_aug, poison_directory=poisoned_set_img_dir,
                                                   poison_indices=poison_indices, target_class=config.target_class['imagenet'],
                                                   num_classes=1000)

    # ------ Set Up Clean Set ------
    clean_train_set = datasets.ImageFolder(root=train_set_dir, transform=data_transform)
    # clean_train_loader = torch.utils.data.DataLoader(
    #     clean_train_set,
    #     batch_size=batch_size, shuffle=True, worker_init_fn=tools.worker_init, **kwargs)

    # ------ Set Up Subset ------
    # poisoned_train_set = create_subset.create_val_set(poisoned_train_set, eval_ratio=0.05)

    # For ImageNet the shuffled clean training set itself serves as the
    # validation loader.
    val_loader = torch.utils.data.DataLoader(
        clean_train_set,
        batch_size=batch_size, shuffle=True, worker_init_fn=tools.worker_init, **kwargs)

    # ------ Set Up Test Set for Debug & Evaluation ------
    poison_transform = supervisor.get_poison_transform(poison_type=args.poison_type, dataset_name=args.dataset,
                                                       target_class=config.target_class[args.dataset], trigger_transform=data_transform,
                                                       is_normalized_input=True,
                                                       alpha=args.alpha if args.test_alpha is None else args.test_alpha,
                                                       trigger_name=args.trigger, args=args)

    test_set = imagenet.imagenet_dataset(directory=test_set_dir, shift=False, data_transform=data_transform,
                                         label_file=imagenet.test_set_labels, num_classes=1000)

    # Restrict evaluation to the stored test-split indices.
    test_split_meta_dir = os.path.join('clean_set', args.dataset, 'test_split')
    test_indices = torch.load(os.path.join(test_split_meta_dir, 'test_indices'))

    test_set = torch.utils.data.Subset(test_set, test_indices)
    test_set_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size, shuffle=False, worker_init_fn=tools.worker_init, **kwargs)

# ======================================================================

# ------------ $$$ load a set of imgs, masks and modified_imgs for model editing ------------
# ------------ $$$ CD is labels of ImageNet ------------
# train_data, test_data = dh.get_vehicles_on_snow_data(DATASET_NAME, CD)
# print("Train exemplars")
# vh.show_image_row([train_data['imgs'], train_data['masks'], train_data['modified_imgs']],
#                   ['Original', 'Mask', 'Modified'], fontsize=20)

# ------------ $$$ different groups of test images ------------
# print("Flickr-sourced test set")
# for c, x in test_data.items():
#     vh.show_image_row([x[:5]], title=f'{CD[c]} ({c})')

# ---------------------- create patched model path ---------------------------
# Edited / fine-tuned checkpoints go into a 'patched_model' folder that sits
# next to the loaded checkpoint.
head, tail = os.path.split(model_path)
patched_model_path = os.path.join(head, 'patched_model')
# exist_ok avoids the check-then-create race of a separate os.path.exists
# test followed by os.mkdir.
os.makedirs(patched_model_path, exist_ok=True)

# -------------------------------------- load trigger sample -----------------------------------------------
from tqdm import tqdm
from PIL import Image
# Load the trigger image as a float32 tensor and derive a mask from it.
trigger_path = 'triggers/' + args.trigger
# np.array() copies the pixel data, so the file handle can be closed as soon
# as the array has been built.
with Image.open(trigger_path) as _trigger_img:
    trigger = torch.tensor(np.array(_trigger_img, dtype=np.float32))
# The permute below requires a 3-D image array; the last axis is moved to
# the front.
trigger = trigger.permute(2, 0, 1)

# The mask has the same shape as the trigger: 1 everywhere except at pixel
# positions whose values summed over the leading axis are (numerically)
# zero, where it is 0 across that whole axis.
mask = torch.ones_like(trigger)
_pixel_sum = torch.sum(trigger, dim=0)
mask[:, _pixel_sum < 1e-6] = 0.

# -------------------------------------- load target data -----------------------------------------------
# Collect up to `-used_data_len` samples whose label differs from the attack
# target class, together with their triggered versions and both label sets.
used_data_len = args.used_data_len  # 45
if used_data_len is None or used_data_len <= 0:
    # Without this check the buffers below cannot be allocated (None) or
    # indexed (0) and the script dies with an unrelated-looking error.
    parser.error('-used_data_len must be a positive integer')

# Peek at one batch to learn the per-sample tensor shape. This stays a
# separate pass over the loader so that the batches seen by the collection
# loop below are drawn exactly as before.
_probe_batch = next(iter(val_loader), None)
if _probe_batch is None:
    raise ValueError('val_loader yielded no batches; cannot collect fine-tuning samples')
input_tensor_shape = list(_probe_batch[0].shape)
input_tensor_shape[0] = used_data_len

used_data = torch.zeros(*input_tensor_shape)
used_poisoned_data = torch.zeros(*input_tensor_shape)
used_data_label = torch.zeros(used_data_len, dtype=torch.int64)
used_ori_label = torch.zeros(used_data_len, dtype=torch.int64)

poisoned_class = config.target_class[args.dataset]
stored_data_len = 0
for data, labels in tqdm(val_loader):
    # Positions of samples that do not already belong to the target class.
    data_indices = (labels != poisoned_class).nonzero().view(-1)
    poisoned_data, target_labels = poison_transform.transform(data, labels)
    # ori_label = torch.index_select(labels, 0, data_indices)

    # Copy as many eligible samples as still fit, in their original order.
    _take = min(data_indices.shape[0], used_data_len - stored_data_len)
    if _take > 0:
        _chosen = data_indices[:_take]
        _dst = slice(stored_data_len, stored_data_len + _take)
        used_data[_dst] = data[_chosen]
        used_poisoned_data[_dst] = poisoned_data[_chosen]
        used_data_label[_dst] = target_labels[_chosen]
        used_ori_label[_dst] = labels[_chosen]
        stored_data_len += _take
    if stored_data_len == used_data_len:
        break

if stored_data_len == 0:
    raise RuntimeError('no sample outside the target class was found in val_loader')
if stored_data_len < used_data_len:
    # The loader ran out early: drop the unfilled tail so that all-zero
    # images with label 0 are not used for fine-tuning.
    print('<Warning> only %d of the requested %d samples were collected'
          % (stored_data_len, used_data_len))
    used_data = used_data[:stored_data_len]
    used_poisoned_data = used_poisoned_data[:stored_data_len]
    used_data_label = used_data_label[:stored_data_len]
    used_ori_label = used_ori_label[:stored_data_len]

# ---------------------------
data = used_data
poisoned_data = used_poisoned_data
target = used_data_label

# ---------------------------------------------------------------------
# Keep only the parameters whose name contains `-bef_layer_name` trainable;
# every other parameter of `context_model` is frozen.
for _param_name, _parameter in context_model.named_parameters():
    _parameter.requires_grad = args.bef_layer_name in _param_name

# Learning rate used for fine-tuning (replaces the per-dataset value).
learning_rate = 0.001

# Baseline evaluation before any fine-tuning step is taken.
_baseline_eval = tools.test(
    model=model,
    test_loader=test_set_loader,
    poison_test=True,
    poison_transform=poison_transform,
    num_classes=num_classes,
    source_classes=None,
)
ori_overall_clean_accu, ori_poisoned_accu = _baseline_eval
poisoned_clean_gap = ori_poisoned_accu
print('cleansed and poisoned accuracy gap before editing: {:.4f}'.format(poisoned_clean_gap))


# Loss, optimizer and LR schedule for fine-tuning.
criterion = nn.CrossEntropyLoss().cuda()
# optimizer = torch.optim.SGD(model.parameters(), learning_rate, momentum=momentum, weight_decay=weight_decay)
# Only parameters that are still trainable are handed to the optimizer.
# NOTE(review): requires_grad was set on `context_model`'s parameters while
# the optimizer is built from `model` -- this relies on both sharing the
# same parameter tensors; confirm in classifier_helpers.load_classifier.
_trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(_trainable_params, learning_rate, momentum=momentum, weight_decay=weight_decay)
# MultiStepLR needs plain integer milestones: 0-d tensor elements hash by
# identity, so a tensor of milestones would never match the epoch counter
# and the learning rate would never decay.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(m) for m in milestones])
from tqdm import tqdm

# Fine-tune on the collected triggered samples paired with their ORIGINAL
# labels. The whole sample set is used as a single batch, so each "epoch"
# below is exactly one optimizer step.
data = poisoned_data.cuda()
target = used_ori_label.cuda()

# Checkpoints are written into this sub-directory; create it up front so
# that torch.save does not fail on a missing folder.
finetuned_model_dir = os.path.join(patched_model_path, 'finetuned_models')
os.makedirs(finetuned_model_dir, exist_ok=True)

for epoch in range(1, epochs + 1):
    start_time = time.perf_counter()

    # Switch back to training mode on every iteration.
    # NOTE(review): presumably tools.test() leaves the model in eval mode -- confirm.
    model.train()
    # for data, target in tqdm(train_loader):
    # for data in poisoned_data:

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    scheduler.step()

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    print('<Cleansed Training> Train Epoch: {} \tLoss: {:.6f}, lr: {:.6f}, Time: {:.2f}s'.
          format(epoch, loss.item(), optimizer.param_groups[0]['lr'], elapsed_time))

    # Re-evaluate after the step. The message text is kept as-is even
    # though it now reports the state after this fine-tuning step.
    ori_overall_clean_accu, ori_poisoned_accu = tools.test(model=model, test_loader=test_set_loader,
                                                           poison_test=True, poison_transform=poison_transform,
                                                           num_classes=num_classes, source_classes=None)
    poisoned_clean_gap = ori_poisoned_accu
    print('cleansed and poisoned accuracy gap before editing: {:.4f}'.format(poisoned_clean_gap))
    # Stop without saving once the first returned metric drops below 0.90 ...
    if ori_overall_clean_accu < 0.90:
        break
    # ... or once the second returned metric is below 0.001.
    if poisoned_clean_gap < 0.001:
        break

    # Encode both metrics (up to five digits after the decimal point) in the
    # checkpoint file name.
    c_acc = str(ori_overall_clean_accu).split('.')[-1]
    g_acc = str(poisoned_clean_gap).split('.')[-1]
    result = c_acc[:5] + '-' + g_acc[:5]

    torch.save(context_model.state_dict(), os.path.join(finetuned_model_dir,
                                                        'best_refined_model_with_'+str(used_data_len)+'_'+result+'.pth'))
    print('model saved')