import os
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import logging
import numpy as np
import torch.nn as nn

class DatasetIndex(Dataset):
    """Dataset wrapper that appends the requested index to every sample.

    The wrapped dataset must yield ``(data, target)`` pairs; this wrapper
    yields ``(data, target, index)`` instead.
    """

    def __init__(self, dataset):
        # Only a reference is stored; the wrapped dataset is not copied.
        self.dataset = dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for the sample at ``index``."""
        sample, label = self.dataset[index]
        return sample, label, index

def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Build the preprocessing pipeline applied to training images.

    Steps, in order: resize to ``(resize_size, resize_size)``, random crop of
    ``crop_size``, random horizontal flip, conversion to tensor, normalization.

    Args:
        resize_size: Side length the image is resized to before cropping.
        crop_size: Side length of the random crop.
        alexnet: If true, normalize with the mean file
            ``./ilsvrc_2012_mean.npy`` instead of fixed per-channel constants.

    Returns:
        A ``transforms.Compose`` object.
    """
    if alexnet:
        # NOTE(review): `Normalize` is neither defined nor imported in the
        # visible part of this file; confirm it is in scope before relying on
        # alexnet=True.
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Build the deterministic preprocessing pipeline applied to test images.

    Steps, in order: resize to ``(resize_size, resize_size)``, center crop of
    ``crop_size``, conversion to tensor, normalization.

    Args:
        resize_size: Side length the image is resized to before cropping.
        crop_size: Side length of the center crop.
        alexnet: If true, normalize with the mean file
            ``./ilsvrc_2012_mean.npy`` instead of fixed per-channel constants.

    Returns:
        A ``transforms.Compose`` object.
    """
    if alexnet:
        # NOTE(review): `Normalize` is neither defined nor imported in the
        # visible part of this file; confirm it is in scope before relying on
        # alexnet=True.
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def data_load(args):
    """Create the target-domain loaders from ``args``.

    Reads ``args.batch_size``, ``args.root_path``, ``args.tar`` and
    ``args.worker``.

    Returns:
        Dict of ``DataLoader`` objects with keys:

        * ``"target_tr"`` - 80% random split of the folder (training
          transform), wrapped in ``DatasetIndex`` and shuffled.
        * ``"target_te"`` - the remaining 20% of that same split, unshuffled,
          batch size ``3 * args.batch_size``.
        * ``"test"`` - the whole folder with the test transform, shuffled,
          batch size ``2 * args.batch_size``.
    """
    train_bs = args.batch_size
    folder = os.path.join(args.root_path, args.tar)

    # Both split halves come from this one dataset, so they share the
    # training transform.
    augmented = datasets.ImageFolder(folder, transform=image_train())
    n_train = int(0.8 * len(augmented))
    n_held_out = len(augmented) - n_train
    train_part, held_out_part = torch.utils.data.random_split(augmented, [n_train, n_held_out])

    loaders = {}
    loaders["target_tr"] = DataLoader(
        DatasetIndex(train_part),
        batch_size=train_bs,
        shuffle=True,
        num_workers=args.worker,
        drop_last=False,
    )
    loaders["target_te"] = DataLoader(
        held_out_part,
        batch_size=train_bs * 3,
        shuffle=False,
        num_workers=args.worker,
        drop_last=False,
    )

    whole_folder = datasets.ImageFolder(folder, transform=image_test())
    loaders["test"] = DataLoader(
        whole_folder,
        batch_size=train_bs * 2,
        shuffle=True,
        num_workers=args.worker,
        drop_last=False,
    )
    return loaders

def load_partition_target_data(root_path, dir, batch_size, n_nets, seed=2020, num_workers=4):
    """Partition one image folder into ``n_nets`` client shards and build loaders.

    The folder ``root_path/dir`` is opened twice: once with the training
    transform and once with the test transform. A seeded permutation of all
    sample indices is cut into ``n_nets`` shards with ``np.array_split``;
    client ``i`` gets a training loader and a test loader over shard ``i``.

    Args:
        root_path: Directory that contains the domain folders.
        dir: Name of the domain folder inside ``root_path``.
        batch_size: Batch size of the training loaders; every test loader uses
            ``3 * batch_size``.
        n_nets: Number of clients, i.e. number of shards.
        seed: Value passed to ``np.random.seed`` before the permutation is
            drawn. This reseeds NumPy's global generator.
        num_workers: Worker count handed to every ``DataLoader``. Defaults to
            4, the value that was previously hard-coded.

    Returns:
        Tuple ``(total_num, data_local_num_dict, train_data_local_dict,
        test_data_local_dict, test_data)``. The three dicts are keyed by client
        index ``0 .. n_nets - 1``; ``test_data`` iterates the whole folder with
        the test transform.
    """
    folder = os.path.join(root_path, dir)
    train_set = datasets.ImageFolder(root=folder, transform=image_train())
    eval_set = datasets.ImageFolder(root=folder, transform=image_test())

    total_num = len(train_set.imgs)
    np.random.seed(seed)
    shards = np.array_split(np.random.permutation(total_num), n_nets)

    test_data = DataLoader(eval_set, batch_size=batch_size * 3, shuffle=False,
                           drop_last=False, num_workers=num_workers)

    data_local_num_dict = {}
    train_data_local_dict = {}
    test_data_local_dict = {}

    for client_idx, shard in enumerate(shards):
        data_local_num_dict[client_idx] = len(shard)
        logging.info("client_idx = %d, local_sample_number = %d", client_idx, len(shard))

        # Training view of the shard: augmented samples that also report their
        # position inside the shard (see DatasetIndex).
        local_train_data = DatasetIndex(torch.utils.data.Subset(train_set, shard))
        train_data_local = DataLoader(local_train_data, batch_size=batch_size, shuffle=True,
                                      drop_last=False, num_workers=num_workers)

        # Evaluation view of the same indices, using the test transform.
        local_data_test = torch.utils.data.Subset(eval_set, shard)
        test_data_local = DataLoader(local_data_test, batch_size=batch_size * 3, shuffle=False,
                                     drop_last=False, num_workers=num_workers)

        logging.info("client_idx = %d, batch_num_train_local = %d", client_idx, len(train_data_local))
        train_data_local_dict[client_idx] = train_data_local
        test_data_local_dict[client_idx] = test_data_local

    return total_num, data_local_num_dict, train_data_local_dict, test_data_local_dict, test_data

def load_target_data_multiple_devices(root_path, dir, batch_size, n_nets, process_id, seed=2020):
    """Build the loaders for the single client served by ``process_id``.

    All sample indices of ``root_path/dir`` are permuted with a NumPy seed and
    split into ``n_nets`` shards; the shard used is the one keyed
    ``process_id - 1``.

    Args:
        root_path: Directory that contains the domain folders.
        dir: Name of the domain folder inside ``root_path``.
        batch_size: Training batch size; test loaders use ``3 * batch_size``.
        n_nets: Number of shards the permutation is split into.
        process_id: One-based identifier; ``process_id - 1`` selects the shard.
        seed: Value passed to ``np.random.seed`` (global NumPy generator).

    Returns:
        Tuple ``(total_num, local_data_num, train_data_local, test_data_local,
        test_data)``.

    Raises:
        KeyError: If ``process_id - 1`` is not one of ``0 .. n_nets - 1``.
    """
    client_idx = process_id - 1
    folder = os.path.join(root_path, dir)

    augmented = datasets.ImageFolder(root=folder, transform=image_train())
    plain = datasets.ImageFolder(root=folder, transform=image_test())

    total_num = len(augmented.imgs)
    np.random.seed(seed)
    permutation = np.random.permutation(total_num)
    # Keyed by shard number so an out-of-range client index raises KeyError.
    shard_by_client = dict(enumerate(np.array_split(permutation, n_nets)))

    eval_bs = batch_size * 3
    test_data = torch.utils.data.DataLoader(
        plain, batch_size=eval_bs, shuffle=False, drop_last=False, num_workers=4)

    shard = shard_by_client[client_idx]
    local_data_num = len(shard)
    logging.info("client_idx = %d, local_sample_number = %d" % (client_idx, local_data_num))

    # Training view: augmented shard whose items also carry their shard position.
    indexed_shard = DatasetIndex(torch.utils.data.Subset(augmented, shard))
    train_data_local = torch.utils.data.DataLoader(
        indexed_shard, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4)

    # Evaluation view: same indices, test transform.
    test_data_local = torch.utils.data.DataLoader(
        torch.utils.data.Subset(plain, shard),
        batch_size=eval_bs, shuffle=False, drop_last=False, num_workers=4)

    return total_num, local_data_num, train_data_local, test_data_local, test_data


def load_target_data_DA(root_path, dir, batch_size, percent=0.8, seed=2020):
    """Build train / reference / test loaders for a single target domain.

    Args:
        root_path: Directory that contains the domain folders.
        dir: Name of the domain folder inside ``root_path``.
        batch_size: Training batch size; the other loaders use ``3 * batch_size``.
        percent: Fraction of the folder assigned to the training split.
        seed: Value passed to ``torch.manual_seed`` before each split.

    Returns:
        Tuple ``(train_size, train_data, ref_data, test_data)``:
        ``train_data`` iterates the training split (training transform,
        wrapped in ``DatasetIndex``, shuffled); ``ref_data`` iterates the first
        part of the equally-seeded split of the test-transform dataset;
        ``test_data`` iterates the whole folder with the test transform.
    """
    folder = os.path.join(root_path, dir)
    augmented = datasets.ImageFolder(root=folder, transform=image_train())
    plain = datasets.ImageFolder(root=folder, transform=image_test())

    n_total = len(augmented.imgs)
    train_size = int(percent * n_total)
    lengths = [train_size, n_total - train_size]

    # The generator is reseeded with the same value before each split
    # (presumably so both datasets are partitioned identically).
    torch.manual_seed(seed)
    train_part = torch.utils.data.random_split(augmented, lengths)[0]

    torch.manual_seed(seed)
    ref_part = torch.utils.data.random_split(plain, lengths)[0]

    eval_bs = batch_size * 3
    train_data = torch.utils.data.DataLoader(
        DatasetIndex(train_part), batch_size=batch_size, shuffle=True,
        drop_last=False, num_workers=4)
    ref_data = torch.utils.data.DataLoader(
        ref_part, batch_size=eval_bs, shuffle=False, drop_last=False, num_workers=4)
    # Evaluation runs over the full folder, not only the held-out part.
    test_data = torch.utils.data.DataLoader(
        plain, batch_size=eval_bs, shuffle=False, drop_last=False, num_workers=4)

    return train_size, train_data, ref_data, test_data


def load_partition_target_data_UD(root_path, dir, batch_size, n_nets, process_id, percent=0.8, seed=2020):
    """Build the training loader of one client plus the shared test loader.

    A ``percent`` share of ``root_path/dir`` is drawn with a seeded
    ``random_split``; that share is permuted (NumPy seed) and cut into
    ``n_nets`` shards, and the shard keyed ``process_id - 1`` is used.

    Args:
        root_path: Directory that contains the domain folders.
        dir: Name of the domain folder inside ``root_path``.
        batch_size: Training batch size; the test loader uses ``3 * batch_size``.
        n_nets: Number of shards.
        process_id: One-based identifier; ``process_id - 1`` selects the shard.
        percent: Fraction of the folder that forms the training pool.
        seed: Seed for both ``torch.manual_seed`` and ``np.random.seed``.

    Returns:
        Tuple ``(train_size, local_data_num, train_data_local, test_data)``.

    Raises:
        KeyError: If ``process_id - 1`` is not one of ``0 .. n_nets - 1``.
    """
    client_idx = process_id - 1
    folder = os.path.join(root_path, dir)

    augmented = datasets.ImageFolder(root=folder, transform=image_train())
    plain = datasets.ImageFolder(root=folder, transform=image_test())

    n_total = len(augmented.imgs)
    train_size = int(percent * n_total)

    torch.manual_seed(seed)
    train_pool = torch.utils.data.random_split(augmented, [train_size, n_total - train_size])[0]

    # Shard positions refer to the training pool, not to the original folder.
    np.random.seed(seed)
    shard_by_client = dict(enumerate(np.array_split(np.random.permutation(train_size), n_nets)))

    # The test loader covers the whole folder with the test transform.
    test_data = torch.utils.data.DataLoader(
        plain, batch_size=batch_size * 3, shuffle=False, drop_last=False, num_workers=4)

    shard = shard_by_client[client_idx]
    local_data_num = len(shard)
    logging.info("client_idx = %d, local_sample_number = %d" % (client_idx, local_data_num))

    local_train_data = DatasetIndex(torch.utils.data.Subset(train_pool, shard))
    print("Local train data: {}".format(local_train_data))

    # Drop the final batch only when it would contain exactly one sample.
    lone_sample_left = (len(local_train_data) % batch_size) == 1
    train_data_local = torch.utils.data.DataLoader(
        local_train_data, batch_size=batch_size, shuffle=True,
        drop_last=lone_sample_left, num_workers=4)
    logging.info("client_idx = %d, batch_num_train_local = %d" % (client_idx, len(train_data_local)))

    return train_size, local_data_num, train_data_local, test_data


def load_two_target_data_UD(root_path, dir, batch_size, percent=0.8, seed=2020):
    """Split the training pool of one domain into two halves and build loaders.

    Args:
        root_path: Directory that contains the domain folders.
        dir: Name of the domain folder inside ``root_path``.
        batch_size: Training batch size; the test loader uses ``3 * batch_size``.
        percent: Fraction of the folder that forms the training pool.
        seed: Seed for both ``torch.manual_seed`` and ``np.random.seed``.

    Returns:
        Tuple ``(train_size, local_data_num, train_data_local, test_data)``.
        ``local_data_num`` equals ``train_size``; ``train_data_local`` is a
        dict with the keys ``"first"`` and ``"second"``; ``test_data`` iterates
        the whole folder with the test transform.
    """
    folder = os.path.join(root_path, dir)
    augmented = datasets.ImageFolder(root=folder, transform=image_train())
    plain = datasets.ImageFolder(root=folder, transform=image_test())

    n_total = len(augmented.imgs)
    train_size = int(percent * n_total)

    torch.manual_seed(seed)
    train_pool = torch.utils.data.random_split(augmented, [train_size, n_total - train_size])[0]

    # Positions below index into the training pool, not the original folder.
    np.random.seed(seed)
    half_a, half_b = np.array_split(np.random.permutation(train_size), 2)

    test_data = torch.utils.data.DataLoader(
        plain, batch_size=batch_size * 3, shuffle=False, drop_last=False, num_workers=4)

    local_data_num = train_size

    train_data_local = {}
    for key, half in (("first", half_a), ("second", half_b)):
        indexed_half = DatasetIndex(torch.utils.data.Subset(train_pool, half))
        train_data_local[key] = torch.utils.data.DataLoader(
            indexed_half, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4)

    return train_size, local_data_num, train_data_local, test_data


def load_source_data(root_path, dir, batch_size, seed=2020, drop_last=False):
    """Build the 90/10 train/test loaders of a source domain.

    The folder ``root_path/dir`` is opened with the training transform and
    with the test transform. Both datasets are split 90/10 after seeding
    torch with ``seed``; the training loader uses the first part of the
    training-transform split and the test loader the second part of the
    test-transform split.

    Args:
        root_path: Directory that contains the domain folders.
        dir: Name of the domain folder inside ``root_path``.
        batch_size: Training batch size; the test loader uses ``3 * batch_size``.
        seed: Value passed to ``torch.manual_seed`` before each split.
        drop_last: Drop the last incomplete batch of the training loader.
            Previously this argument was accepted but ignored. Incomplete
            batches are always dropped when ``dir == 'RealWorld'``.

    Returns:
        Tuple ``(source_data_num, dset_loaders)`` where ``dset_loaders`` has
        the keys ``"source_tr"`` and ``"source_te"``.
    """
    train_bs = batch_size

    tr_path = os.path.join(root_path, dir)
    data_train = datasets.ImageFolder(root=tr_path, transform=image_train())
    data_test = datasets.ImageFolder(root=tr_path, transform=image_test())

    source_data_num = len(data_train.imgs)
    train_size = int(0.9 * source_data_num)
    print(train_size)
    test_size = source_data_num - train_size

    # The generator is reseeded with the same value before each split.
    torch.manual_seed(seed)
    source_train_data, _ = torch.utils.data.random_split(data_train, [train_size, test_size])

    torch.manual_seed(seed)
    _, source_test_data = torch.utils.data.random_split(data_test, [train_size, test_size])

    # 'RealWorld' always drops the trailing batch, as before.
    force_drop = dir == 'RealWorld'

    dset_loaders = {}
    dset_loaders["source_tr"] = DataLoader(source_train_data, batch_size=train_bs, shuffle=True,
                                           num_workers=4, drop_last=bool(drop_last) or force_drop)

    # NOTE(review): for 'RealWorld' the test loader also drops its last
    # incomplete batch, so some test samples are never evaluated; kept
    # unchanged - confirm this is intended.
    dset_loaders["source_te"] = DataLoader(source_test_data, batch_size=train_bs * 3, shuffle=False,
                                           num_workers=4, drop_last=force_drop)

    return source_data_num, dset_loaders



def load_target_data_UD(args, percent=0.8, seed=2020):
    """Build the training-split loader and the full test loader of the target.

    Reads ``args.batch_size``, ``args.root_path``, ``args.tar`` and
    ``args.worker``.

    Args:
        args: Namespace-like object providing the attributes listed above.
        percent: Fraction of the folder used for training; must lie in
            ``[0.0, 1.0]``.
        seed: Value passed to ``torch.manual_seed``.

    Returns:
        Dict with the keys ``"target_tr"`` (shuffled loader over the training
        split, training transform) and ``"test"`` (unshuffled loader over the
        whole folder, test transform, batch size ``3 * args.batch_size``).

    Raises:
        ValueError: If ``percent`` is outside ``[0.0, 1.0]``.
    """
    if not 0.0 <= percent <= 1.0:
        # ValueError is a subclass of Exception, so existing handlers still match.
        raise ValueError('Value error for percent: expected a value in [0.0, 1.0], got {!r}'.format(percent))

    train_bs = args.batch_size
    txt_path = os.path.join(args.root_path, args.tar)

    dsets = {
        "target": datasets.ImageFolder(txt_path, transform=image_train()),
        "test": datasets.ImageFolder(txt_path, transform=image_test()),
    }

    # Use the already-built dataset for the size instead of scanning the
    # folder a third time.
    total = len(dsets["target"])
    train_size = int(percent * total)
    test_size = total - train_size

    torch.manual_seed(seed)
    dsets["target_tr"], _ = torch.utils.data.random_split(dsets["target"], [train_size, test_size])
    # Kept from the original: reseed the global torch generator after the split.
    torch.manual_seed(seed)

    dset_loaders = {}
    dset_loaders["target_tr"] = DataLoader(dsets["target_tr"], batch_size=train_bs, shuffle=True,
                                           num_workers=args.worker, drop_last=False)
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False,
                                      num_workers=args.worker, drop_last=False)

    return dset_loaders

def select_target_training_data(args, netF, netB, netC, device, drop_last=False):
    """Select target samples the model predicts with high confidence.

    Every image of the target folder is passed (with the test transform)
    through ``netC(netB(netF(x)))``. Samples whose maximum softmax
    probability is at least ``args.threshold`` are selected, and a loader over
    those samples (with the training transform) is returned next to loaders
    over the full folder.

    Reads ``args.batch_size``, ``args.threshold``, ``args.root_path`` and
    ``args.tar``.

    Args:
        args: Namespace-like object providing the attributes listed above.
        netF, netB, netC: Callables chained as ``netC(netB(netF(images)))``.
        device: Device the image batches are moved to before the forward pass.
        drop_last: NOTE(review): accepted but not used by this function; kept
            for interface compatibility - confirm whether it should apply.

    Returns:
        Dict of loaders with the keys ``"test"`` (full folder, test transform),
        ``"train"`` (full folder, training transform) and ``"sel"`` (selected
        samples, training transform). ``"train"`` and ``"sel"`` drop the last
        batch only when it would hold exactly one sample.
    """
    train_bs = args.batch_size
    test_bs = 128
    threshold = args.threshold

    txt_path = os.path.join(args.root_path, args.tar)
    dsets = {
        "target": datasets.ImageFolder(txt_path, transform=image_train()),
        "test": datasets.ImageFolder(txt_path, transform=image_test()),
    }

    # DatasetIndex makes every batch report the dataset positions it holds.
    scoring_loader = DataLoader(DatasetIndex(dsets["test"]), batch_size=test_bs, shuffle=False,
                                drop_last=False, num_workers=4)

    selected_chunks = []
    with torch.no_grad():
        for images, _, idx in scoring_loader:
            output = netC(netB(netF(images.to(device))))
            prob = torch.softmax(output, dim=1).max(dim=1)[0].cpu().numpy()
            selected_chunks.append(idx.numpy()[prob >= threshold])

    # Concatenate once at the end; an empty loader yields an empty selection
    # instead of leaving the index array undefined.
    if selected_chunks:
        select_idx = np.concatenate(selected_chunks)
    else:
        select_idx = np.empty(0, dtype=np.int64)
    if len(select_idx) == 0:
        logging.warning("no target sample reached the confidence threshold %s", threshold)

    dset_loaders = {}
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, shuffle=False,
                                      num_workers=4, drop_last=False)

    train_size = len(dsets["target"])
    dset_loaders["train"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True,
                                       num_workers=4, drop_last=(train_size % train_bs == 1))

    train_data = torch.utils.data.Subset(dsets["target"], select_idx)
    sel_size = len(train_data)
    dset_loaders["sel"] = DataLoader(train_data, batch_size=train_bs, shuffle=True,
                                     num_workers=4, drop_last=(sel_size % train_bs == 1))

    return dset_loaders



