import random

import torch
from torch import nn
import torch.nn.functional as F
import foolbox
import heapq

import data
from utils import *
from config import opt
from utils import get_dataset_by_name, classifier_dict, load_model_weights
# import loss

def get_deepfool_dataset(sub_dataset, unqueried_sub_dataset, victim, substitute, n_query, dataset='cifar100',
                         return_idx=True):
    """Pick the unqueried samples with the smallest DeepFool perturbation.

    Every image in ``unqueried_sub_dataset`` is attacked with foolbox's
    ``L2DeepFoolAttack`` against the substitute model, using the substitute's
    own predicted class as the label. The L2 distance between each image and
    its adversarial counterpart is used as the score, and the ``n_query``
    samples with the smallest distance are selected.

    Args:
        sub_dataset: Already-queried dataset exposing ``items`` and ``len()``.
            When it is empty, ``unqueried_sub_dataset`` is first filled from
            the training split of ``dataset``.
        unqueried_sub_dataset: Pool of candidate samples; iterated as 4-tuples
            whose second element is the image.
        victim: Model queried for soft labels. Only used when ``return_idx``
            is False. It is not moved to the GPU here although its inputs are,
            so it must already live there.
        substitute: Passed to ``load_model_weights`` together with a freshly
            built ``opt.eval_model`` classifier; the result is the model that
            is attacked.
        n_query: Maximum number of samples to select.
        dataset: Dataset name handed to ``get_dataset_by_name``.
        return_idx: If True, only the selected indices are returned and no
            dataset is modified (apart from the initial pool fill).

    Returns:
        If ``return_idx`` is True, a list of integer indices into
        ``unqueried_sub_dataset`` ordered by increasing perturbation size.
        Otherwise the tuple ``(sub_dataset, new_sub_dataset,
        unqueried_sub_dataset, next_dataset)`` where ``new_sub_dataset`` holds
        only the newly labelled samples, the selected samples have been removed
        from ``unqueried_sub_dataset`` and ``next_dataset`` is an empty
        ``data.SubDataset``.
    """
    if len(sub_dataset) == 0:
        # First round: seed the pool with every training image. The -1 slots
        # are placeholders (presumably index / label / extra -- confirm
        # against data.SubDataset).
        base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
        for i in range(len(base_dataset)):
            unqueried_sub_dataset.items.append((-1, base_dataset[i][0], -1, -1))

    n_outputs = opt.victim_n_classes
    eval_model = classifier_dict[opt.eval_model](
        n_outputs=n_outputs
    )
    substitute = load_model_weights(substitute, eval_model)
    substitute.eval()
    # foolbox is told that valid inputs lie in [-1.5, 1.5]
    # (presumably the range of the normalised images -- confirm).
    fmodel = foolbox.models.PyTorchModel(substitute, bounds=(-1.5, 1.5))
    batch_size = 100
    dataloader = torch.utils.data.DataLoader(
        unqueried_sub_dataset, batch_size=batch_size, shuffle=False, num_workers=4
    )
    # The attack configuration is loop-invariant, so build it once.
    attack = foolbox.attacks.L2DeepFoolAttack(steps=20, overshoot=2)  # params

    batch_diffs = []
    for _, imgs, _, _ in dataloader:
        imgs = imgs.cuda()
        with torch.no_grad():
            # argmax of the logits equals argmax of the softmax probabilities.
            labels = substitute(imgs).argmax(dim=1)
        adv_imgs = attack.run(fmodel, imgs, labels)
        # Flatten using the actual batch size: the last batch may hold fewer
        # than `batch_size` samples.
        n_imgs = imgs.size(0)
        batch_diff = F.pairwise_distance(imgs.reshape(n_imgs, -1), adv_imgs.reshape(n_imgs, -1))
        batch_diffs.append(batch_diff.detach())

    if batch_diffs:
        diffs = torch.cat(batch_diffs, 0)
        # Ascending sort: smallest perturbation (closest to the boundary) first.
        _, indices = torch.sort(diffs)
        index_list = indices[:n_query].tolist()
    else:
        # Empty pool: nothing can be selected.
        index_list = []

    if return_idx:
        return index_list

    new_sub_dataset = data.SubDataset()
    next_dataset = data.SubDataset()
    for idx in index_list:
        item = torch.unsqueeze(unqueried_sub_dataset[idx][1], 0).cuda()
        with torch.no_grad():
            probs = F.softmax(victim(item), dim=1)
        # Store CPU copies of the image and the victim's soft label.
        sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))
        new_sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))

    # Remove the queried samples from the pool; a set keeps the filter linear.
    selected = set(index_list)
    unqueried_sub_dataset.items = [unqueried_sub_dataset.items[i] for i in range(len(unqueried_sub_dataset))
                                   if i not in selected]
    return sub_dataset, new_sub_dataset, unqueried_sub_dataset, next_dataset


