import os
import copy
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import logging
import datetime

import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from utils.conv import *
from utils.resnet import *
from utils.utils import *
from utils.parser import *
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import cdist

class LinfPGDAttack(object):
    """Iterative L-infinity attack that ascends the cross-entropy loss.

    The attack starts from a uniformly random point within ``epsilon`` of the
    clean input, then repeats ``num_steps`` times: take a signed-gradient step
    of size ``step_size``, project back into the epsilon-box around the clean
    input, and clip the result to ``[0, 1]``.
    """

    def __init__(self, model, epsilon, num_steps, step_size):
        self.model = model
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size

    def perturb(self, x_natural, y):
        """Return a perturbed copy of ``x_natural`` crafted against labels ``y``."""
        eps = self.epsilon
        # The projection bounds never change, so compute them once.
        lower = x_natural - eps
        upper = x_natural + eps

        # Random start inside the epsilon-box (not clipped before the first step).
        start_noise = torch.empty_like(x_natural).uniform_(-eps, eps)
        x_adv = x_natural.detach() + start_noise

        for _ in range(self.num_steps):
            x_adv.requires_grad_()
            # Gradients are needed even if the caller is inside torch.no_grad().
            with torch.enable_grad():
                loss = F.cross_entropy(self.model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, [x_adv])

            stepped = x_adv.detach() + self.step_size * torch.sign(grad.detach())
            projected = torch.min(torch.max(stepped, lower), upper)
            x_adv = torch.clamp(projected, 0, 1)
        return x_adv

