import random
import math

import torch
from torch import nn

import data
from utils import *
from config import opt
from utils import get_dataset_by_name
from .query_victim import query_victim
# import loss


def get_fusion_pre_dataset(n_fuse, n_query, dataset='cifar100'):
    """Build an unlabeled dataset of image groups taken in order from a base dataset.

    Group ``q`` is made of the base-dataset samples with indices
    ``q * n_fuse`` .. ``q * n_fuse + n_fuse - 1``. Every entry appended to the
    result is ``(group, -1, -1, -1)`` where ``group`` is a tuple of ``n_fuse``
    CPU tensors viewed as ``(3, 32, 32)``.

    Args:
        n_fuse: number of images per group.
        n_query: number of groups to build.
        dataset: dataset name forwarded to ``get_dataset_by_name``.

    Returns:
        A ``data.SubDataset`` whose ``items`` list holds ``n_query`` entries.
    """
    pre_dataset = data.SubDataset()
    base_dataset = get_dataset_by_name(dataset, opt.victim_img_size).train_dataset
    if dataset == 'cifar100':
        # NOTE(review): the filtered arrays are stored in `train_data` /
        # `train_labels`, while samples below are read through
        # `base_dataset[...]`. Whether classes 90-99 are really excluded
        # depends on the dataset class reading those attributes -- confirm.
        forbidden = set(range(90, 100))
        indexes = [i for i, value in enumerate(base_dataset.targets) if value not in forbidden]
        base_dataset.train_data = base_dataset.data[indexes]
        base_dataset.train_labels = [base_dataset.targets[i] for i in indexes]
    start = 0
    for query in range(n_query):
        first = start + query * n_fuse
        # The images end up on the CPU anyway, so they are copied directly
        # instead of bouncing through CUDA (which also required a GPU).
        items = tuple(
            base_dataset[first + i][0].view(3, 32, 32).cpu().clone()
            for i in range(n_fuse)
        )
        pre_dataset.items.append((items, -1, -1, -1))
    return pre_dataset


# def get_fusion_dataset(n_fuse, n_query, victim, data_gen, sub_dataset, unlabeled_dataset, dataset='cifar100'):
#     new_sub_dataset = data.SubDataset()
#     next_dataset = data.SubDataset()
#     dataset_parse = dataset.split('-')
#     dataset_name = dataset_parse[0]
#     partition = dataset_parse[1] if len(dataset_parse) > 1 else None
#     base_dataset = dataset_dict[dataset_name](input_size=opt.victim_img_size, partition=partition).train_dataset
    # if dataset == 'cifar100':
    #     forbidden = [i for i in range(90, 100)]
    #     indexes = [i for i, value in enumerate(base_dataset.targets) if value not in forbidden]
#     #     base_dataset.train_data = base_dataset.data[indexes]
#     #     base_dataset.train_labels = [base_dataset.targets[i] for i in indexes]
#     if len(sub_dataset) == 0:
#         for i in range(len(base_dataset)):
#             unlabeled_dataset.items.append((base_dataset[i][0], -1, -1, 0))
#     if opt.use_gpu:
#         victim.cuda()
#         data_gen.cuda()
#     if opt.same_origin:
#         start = 0
#     else:
#         start = len(sub_dataset) * n_fuse
#
#     for query in range(n_query):
#         items = []
#         next_items = []
#         for i in range(n_fuse):
#             items.append(base_dataset[start + query * n_fuse + i][0].view(1, 3, 32, 32).cuda())
#         for i in range(n_fuse):
#             next_items.append(base_dataset[start + n_query * n_fuse + query * n_fuse + i][0].view(1, 3, 32, 32).cuda())
#         with torch.no_grad():
#             adv_item = data_gen(tuple(items))
#             output = victim(adv_item)
#         softmax = nn.Softmax(dim=1)
#         prob = softmax(output)
#         items = tuple([item[0].cpu() for item in items])
#         next_items = tuple([item[0].cpu() for item in next_items])
#         sub_dataset.items.append((items, adv_item[0].cpu(), prob[0].cpu(), next_items))
#         new_sub_dataset.items.append((items, adv_item[0].cpu(), prob[0].cpu(), next_items))
#     start = len(sub_dataset) * n_fuse
#     for query in range(n_query):
#         items = []
#         for i in range(n_fuse):
#             items.append(base_dataset[start + query * n_fuse + i][0].view(1, 3, 32, 32)[0].cpu())
#         next_dataset.items.append((items, -1, -1, -1))
#     return sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset

