import sys
sys.path.append('..')

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

import data
from utils import *
from config import opt
from utils import get_dataset_by_name
# import loss

def get_adv_dataset(n_query, victim, data_gen, sub_dataset, dataset='cifar100'):
    """Query the victim with generator-transformed images and record the results.

    For each of ``n_query`` base images the image is passed through
    ``data_gen``, the generated image is fed to ``victim``, and the softmax of
    the victim output is stored together with the clean image and a
    "next" base image located ``n_query`` positions further on.

    Args:
        n_query: Number of victim queries to issue.
        victim: Model being queried. Moved to GPU when ``opt.use_gpu`` is set.
        data_gen: Model that maps a clean image batch to the batch sent to the
            victim. Moved to GPU when ``opt.use_gpu`` is set.
        sub_dataset: Accumulated dataset; its ``items`` list is extended
            in place with the new records.
        dataset: Dataset name forwarded to ``get_dataset_by_name``. For
            ``'cifar100'`` the samples labelled 90..99 are removed first.

    Returns:
        A tuple ``(sub_dataset, new_sub_dataset, next_dataset)``:
        ``sub_dataset`` is the input object after extension,
        ``new_sub_dataset`` holds only the records created by this call, and
        ``next_dataset`` holds ``(image, -1, -1, -1)`` placeholders for the
        ``n_query`` base images starting at the new ``len(sub_dataset)``.
        Each record is ``(item, adv_item, prob, next_item)`` with CPU tensors
        whose batch dimension has been stripped.
    """
    new_sub_dataset = data.SubDataset()
    next_dataset = data.SubDataset()
    base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
    if dataset == 'cifar100':
        forbidden = range(90, 100)
        keep = [i for i, label in enumerate(base_dataset.targets) if label not in forbidden]
        kept_data = base_dataset.data[keep]
        kept_labels = [base_dataset.targets[i] for i in keep]
        # The filter is read from ``data``/``targets``, so it has to be written
        # back to the same attributes to take effect when the dataset is
        # indexed below. Previously only ``train_data``/``train_labels`` were
        # assigned, which left the forbidden classes in place.
        # NOTE(review): assumes ``base_dataset.__getitem__`` reads
        # ``data``/``targets`` (torchvision-style) -- confirm.
        base_dataset.data = kept_data
        base_dataset.targets = kept_labels
        # Legacy attribute names are still populated for any code reading them.
        base_dataset.train_data = kept_data
        base_dataset.train_labels = kept_labels
    if opt.use_gpu:
        victim.cuda()
        data_gen.cuda()

    softmax = nn.Softmax(dim=1)
    # Must be captured before the loop: ``sub_dataset`` grows while querying.
    start = 0 if opt.same_origin else len(sub_dataset)
    for query in range(n_query):
        item = base_dataset[start + query][0].view(1, 3, 32, 32)
        next_item = base_dataset[start + n_query + query][0].view(1, 3, 32, 32)
        if opt.use_gpu:
            item = item.cuda()
        with torch.no_grad():
            adv_item = data_gen(item)
            prob = softmax(victim(adv_item))
        # Build the record once so each tensor is copied to the CPU only once;
        # both datasets then reference the same immutable tuple.
        record = (item[0].cpu(), adv_item[0].cpu(), prob[0].cpu(), next_item[0].cpu())
        sub_dataset.items.append(record)
        new_sub_dataset.items.append(record)

    # Placeholders for the following round start right after the data used so far.
    start = len(sub_dataset)
    for query in range(n_query):
        item = base_dataset[start + query][0].view(1, 3, 32, 32)
        next_dataset.items.append((item[0].cpu(), -1, -1, -1))
    return sub_dataset, new_sub_dataset, next_dataset


