import logging
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch.utils.data as data
from torchvision import transforms

from .datasets import CIFAR10_truncated
# generate the non-IID distribution for all methods
from ..augmentations import get_aug

import pickle


def _plot_label_distribution(client_number, class_num, net_dataidx_map, y_complete):
    """Save a heatmap of each client's label distribution.

    A ``(class_num, client_number)`` matrix is filled so that entry
    ``[label, client]`` holds the fraction of that client's samples carrying
    ``label``. The matrix is drawn with seaborn on a new figure titled
    "label distribution" and saved through ``plt.savefig('./cifar100heatmap')``.

    Args:
        client_number: Number of clients; keys ``0 .. client_number - 1`` are
            looked up in ``net_dataidx_map``.
        class_num: Number of rows of the heatmap matrix.
        net_dataidx_map: Mapping from client index to that client's sample indices.
        y_complete: Label array, indexed with each client's sample indices. Label
            values are used directly as row indices of the matrix.
    """
    label_share = np.zeros((class_num, client_number))

    for client in range(client_number):
        sample_ids = net_dataidx_map[client]
        client_labels = y_complete[sample_ids]
        logging.info("idxx = %s", str(sample_ids))
        logging.info("y_train[idxx] = %s", client_labels)

        present_labels, label_counts = np.unique(client_labels, return_counts=True)
        logging.info("valuess = %s", present_labels)
        logging.info("counts = %s", label_counts)

        # Normalise by the client's total so every column sums to one.
        total = len(sample_ids)
        for label, count in zip(present_labels, label_counts):
            label_share[label][client] = count / total

    plt.figure()
    sns.heatmap(label_share, linewidths=0.05, cmap="YlGnBu", cbar=True)
    plt.title("label distribution")
    plt.savefig('./cifar100heatmap')


def _plot_sample_distribution(prob_dist):
    """Save a bar chart of the per-client dataset sizes.

    The chart is drawn on pyplot figure number 0, titled with the smallest and
    largest size, and saved through ``plt.savefig('./sample_distribution')``.

    Args:
        prob_dist: Mapping from client index to the number of samples that
            client holds. Must be non-empty, since ``min``/``max`` are taken
            over its values.
    """
    client_ids = list(prob_dist.keys())
    sizes = prob_dist.values()

    plt.figure(0)
    logging.info("list(prob_dist.keys()) = %s", client_ids)
    logging.info("prob_dist.values() = %s", sizes)

    plt.bar(client_ids, sizes, color='g')
    plt.xlabel('Client number')
    plt.ylabel('local training dataset size')

    smallest = str(min(sizes))
    largest = str(max(sizes))
    plt.title("Min = " + smallest + ", Max = " + largest)

    plt.savefig('./sample_distribution')
    logging.info('Figure saved')


def record_net_data_stats(y_train, net_dataidx_map):
    """Count how many samples of each label every client holds.

    Args:
        y_train: Label array; it is indexed with each client's index collection.
        net_dataidx_map: Mapping from client id to that client's sample indices.

    Returns:
        A dict mapping each client id to a ``{label: count}`` dict that only
        contains the labels present in that client's data.
    """
    net_cls_counts = {}

    for client_id, sample_ids in net_dataidx_map.items():
        labels, occurrences = np.unique(y_train[sample_ids], return_counts=True)
        net_cls_counts[client_id] = dict(zip(labels, occurrences))

    logging.debug('Data statistics: %s', str(net_cls_counts))
    return net_cls_counts


def _data_transforms_cifar100_ssl(ssl_method, transform_type):
    cifar10_mean_std = [[0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768]]
    if transform_type == "ssl":
        train_transform = get_aug(ssl_method, mean_std=cifar10_mean_std, image_size=32, train=True)
        test_transform = get_aug(ssl_method, mean_std=cifar10_mean_std, image_size=32, train=False,
                                 train_classifier=False)
    elif transform_type == "knn":
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                 (0.24703233, 0.24348505, 0.26158768)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                 (0.24703233, 0.24348505, 0.26158768)),
        ])
    elif transform_type == "linear":
        train_transform = get_aug(ssl_method, mean_std=cifar10_mean_std, image_size=32, train=False,
                                  train_classifier=True)
        test_transform = get_aug(ssl_method, mean_std=cifar10_mean_std, image_size=32, train=False,
                                 train_classifier=False)
    else:
        raise Exception("no such transform type")
    return train_transform, test_transform