# def get_fusion_dataset(n_fuse, n_query, victim, data_gen,
#                        sub_dataset, unlabeled_dataset, next_diff_dataset, dataset='cifar100'):
#     new_sub_dataset = data.SubDataset()
#     next_dataset = data.SubDataset()
#     dataset_parse = dataset.split('-')
#     dataset_name = dataset_parse[0]
#     partition = dataset_parse[1] if len(dataset_parse) > 1 else None
#     base_dataset = dataset_dict[dataset_name](input_size=opt.victim_img_size, partition=partition).train_dataset
#     if len(sub_dataset) == 0:
#         for i in range(len(base_dataset)):
#             unlabeled_dataset.items.append([base_dataset[i][0], -1, -1, 0])
#     if opt.use_gpu:
#         victim.cuda()
#         data_gen.cuda()
#     softmax = nn.Softmax(dim=1)
#
#     if len(sub_dataset) == 0:
#         init_idx_list = random.sample([i for i in range(len(unlabeled_dataset))], n_query*n_fuse)
#         # print(init_idx_list)
#         for query in range(n_query):
#             items = []
#             for i in range(n_fuse):
#                 items.append(unlabeled_dataset.items[init_idx_list[query * n_fuse + i]][0].view(
#                     1, 3, opt.public_img_size, opt.public_img_size).cuda())
#                 unlabeled_dataset.items[init_idx_list[query * n_fuse + i]][3] = 1
#             with torch.no_grad():
#                 adv_item = data_gen(tuple(items))
#                 output = victim(adv_item)
#             prob = softmax(output)
#             items = tuple([item[0].cpu() for item in items])
#             sub_dataset.items.append((items, adv_item[0].cpu(), prob[0].cpu(), -1))
#             new_sub_dataset.items.append((items, adv_item[0].cpu(), prob[0].cpu(), -1))
#
#         # for query in range(n_query):
#         #     items = []
#         #     for i in range(n_fuse):
#         #         items.append(unlabeled_dataset.items[query * n_fuse + i][0].view(1, 3, opt.public_img_size, opt.public_img_size).cuda())
#         #         unlabeled_dataset.items[query * n_fuse + i][3] = 1
#         #     with torch.no_grad():
#         #         adv_item = data_gen(tuple(items))
#         #         output = victim(adv_item)
#         #     prob = softmax(output)
#         #     items = tuple([item[0].cpu() for item in items])
#         #     sub_dataset.items.append((items, adv_item[0].cpu(), prob[0].cpu(), -1))
#         #     new_sub_dataset.items.append((items, adv_item[0].cpu(), prob[0].cpu(), -1))
#     else:
#         batch_size = 100
#         next_diff_dataloader = torch.utils.data.DataLoader(
#             next_diff_dataset,
#             batch_size=batch_size,
#             num_workers=4
#         )
#         for i, (items, _, _, _) in enumerate(next_diff_dataloader):
#             items = tuple([item.cuda() for item in items])
#             with torch.no_grad():
#                 adv_items = data_gen(items)
#                 outputs = victim(adv_items)
#             probs = softmax(outputs)
#             for j in range(len(probs)):
#                 next_diff_dataset.items[i * batch_size + j] = \
#                     (next_diff_dataset.items[i * batch_size + j][0], adv_items[j].cpu(), probs[j].cpu(), -1)
#         new_sub_dataset = next_diff_dataset
#         sub_dataset.items = sub_dataset.items + next_diff_dataset.items
#
#     return sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset


