import logging
import random
import torch

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch.utils.data as data
import torchvision.transforms as transforms

from .datasets import CIFAR10_truncated


def _plot_label_distribution(client_number, class_num, net_dataidx_map, y_complete):
    """Save a heat map of each client's label proportions.

    For every client, the fraction of its samples carrying each label is
    written into a ``(class_num, client_number)`` matrix, which is drawn with
    seaborn and saved to ``./newcifar100heatmap``.

    Args:
        client_number: Number of clients; indices ``0..client_number - 1`` are
            looked up in ``net_dataidx_map``.
        class_num: Number of rows of the heat map. Label values are used
            directly as row indices.
        net_dataidx_map: Mapping from client index to that client's sample
            indices into ``y_complete``.
        y_complete: Label array indexed with each client's index list
            (presumably a NumPy array of integer labels -- verify against
            caller).
    """
    # Rows are labels, columns are clients.
    label_share = np.zeros((class_num, client_number))

    for client_idx in range(client_number):
        sample_ids = net_dataidx_map[client_idx]
        logging.info("idxx = %s" % str(sample_ids))
        client_labels = y_complete[sample_ids]
        logging.info("y_train[idxx] = %s" % client_labels)

        labels_present, label_counts = np.unique(client_labels, return_counts=True)
        logging.info("valuess = %s" % labels_present)
        logging.info("counts = %s" % label_counts)

        # Normalise by the client's own sample count so each column sums to 1.
        n_local = len(sample_ids)
        for label, count in zip(labels_present, label_counts):
            label_share[label, client_idx] = count / n_local

    plt.figure()
    sns.heatmap(label_share, linewidths=0.05, cmap="YlGnBu", cbar=True)
    plt.xlabel('Client number')
    plt.title("label distribution")
    plt.savefig('./newcifar100heatmap')


def _plot_sample_distribution(prob_dist):
    """Save a bar chart of the local dataset size of every client.

    The chart is drawn on matplotlib figure number 0 and written to
    ``./newsample_distribution``; the title reports the smallest and largest
    value.

    Args:
        prob_dist: Mapping from client index to the number of samples held by
            that client.
    """
    plt.figure(0)

    client_ids = list(prob_dist.keys())
    sizes = prob_dist.values()
    logging.info("list(prob_dist.keys()) = %s" % client_ids)
    logging.info("prob_dist.values() = %s" % sizes)

    plt.bar(client_ids, sizes, color='g')
    plt.xlabel('Client number')
    plt.ylabel('local training dataset size')

    smallest = min(sizes)
    largest = max(sizes)
    plt.title("Min = %s, Max = %s" % (smallest, largest))

    plt.savefig('./newsample_distribution')
    logging.info('Figure saved')


