############################################################
#
# poison_test.py
# Load poison examples from file and test
# Developed as part of Poison Attack Benchmarking project
# June 2020
#
############################################################
import os
import pickle
import sys
from collections import OrderedDict
from defenses import get_defense
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms as transforms

from learning_module import now, get_model, load_model_from_checkpoint, get_dataset
from learning_module import (
    train,
    test,
    adjust_learning_rate,
    to_log_file,
    to_results_table,
    compute_perturbation_norms,
)

def _load_poisons(poisons_path):
    """Unpickle and return ``(poison_tuples, poison_indices)`` from ``poisons_path``.

    Reads ``poisons.pickle`` (the poison samples) and ``base_indices.pickle``
    (their indices within the training set).
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # point poisons_path at files from a trusted source.
    with open(os.path.join(poisons_path, "poisons.pickle"), "rb") as handle:
        poison_tuples = pickle.load(handle)
    with open(os.path.join(poisons_path, "base_indices.pickle"), "rb") as handle:
        poison_indices = pickle.load(handle)
    return poison_tuples, poison_indices


def _predict(net, img, device):
    """Return the class index ``net`` predicts for the single image tensor ``img``.

    The network is put in eval mode and the forward pass runs under
    ``torch.no_grad()`` so no autograd graph is built for this inference.
    """
    net.eval()
    with torch.no_grad():
        return net(img.unsqueeze(0).to(device)).max(1)[1].item()


def _train_filter_and_clean(args, trainloader, testloader, poison_trainset,
                            poison_indices, device):
    """Train a ResNet-18 filter model, then drop the samples the defense flags.

    Returns:
        (poison_trainset, trainloader): the subset of the training set kept by
        the defense, and a fresh shuffled DataLoader over that subset.
    """
    print("Training model to filter poisons")
    net = get_model("resnet18", args.dataset)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    # Fixed recipe for the filter model: 80 epochs, LR multiplied by 0.1 at
    # epochs 30, 50 and 70; report test accuracy every 20 epochs.
    for epoch in range(80):
        adjust_learning_rate(optimizer, epoch, [30, 50, 70], 0.1)
        loss, acc = train(
            net, trainloader, optimizer, criterion, device, defense=args.defense
        )
        if (epoch + 1) % 20 == 0:
            natural_acc = test(net, testloader, device)
            print(
                now(),
                " Epoch: ", epoch,
                ", Loss: ", loss,
                ", Training acc: ", acc,
                ", Test accuracy: ", natural_acc,
            )

    defense_filter = get_defense(args.defense)
    net.eval()
    clean_indices, bad_indices = defense_filter(
        net, poison_trainset, num_poisons_expected=args.num_poisons
    )

    unique_poisons = set(poison_indices)
    if unique_poisons:
        proportion_unrecognized = len(unique_poisons - set(bad_indices)) / len(unique_poisons)
        print('Proportion of unrecognized poisons is: ', proportion_unrecognized)
    else:
        # Avoid ZeroDivisionError when no poison indices were supplied.
        print("No poison indices supplied; cannot compute proportion of unrecognized poisons.")

    poison_trainset = torch.utils.data.Subset(poison_trainset, clean_indices)
    trainloader = torch.utils.data.DataLoader(
        poison_trainset, batch_size=args.batch_size, shuffle=True, num_workers=4
    )
    return poison_trainset, trainloader


def _load_target_image(args, transform_test):
    """Load the target sample and return ``(target_img, target_class)``.

    ``target.pickle`` holds a tuple whose first two entries are the target
    image (a PIL image) and its class. If the tuple has 4 entries, entry 2 is
    a ``[3, patch_size, patch_size]`` patch and entry 3 its ``(startx, starty)``
    position. The patch is pasted onto rows ``starty:starty+patch_size`` and
    columns ``startx:startx+patch_size`` before ``transform_test`` is applied.
    Exits with status 1 if the patch shape or its position is invalid.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(args.poisons_path, "target.pickle"), "rb") as handle:
        target_img_tuple = pickle.load(handle)

    target_img_pil = target_img_tuple[0]
    target_class = target_img_tuple[1]

    if len(target_img_tuple) == 4:
        size = args.patch_size
        patch = torch.as_tensor(target_img_tuple[2])
        # Comparing the full shape also rejects patches of the wrong rank.
        if tuple(patch.shape) != (3, size, size):
            print(
                f"Expected shape of the patch is [3, {args.patch_size}, {args.patch_size}] "
                f"but is {patch.shape}. Exiting from poison_test.py."
            )
            sys.exit(1)

        startx, starty = target_img_tuple[3]
        # PIL's Image.size is (width, height), not (height, width).
        w, h = target_img_pil.size

        if startx < 0 or starty < 0 or starty + size > h or startx + size > w:
            print(
                "Invalid startx or starty point for the patch. Exiting from poison_test.py."
            )
            sys.exit(1)

        target_img_tensor = transforms.ToTensor()(target_img_pil)
        target_img_tensor[:, starty : starty + size, startx : startx + size] = patch
        target_img_pil = transforms.ToPILImage()(target_img_tensor)

    return transform_test(target_img_pil), target_class