def get_large_dataset_by_noise(dataset, target_size, noise_std, noise_mean=0.):
    """Enlarge ``dataset`` with Gaussian-perturbed copies of its samples.

    With ``times = ceil(target_size / len(dataset))``, the function appends
    ``times - 1`` noisy passes over the source to a fresh
    ``data.LabeledDataset`` and concatenates the original dataset after them.
    The result is therefore built from whole passes and may hold more than
    ``target_size`` samples.

    Args:
        dataset: indexable dataset yielding ``(img, target)`` pairs.
        target_size: requested size; converted to ``float`` before use.
        noise_std: scale applied to the standard-normal noise.
        noise_mean: offset added to the scaled noise.

    Returns:
        A ``torch.utils.data.ConcatDataset`` of the noisy samples followed by
        the original dataset.
    """
    noisy_part = data.LabeledDataset()
    source_size = len(dataset)
    target_size = float(target_size)
    n_passes = math.ceil(target_size / source_size)
    # One pass is covered by the original dataset itself, hence ``- 1``.
    for _ in range(n_passes - 1):
        for idx in range(source_size):
            img, target = dataset[idx]
            perturbation = torch.randn(img.shape) * noise_std + noise_mean
            noisy_part.items.append((img + perturbation, target))
    enlarged = torch.utils.data.ConcatDataset([noisy_part, dataset])
    print(f'dataset enlarged: {source_size} -> {target_size}')
    return enlarged

def get_large_dataset_by_interpolate(dataset, target_size):
    """Enlarge ``dataset`` with random convex combinations of sample pairs.

    For each missing sample two distinct indices and a weight ``p`` in
    ``[0, 1)`` are drawn, and ``p * img1 + (1 - p) * img2`` is stored with the
    placeholder label ``0``. The original dataset is concatenated after the
    interpolated samples.

    Args:
        dataset: indexable dataset whose items start with the image.
        target_size: desired total number of samples.

    Returns:
        A ``torch.utils.data.ConcatDataset`` of the interpolated samples
        followed by the original dataset.

    Raises:
        ValueError: if samples must be generated but ``dataset`` holds fewer
            than two samples, so no distinct pair exists.
    """
    large_dataset = data.LabeledDataset()
    source_size = len(dataset)
    n_new = target_size - source_size
    # Without this guard a one-sample source would spin forever in the
    # "redraw until the indices differ" loop below.
    if n_new > 0 and source_size < 2:
        raise ValueError(
            f'need at least 2 source samples to interpolate, got {source_size}')
    for _ in range(n_new):
        i1 = random.randint(0, source_size - 1)
        i2 = random.randint(0, source_size - 1)
        while i1 == i2:
            i2 = random.randint(0, source_size - 1)
        p = random.random()
        img = p * dataset[i1][0] + (1 - p) * dataset[i2][0]
        large_dataset.items.append((img, 0))
    large_dataset = torch.utils.data.ConcatDataset([large_dataset, dataset])
    print(f'dataset enlarged: {source_size} -> {target_size}')
    return large_dataset


def get_base_dataset(n_fuse, n_query, dataset='cifar100'):
    """Return the ``train_dataset`` of the dataset named ``dataset``.

    The dataset is loaded through ``get_dataset_by_name`` with
    ``opt.victim_img_size``. ``n_fuse`` and ``n_query`` are accepted but not
    used here.
    """
    # NOTE: enlarging the dataset when n_fuse * n_query exceeds its size
    # (by noise or by interpolation) is currently disabled.
    loaded = get_dataset_by_name(dataset, opt.victim_img_size)
    return loaded.train_dataset