def get_papernot_dataset(sub_dataset, victim, substitute, init_n_per_class, dataset='cifar100', lamb=0.1):
    """Grow a substitute-training set by one round of FGSM-based augmentation.

    When ``sub_dataset`` is empty, the first ``init_n_per_class`` images of
    every class are taken from the base training set (in dataset order), the
    victim is queried on them, and the resulting records are stored.
    Otherwise every image already in ``sub_dataset`` is perturbed by
    ``lamb * sign(grad)`` of the substitute's cross-entropy loss (using the
    arg-max of the stored probabilities as the label), the victim is queried
    on the perturbed images, and the new records are appended.

    Args:
        sub_dataset: Accumulated dataset; its ``items`` list is extended
            in place. Its length must be a multiple of the batch size
            ``int(init_n_per_class * n_class / 5)``.
        victim: Model being queried. Put in eval mode and, when
            ``opt.use_gpu`` is set, moved to GPU.
        substitute: Model whose input gradient drives the perturbation. Put
            in eval mode and, when ``opt.use_gpu`` is set, moved to GPU.
        init_n_per_class: Number of seed images per class for the first round.
        dataset: ``'cifar10'`` or ``'cifar100'``.
        lamb: Step size of the sign-gradient perturbation.

    Returns:
        A tuple ``(sub_dataset, new_sub_dataset, next_dataset)``:
        ``sub_dataset`` is the input object after extension,
        ``new_sub_dataset`` holds only the records created by this call, and
        ``next_dataset`` is an empty ``data.SubDataset`` (never populated
        here). Each record is ``(-1, image, prob, -1)`` with CPU tensors.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name.
        AssertionError: If ``len(sub_dataset)`` is not a multiple of the
            batch size.
    """
    if dataset == 'cifar10':
        n_class = 10
    elif dataset == 'cifar100':
        n_class = 100
    else:
        raise NotImplementedError('Unknown dataset')
    n_init = init_n_per_class * n_class
    batch_size = int(n_init / 5)
    old_size = len(sub_dataset)
    assert old_size % batch_size == 0

    new_sub_dataset = data.SubDataset()
    next_dataset = data.SubDataset()
    base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
    softmax = nn.Softmax(dim=1)
    if opt.use_gpu:
        victim.cuda()
        substitute.cuda()
    victim.eval()
    substitute.eval()

    def get_adv(clean_items, model, labels, lamb=0.1):
        """Return ``clean_items + lamb * sign(d loss / d input)`` for ``model``."""
        inputs = clean_items.detach().requires_grad_(True)
        loss = F.cross_entropy(model(inputs), labels)
        # Differentiate w.r.t. the input only, so no gradients are accumulated
        # into the model's parameters as a side effect.
        input_grad, = torch.autograd.grad(loss, inputs)
        return clean_items.detach() + lamb * input_grad.sign()

    if old_size == 0:
        # Collect the first ``init_n_per_class`` images of each class.
        picked = []
        class_num_list = np.zeros(n_class)
        idx = 0
        while len(picked) < n_init:
            img, label = base_dataset[idx]
            if class_num_list[label] < init_n_per_class:
                picked.append(img)
                class_num_list[label] += 1
            idx += 1
        # Stack once instead of concatenating per sample.
        init_samples = torch.stack(picked)
        assert init_samples.shape[0] == n_init
        # Query the victim batch by batch.
        for start in range(0, n_init, batch_size):
            items = init_samples[start:start + batch_size]
            if opt.use_gpu:
                items = items.cuda()
            with torch.no_grad():
                probs = softmax(victim(items))
            # One device-to-host copy per batch; iterate over the actual batch length.
            for img, prob in zip(items.cpu(), probs.cpu()):
                record = (-1, img, prob, -1)
                sub_dataset.items.append(record)
                new_sub_dataset.items.append(record)
    else:
        all_dataloader = torch.utils.data.DataLoader(
            sub_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4
        )
        # New records are buffered and appended only after iteration finishes:
        # growing ``sub_dataset`` while the shuffling loader over it is still
        # live can change how many indices its sampler yields, because the
        # sampler consults the dataset length lazily.
        augmented = []
        for _, imgs, probs, _ in all_dataloader:
            if opt.use_gpu:
                imgs = imgs.cuda()
                probs = probs.cuda()
            labels = probs.max(1)[1]
            # Perturb images with FGSM on the substitute.
            adv_imgs = get_adv(imgs, substitute, labels, lamb=lamb)
            # Query the victim; no autograd graph is needed for its output.
            with torch.no_grad():
                adv_probs = softmax(victim(adv_imgs))
            augmented.extend(
                (-1, img, prob, -1)
                for img, prob in zip(adv_imgs.cpu(), adv_probs.cpu())
            )
        sub_dataset.items.extend(augmented)
        new_sub_dataset.items.extend(augmented)

    return sub_dataset, new_sub_dataset, next_dataset