def load_cifar10_data(datadir):
    """Load CIFAR-10 through ``CIFAR10_truncated`` and return its arrays.

    Args:
        datadir: Directory passed to ``CIFAR10_truncated``; ``download=True`` is
            always requested.

    Returns:
        Tuple ``(data, targets)`` taken from the dataset object's ``data`` and
        ``targets`` attributes.
    """
    complete_ds = CIFAR10_truncated(datadir, download=True)
    return complete_ds.data, complete_ds.targets


def get_dataloader(dataset, datadir, ssl_method, transform_type, train_bs, test_bs, train_dataidxs=None,
                   test_dataidxs=None):
    """Build the train/test loaders by delegating to ``get_dataloader_CIFAR10``.

    ``dataset`` is accepted but not used: every call is routed to the CIFAR-10
    loader. All remaining arguments are forwarded unchanged.

    Returns:
        Tuple ``(train_dl, test_dl)`` as returned by ``get_dataloader_CIFAR10``.
    """
    loaders = get_dataloader_CIFAR10(
        datadir, ssl_method, transform_type, train_bs, test_bs,
        train_dataidxs=train_dataidxs,
        test_dataidxs=test_dataidxs,
    )
    return loaders


def get_dataloader_CIFAR10(datadir, ssl_method, transform_type, train_bs, test_bs, train_dataidxs=None,
                           test_dataidxs=None):
    """Create CIFAR-10 train and test ``DataLoader`` objects.

    Args:
        datadir: Dataset directory handed to ``CIFAR10_truncated``.
        ssl_method: Forwarded to ``_data_transforms_cifar100_ssl``.
        transform_type: Forwarded to ``_data_transforms_cifar100_ssl``.
        train_bs: Batch size of the training loader.
        test_bs: Batch size of the test loader.
        train_dataidxs: Passed as ``dataidxs`` to the ``train=True`` dataset.
        test_dataidxs: Passed as ``dataidxs`` to the ``train=False`` dataset.

    Returns:
        Tuple ``(train_dl, test_dl)``. The training loader shuffles, the test
        loader does not; both are built with ``drop_last=True``.
    """
    train_tf, test_tf = _data_transforms_cifar100_ssl(ssl_method, transform_type)

    train_set = CIFAR10_truncated(datadir, dataidxs=train_dataidxs, train=True,
                                  transform=train_tf, download=True)
    test_set = CIFAR10_truncated(datadir, dataidxs=test_dataidxs, train=False,
                                 transform=test_tf, download=True)

    train_loader = data.DataLoader(dataset=train_set, batch_size=train_bs, shuffle=True, drop_last=True)
    test_loader = data.DataLoader(dataset=test_set, batch_size=test_bs, shuffle=False, drop_last=True)
    return train_loader, test_loader


def __partition_data_by_lda(dataset, datadir, partition, n_nets, alpha):
    """Generate a Dirichlet label partition, pickle it on MPI rank 0, then exit.

    For each of the 10 classes the class's sample indices are shuffled and
    split across ``n_nets`` clients using proportions drawn from
    ``Dirichlet(alpha)``; a client that already holds at least ``N / n_nets``
    samples gets proportion zero for that class. The whole draw is repeated
    until every client holds at least 10 samples.

    Rank 0 of ``MPI.COMM_WORLD`` writes ``net_dataidx_map.pk`` and
    ``traindata_cls_counts.pk`` into the current working directory. Every rank
    then calls ``exit()``, so the trailing ``return`` is never reached.

    Args:
        dataset: Unused.
        datadir: Directory passed to ``load_cifar10_data``.
        partition: Unused.
        n_nets: Number of clients to split the data across.
        alpha: Concentration value repeated ``n_nets`` times for the Dirichlet draw.
    """
    logging.info("*********partition data***************")
    X_train, y_train = load_cifar10_data(datadir)

    min_size = 0
    K = 10
    logging.info(y_train)
    logging.info(y_train[0])
    N = y_train.shape[0]
    logging.info("N = " + str(N))
    net_dataidx_map = {}

    while min_size < 10:
        idx_batch = [[] for _ in range(n_nets)]
        # for each class in the dataset
        for k in range(K):
            idx_k = np.where(y_train == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
            # Balance: zero out clients that already reached the average share.
            proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            # Cumulative proportions -> split points inside idx_k.
            proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
            min_size = min([len(idx_j) for idx_j in idx_batch])

    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]

    traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map)

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    if comm.Get_rank() == 0:
        # Context managers guarantee the files are flushed and closed before
        # the interpreter is torn down by exit() below.
        with open('net_dataidx_map.pk', 'wb') as map_file:
            pickle.dump(net_dataidx_map, map_file)
        with open('traindata_cls_counts.pk', 'wb') as counts_file:
            pickle.dump(traindata_cls_counts, counts_file)
    exit()

    # NOTE(review): unreachable because of the unconditional exit() above;
    # kept so the function's nominal return shape stays documented.
    return X_train, y_train, net_dataidx_map, traindata_cls_counts


