import sys
sys.path.append('..')

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import foolbox
import heapq
from torch.autograd import Variable

import data
from utils import *
from config import opt
from utils import get_dataset_by_name
# import loss

# get substitute dataset by querying victim
def get_sub_dataset(n_query, victim, data_gen,
                    substitute, sub_dataset, batch_size=20):
    """Synthesize query images with ``data_gen`` and label them with ``victim``.

    Noise seeds are drawn uniformly from [-5, 5) with shape
    ``(batch_size, opt.noise_dim)``, passed through ``data_gen``, and the
    resulting items are fed to ``victim``; the softmax of the victim's output
    is stored next to each seed and item.

    Args:
        n_query: Number of ``(seed, item, probs)`` triples to collect. When it
            is not a multiple of ``batch_size``, the last batch is still
            generated at full size and only the rows needed to reach
            ``n_query`` are kept, so the generator/victim never see an
            undersized batch.
        victim: Model whose outputs are turned into soft labels.
        data_gen: Generator mapping noise seeds to query items.
        substitute: Unused; kept for call-site compatibility.
        sub_dataset: Accumulated dataset; new triples are appended to its
            ``items`` list (it is mutated in place).
        batch_size: Number of seeds generated per forward pass.

    Returns:
        ``(sub_dataset, new_sub_dataset)``: the same ``sub_dataset`` object that
        was passed in, and a fresh ``data.SubDataset`` holding only the triples
        added by this call. All stored tensors are on the CPU.
    """
    new_sub_dataset = data.SubDataset()
    if opt.use_gpu:
        # Moving the modules is loop-invariant; do it once, not per batch.
        victim.cuda()
        data_gen.cuda()
    softmax = nn.Softmax(dim=1)

    collected = 0
    while collected < n_query:
        # Previously n_query // batch_size batches were run, silently dropping
        # the remainder; keep only the rows still needed from the last batch.
        n_keep = min(batch_size, n_query - collected)
        seeds = torch.Tensor(
            np.random.uniform(-5., 5., size=(batch_size, opt.noise_dim))
        )
        if opt.use_gpu:
            seeds = seeds.cuda()
        with torch.no_grad():
            items = data_gen(seeds)
            # query victim model and update substitute dataset
            probs = softmax(victim(items))
        # One device->host transfer per batch instead of two per row.
        seeds, items, probs = seeds.cpu(), items.cpu(), probs.cpu()
        for row in range(n_keep):
            triple = (seeds[row], items[row], probs[row])
            sub_dataset.items.append(triple)
            new_sub_dataset.items.append(triple)
        collected += n_keep
    return sub_dataset, new_sub_dataset


def get_baseline_dataset(n_query, victim, sub_dataset, dataset='cifar10'):
    """Label real images from a reference dataset with ``victim``.

    Images come from the training split returned by
    ``get_dataset_by_name(dataset, opt.victim_img_size)``, starting at index
    ``len(sub_dataset)`` so that repeated calls continue where the previous
    one stopped. For ``'cifar100'`` the samples of classes 90-99 are removed
    before indexing.

    Args:
        n_query: Number of images to query.
        victim: Model whose outputs are turned into soft labels.
        sub_dataset: Accumulated dataset; new triples are appended to its
            ``items`` list (it is mutated in place).
        dataset: Dataset name forwarded to ``get_dataset_by_name``.

    Returns:
        ``(sub_dataset, new_sub_dataset)``: the same ``sub_dataset`` object that
        was passed in, and a fresh ``data.SubDataset`` holding only the triples
        added by this call. Each triple is ``(0, image, probs)`` with CPU
        tensors; the ``0`` occupies the slot where ``get_sub_dataset`` stores
        the noise seed.

    Note:
        Indexing past the end of the base dataset (``len(sub_dataset) +
        n_query`` too large) will fail inside the dataset's ``__getitem__``.
    """
    new_sub_dataset = data.SubDataset()
    base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
    if dataset == 'cifar100':
        # Exclude classes 90-99 from the pool of query images.
        # NOTE(review): train_data/train_labels are the older torchvision
        # attribute names (newer releases use data/targets) — confirm against
        # the pinned torchvision version.
        forbidden = set(range(90, 100))
        indexes = [i for i, value in enumerate(base_dataset.train_labels)
                   if value not in forbidden]
        base_dataset.train_data = base_dataset.train_data[indexes]
        base_dataset.train_labels = [base_dataset.train_labels[i] for i in indexes]
    if opt.use_gpu:
        # Inputs are moved to the GPU below, so the victim must be there too.
        victim.cuda()
    softmax = nn.Softmax(dim=1)
    # Continue after the images consumed by earlier calls.
    start = len(sub_dataset)
    for query in range(n_query):
        # Add a batch dimension while keeping the image's own (C, H, W);
        # the former view(1, 3, 32, 32) broke for any other image size.
        item = base_dataset[start + query][0].unsqueeze(0)
        if opt.use_gpu:
            item = item.cuda()
        with torch.no_grad():
            prob = softmax(victim(item))
        image, prob = item[0].cpu(), prob[0].cpu()
        sub_dataset.items.append((0, image, prob))
        new_sub_dataset.items.append((0, image, prob))
    return sub_dataset, new_sub_dataset