class LinfCWAttack(object):
    """Iterative L-infinity attack driven by a logit-margin ("CW"-style) loss.

    The optimisation loop is the same random-start, signed-gradient,
    project-and-clip procedure as the cross-entropy variant; only the loss
    differs. The loss is exposed as ``self.loss_func(logit, target)``.
    """

    def __init__(self, model, epsilon, num_steps, step_size):
        self.model = model
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size
        self.loss_func = self._margin_loss

    @staticmethod
    def _margin_loss(logit, target):
        """Return minus the mean hinge of (true logit - best other logit + 50)."""
        one_hot = torch.zeros_like(logit)
        one_hot.scatter_(1, target.unsqueeze(1), 1.0)

        true_logit = (one_hot * logit).sum(dim=1)
        # Subtract a large constant at the true class so it cannot win the max.
        others = (1 - one_hot) * logit - 1e4 * one_hot
        best_other = others.max(dim=1)[0]
        return -F.relu(true_logit - best_other + 50).mean()

    def perturb(self, x_natural, y):
        """Return a perturbed copy of ``x_natural`` crafted against labels ``y``."""
        eps = self.epsilon
        lower = x_natural - eps
        upper = x_natural + eps

        # Random start inside the epsilon-box around the clean input.
        x_adv = x_natural.detach() + torch.empty_like(x_natural).uniform_(-eps, eps)

        for _ in range(self.num_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss = self.loss_func(self.model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, [x_adv])

            stepped = x_adv.detach() + self.step_size * torch.sign(grad.detach())
            projected = torch.min(torch.max(stepped, lower), upper)
            x_adv = torch.clamp(projected, 0, 1)
        return x_adv

class FGSMAttack(object):
    """Single-step signed-gradient attack on the cross-entropy loss."""

    def __init__(self, model, epsilon):
        self.model = model
        self.epsilon = epsilon
        self.loss = nn.CrossEntropyLoss()

    def perturb(self, inputs, targets):
        """Return ``clamp(inputs + epsilon * sign(dLoss/dInputs), 0, 1)``.

        The gradient is taken on a detached copy of ``inputs`` with
        ``torch.autograd.grad``, so:

        * the caller's tensor is left untouched (its ``requires_grad`` flag is
          not flipped and nothing accumulates in ``inputs.grad``);
        * no gradients are written to the model parameters;
        * the call also works when the caller is inside ``torch.no_grad()``;
        * the returned tensor is detached from the autograd graph.

        Args:
            inputs: batch of inputs fed to ``self.model``.
            targets: class indices passed to ``nn.CrossEntropyLoss``.

        Returns:
            The perturbed batch, same shape as ``inputs``.
        """
        clean = inputs.detach()
        # detach() yields a new tensor object, so enabling grad here does not
        # affect the tensor owned by the caller.
        x = inputs.detach().requires_grad_(True)
        with torch.enable_grad():
            cost = self.loss(self.model(x), targets)
        (grad,) = torch.autograd.grad(cost, [x])

        perturbed_inputs = clean + self.epsilon * grad.sign()
        return torch.clamp(perturbed_inputs, 0, 1)

def empirical_search(args, net_adv, k, target_id, knn_data, sub_set, test_dataset, train_dataset):
    """Greedily select training samples to unlearn for one target test sample.

    For ``args.c // args.t`` rounds, every disjoint group of ``args.t``
    consecutive ids in ``knn_data`` is tentatively unlearned from a copy of the
    current model. Each candidate group is scored with

        ``get_lid(...) + args.lam * -(mean cosine similarity)``

    and the highest-scoring group is then unlearned for real from the working
    model, appended to the result and removed from ``knn_data``.

    Args:
        args: namespace providing ``c``, ``t``, ``tau``, ``lam``, ``dim`` and
            ``device`` (it is also forwarded to ``get_grad_diff``).
        net_adv: model to start from; it is deep-copied, never modified.
            NOTE(review): the code reads ``.module.linear.weight``, so it
            expects a ``DataParallel``-style wrapper around a network with a
            ``linear`` head - confirm for other architectures.
        k: number of nearest neighbours passed to ``get_lid``.
        target_id: index of the target sample in ``test_dataset``.
        knn_data: list of candidate training indices.
        sub_set: indexable collection of training indices used as the
            neighbourhood pool for the score.
        test_dataset: dataset yielding ``(input, label)`` for the target.
        train_dataset: dataset yielding ``(input, label)`` for training ids.

    Returns:
        List of the selected training indices, in selection order. The search
        stops early when no scorable candidate group is left.
    """

    def _full_batch_loader(dataset, ids):
        # Loader that serves all `ids` of `dataset` as a single batch.
        picked = torch.utils.data.Subset(dataset, ids)
        return torch.utils.data.DataLoader(picked, batch_size=len(picked), shuffle=False, num_workers=1)

    def _unlearn_in_place(model, loader):
        # First-order update: theta <- theta - tau * d_theta for every
        # trainable parameter, with d_theta supplied by get_grad_diff.
        d_theta = get_grad_diff(args, model, loader)
        model.eval()
        trainable = [p for p in model.parameters() if p.requires_grad]
        if len(d_theta) < len(trainable):
            raise IndexError('get_grad_diff returned fewer updates than trainable parameters')
        with torch.no_grad():
            for p, delta in zip(trainable, d_theta):
                p.copy_(p - args.tau * delta)

    def _avgpool_features(model, loader):
        # 'avgpool' outputs of `model` for every sample in `loader`, as lists.
        feats = []
        with torch.no_grad():
            for inputs, _ in loader:
                out_layer = model(inputs.to(args.device), out_keys=True)
                feats += out_layer['avgpool'].detach().cpu().tolist()
        return feats

    def _nearest_allowed(order, limit, banned):
        # First `limit` positions of `order` whose training id is not banned.
        picked = []
        for idx in order:
            if len(picked) >= limit:
                break
            if sub_set[idx] not in banned:
                picked.append(idx)
        return picked

    search_ids = []
    net = copy.deepcopy(net_adv)

    # These do not depend on the candidate group, so build them only once.
    target_loader = _full_batch_loader(test_dataset, [target_id])
    sub_train_data = torch.utils.data.Subset(train_dataset, sub_set)
    sub_train_loader = torch.utils.data.DataLoader(sub_train_data, batch_size=100, shuffle=False, num_workers=1)
    target_input, target_label = test_dataset[target_id]
    target_input = target_input.unsqueeze(0).to(args.device)
    target_label = torch.tensor(target_label).unsqueeze(0)

    for _ in range(args.c // args.t):
        max_loss = float('-inf')
        best_poi = []
        removed = set(search_ids)

        for j in range(len(knn_data) // args.t):
            # Disjoint groups of t candidates. (The loop runs len // t times,
            # so a stride of 1 would overlap groups and skip the tail of
            # knn_data whenever t > 1.)
            tmp_ids = knn_data[j * args.t:(j + 1) * args.t]

            # temporarily unlearn the candidate group on a scratch copy
            tmp_unlearn = copy.deepcopy(net)
            _unlearn_in_place(tmp_unlearn, _full_batch_loader(train_dataset, tmp_ids))

            # neighbourhood of the target in the scratch model's feature space
            tmp_test_features = _avgpool_features(tmp_unlearn, target_loader)
            tmp_train_features = _avgpool_features(tmp_unlearn, sub_train_loader)
            sort_ids = np.argsort(cdist(tmp_test_features, tmp_train_features))[0]

            # target + its k nearest not-yet-selected neighbours -> get_lid
            neighbours = _nearest_allowed(sort_ids, k, removed)
            data = np.array([tmp_test_features[0]] + [tmp_train_features[idx] for idx in neighbours])
            density = get_lid(data, k)

            # nearest neighbours that are neither selected nor in this group
            dim_positions = _nearest_allowed(sort_ids, args.dim, removed | set(tmp_ids))
            dim_ids = [sub_set[idx] for idx in dim_positions]

            with torch.no_grad():
                # softmax(embed @ W^T) - one_hot(label); the head's bias is not
                # applied, exactly as in the expression this was derived from.
                embed = tmp_unlearn(target_input, out_keys=True)['avgpool']
                head_weight = tmp_unlearn.module.linear.weight
                confs = torch.softmax(torch.matmul(embed, head_weight.T), dim=1).cpu()
                one_hot = torch.eye(confs.shape[1])[target_label.long()]
                gradient = (confs - one_hot).numpy()

                # The target's plain model output is the same for every
                # neighbour, so evaluate it once.
                # NOTE(review): without out_keys the model presumably returns
                # logits; their width must match `gradient` for the cosine.
                target_out = tmp_unlearn(target_input)
                cos = 0
                for idx in dim_ids:
                    train_input, _ = train_dataset[idx]
                    train_out = tmp_unlearn(train_input.unsqueeze(0).to(args.device))
                    vec = (train_out - target_out).view(1, -1).cpu().numpy()
                    cos += cosine_similarity(vec, gradient)[0][0]

            if dim_ids:
                adv_loss = density + args.lam * -cos / len(dim_ids)
            else:
                # No eligible neighbour: score on the get_lid term alone
                # instead of dividing by zero.
                adv_loss = density

            if adv_loss > max_loss:
                max_loss = adv_loss
                best_poi = tmp_ids

        if not best_poi:
            # Nothing left to pick; an empty Subset would make DataLoader
            # reject batch_size=0.
            break

        # commit the best group to the working model
        _unlearn_in_place(net, _full_batch_loader(train_dataset, best_poi))

        # add and remove training sample
        search_ids += best_poi
        knn_data = [x for x in knn_data if x not in best_poi]

    return search_ids

def test(args, logger, net, test_loader):
    """Evaluate ``net`` on clean inputs and on inputs perturbed by an attack.

    The attack is chosen by ``args.adv`` (``'pgd'``, ``'cw'`` or ``'fgsm'``).
    Both accuracies are printed and written to ``logger``.

    Args:
        args: namespace providing ``adv`` and ``device``.
        logger: logger receiving the two accuracy lines.
        net: model under evaluation; it is switched to eval mode.
        test_loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(benign_correct, adv_correct)``: two lists of per-sample booleans, in
        loader order, for the clean and the perturbed inputs respectively.

    Raises:
        ValueError: if ``args.adv`` is not one of the supported attacks.
    """
    if args.adv == 'pgd':
        adversary = LinfPGDAttack(net, epsilon=0.0314, num_steps=10, step_size=0.00784)
    elif args.adv == 'cw':
        adversary = LinfCWAttack(net, epsilon=0.0314, num_steps=30, step_size=0.00784)
    elif args.adv == 'fgsm':
        adversary = FGSMAttack(net, epsilon=0.0314)
    else:
        # Fail up front instead of hitting an unbound `adversary` mid-loop.
        raise ValueError(f'unsupported attack type: {args.adv!r}')

    print('[ Test Start ]')
    net.eval()
    benign_correct = []
    adv_correct = []
    total = 0
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        total += targets.size(0)

        # Pure evaluation: no autograd graph is needed for the predictions.
        with torch.no_grad():
            predicted = net(inputs).max(1)[1]
        benign_correct += predicted.eq(targets).cpu().tolist()

        # The attack needs gradients, so it runs outside the no_grad block.
        adv = adversary.perturb(inputs, targets)
        with torch.no_grad():
            predicted = net(adv).max(1)[1]
        adv_correct += predicted.eq(targets).cpu().tolist()

    benign_true = benign_correct.count(True)
    adv_true = adv_correct.count(True)
    # An empty loader reports 0.0 rather than dividing by zero.
    benign_acc = 100. * benign_true / total if total else 0.0
    adv_acc = 100. * adv_true / total if total else 0.0

    benign_msg = f'Total benign test accuarcy: {benign_acc}%'
    adv_msg = f'Total adversarial test accuarcy: {adv_acc}%'
    print(benign_msg)
    logger.info(benign_msg)
    print(adv_msg)
    logger.info(adv_msg)

    return benign_correct, adv_correct

def _load_data_parallel_resnet(args, checkpoint_path):
    """Build a DataParallel ResNet18 on ``args.device`` and load ``checkpoint['net']``."""
    net = torch.nn.DataParallel(ResNet18().to(args.device))
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    net.load_state_dict(checkpoint['net'])
    return net


def _stack_batches(chunks):
    """Concatenate per-batch arrays along axis 0 (empty array if there are none)."""
    return np.concatenate(chunks, axis=0) if chunks else np.array([])


def _flattened_inputs(loader):
    """Return all inputs of ``loader`` as a float64 array, one flat row per sample."""
    chunks = []
    for inputs, _ in loader:
        # Use the actual batch size so a short final batch is handled too.
        chunks.append(inputs.reshape(inputs.size(0), -1).detach().cpu().double().numpy())
    return _stack_batches(chunks)


def _avgpool_representations(args, net, loader):
    """Return the 'avgpool' outputs of ``net`` over ``loader`` as a float64 array."""
    chunks = []
    with torch.no_grad():
        for inputs, _ in loader:
            out_layer = net(inputs.to(args.device), out_keys=True)
            chunks.append(out_layer['avgpool'].detach().cpu().double().numpy())
    return _stack_batches(chunks)


def _unlearn_and_evaluate(args, logger, net_adv, train_dataset, id_lists, label, eval_loader):
    """Unlearn the unique ids in ``id_lists`` from a copy of ``net_adv`` and test it.

    ``label`` only names the strategy in the printed/logged header line.
    Returns whatever ``test`` returns for the unlearned copy.
    """
    unique_ids = get_unique(id_lists)
    print(len(unique_ids))
    unlearn_data = torch.utils.data.Subset(train_dataset, unique_ids)
    unlearn_loader = torch.utils.data.DataLoader(unlearn_data, batch_size=len(unlearn_data), shuffle=False, num_workers=1)
    net_unlearn = copy.deepcopy(net_adv)

    # first-order unlearning method: theta <- theta - tau * d_theta
    d_theta = get_grad_diff(args, net_unlearn, unlearn_loader)
    net_unlearn.eval()
    trainable = [p for p in net_unlearn.parameters() if p.requires_grad]
    if len(d_theta) < len(trainable):
        raise IndexError('get_grad_diff returned fewer updates than trainable parameters')
    with torch.no_grad():
        for p, delta in zip(trainable, d_theta):
            p.copy_(p - args.tau * delta)

    header = f'==> unlearning {label} = {args.c}'
    print(header)
    logger.info(header)
    return test(args, logger, net_unlearn, eval_loader)


def main():
    """Run the unlearning experiment end to end.

    Steps: parse arguments and set up file logging; load CIFAR-10 and the two
    checkpointed models; evaluate both; sample test points that are
    misclassified under attack by the basic model but still classified
    correctly by the adversarially trained one; then, for every ``k`` in
    ``args.K``, unlearn training samples chosen at random, by input-space
    nearest neighbours, and by ``empirical_search``, and re-evaluate the
    unlearned model on the sampled test points each time.

    Raises:
        ValueError: if ``args.dataset`` is not ``'CIFAR'``.
    """
    args = get_parameter()

    os.makedirs(args.logdir, exist_ok=True)

    log_path = args.craftproj + '-log-%s' % (datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S"))
    log_path = log_path + '.txt'
    logging.basicConfig(
        filename=os.path.join(args.logdir, log_path),
        format="%(asctime)s - %(name)s - %(message)s",
        datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO, filemode='w')

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.info(str(args))

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'device: {args.device}')
    logger.info(f'device: {args.device}')

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=False, num_workers=4)

    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4)

    if args.dataset != 'CIFAR':
        # Only the CIFAR models exist; fail clearly instead of a NameError later.
        raise ValueError(f'unsupported dataset: {args.dataset!r}')

    cudnn.benchmark = True
    net_basic = _load_data_parallel_resnet(args, './models/basic_training')
    net_adv = _load_data_parallel_resnet(args, './models/pgd_adversarial_training')

    _, correct_basic = test(args, logger, net_basic, test_loader)
    _, correct_adv = test(args, logger, net_adv, test_loader)

    # random select test samples that are attacked succeed (incorrect) in basic training but failed (correct) in adversarial training
    samples = [
        i for i, (basic_ok, adv_ok) in enumerate(zip(correct_basic, correct_adv))
        if not basic_ok and adv_ok
    ]
    select_ids = random.sample(samples, args.test_samples)
    select_data = torch.utils.data.Subset(test_dataset, select_ids)
    select_loader = torch.utils.data.DataLoader(select_data, batch_size=len(select_data), shuffle=False, num_workers=1)

    # input space
    input_blobs_test = _flattened_inputs(select_loader)
    input_blobs_train = _flattened_inputs(train_loader)

    # representation space
    features_blobs_test = _avgpool_representations(args, net_adv, select_loader)
    features_blobs_train = _avgpool_representations(args, net_adv, train_loader)

    # Neither neighbour ranking depends on k, so compute both once up front.
    # cdist gives the Euclidean distances without a per-row Python loop.
    input_dists = cdist(input_blobs_test, input_blobs_train)
    input_ids = [np.argsort(row)[:args.c].tolist() for row in input_dists]
    feature_order = np.argsort(cdist(features_blobs_test, features_blobs_train), axis=1)

    for k in args.K:
        rand_ids = []
        search_ids = []
        for i, order in enumerate(feature_order):
            print(f'test sample {i}')
            subset = order[:args.P]
            knn_data = subset[:k].tolist()

            rand_ids.append(random.sample(range(len(features_blobs_train)), args.c))
            search_ids.append(empirical_search(args, net_adv, k, select_ids[i], knn_data, subset, test_dataset, train_dataset))

        # unlearn random baseline
        _unlearn_and_evaluate(args, logger, net_adv, train_dataset, rand_ids, 'random', select_loader)

        # unlearn knn baseline
        _unlearn_and_evaluate(args, logger, net_adv, train_dataset, input_ids, 'knn', select_loader)

        # unlearn ours
        _unlearn_and_evaluate(args, logger, net_adv, train_dataset, search_ids, 'search', select_loader)

# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()