def partition_data_by_lda(dataset, datadir, partition, n_nets, alpha):
    """Load a precomputed LDA partition of CIFAR-10 from ``./fixed_split``.

    Only one configuration is supported: ``dataset == 'cifar10'``,
    ``partition == 'lda'`` and ``n_nets == 10``. The arguments are checked
    before the dataset is loaded, so an unsupported request fails immediately.

    Args:
        dataset: Must be ``'cifar10'``.
        datadir: Directory passed to ``load_cifar10_data``.
        partition: Must be ``'lda'``.
        n_nets: Must be ``10``.
        alpha: Interpolated into the pickle file names
            (``net_dataidx_map_a{alpha}.pk``, ``traindata_cls_counts_a{alpha}.pk``).

    Returns:
        Tuple ``(X_train, y_train, net_dataidx_map, traindata_cls_counts)``.

    Raises:
        AssertionError: If the requested configuration has no predefined split.
    """
    logging.info("*********partition data***************")

    assert n_nets == 10, f'No predefined split for n={n_nets}'
    assert partition == 'lda', f'No predefined split for partition={partition}'
    assert dataset == 'cifar10', f'No predefined split for dataset={dataset}'

    X_train, y_train = load_cifar10_data(datadir)

    # NOTE: pickle executes arbitrary code on load; these files must come from
    # a trusted source (they are expected to be project-generated splits).
    with open(f'./fixed_split/net_dataidx_map_a{alpha}.pk', 'rb') as map_file:
        net_dataidx_map = pickle.load(map_file)
    with open(f'./fixed_split/traindata_cls_counts_a{alpha}.pk', 'rb') as counts_file:
        traindata_cls_counts = pickle.load(counts_file)

    return X_train, y_train, net_dataidx_map, traindata_cls_counts