# dataset generation for large query budget
def get_fusion_dataset(base_dataset, n_fuse, n_query, victim, data_gen,
                       sub_dataset, unlabeled_dataset, next_diff_dataset, dataset='cifar100',
                       victim_return_type='original', cip_ood_counter=None):
    """Fuse public images with ``data_gen``, query the victim and grow ``sub_dataset``.

    Two modes, chosen by the state of ``sub_dataset``:

    * ``sub_dataset`` is empty or ``opt.enlarge_every_loop`` is set: the
      unlabeled pool is rebuilt from ``base_dataset``, ``n_query`` random
      groups of ``n_fuse`` images are drawn from it, fused and sent to the
      victim. Results are appended to both ``sub_dataset`` and a fresh
      ``new_sub_dataset``.
    * otherwise: the groups already stored in ``next_diff_dataset`` are fused
      and queried; its entries are overwritten in place with the results, it
      is returned as ``new_sub_dataset`` and its items are appended to
      ``sub_dataset``.

    Every stored entry is ``(group, fused_image, victim_output, -1)`` with the
    tensors moved to the CPU.

    Args:
        base_dataset: indexable dataset whose items start with the image.
        n_fuse: number of images fused into one query.
        n_query: number of queries to issue in the first mode.
        victim: model passed to ``query_victim``.
        data_gen: generator called with a tuple of ``n_fuse`` image batches.
        sub_dataset: accumulated labeled dataset; extended by this call.
        unlabeled_dataset: pool whose entries are
            ``[image, -1, -1, used_flag]``; rebuilt in the first mode.
        next_diff_dataset: pre-selected groups used in the second mode.
        dataset: unused; kept for backward compatibility.
        victim_return_type: unused; kept for backward compatibility.
        cip_ood_counter: forwarded to ``query_victim``.

    Returns:
        ``(sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset)``,
        where ``next_dataset`` is always a fresh, empty ``data.SubDataset``.

    Raises:
        ValueError: if the rebuilt pool holds fewer than ``n_fuse`` images.
    """
    new_sub_dataset = data.SubDataset()
    next_dataset = data.SubDataset()

    if opt.use_gpu:
        victim.cuda()
        data_gen.cuda()
    batch_size = 1000

    def fuse_and_query(items):
        """Fuse one collated batch of groups and return (fused, victim output)."""
        # NOTE(review): inputs go to CUDA unconditionally, although the models
        # are only moved when opt.use_gpu is set -- confirm CPU runs are not
        # expected.
        items = tuple(item.cuda() for item in items)
        with torch.no_grad():
            adv_items = data_gen(items)
        probs = query_victim(adv_items, victim, cip_ood_counter)
        return adv_items, probs

    def take_group(idx_list, query):
        """Collect the ``query``-th run of ``n_fuse`` pool images and flag them used."""
        group = []
        for idx in idx_list[query * n_fuse:(query + 1) * n_fuse]:
            entry = unlabeled_dataset.items[idx]
            group.append(entry[0].view(3, opt.public_img_size, opt.public_img_size))
            entry[3] = 1
        return group

    if len(sub_dataset) == 0 or opt.enlarge_every_loop:
        unlabeled_dataset.items = [
            [base_dataset[i][0], -1, -1, 0] for i in range(len(base_dataset))
        ]
        pool_size = len(unlabeled_dataset)
        groups_per_pass = pool_size // n_fuse
        if groups_per_pass == 0:
            # Previously surfaced as an opaque ZeroDivisionError below.
            raise ValueError(
                f'need at least n_fuse={n_fuse} public samples to build a query, '
                f'got {pool_size}')
        first_base_dataset = data.LabeledDataset()
        times = math.ceil(n_query / groups_per_pass)
        if times > 1:
            # The pool is too small for n_query disjoint groups: reshuffle the
            # whole pool several times and cut the surplus afterwards.
            for _ in range(times):
                random_idx_list = random.sample(range(pool_size), pool_size)
                for query in range(groups_per_pass):
                    first_base_dataset.items.append((take_group(random_idx_list, query), 0))
            first_base_dataset.items = first_base_dataset.items[:n_query]
        else:
            random_idx_list = random.sample(range(pool_size), n_query * n_fuse)
            for query in range(n_query):
                first_base_dataset.items.append((take_group(random_idx_list, query), 0))

        first_base_dataloader = torch.utils.data.DataLoader(
            first_base_dataset,
            batch_size=batch_size,
            num_workers=4
        )
        for i, (items, _) in enumerate(first_base_dataloader):
            adv_items, probs = fuse_and_query(items)
            # The loader does not shuffle, so batch i / row j maps back to
            # position i * batch_size + j of the source list.
            for j in range(len(probs)):
                group = first_base_dataset.items[i * batch_size + j][0]
                new_sub_dataset.items.append((group, adv_items[j].cpu(), probs[j].cpu(), -1))
                sub_dataset.items.append((group, adv_items[j].cpu(), probs[j].cpu(), -1))
    else:
        next_diff_dataloader = torch.utils.data.DataLoader(
            next_diff_dataset,
            batch_size=batch_size,
            num_workers=4
        )
        for i, (items, _, _, _) in enumerate(next_diff_dataloader):
            adv_items, probs = fuse_and_query(items)
            for j in range(len(probs)):
                pos = i * batch_size + j
                next_diff_dataset.items[pos] = (
                    next_diff_dataset.items[pos][0], adv_items[j].cpu(), probs[j].cpu(), -1)
        new_sub_dataset = next_diff_dataset
        sub_dataset.items = sub_dataset.items + next_diff_dataset.items

    # reset usage record for public data if there is not enough public data for next loop
    used = 0
    for item in unlabeled_dataset.items:
        used += item[3]
    total = len(unlabeled_dataset)
    if total - used < math.ceil(opt.n_fuse*opt.query / opt.n_loop) * 1.5:
        for i in range(total):
            unlabeled_dataset.items[i][3] = 0
    # The counts reported here are the ones measured before any reset.
    print(f'unused/total: {total-used}/{total}')

    return sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset

def get_random_aggr_gen_dataset(base_dataset, n_fuse, n_query, victim, data_gen,
                       sub_dataset, unlabeled_dataset, next_diff_dataset, dataset='cifar100'):
    """Fuse public images with a noise-conditioned generator and label them with the victim.

    Two modes, chosen by the state of ``sub_dataset``:

    * ``sub_dataset`` is empty or ``opt.enlarge_every_loop`` is set: the
      unlabeled pool is rebuilt from ``base_dataset``, the pool is shuffled as
      many times as needed to form ``n_query`` groups of ``n_fuse`` images,
      and each group is fused and sent to the victim. Results are appended to
      both ``sub_dataset`` and a fresh ``new_sub_dataset``.
    * otherwise: the groups already stored in ``next_diff_dataset`` are fused
      and queried; its entries are overwritten in place with the results, it
      is returned as ``new_sub_dataset`` and its items are appended to
      ``sub_dataset``.

    The generator is called as ``data_gen(items, noise)`` with standard-normal
    noise of shape ``(batch, opt.noise_dim)``. The stored victim output is a
    softmax over dim 1 of the victim's output (of its last element when
    ``opt.victim_wm_dataset`` is set). Every stored entry is
    ``(group, fused_image, probabilities, -1)`` with the tensors on the CPU.

    Args:
        base_dataset: indexable dataset whose items start with the image.
        n_fuse: number of images fused into one query.
        n_query: number of queries to issue in the first mode.
        victim: model called on the fused images.
        data_gen: generator producing the fused images.
        sub_dataset: accumulated labeled dataset; extended by this call.
        unlabeled_dataset: pool whose entries are
            ``[image, -1, -1, used_flag]``; rebuilt in the first mode.
        next_diff_dataset: pre-selected groups used in the second mode.
        dataset: unused; kept for backward compatibility.

    Returns:
        ``(sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset)``,
        where ``next_dataset`` is always a fresh, empty ``data.SubDataset``.

    Raises:
        ValueError: if the rebuilt pool holds fewer than ``n_fuse`` images.
    """
    new_sub_dataset = data.SubDataset()
    next_dataset = data.SubDataset()

    if opt.use_gpu:
        victim.cuda()
        data_gen.cuda()
    softmax = nn.Softmax(dim=1)
    batch_size = 5000

    def fuse_and_query(items):
        """Fuse one collated batch of groups and return (fused, probabilities)."""
        # NOTE(review): inputs go to CUDA unconditionally, although the models
        # are only moved when opt.use_gpu is set -- confirm CPU runs are not
        # expected.
        noise = torch.randn(items[0].shape[0], opt.noise_dim).cuda()
        items = tuple(item.cuda() for item in items)
        with torch.no_grad():
            adv_items = data_gen(items, noise)
            outputs = victim(adv_items)
        probs = softmax(outputs[-1] if opt.victim_wm_dataset else outputs)
        return adv_items, probs

    if len(sub_dataset) == 0 or opt.enlarge_every_loop:
        unlabeled_dataset.items = [
            [base_dataset[i][0], -1, -1, 0] for i in range(len(base_dataset))
        ]
        pool_size = len(unlabeled_dataset)
        groups_per_pass = pool_size // n_fuse
        if groups_per_pass == 0:
            # Previously surfaced as an opaque ZeroDivisionError below.
            raise ValueError(
                f'need at least n_fuse={n_fuse} public samples to build a query, '
                f'got {pool_size}')
        first_base_dataset = data.LabeledDataset()
        times = math.ceil(n_query / groups_per_pass)
        # Reshuffle the whole pool once per pass; the surplus groups of the
        # last pass are cut off afterwards.
        for _ in range(times):
            random_idx_list = random.sample(range(pool_size), pool_size)
            for query in range(groups_per_pass):
                group = []
                for idx in random_idx_list[query * n_fuse:(query + 1) * n_fuse]:
                    entry = unlabeled_dataset.items[idx]
                    group.append(entry[0].view(3, opt.public_img_size, opt.public_img_size))
                    entry[3] = 1
                first_base_dataset.items.append((group, 0))
        first_base_dataset.items = first_base_dataset.items[:n_query]

        first_base_dataloader = torch.utils.data.DataLoader(
            first_base_dataset,
            batch_size=batch_size,
            num_workers=4
        )
        for i, (items, _) in enumerate(first_base_dataloader):
            adv_items, probs = fuse_and_query(items)
            # The loader does not shuffle, so batch i / row j maps back to
            # position i * batch_size + j of the source list.
            for j in range(len(probs)):
                group = first_base_dataset.items[i * batch_size + j][0]
                new_sub_dataset.items.append((group, adv_items[j].cpu(), probs[j].cpu(), -1))
                sub_dataset.items.append((group, adv_items[j].cpu(), probs[j].cpu(), -1))
    else:
        next_diff_dataloader = torch.utils.data.DataLoader(
            next_diff_dataset,
            batch_size=batch_size,
            num_workers=4
        )
        for i, (items, _, _, _) in enumerate(next_diff_dataloader):
            adv_items, probs = fuse_and_query(items)
            for j in range(len(probs)):
                pos = i * batch_size + j
                next_diff_dataset.items[pos] = (
                    next_diff_dataset.items[pos][0], adv_items[j].cpu(), probs[j].cpu(), -1)
        new_sub_dataset = next_diff_dataset
        sub_dataset.items = sub_dataset.items + next_diff_dataset.items

    # reset usage record for public data if there is not enough public data for next loop
    used = 0
    for item in unlabeled_dataset.items:
        used += item[3]
    total = len(unlabeled_dataset)
    if total - used < math.ceil(opt.n_fuse*opt.query / opt.n_loop) * 1.5:
        for i in range(total):
            unlabeled_dataset.items[i][3] = 0
    # The counts reported here are the ones measured before any reset.
    print(f'unused/total: {total-used}/{total}')

    return sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset


def update_div_dataset(div_dataset, new_sub_dataset, div_threshold=0.1):
    """Append to ``div_dataset`` the new samples that differ enough from it.

    The difference between two samples is ``1 - cosine_similarity`` of their
    flattened second elements. A candidate is appended to
    ``div_dataset.items`` when its smallest difference to every entry
    currently in ``div_dataset`` exceeds ``div_threshold``. Entries accepted
    earlier in the same call take part in later comparisons, and a candidate
    is always accepted while ``div_dataset`` is empty.

    Args:
        div_dataset: dataset exposing ``items``, ``len()`` and indexing;
            extended in place.
        new_sub_dataset: dataset of candidate 4-tuples.
        div_threshold: minimum cosine distance required for acceptance.

    Returns:
        The same ``div_dataset`` object.
    """
    cos = nn.CosineSimilarity(dim=0, eps=1e-6)
    for i in range(len(new_sub_dataset)):
        # Fetch the candidate once instead of once per comparison.
        candidate = new_sub_dataset[i]
        min_diff = 1e6
        # len(div_dataset) is re-read for every candidate on purpose: it grows
        # as candidates are accepted.
        for j in range(len(div_dataset)):
            _, div_data, _, _ = div_dataset[j]
            diff = 1 - cos(div_data.flatten(0), candidate[1].flatten(0))
            if diff < min_diff:
                min_diff = diff
        if min_diff > div_threshold:
            div_dataset.items.append(candidate)
    return div_dataset