def _reinitialize_classifier(net, num_classes):
    """Replace the network's final linear layer with a fresh ``num_classes``-way one.

    Uses ``net.linear`` when the model has it; otherwise replaces the layer at
    index 6 of ``net.classifier``. A new ``nn.Linear`` has ``requires_grad=True``,
    so it is trainable even when the rest of the network is frozen.
    """
    try:
        num_ftrs = net.linear.in_features
    except AttributeError:
        num_ftrs = net.classifier[6].in_features
        net.classifier[6] = nn.Linear(num_ftrs, num_classes)
    else:
        net.linear = nn.Linear(num_ftrs, num_classes)


def _build_optimizer(args, net):
    """Create the optimizer named by ``args.optimizer`` ("SGD" or "ADAM", case-insensitive).

    Raises:
        ValueError: if ``args.optimizer`` names any other optimizer.
    """
    name = args.optimizer.upper()
    if name == "SGD":
        return optim.SGD(
            net.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9
        )
    if name == "ADAM":
        return optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    raise ValueError(
        f"Unsupported optimizer {args.optimizer!r}; expected 'SGD' or 'ADAM'."
    )


def main(args):
    """Main function to check the success rate of the given poisons.

    Loads the poisons and the target from ``args.poisons_path``. If a
    filtering defense is requested, removes the samples it flags. Then trains
    (or fine-tunes from ``args.model_path``) ``args.model`` on the poisoned
    training set. Finally, reports whether the target is classified as the
    poisoned label and writes the results with ``to_results_table``.
    Side effect: ``args.ffe`` is forced to False when no ``args.model_path``
    is given.

    input:
        args:       Argparse object
    return:
        void
    """
    print(now(), "poison_test.py main() running.")

    test_log = "poison_test_log.txt"
    to_log_file(args, args.output, test_log)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    poison_tuples, poison_indices = _load_poisons(args.poisons_path)
    print(len(poison_tuples), " poisons in this trial.")
    # The attack counts as successful when the target is predicted as this label.
    poisoned_label = poison_tuples[0][1]

    # get the dataset and the dataloaders
    trainloader, testloader, dataset, _transform_train, transform_test, num_classes, \
        poison_trainset = get_dataset(args, poison_tuples, poison_indices)

    # cutmix/mixup are not filtering defenses; they are passed to train() via
    # `defense=` below. Every other defense filters the training set up front.
    if args.defense is not None and args.defense.lower() not in ["cutmix", "mixup"]:
        poison_trainset, trainloader = _train_filter_and_clean(
            args, trainloader, testloader, poison_trainset, poison_indices, device
        )

    target_img, target_class = _load_target_image(args, transform_test)

    poison_perturbation_norms = compute_perturbation_norms(
        poison_tuples, dataset, poison_indices
    )
    # NOTE(review): the clean-label check on max(poison_perturbation_norms) is
    # disabled. The previous commented-out assertion used a 17/255 threshold,
    # which contradicted its own comment (8/255 budget, < 9/255 allowing for PIL
    # truncation). Pick the intended bound before re-enabling it.

    ####################################################
    #           Network and Optimizer

    # load model from path if a path is provided
    if args.model_path is not None:
        net = load_model_from_checkpoint(
            args.model, args.model_path, args.pretrain_dataset
        )
    else:
        args.ffe = False  # we wouldn't fine-tune from a random initialization
        net = get_model(args.model, args.dataset)

    # freeze weights in feature extractor if not doing from scratch retraining
    if args.ffe and not args.e2e:
        for param in net.parameters():
            param.requires_grad = False

    _reinitialize_classifier(net, num_classes)
    optimizer = _build_optimizer(args, net)
    criterion = nn.CrossEntropyLoss()

    ####################################################
    #        Poison and Train and Test
    print("==> Training network...")
    epoch = 0
    # Defined up front so the results row can still be written when args.epochs == 0.
    loss = acc = float("nan")
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_schedule, args.lr_factor)
        loss, acc = train(
            net, trainloader, optimizer, criterion, device,
            # presumably controls BatchNorm updates inside train(); TODO confirm
            train_bn=not args.ffe or args.e2e,
            defense=args.defense,
        )

        if (epoch + 1) % args.val_period == 0:
            natural_acc = test(net, testloader, device)
            p_acc = _predict(net, target_img, device) == poisoned_label
            print(
                now(),
                " Epoch: ", epoch,
                ", Loss: ", loss,
                ", Training acc: ", acc,
                ", Natural accuracy: ", natural_acc,
                ", poison success: ", p_acc,
            )

    # test
    natural_acc = test(net, testloader, device)
    print(
        now(), " Training ended at epoch ", epoch, ", Natural accuracy: ", natural_acc
    )
    prediction = _predict(net, target_img, device)
    p_acc = prediction == poisoned_label

    print(
        now(), " poison success: ",
        p_acc, " poisoned_label: ",
        poisoned_label, " prediction: ",
        prediction,
    )

    # Row of results written to the CSV results table
    stats = OrderedDict(
        [
            ("poisons path", args.poisons_path),
            ("model", args.model_path if args.model_path is not None else args.model),
            ("target class", target_class),
            ("base class", poisoned_label),
            ("num poisons", len(poison_tuples)),
            ("max perturbation norm", np.max(poison_perturbation_norms)),
            ("epoch", epoch),
            ("loss", loss),
            ("training_acc", acc),
            ("natural_acc", natural_acc),
            ("poison_acc", p_acc),
        ]
    )
    to_results_table(stats, args.output, log_name=f"results_defense={args.defense}_"
                                                  f"end2end={args.e2e}_ffe={args.ffe}_"
                                                  f"scratch={args.from_scratch}_{args.str}.csv")