def partition_data_byclass(dataset, datadir, partition, n_nets, classes=5):
    """Partition CIFAR-10 so that each client receives a few whole classes.

    Each client draws ``classes`` distinct class ids at random from the classes
    that still have samples, and takes up to
    ``int(N / (n_nets * classes))`` samples from each. A class is retired once
    it has no samples left. When fewer than ``classes`` classes remain, the
    client instead receives every remaining sample of every remaining class.

    The number of classes is derived from the labels (``max label + 1``)
    rather than hard-coded, so only classes that can actually occur are
    sampled. Labels are assumed to be the integers ``0 .. K - 1``.

    Args:
        dataset: Unused.
        datadir: Directory passed to ``load_cifar10_data``.
        partition: Unused.
        n_nets: Number of clients.
        classes: Number of classes drawn per client. It is passed to
            ``random.sample`` as the sample size, so it must be an integer.

    Returns:
        Tuple ``(X_complete, y_complete, net_dataidx_map, n_nets)`` where
        ``net_dataidx_map`` maps each client index to a shuffled list of sample
        indices.
    """
    logging.info("*********partition data  by classes***************")
    X_complete, y_complete = load_cifar10_data(datadir)
    logging.info(X_complete.shape)
    logging.info(y_complete.shape)

    N = y_complete.shape[0]
    # A fixed K larger than the real label range would add empty "classes"
    # that get sampled and leave clients with too few (or no) samples.
    K = int(np.max(y_complete)) + 1 if N else 0
    logging.info("N = " + str(N))
    net_dataidx_map = {}

    ## Separate each class data
    idx_xlist = []
    for k in range(K):
        idx_k = np.where(y_complete == k)[0]
        logging.info(idx_k)
        idx_xlist.append(idx_k)

    images_perclass = int(N / (n_nets * classes))
    logging.info(" images per class " + str(images_perclass))

    logging.info("idx_xlist = %s" % str(idx_xlist))
    if idx_xlist:
        logging.info("idx_xlist[0] = %s" % str(idx_xlist[0]))
    idx_batch = [[] for _ in range(n_nets)]

    class_assigned_fully_set = set()
    class_available_set = set(range(K))

    for j in range(n_nets):
        logging.info("class_assigned_fully_set = %s" % str(list(class_assigned_fully_set)))
        logging.info("class_available_set = %s" % str(class_available_set))
        if len(class_available_set) < classes:
            # Not enough classes left for a full draw: hand out all leftovers.
            # NOTE(review): the leftovers are not consumed here, so every later
            # client that reaches this branch receives the same indices --
            # confirm whether that overlap is intended.
            logging.info("j = %d" % j)
            for cls_id in class_available_set:
                logging.info("len of class %d = %d" % (cls_id, len(idx_xlist[cls_id])))
                idx_batch[j].extend(idx_xlist[cls_id].tolist())
        else:
            classes_picked = random.sample(sorted(class_available_set), classes)
            for cls_id in classes_picked:
                idx_batch[j].extend(idx_xlist[cls_id][0:images_perclass].tolist())
                idx_xlist[cls_id] = idx_xlist[cls_id][images_perclass:]
                if len(idx_xlist[cls_id]) == 0:
                    # Class exhausted: retire it from future draws.
                    class_assigned_fully_set.add(cls_id)
                    class_available_set.discard(cls_id)
                    logging.info("no samples left in class %d" % cls_id)

    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]

    return X_complete, y_complete, net_dataidx_map, n_nets


def _save_client_label_heatmap(y_train, net_dataidx_map, class_num, client_number, partition_alpha):
    """Save the per-client label-share heatmap through ``plt.savefig('./cifar10heatmap')``."""
    heat_map_data = np.zeros((class_num, client_number))

    for client_idx in range(client_number):
        idxx = net_dataidx_map[client_idx]
        client_labels = y_train[idxx]
        logging.info("idxx = %s" % str(idxx))
        logging.info("y_train[idxx] = %s" % client_labels)

        valuess, counts = np.unique(client_labels, return_counts=True)
        logging.info("valuess = %s" % valuess)
        logging.info("counts = %s" % counts)
        # Fraction of this client's samples that carry each label.
        for (i, j) in zip(valuess, counts):
            heat_map_data[i][int(client_idx)] = j / len(idxx)

    # plt.subplots already creates the figure; no separate plt.figure() is
    # needed (it would only leave an empty figure behind).
    fig, ax = plt.subplots(figsize=(30, 10))
    sns.heatmap(heat_map_data, linewidths=0.05, cmap="YlGnBu", cbar=True)
    plt.title("alpha = " + str(partition_alpha))
    plt.savefig('./cifar10heatmap')
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)


