from torchvision import datasets, transforms
import torch
import os
import logging
import numpy as np
import torch.utils.data as data
# from .datasets import CIFAR100_truncated

# Attach a default stderr handler to the root logger (no-op if it already has
# handlers) and lower its threshold so the per-client INFO messages below are
# actually emitted. ``logging.root`` is the object ``logging.getLogger()``
# returns when called without a name.
logging.basicConfig()
logger = logging.root
logger.setLevel(level=logging.INFO)

def load_partition_target_data(root_path, dir, batch_size, phase, n_nets, num_workers=4):
    """Load an ImageFolder dataset and split it uniformly at random across clients.

    Images under ``os.path.join(root_path, dir)`` are read with
    ``torchvision.datasets.ImageFolder``. Their indices are shuffled with the
    global NumPy RNG, so seeding via ``np.random.seed`` makes the split
    reproducible. The shuffled indices are then cut into ``n_nets`` near-equal
    shards with ``np.array_split``.

    Args:
        root_path: Parent directory that contains the dataset folder.
        dir: Name of the dataset folder under ``root_path`` (one sub-folder
            per class, as ImageFolder expects).
        batch_size: Batch size for the global DataLoader and for every
            per-client DataLoader.
        phase: ``'src'`` selects the random-crop/flip augmentation pipeline.
            ``'tar'`` selects the deterministic resize pipeline. Any other
            value raises ``KeyError``.
        n_nets: Number of clients (shards). ``np.array_split`` raises
            ``ValueError`` if this is not positive.
        num_workers: Worker processes for each DataLoader. Defaults to 4,
            the previously hard-coded value.

    Returns:
        A tuple ``(train_data_num, train_data_global, data_local_num_dict,
        train_data_local_dict, class_num)``:

        - ``train_data_num``: total number of images.
        - ``train_data_global``: a shuffled DataLoader over all of them.
        - ``data_local_num_dict``: maps client index to shard size.
        - ``train_data_local_dict``: maps client index to a shuffled
          DataLoader over that client's shard.
        - ``class_num``: ``ImageFolder.classes``. This is the list of class
          *names*, despite the variable name; use ``len(class_num)`` for the
          count.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_dict = {
        'src': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        # Resize to an explicit (h, w). An int size only fixes the shorter
        # side, so non-square images would give differently sized tensors,
        # and the default collate could not stack them into one batch.
        'tar': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    # Named ``dataset`` so it does not shadow the module-level alias
    # ``data`` (``torch.utils.data``).
    dataset = datasets.ImageFolder(root=os.path.join(root_path, dir),
                                   transform=transform_dict[phase])

    train_data_global = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, drop_last=False,
        num_workers=num_workers)

    train_data_num = len(dataset)
    class_num = dataset.classes

    # Random near-equal split of sample indices: client i gets shard i.
    shuffled_idxs = np.random.permutation(train_data_num)
    net_dataidx_map = dict(enumerate(np.array_split(shuffled_idxs, n_nets)))

    # Build one DataLoader per client over its shard of the dataset.
    data_local_num_dict = {}
    train_data_local_dict = {}
    for client_idx, dataidxs in net_dataidx_map.items():
        local_data_num = len(dataidxs)
        data_local_num_dict[client_idx] = local_data_num
        logging.info("client_idx = %d, local_sample_number = %d", client_idx, local_data_num)

        local_data = torch.utils.data.Subset(dataset, dataidxs)
        train_data_local = torch.utils.data.DataLoader(
            local_data, batch_size=batch_size, shuffle=True, drop_last=False,
            num_workers=num_workers)

        logging.info("client_idx = %d, batch_num_train_local = %d", client_idx, len(train_data_local))
        train_data_local_dict[client_idx] = train_data_local

    return train_data_num, train_data_global, data_local_num_dict, train_data_local_dict, class_num


def load_source_data(root_path, dir, batch_size, phase, num_workers=4):
    """Load an ImageFolder dataset as a single shuffled DataLoader.

    Args:
        root_path: Parent directory that contains the dataset folder.
        dir: Name of the dataset folder under ``root_path`` (one sub-folder
            per class, as ImageFolder expects).
        batch_size: Batch size of the returned DataLoader.
        phase: ``'src'`` selects the random-crop/flip augmentation pipeline.
            ``'tar'`` selects the deterministic resize pipeline. Any other
            value raises ``KeyError``.
        num_workers: Worker processes for the DataLoader. Defaults to 4,
            the previously hard-coded value.

    Returns:
        A tuple ``(train_data_num, train_data_global, class_num)``:

        - ``train_data_num``: number of images.
        - ``train_data_global``: a shuffled DataLoader over them.
        - ``class_num``: ``ImageFolder.classes``. This is the list of class
          *names*, despite the variable name; use ``len(class_num)`` for the
          count.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_dict = {
        'src': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        # Resize to an explicit (h, w). An int size only fixes the shorter
        # side, so non-square images would give differently sized tensors,
        # and the default collate could not stack them into one batch.
        'tar': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    # Named ``dataset`` so it does not shadow the module-level alias
    # ``data`` (``torch.utils.data``).
    dataset = datasets.ImageFolder(root=os.path.join(root_path, dir),
                                   transform=transform_dict[phase])

    train_data_global = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, drop_last=False,
        num_workers=num_workers)

    train_data_num = len(dataset)
    class_num = dataset.classes

    return train_data_num, train_data_global, class_num
