import random
import math

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import foolbox
import heapq
from torch.autograd import Variable

import data
from utils import *
from config import opt
from utils import get_dataset_by_name
# import loss

def get_mosafi_dataset(victim, n_query, dataset):
    """Query ``victim`` on random convex blends of two base-dataset images.

    For each query, two distinct indices are drawn from the training split
    returned by ``get_dataset_by_name(dataset, opt.victim_img_size)``, a
    weight ``p`` is drawn with ``random.random()``, and the blended input
    ``p * img1 + (1 - p) * img2`` is sent through ``victim``. The softmax
    (over dim 1) of the victim's output is stored alongside the blended input.

    Args:
        victim: Model to query. It is switched to eval mode and, when
            ``opt.use_gpu`` is set, moved to the GPU.
        n_query: Number of blended samples to generate.
        dataset: Dataset name forwarded to ``get_dataset_by_name``.

    Returns:
        A ``data.SubDataset`` whose ``items`` list holds one tuple
        ``(-1, blended_input_cpu, victim_probs_cpu, -1)`` per query.
        NOTE(review): the two ``-1`` fields look like placeholders (e.g.
        index / label) -- confirm against ``data.SubDataset``.

    Raises:
        ValueError: If queries are requested but the base dataset has fewer
            than two elements, so two distinct images cannot be drawn.
    """
    victim.eval()
    if opt.use_gpu:
        victim.cuda()
    new_sub_dataset = data.SubDataset()
    base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
    softmax = nn.Softmax(dim=1)

    n_base = len(base_dataset)
    if n_query > 0 and n_base < 2:
        # With a single element the "draw a different index" loop below
        # could never terminate; fail loudly instead.
        raise ValueError(
            "base dataset must contain at least 2 items to blend, got %d" % n_base
        )

    # Run queries on whatever device the victim actually lives on, rather
    # than unconditionally calling .cuda() (which breaks CPU-only runs).
    first_param = next(victim.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if opt.use_gpu else "cpu")

    for _ in range(n_query):
        i1 = random.randint(0, n_base - 1)
        i2 = random.randint(0, n_base - 1)
        while i1 == i2:
            i2 = random.randint(0, n_base - 1)
        p = random.random()
        item = p * base_dataset[i1][0] + (1 - p) * base_dataset[i2][0]

        # Add a batch dimension of size 1 for the forward pass.
        item = torch.unsqueeze(item, 0).to(device)
        with torch.no_grad():
            outputs = victim(item)
            probs = softmax(outputs)
            new_sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))

    return new_sub_dataset


def get_avg_dataset(victim, n_query, dataset):
    """Query ``victim`` on element-wise averages of several base-dataset images.

    For each query, ``opt.n_fuse`` distinct indices are sampled from the
    training split returned by
    ``get_dataset_by_name(dataset, opt.victim_img_size)``; the corresponding
    images (element 0 of each dataset entry) are averaged and sent through
    ``victim``. The softmax (over dim 1) of the victim's output is stored
    alongside the averaged input.

    Args:
        victim: Model to query. It is switched to eval mode and, when
            ``opt.use_gpu`` is set, moved to the GPU.
        n_query: Number of averaged samples to generate.
        dataset: Dataset name forwarded to ``get_dataset_by_name``.

    Returns:
        A ``data.SubDataset`` whose ``items`` list holds one tuple
        ``(-1, averaged_input_cpu, victim_probs_cpu, -1)`` per query.
        NOTE(review): the two ``-1`` fields look like placeholders (e.g.
        index / label) -- confirm against ``data.SubDataset``.

    Raises:
        ValueError: From ``random.sample`` when ``opt.n_fuse`` is negative or
            larger than the base dataset.
    """
    victim.eval()
    if opt.use_gpu:
        victim.cuda()
    new_sub_dataset = data.SubDataset()
    base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
    softmax = nn.Softmax(dim=1)

    # Sampling from a range draws the same indices as sampling from a
    # materialised list, without rebuilding an O(N) list for every query.
    index_pool = range(len(base_dataset))

    # Run queries on whatever device the victim actually lives on, rather
    # than unconditionally calling .cuda() (which breaks CPU-only runs).
    first_param = next(victim.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if opt.use_gpu else "cpu")

    for _ in range(n_query):
        idx_list = random.sample(index_pool, opt.n_fuse)
        items = [base_dataset[idx][0] for idx in idx_list]
        item = sum(items) / len(items)

        # Add a batch dimension of size 1 for the forward pass.
        item = torch.unsqueeze(item, 0).to(device)
        with torch.no_grad():
            outputs = victim(item)
            probs = softmax(outputs)
            new_sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))

    return new_sub_dataset