def load_partition_data_cifar10_ssl(dataset, data_dir, ssl_method,
                                    partition_method, partition_alpha, client_number, batch_size, gradient_acc_steps):
    """Partition CIFAR-10 across clients and build SSL-training data loaders.

    Steps:
      1. Obtain the client partition (``partition_data_by_lda`` when
         ``partition_method == "lda"``, otherwise ``partition_data_byclass``).
      2. Save a label-distribution heatmap and, at the end, a bar chart of the
         client dataset sizes.
      3. Build global loaders with the ``"linear"`` transforms.
      4. For every client, shuffle its indices, split them 70/30 into
         train/test and build two loader pairs: one with the ``"ssl"``
         transforms and one with the ``"knn"`` transforms (the latter using a
         batch size of ``batch_size * gradient_acc_steps``).

    Args:
        dataset: Dataset name, forwarded to the partition and loader helpers.
        data_dir: Dataset directory.
        ssl_method: Forwarded to the loader helpers.
        partition_method: ``"lda"`` selects the LDA partition; anything else
            selects the by-class partition.
        partition_alpha: Partition parameter (logged with ``%f``).
        client_number: Number of clients.
        batch_size: Batch size for the global and ``"ssl"`` loaders.
        gradient_acc_steps: Multiplier applied to ``batch_size`` for the
            ``"knn"`` loaders.

    Returns:
        Tuple ``(train_data_num, test_data_num, train_data_global,
        test_data_global, data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num)`` where

        * ``train_data_num`` is the total number of partitioned sample indices,
        * ``test_data_num`` is ``len(test_data_global)`` (the loader's length,
          i.e. a batch count),
        * ``data_local_num_dict[c]`` is client ``c``'s index count before the
          70/30 split,
        * ``train_data_local_dict[c]`` / ``test_data_local_dict[c]`` are
          ``(ssl_loader, knn_loader)`` tuples,
        * ``class_num`` is the number of distinct labels in ``y_train``.
    """
    logging.info("------------------------------------------------------")
    logging.info("dataset = %s, data_dir = %s, ssl_method = %s , partition_method = %s, partition_alpha = %f, "
                 "client_number = %d, batch_size = %d" %
                 (dataset, data_dir, ssl_method, partition_method, partition_alpha, client_number, batch_size))
    if partition_method == "lda":
        X_train, y_train, net_dataidx_map, traindata_cls_counts = partition_data_by_lda(
            dataset, data_dir, partition_method, client_number, partition_alpha)
    else:
        # NOTE(review): partition_alpha lands in partition_data_byclass's
        # `classes` parameter here -- confirm this is intended.
        X_train, y_train, net_dataidx_map, traindata_cls_counts = partition_data_byclass(
            dataset, data_dir, partition_method, client_number, partition_alpha)
    class_num = len(np.unique(y_train))
    logging.info("traindata_cls_counts = " + str(traindata_cls_counts))
    train_data_num = sum(len(net_dataidx_map[r]) for r in range(client_number))

    is_plot = True
    if is_plot:
        _save_client_label_heatmap(y_train, net_dataidx_map, class_num, client_number, partition_alpha)

    train_data_global, test_data_global = get_dataloader(dataset, data_dir, ssl_method, "linear",
                                                         batch_size, batch_size)

    logging.info("train_dl_global number = " + str(len(train_data_global)))
    logging.info("test_dl_global number = " + str(len(test_data_global)))
    test_data_num = len(test_data_global)

    # get local dataset
    data_local_num_dict = dict()
    train_data_local_dict = dict()
    test_data_local_dict = dict()
    prob_dist = {}

    for client_idx in range(client_number):
        dataidxs = net_dataidx_map[client_idx]
        local_data_num = len(dataidxs)
        data_local_num_dict[client_idx] = local_data_num
        prob_dist[client_idx] = local_data_num

        # In-place shuffle, then a 70/30 train/test split of the client's indices.
        random.shuffle(dataidxs)
        train_len = int(0.70 * local_data_num)
        train_dataidxs = dataidxs[:train_len]
        test_dataidxs = dataidxs[train_len:]

        logging.info("client_idx = %d, local_sample_number = %d" % (client_idx, local_data_num))

        train_data_local, test_data_local = get_dataloader(dataset, data_dir, ssl_method, "ssl",
                                                           batch_size, batch_size,
                                                           train_dataidxs, test_dataidxs)

        knn_batch_size = batch_size * gradient_acc_steps
        train_data_local_knn, test_data_local_knn = get_dataloader(dataset, data_dir, ssl_method, "knn",
                                                                   knn_batch_size, knn_batch_size,
                                                                   train_dataidxs, test_dataidxs)

        logging.info("client_idx = %d, batch_num_train_local = %d, batch_num_test_local = %d" % (
            client_idx, len(train_data_local), len(test_data_local)))
        train_data_local_dict[client_idx] = (train_data_local, train_data_local_knn)
        test_data_local_dict[client_idx] = (test_data_local, test_data_local_knn)

    if is_plot:
        _plot_sample_distribution(prob_dist)

    return train_data_num, test_data_num, train_data_global, test_data_global, \
           data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num