def _data_transforms_cifar10_fednas():
    """Build the 224x224 train/test transform pipelines.

    Both pipelines convert the input to a PIL image, resize it, convert it to
    a tensor and normalise every channel with mean 0.5 and std 0.5. The train
    pipeline additionally applies a 224 random crop and a random horizontal
    flip.

    Author's notes carried over from the original code:
      * the std 0.5 normalization is proposed by BiT (Big Transfer), which can
        increase the accuracy around 3%;
      * transforms.RandomSizedCrop((args.img_size, args.img_size),
        scale=(0.05, 1.0)) leads to a very low training accuracy, so it is
        not used.

    Returns:
        Tuple ``(transform_train, transform_test)`` of ``transforms.Compose``
        objects.
    """
    image_size = 224
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]

    train_steps = [
        transforms.ToPILImage(),
        transforms.Resize(image_size),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    test_steps = [
        transforms.ToPILImage(),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(train_steps), transforms.Compose(test_steps)


def _data_transforms_cifar10_personalized_fednas():
    """Build the 244x244 train/test transform pipelines.

    Both pipelines convert the input to a PIL image, resize it, convert it to
    a tensor and normalise every channel with mean 0.5 and std 0.5. The train
    pipeline additionally applies a 244 random crop and a random horizontal
    flip.

    Author's notes carried over from the original code:
      * the std 0.5 normalization is proposed by BiT (Big Transfer), which can
        increase the accuracy around 3%;
      * transforms.RandomSizedCrop((args.img_size, args.img_size),
        scale=(0.05, 1.0)) leads to a very low training accuracy, so it is
        not used.

    Returns:
        Tuple ``(transform_train, transform_test)`` of ``transforms.Compose``
        objects.
    """
    # NOTE(review): 244 here vs. 224 in _data_transforms_cifar10_fednas --
    # confirm the difference is intentional and not a typo.
    image_size = 244
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]

    train_steps = [
        transforms.ToPILImage(),
        transforms.Resize(image_size),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    test_steps = [
        transforms.ToPILImage(),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(train_steps), transforms.Compose(test_steps)


class Cutout:
    """Transform that zeroes one randomly centred square patch of an image.

    The patch centre is drawn uniformly over the image; the patch extends
    ``length // 2`` pixels to each side and is clipped at the image borders.
    """

    def __init__(self, length):
        """Store the nominal side length of the patch."""
        self.length = length

    def __call__(self, img):
        """Apply the cutout to ``img`` in place and return it.

        Args:
            img: Tensor whose dimensions 1 and 2 are read as height and width
                (presumably a ``(C, H, W)`` image tensor -- verify against
                caller). It is multiplied in place by the mask.

        Returns:
            The same ``img`` object, with the patch set to zero.
        """
        height, width = img.size(1), img.size(2)

        # Draw the centre: row first, then column (order fixes the RNG stream).
        center_y = np.random.randint(height)
        center_x = np.random.randint(width)

        half = self.length // 2
        top, bottom = np.clip([center_y - half, center_y + half], 0, height)
        left, right = np.clip([center_x - half, center_x + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        # Broadcast the 2-D mask over the leading dimension of the image.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_cifar10():
    """Build the 32x32 train/validation transform pipelines.

    The train pipeline converts the input to a PIL image, applies a padded
    random crop (size 32, padding 4) and a random horizontal flip, converts to
    a tensor, normalises per channel and finally applies ``Cutout(16)``. The
    validation pipeline only converts to a tensor and normalises.

    Returns:
        Tuple ``(train_transform, valid_transform)`` of ``transforms.Compose``
        objects.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    train_steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
        # Cutout operates on the normalised tensor, so it must come last.
        Cutout(16),
    ]
    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    return transforms.Compose(train_steps), transforms.Compose(valid_steps)


def load_cifar10_data(datadir):
    """Load the dataset through ``CIFAR10_truncated`` and return its arrays.

    Args:
        datadir: Directory handed to ``CIFAR10_truncated`` (constructed with
            ``download=True``).

    Returns:
        Tuple ``(data, target)`` taken from the dataset object's ``data`` and
        ``target`` attributes.
    """
    complete_ds = CIFAR10_truncated(datadir, download=True)
    return complete_ds.data, complete_ds.target


def partition_data_byclass(dataset, datadir, partition, n_nets, classes=5):
    """Partition sample indices so that each client draws from a few classes.

    Each client picks ``classes`` class ids at random from the classes that
    still have unassigned samples and takes the next
    ``int(N / (n_nets * classes))`` indices of each picked class. When fewer
    than ``classes`` classes remain available, the client instead receives
    every remaining index of the remaining classes.

    Args:
        dataset: Unused; kept for interface compatibility.
        datadir: Directory passed to ``load_cifar10_data``.
        partition: Unused; kept for interface compatibility.
        n_nets: Number of clients.
        classes: Number of classes sampled for each client.

    Returns:
        Tuple ``(X_complete, y_complete, net_dataidx_map, n_nets)`` where
        ``net_dataidx_map`` maps each client index to its shuffled list of
        sample indices.
    """
    logging.info("*********partition data  by classes***************")
    X_complete, y_complete = load_cifar10_data(datadir)
    logging.info(X_complete.shape)
    logging.info(y_complete.shape)

    K = 10  # number of class ids scanned (0..9)
    N = y_complete.shape[0]
    logging.info("N = " + str(N))

    # Per-class pools of not-yet-assigned sample indices.
    idx_xlist = []
    for k in range(K):
        class_indices = np.where(y_complete == k)[0]
        logging.info(class_indices)
        idx_xlist.append(class_indices)

    # Size of the chunk a client takes from each class it picks.
    images_perclass = int(N / (n_nets * classes))
    logging.info(" images per class " + str(images_perclass))

    logging.info("idx_xlist = %s" % str(idx_xlist))
    logging.info("idx_xlist[0] = %s" % str(idx_xlist[0]))
    idx_batch = [[] for _ in range(n_nets)]

    class_assigned_fully_set = set()
    class_available_set = set(range(K))

    for j in range(n_nets):
        logging.info("class_assigned_fully_set = %s" % str(list(class_assigned_fully_set)))
        logging.info("class_available_set = %s" % str(class_available_set))

        if len(class_available_set) < classes:
            # Not enough classes left to sample from: hand over whatever
            # remains in the still-available classes.
            # NOTE(review): the pools are not consumed here, so every later
            # client that reaches this branch receives the same indices --
            # confirm this overlap is intended.
            logging.info("j = %d" % j)
            for class_id in class_available_set:
                # Report the class id together with its remaining pool size.
                logging.info("len of class %d = %d" % (class_id, len(idx_xlist[class_id])))
                idx_batch[j].extend(idx_xlist[class_id].tolist())
            continue

        classes_picked = random.sample(list(class_available_set), classes)
        for class_id in classes_picked:
            # Take the next chunk of this class and remove it from the pool.
            idx_batch[j].extend(idx_xlist[class_id][0:images_perclass].tolist())
            idx_xlist[class_id] = idx_xlist[class_id][images_perclass:]

            if len(idx_xlist[class_id]) == 0 and class_id not in class_assigned_fully_set:
                class_assigned_fully_set.add(class_id)
                class_available_set.difference_update(class_assigned_fully_set)
                logging.info("no samples left in class %d" % class_id)

    net_dataidx_map = {}
    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]

    return X_complete, y_complete, net_dataidx_map, n_nets


def partition_data_by_lda(dataset, datadir, partition, n_nets, alpha):
    """Partition sample indices across clients with per-class Dirichlet draws.

    For every class label present in the data, the (shuffled) indices of that
    class are split across the clients according to proportions drawn from
    ``Dirichlet(alpha, ..., alpha)``. Clients that already hold at least
    ``N / n_nets`` samples get a zero share for the current class. The whole
    assignment is redrawn until every client holds at least 10 samples.

    The classes are taken from the labels actually present in the data
    (``np.unique``) rather than from a fixed class count, so no Dirichlet
    draws are spent on labels that do not occur.

    Args:
        dataset: Unused; kept for interface compatibility.
        datadir: Directory passed to ``load_cifar10_data``.
        partition: Unused; kept for interface compatibility.
        n_nets: Number of clients.
        alpha: Concentration value repeated ``n_nets`` times as the Dirichlet
            parameter.

    Returns:
        Tuple ``(x_train, y_train, net_dataidx_map, n_nets)`` where
        ``net_dataidx_map`` maps each client index to its shuffled list of
        sample indices.
    """
    logging.info("*********partition data***************")
    x_train, y_train = load_cifar10_data(datadir)

    logging.info(y_train)
    logging.info(y_train[0])
    N = y_train.shape[0]
    logging.info("N = " + str(N))

    class_labels = np.unique(y_train)

    min_size = 0
    # NOTE(review): this loop does not terminate if the data cannot give every
    # client 10 samples (e.g. very small N relative to n_nets).
    while min_size < 10:
        idx_batch = [[] for _ in range(n_nets)]
        for k in class_labels:
            idx_k = np.where(y_train == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
            # Balance: clients already at or above the average share get nothing.
            proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            # Cumulative shares -> split points inside idx_k (last one dropped,
            # np.split produces the final segment implicitly).
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        min_size = min(len(idx_j) for idx_j in idx_batch)

    net_dataidx_map = {}
    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]

    return x_train, y_train, net_dataidx_map, n_nets


def get_dataloader(dataset, datadir, train_bs, test_bs, train_dataidxs=None, test_dataidxs=None):
    """Return ``(train_loader, test_loader)`` by delegating to ``get_dataloader_CIFAR10``.

    Args:
        dataset: Unused; kept for interface compatibility.
        datadir: Dataset directory.
        train_bs: Batch size of the train loader.
        test_bs: Batch size of the test loader.
        train_dataidxs: Optional sample indices restricting the train set.
        test_dataidxs: Optional sample indices restricting the test set.
    """
    loaders = get_dataloader_CIFAR10(
        datadir,
        train_bs,
        test_bs,
        dataidxs_train=train_dataidxs,
        dataidxs_test=test_dataidxs,
    )
    return loaders


def get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None):
    """Create the train and test ``DataLoader`` objects.

    Both datasets are ``CIFAR10_truncated`` instances built from ``datadir``
    (with ``download=True``) using the transforms returned by
    ``_data_transforms_cifar10``. The train loader shuffles; the test loader
    does not. Both loaders drop the last incomplete batch.

    Args:
        datadir: Dataset directory.
        train_bs: Batch size of the train loader.
        test_bs: Batch size of the test loader.
        dataidxs_train: Optional sample indices forwarded as ``dataidxs`` to
            the train dataset.
        dataidxs_test: Optional sample indices forwarded as ``dataidxs`` to
            the test dataset.

    Returns:
        Tuple ``(train_dl, test_dl)``.
    """
    train_tf, test_tf = _data_transforms_cifar10()

    train_set = CIFAR10_truncated(datadir, dataidxs=dataidxs_train, transform=train_tf, download=True)
    test_set = CIFAR10_truncated(datadir, dataidxs=dataidxs_test, transform=test_tf, download=True)

    train_loader = data.DataLoader(train_set, batch_size=train_bs, shuffle=True, drop_last=True)
    # NOTE(review): drop_last=True on the test loader skips the trailing
    # partial batch during evaluation -- confirm this is intended.
    test_loader = data.DataLoader(test_set, batch_size=test_bs, shuffle=False, drop_last=True)

    return train_loader, test_loader


def load_partition_data_cifar10(dataset, data_dir, partition_method, partition_alpha, client_number, batch_size,
                                classes_perclient):
    """Partition the data across clients and build global and per-client loaders.

    The sample indices are partitioned with ``partition_data_by_lda`` when
    ``partition_method == "lda"`` and with ``partition_data_byclass``
    otherwise. Each client's indices are shuffled and split 75% / 25% into a
    local train and a local test loader.

    Args:
        dataset: Forwarded to the partitioning and loader helpers.
        data_dir: Dataset directory.
        partition_method: ``"lda"`` selects the Dirichlet partition; any other
            value selects the by-class partition.
        partition_alpha: Dirichlet concentration used by the ``"lda"`` method.
        client_number: Number of clients.
        batch_size: Batch size used for every loader (train and test).
        classes_perclient: Classes per client for the by-class partition.

    Returns:
        Tuple ``(train_data_num, test_data_num, train_data_global,
        test_data_global, data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num, client_number)``.
    """
    if partition_method == "lda":
        partitioned = partition_data_by_lda(dataset, data_dir, partition_method, client_number,
                                            partition_alpha)
    else:
        partitioned = partition_data_byclass(dataset, data_dir, partition_method, client_number,
                                             classes_perclient)
    _, y_complete, net_dataidx_map, _ = partitioned

    class_num = len(np.unique(y_complete))
    train_data_num = sum(len(net_dataidx_map[r]) for r in range(client_number))

    # Plotting is switched off; flip this flag to save the distribution figures.
    is_plot = False
    if is_plot:
        _plot_label_distribution(client_number, class_num, net_dataidx_map, y_complete)

    train_data_global, test_data_global = get_dataloader(dataset, data_dir, batch_size, batch_size)
    # NOTE(review): len() of a DataLoader is its number of batches, whereas
    # train_data_num above counts samples -- confirm which unit callers expect.
    test_data_num = len(test_data_global)

    data_local_num_dict = {}
    train_data_local_dict = {}
    test_data_local_dict = {}
    prob_dist = {}  # client index -> local sample count, used only for plotting

    for client_idx in range(client_number):
        logging.info("client_idx = %d" % client_idx)
        client_indices = net_dataidx_map[client_idx]
        local_data_num = len(client_indices)

        # NOTE(review): this count covers the client's train AND test samples.
        data_local_num_dict[client_idx] = local_data_num
        prob_dist[client_idx] = local_data_num

        # Shuffle (in place) so the local test split contains various labels.
        random.shuffle(client_indices)

        # 75% of the client's samples go to training, the remainder to testing.
        train_len = int(0.75 * local_data_num)
        train_dataidxs = client_indices[:train_len]
        test_dataidxs = client_indices[train_len:]

        logging.info(len(train_dataidxs))
        logging.info(len(test_dataidxs))

        local_loaders = get_dataloader(dataset, data_dir, batch_size, batch_size,
                                       train_dataidxs, test_dataidxs)
        train_data_local_dict[client_idx], test_data_local_dict[client_idx] = local_loaders

    if is_plot:
        _plot_sample_distribution(prob_dist)

    return (train_data_num, test_data_num, train_data_global, test_data_global,
            data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num, client_number)