def get_kcenter_dataset(sub_dataset, unqueried_sub_dataset, victim, substitute, n_query, dataset='cifar100',
                        return_idx=True):
    """Pick the unqueried samples farthest from the already-queried set.

    On the first round (``sub_dataset`` empty) ``n_query`` samples are drawn
    uniformly at random. Afterwards each unqueried sample is scored by the
    minimum L2 distance between its substitute softmax output and the
    substitute softmax outputs of all queried samples; the ``n_query``
    samples with the largest score are selected. All scores are computed
    once -- they are not updated after each individual pick.

    Args:
        sub_dataset: Already-queried dataset exposing ``items`` and ``len()``.
            When it is empty, ``unqueried_sub_dataset`` is first filled from
            the training split of ``dataset``.
        unqueried_sub_dataset: Pool of candidate samples; iterated as 4-tuples
            whose second element is the image.
        victim: Model queried for soft labels; moved to the GPU and put in
            eval mode.
        substitute: Passed to ``load_model_weights`` together with a freshly
            built ``opt.eval_model`` classifier; the result is used to embed
            the samples.
        n_query: Number of samples to select.
        dataset: Dataset name handed to ``get_dataset_by_name``.
        return_idx: If True, only the selected indices are returned and no
            dataset is modified (apart from the initial pool fill).

    Returns:
        If ``return_idx`` is True, a list of distinct integer indices into
        ``unqueried_sub_dataset``. Otherwise the tuple ``(sub_dataset,
        new_sub_dataset, unqueried_sub_dataset, next_dataset)`` where
        ``new_sub_dataset`` holds only the newly labelled samples, the selected
        samples have been removed from ``unqueried_sub_dataset`` and
        ``next_dataset`` is an empty ``data.SubDataset``.

    Raises:
        ValueError: On the first round, if ``n_query`` exceeds the pool size
            (raised by ``random.sample``).
    """
    first_round = len(sub_dataset) == 0
    if first_round:
        # Seed the pool with every training image. The -1 slots are
        # placeholders (presumably index / label / extra -- confirm against
        # data.SubDataset).
        base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
        for i in range(len(base_dataset)):
            unqueried_sub_dataset.items.append((-1, base_dataset[i][0], -1, -1))

    n_outputs = opt.victim_n_classes
    eval_model = classifier_dict[opt.eval_model](
        n_outputs=n_outputs
    )
    substitute = load_model_weights(substitute, eval_model)
    victim.cuda()
    victim.eval()
    substitute.cuda()
    substitute.eval()

    if first_round:
        # Nothing has been queried yet, so there are no centres to measure
        # against: fall back to a uniform random sample.
        index_list = random.sample(range(len(unqueried_sub_dataset)), n_query)
    else:
        dataloader = torch.utils.data.DataLoader(
            sub_dataset,
            batch_size=100,
            shuffle=False,
            num_workers=4
        )
        unqueried_dataloader = torch.utils.data.DataLoader(
            unqueried_sub_dataset,
            batch_size=100,
            shuffle=False,
            num_workers=4
        )

        # Substitute probabilities for the already-queried samples. Batches
        # are collected and concatenated once, so no class count has to be
        # hard-coded for an empty seed tensor.
        old_batches = []
        with torch.no_grad():
            for _, old_items, _, _ in dataloader:
                old_batches.append(F.softmax(substitute(old_items.cuda()), dim=1))
        old_predictions = torch.cat(old_batches, 0)

        # Minimum distance from each unqueried sample to any queried sample.
        min_dist_list = []
        with torch.no_grad():
            for _, unqueried_items, _, _ in unqueried_dataloader:
                # Use softmax here too so both sides of the distance live in
                # the same (probability) space.
                predictions = F.softmax(substitute(unqueried_items.cuda()), dim=1)
                for prediction in predictions:
                    item_dist_list = F.pairwise_distance(prediction, old_predictions)
                    min_dist_list.append(torch.min(item_dist_list).item())

        # Rank indices directly by their score. Looking values up with
        # list.index() would return the same index for tied distances.
        index_list = heapq.nlargest(n_query, range(len(min_dist_list)), key=min_dist_list.__getitem__)

    if return_idx:
        return index_list

    new_sub_dataset = data.SubDataset()
    next_dataset = data.SubDataset()
    for idx in index_list:
        item = torch.unsqueeze(unqueried_sub_dataset[idx][1], 0).cuda()
        with torch.no_grad():
            probs = F.softmax(victim(item), dim=1)
        # Store CPU copies of the image and the victim's soft label.
        sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))
        new_sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))

    # Remove the queried samples from the pool; a set keeps the filter linear.
    selected = set(index_list)
    unqueried_sub_dataset.items = [unqueried_sub_dataset.items[i] for i in range(len(unqueried_sub_dataset))
                                   if i not in selected]
    return sub_dataset, new_sub_dataset, unqueried_sub_dataset, next_dataset