def load_partition_data_cifar10_ssl_linear_eval(dataset, data_dir, ssl_method,
                                                partition_method, partition_alpha, client_number, batch_size, gradient_acc_steps):
    """Partition CIFAR-10 across clients and build linear-evaluation loaders.

    The client partition comes from ``partition_data_by_lda`` when
    ``partition_method == "lda"`` and from ``partition_data_byclass``
    otherwise. Global loaders use the ``"linear"`` transforms. For every
    client the indices are shuffled in place, split 70/30 into train/test, and
    two loader pairs are built: one with the ``"linear"`` transforms and one
    with the ``"knn"`` transforms (batch size ``batch_size * gradient_acc_steps``).

    Args:
        dataset: Dataset name, forwarded to the partition and loader helpers.
        data_dir: Dataset directory.
        ssl_method: Forwarded to the loader helpers.
        partition_method: ``"lda"`` selects the LDA partition; anything else
            selects the by-class partition.
        partition_alpha: Partition parameter (logged with ``%f``).
        client_number: Number of clients.
        batch_size: Batch size for the global and ``"linear"`` loaders.
        gradient_acc_steps: Multiplier applied to ``batch_size`` for the
            ``"knn"`` loaders.

    Returns:
        Tuple ``(train_data_num, test_data_num, train_data_global,
        test_data_global, data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num)`` where

        * ``train_data_num`` is the total number of partitioned sample indices,
        * ``test_data_num`` is ``len(test_data_global)`` (the loader's length,
          i.e. a batch count),
        * ``data_local_num_dict[c]`` is client ``c``'s index count before the
          70/30 split,
        * ``train_data_local_dict[c]`` / ``test_data_local_dict[c]`` are
          ``(linear_loader, knn_loader)`` tuples,
        * ``class_num`` is the number of distinct labels in ``y_train``.
    """
    logging.info("------------------------------------------------------")
    logging.info("dataset = %s, data_dir = %s, ssl_method = %s , partition_method = %s, partition_alpha = %f, "
                 "client_number = %d, batch_size = %d" %
                 (dataset, data_dir, ssl_method, partition_method, partition_alpha, client_number, batch_size))
    if partition_method == "lda":
        X_train, y_train, net_dataidx_map, traindata_cls_counts = partition_data_by_lda(
            dataset, data_dir, partition_method, client_number, partition_alpha)
    else:
        # NOTE(review): partition_alpha lands in partition_data_byclass's
        # `classes` parameter here -- confirm this is intended.
        X_train, y_train, net_dataidx_map, traindata_cls_counts = partition_data_byclass(
            dataset, data_dir, partition_method, client_number, partition_alpha)
    class_num = len(np.unique(y_train))
    logging.info("traindata_cls_counts = " + str(traindata_cls_counts))
    train_data_num = sum(len(net_dataidx_map[r]) for r in range(client_number))

    train_data_global, test_data_global = get_dataloader(dataset, data_dir, ssl_method, "linear",
                                                         batch_size, batch_size)

    logging.info("train_dl_global number = " + str(len(train_data_global)))
    logging.info("test_dl_global number = " + str(len(test_data_global)))
    test_data_num = len(test_data_global)

    # get local dataset
    data_local_num_dict = dict()
    train_data_local_dict = dict()
    test_data_local_dict = dict()

    for client_idx in range(client_number):
        dataidxs = net_dataidx_map[client_idx]
        local_data_num = len(dataidxs)
        data_local_num_dict[client_idx] = local_data_num

        # In-place shuffle, then a 70/30 train/test split of the client's indices.
        random.shuffle(dataidxs)
        train_len = int(0.70 * local_data_num)
        train_dataidxs = dataidxs[:train_len]
        test_dataidxs = dataidxs[train_len:]

        logging.info("client_idx = %d, local_sample_number = %d" % (client_idx, local_data_num))

        train_data_local, test_data_local = get_dataloader(dataset, data_dir, ssl_method, "linear",
                                                           batch_size, batch_size,
                                                           train_dataidxs, test_dataidxs)

        knn_batch_size = batch_size * gradient_acc_steps
        train_data_local_knn, test_data_local_knn = get_dataloader(dataset, data_dir, ssl_method, "knn",
                                                                   knn_batch_size, knn_batch_size,
                                                                   train_dataidxs, test_dataidxs)

        logging.info("client_idx = %d, batch_num_train_local = %d, batch_num_test_local = %d" % (
            client_idx, len(train_data_local), len(test_data_local)))
        train_data_local_dict[client_idx] = (train_data_local, train_data_local_knn)
        test_data_local_dict[client_idx] = (test_data_local, test_data_local_knn)

    return train_data_num, test_data_num, train_data_global, test_data_global, \
           data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num