def get_random_dataset(sub_dataset, unqueried_sub_dataset, victim, n_query, dataset='cifar100', return_idx=True):
    """Pick ``n_query`` unqueried samples uniformly at random.

    Args:
        sub_dataset: Already-queried dataset exposing ``items`` and ``len()``.
            When it is empty, ``unqueried_sub_dataset`` is first filled from
            the training split of ``dataset``.
        unqueried_sub_dataset: Pool of candidate samples; indexed as 4-tuples
            whose second element is the image.
        victim: Model queried for soft labels; moved to the GPU and put in
            eval mode.
        n_query: Number of samples to select.
        dataset: Dataset name handed to ``get_dataset_by_name``.
        return_idx: If True, only the selected indices are returned and no
            dataset is modified (apart from the initial pool fill).

    Returns:
        If ``return_idx`` is True, a list of ``n_query`` distinct integer
        indices into ``unqueried_sub_dataset``. Otherwise the tuple
        ``(sub_dataset, new_sub_dataset, unqueried_sub_dataset, next_dataset)``
        where ``new_sub_dataset`` holds only the newly labelled samples, the
        selected samples have been removed from ``unqueried_sub_dataset`` and
        ``next_dataset`` is an empty ``data.SubDataset``.

    Raises:
        ValueError: If ``n_query`` exceeds the pool size (raised by
            ``random.sample``).
    """
    if len(sub_dataset) == 0:
        # First round: seed the pool with every training image. The -1 slots
        # are placeholders (presumably index / label / extra -- confirm
        # against data.SubDataset).
        base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
        for i in range(len(base_dataset)):
            unqueried_sub_dataset.items.append((-1, base_dataset[i][0], -1, -1))

    victim.cuda()
    victim.eval()

    index_list = random.sample(range(len(unqueried_sub_dataset)), n_query)

    if return_idx:
        return index_list

    new_sub_dataset = data.SubDataset()
    next_dataset = data.SubDataset()
    softmax = nn.Softmax(dim=1)
    for idx in index_list:
        item = torch.unsqueeze(unqueried_sub_dataset[idx][1], 0).cuda()
        with torch.no_grad():
            probs = softmax(victim(item))
        # Store CPU copies of the image and the victim's soft label.
        sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))
        new_sub_dataset.items.append((-1, item[0].cpu(), probs[0].cpu(), -1))

    # Remove the queried samples from the pool; a set keeps the filter linear
    # instead of scanning index_list for every pool entry.
    selected = set(index_list)
    unqueried_sub_dataset.items = [unqueried_sub_dataset.items[i] for i in range(len(unqueried_sub_dataset))
                                   if i not in selected]
    return sub_dataset, new_sub_dataset, unqueried_sub_dataset, next_dataset