import collections
import csv
import logging
import random
import statistics

import matplotlib.pyplot as plt
import numpy as np
import torch
import seaborn as sns
import torch.utils.data as data
import torchvision.transforms as transforms

from .datasets import Landmarks


def _read_csv(path: str):
    """Reads a csv file, and returns the content inside a list of dictionaries.

    Args:
      path: The path to the csv file.
    Returns:
      A list of dictionaries. Each data row in the csv file will be a list
      entry. The dictionary is keyed by the column names from the header row.
    """
    # The csv module asks for files to be opened with newline='' so that the
    # parser, not the text layer, interprets line endings; otherwise line
    # breaks embedded in quoted fields are silently rewritten.
    with open(path, 'r', newline='') as f:
        return list(csv.DictReader(f))


# class Cutout(object):
#     def __init__(self, length):
#         self.length = length

#     def __call__(self, img):
#         h, w = img.size(1), img.size(2)
#         mask = np.ones((h, w), np.float32)
#         y = np.random.randint(h)
#         x = np.random.randint(w)

#         y1 = np.clip(y - self.length // 2, 0, h)
#         y2 = np.clip(y + self.length // 2, 0, h)
#         x1 = np.clip(x - self.length // 2, 0, w)
#         x2 = np.clip(x + self.length // 2, 0, w)

#         mask[y1: y2, x1: x2] = 0.
#         mask = torch.from_numpy(mask)
#         mask = mask.expand_as(img)
#         img *= mask
#         return img

# def _data_transforms_landmarks():
#     landmarks_MEAN = [0.5071, 0.4865, 0.4409]
#     landmarks_STD = [0.2673, 0.2564, 0.2762]

#     train_transform = transforms.Compose([
#         transforms.ToPILImage(),
#         transforms.RandomCrop(32, padding=4),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(landmarks_MEAN, landmarks_STD),
#     ])

#     train_transform.transforms.append(Cutout(16))

#     valid_transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(landmarks_MEAN, landmarks_STD),
#     ])

#     return train_transform, valid_transform

class Cutout(object):
    """Transform that zeroes a randomly placed square patch of an image tensor.

    The patch is centred on a uniformly sampled pixel and has side `length`
    (less where it is clipped at the image border). The input tensor is
    modified in place and also returned.
    """

    def __init__(self, length):
        # Side length, in pixels, of the square to blank out.
        self.length = length

    def __call__(self, img):
        # Dimensions 1 and 2 of the tensor are used as height and width.
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Sample the patch centre: row first, then column.
        center_y = np.random.randint(height)
        center_x = np.random.randint(width)

        # Clip the patch bounds so they stay inside the image.
        top, bottom = np.clip([center_y - half, center_y + half], 0, height)
        left, right = np.clip([center_x - half, center_x + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        # Broadcast the 2-D mask over the leading dimension and apply in place.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_gld23_fednas():
    """Build the train and test transform pipelines (32x32, mean/std 0.5).

    Per the original author's note, the std 0.5 normalization is proposed by
    BiT (Big Transfer) and can increase the accuracy by around 3%.

    Returns:
      A (transform_train, transform_test) tuple of `transforms.Compose`
      objects.
    """
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    side = 32

    # Author's note: transforms.RandomSizedCrop((img_size, img_size),
    # scale=(0.05, 1.0)) leads to a very low training accuracy, so a plain
    # resize followed by a random crop is used instead. An earlier variant of
    # these pipelines used 224 instead of 32.
    train_steps = [
        transforms.Resize(side),
        transforms.RandomCrop(side),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    test_steps = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    transform_train = transforms.Compose(train_steps)
    transform_test = transforms.Compose(test_steps)
    return transform_train, transform_test




def _data_transforms_landmarks():
    """Build the alternative landmarks pipelines (random-resized-crop + Cutout).

    Returns:
      A (train_transform, valid_transform) tuple of `transforms.Compose`
      objects.
    """
    # Previously tried statistics: mean [0.5071, 0.4865, 0.4409],
    # std [0.2673, 0.2564, 0.2762].
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    crop_size = 32

    # Cutout runs last because it operates on the normalized tensor.
    train_steps = [
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
        Cutout(16),
    ]
    train_transform = transforms.Compose(train_steps)

    # NOTE(review): validation crops to 224 while training crops to 32 --
    # confirm whether this size mismatch is intended.
    valid_steps = [
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    valid_transform = transforms.Compose(valid_steps)

    return train_transform, valid_transform


def get_mapping_per_user(fn):
    """Load a user mapping csv and lay its rows out contiguously per user.

    Args:
      fn: Path to a csv file containing at least the columns user_id,
        image_id and class. user_id values must be convertible with int().

    Returns:
      A tuple (data_files, data_local_num_dict, net_dataidx_map) where
        data_files is the list of row dicts
          [{'user_id': xxx, 'image_id': xxx, 'class': xxx}, ...] reordered so
          that all rows of one user are adjacent (users appear in order of
          first occurrence in the file);
        data_local_num_dict maps int(user_id) to that user's row count;
        net_dataidx_map maps int(user_id) to the half-open (start, end) index
          range of that user's rows inside data_files.

    Raises:
      ValueError: If the file has no data rows or lacks a required column.
    """
    mapping_table = _read_csv(fn)
    if not mapping_table:
        # Without at least one row the column check below cannot run.
        raise ValueError('The mapping file %s contains no data rows.' % fn)

    expected_cols = ['user_id', 'image_id', 'class']
    existing_cols = mapping_table[0].keys()
    if not all(col in existing_cols for col in expected_cols):
        logging.error('%s has wrong format.', fn)
        raise ValueError(
            'The mapping file must contain user_id, image_id and class columns. '
            'The existing columns are %s' % ','.join(existing_cols))

    # Group rows by the raw (string) user id, preserving first-seen order.
    mapping_per_user = collections.defaultdict(list)
    for row in mapping_table:
        mapping_per_user[row['user_id']].append(row)

    data_files = []
    data_local_num_dict = {}
    net_dataidx_map = {}
    start = 0
    for user_id, user_rows in mapping_per_user.items():
        end = start + len(user_rows)
        net_dataidx_map[int(user_id)] = (start, end)
        data_local_num_dict[int(user_id)] = len(user_rows)
        data_files.extend(user_rows)
        start = end

    return data_files, data_local_num_dict, net_dataidx_map


# for centralized training
def get_dataloader(dataset, datadir, train_files, test_files, train_bs, test_bs, train_dataidxs=None,
                   test_dataidxs=None):
    """Return (train_dl, test_dl) by delegating to get_dataloader_Landmarks.

    `dataset` is accepted but not used; every other argument is forwarded
    unchanged.
    """
    loaders = get_dataloader_Landmarks(
        datadir, train_files, test_files, train_bs, test_bs,
        train_dataidxs=train_dataidxs, test_dataidxs=test_dataidxs)
    return loaders


# for local devices
def get_dataloader_test(dataset, datadir, train_files, test_files, train_bs, test_bs, dataidxs_train, dataidxs_test):
    """Return (train_dl, test_dl) by delegating to get_dataloader_test_Landmarks.

    `dataset` is accepted but not used; every other argument is forwarded
    unchanged.
    """
    loaders = get_dataloader_test_Landmarks(
        datadir, train_files, test_files, train_bs, test_bs,
        dataidxs_train=dataidxs_train, dataidxs_test=dataidxs_test)
    return loaders


def get_dataloader_Landmarks(datadir, train_files, test_files, train_bs, test_bs, train_dataidxs=None,
                             test_dataidxs=None):
    """Create the train and test DataLoaders over Landmarks datasets.

    When `test_dataidxs` is given (local/per-client use) the test dataset is
    built from `train_files` restricted to `test_dataidxs`; otherwise it is
    built from `test_files`.

    Returns:
      A (train_dl, test_dl) tuple. Both loaders drop the last incomplete
      batch; only the training loader shuffles.
    """
    # The alternative pipeline _data_transforms_landmarks() is not used here.
    transform_train, transform_test = _data_transforms_gld23_fednas()

    is_local = test_dataidxs is not None
    eval_files = train_files if is_local else test_files

    train_ds = Landmarks(datadir, train_files, dataidxs=train_dataidxs, train=True,
                         transform=transform_train, download=True)
    test_ds = Landmarks(datadir, eval_files, dataidxs=test_dataidxs, train=False,
                        transform=transform_test, download=True)

    train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=True)
    return train_dl, test_dl


def get_dataloader_test_Landmarks(datadir, train_files, test_files, train_bs, test_bs, dataidxs_train=None,
                                  dataidxs_test=None):
    """Create train/test DataLoaders from `train_files` and `test_files`.

    The training dataset is `train_files` restricted to `dataidxs_train`; the
    test dataset is `test_files` restricted to `dataidxs_test`.

    Returns:
      A (train_dl, test_dl) tuple. Both loaders drop the last incomplete
      batch; only the training loader shuffles.
    """
    # The alternative pipeline _data_transforms_landmarks() is not used here.
    transform_train, transform_test = _data_transforms_gld23_fednas()

    train_ds = Landmarks(datadir, train_files, dataidxs=dataidxs_train, train=True,
                         transform=transform_train, download=True)
    test_ds = Landmarks(datadir, test_files, dataidxs=dataidxs_test, train=False,
                        transform=transform_test, download=True)

    train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=True)
    return train_dl, test_dl


def _describe_sizes(sizes):
    """Return (title, mean_text, std_text) strings summarising a list of sizes."""
    title = "Min = " + str(min(sizes)) + ", Max = " + str(max(sizes))
    mean_text = " Mean = " + str(statistics.mean(sizes))
    # The sample standard deviation is undefined for fewer than two points.
    spread = statistics.stdev(sizes) if len(sizes) > 1 else float('nan')
    std_text = ' STD = ' + str(spread)
    return title, mean_text, std_text


def _save_size_histogram(size_by_client, path, mean_xy, std_xy):
    """Save a histogram of per-client dataset sizes to `path`."""
    sizes = list(size_by_client.values())
    if not sizes:
        logging.warning('No clients to plot; skipping %s', path)
        return
    title, mean_text, std_text = _describe_sizes(sizes)
    # A fresh figure per plot keeps plots from being drawn over each other.
    fig = plt.figure()
    plt.hist(sizes, bins=[0, 100, 200, 300, 400, 500, 600, 700, 800], color='g')
    plt.ylabel('Client number')
    plt.xlabel('local training dataset size')
    plt.title(title)
    plt.text(mean_xy[0], mean_xy[1], mean_text)
    plt.text(std_xy[0], std_xy[1], std_text)
    plt.savefig(path)
    plt.close(fig)
    logging.info('Figure saved')


def _save_size_barplot(size_by_client, path, mean_xy, std_xy):
    """Save a bar chart of dataset size per client index to `path`."""
    sizes = list(size_by_client.values())
    if not sizes:
        logging.warning('No clients to plot; skipping %s', path)
        return
    title, mean_text, std_text = _describe_sizes(sizes)
    fig = plt.figure()
    plt.bar(list(size_by_client.keys()), sizes, color='g')
    plt.xlabel('Client number')
    plt.ylabel('local training dataset size')
    plt.title(title)
    plt.text(mean_xy[0], mean_xy[1], mean_text)
    plt.text(std_xy[0], std_xy[1], std_text)
    plt.savefig(path)
    plt.close(fig)
    logging.info('Figure saved')


def _save_label_heatmap(label_fractions, class_num, path):
    """Save a class-by-client heat map of label fractions to `path`.

    `label_fractions` holds one {class_index: fraction} dict per kept client,
    in client order; rows of the heat map are classes, columns are clients.
    """
    if not label_fractions:
        logging.warning('No clients to plot; skipping %s', path)
        return
    highest_label = max((max(fr) for fr in label_fractions if fr), default=-1)
    # Size the matrix from the data rather than a fixed client count, and make
    # sure every observed class index has a row.
    heat_map_data = np.zeros((max(class_num, highest_label + 1), len(label_fractions)))
    for column, fractions in enumerate(label_fractions):
        for class_index, fraction in fractions.items():
            heat_map_data[class_index, column] = fraction

    fig = plt.figure()
    sns.heatmap(heat_map_data, linewidths=0.05, cmap="YlGnBu", cbar=True)
    plt.xlabel('Client number')
    plt.ylabel('Class label Number')
    plt.title("label distribution")
    plt.savefig(path)
    plt.close(fig)


def load_partition_data_landmarks(dataset, data_dir, fed_train_map_file, fed_test_map_file,
                                  partition_method=None, partition_alpha=None, client_number=233, batch_size=10):
    """Build global and per-client data loaders and save distribution plots.

    Clients with fewer than 200 samples are dropped. Each remaining client is
    re-indexed consecutively from 0, its samples are shuffled, and the first
    68% become its local training range while the rest become its local test
    range. Several diagnostic figures (label heat map, size histograms and
    bar charts) are written to disk as a side effect.

    Args:
      dataset: Forwarded to get_dataloader.
      data_dir: Image directory forwarded to get_dataloader.
      fed_train_map_file: csv with user_id, image_id and class columns.
      fed_test_map_file: csv describing the global test set.
      partition_method: Not used by this function.
      partition_alpha: Not used by this function.
      client_number: Number of client ids (0 .. client_number - 1) to scan.
      batch_size: Batch size for every loader created here.

    Returns:
      A tuple (train_data_num, test_data_num, train_data_global,
      test_data_global, data_local_num_dict, train_data_local_dict,
      test_data_local_dict, class_num, new_client_number).
      train_data_local_dict and test_data_local_dict are keyed by the new
      consecutive client index.
      NOTE(review): data_local_num_dict is still keyed by the original user
      id and counts all of that user's samples -- confirm callers expect this.
    """
    train_files, data_local_num_dict, net_dataidx_map = get_mapping_per_user(fed_train_map_file)
    test_files = _read_csv(fed_test_map_file)  # image id, class

    class_num = len(np.unique([item['class'] for item in train_files]))
    train_data_num = len(train_files)
    test_data_num = len(test_files)

    train_data_global, test_data_global = get_dataloader(dataset, data_dir, train_files, test_files, batch_size,
                                                         batch_size)

    train_data_local_dict = dict()
    test_data_local_dict = dict()

    all_client_sizes = {}   # original client id -> sample count
    kept_client_sizes = {}  # new client index -> sample count
    label_fractions = []    # per kept client: {class index: share of samples}
    newclient_idx = 0
    for client_idx in range(client_number):
        start, end = net_dataidx_map[client_idx]
        local_data_num = end - start
        all_client_sizes[client_idx] = local_data_num
        if local_data_num < 200:
            continue

        train_len = int(0.68 * local_data_num)
        split = start + train_len
        train_dataidx = (start, split)
        # Index ranges are half-open (see get_mapping_per_user), so the test
        # range starts exactly where the training range ends; starting one
        # later would drop a sample.
        test_dataidx = (split, end)

        # Shuffle this client's rows in place inside the shared file list.
        new_trainfiles = train_files[start:end]
        random.shuffle(new_trainfiles)
        train_files[start:end] = new_trainfiles
        logging.info(new_trainfiles)

        labels = [d['class'] for d in new_trainfiles]
        logging.info(labels)
        valuess, counts = np.unique(labels, return_counts=True)
        logging.info("valuess = %s" % valuess)
        logging.info("counts = %s" % counts)

        fractions = {}
        for label, count in zip(valuess, counts):
            logging.info(label)
            logging.info(int(newclient_idx))
            fractions[int(label)] = count / len(labels)
        label_fractions.append(fractions)

        train_data_local, test_data_local = get_dataloader(dataset, data_dir, train_files, test_files, batch_size,
                                                           batch_size,
                                                           train_dataidx, test_dataidx)

        train_data_local_dict[newclient_idx] = train_data_local
        test_data_local_dict[newclient_idx] = test_data_local
        kept_client_sizes[newclient_idx] = local_data_num
        newclient_idx = newclient_idx + 1

    new_client_number = newclient_idx
    logging.info("New client Number: %s", new_client_number)

    _save_label_heatmap(label_fractions, class_num, './gldheatmap')

    # Size distribution over all scanned clients.
    logging.info(all_client_sizes)
    _save_size_histogram(all_client_sizes, './../../../gldorghist', (300, 60), (300, 50))
    _save_size_barplot(all_client_sizes, './../../../gldorgplots', (100, 700), (100, 750))

    # Size distribution over the clients that were kept.
    logging.info(kept_client_sizes)
    _save_size_histogram(kept_client_sizes, './../../../gldnewhist', (400, 30), (400, 27))
    _save_size_barplot(kept_client_sizes, './../../../gldnewplots', (10, 700), (10, 750))

    return train_data_num, test_data_num, train_data_global, test_data_global, \
           data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num, new_client_number


if __name__ == '__main__':
    # Manual smoke test: build the partitions and print each local batch size.
    # Alternative image location: './cache/images'
    data_dir = '../../../data/gld/images'
    fed_g23k_train_map_file = '../../../data/gld/data_user_dict/gld23k_user_dict_train.csv'
    fed_g23k_test_map_file = '../../../data/gld/data_user_dict/gld23k_user_dict_test.csv'

    fed_g160k_train_map_file = '../../../data/gld/data_user_dict/gld160k_user_dict_train.csv'
    fed_g160k_map_file = '../../../data/gld/data_user_dict/gld160k_user_dict_test.csv'

    dataset_name = 'g160k'

    if dataset_name == 'g23k':
        client_number = 233
        fed_train_map_file = fed_g23k_train_map_file
        fed_test_map_file = fed_g23k_test_map_file
    elif dataset_name == 'g160k':
        client_number = 1262
        fed_train_map_file = fed_g160k_train_map_file
        fed_test_map_file = fed_g160k_map_file
    else:
        raise ValueError('Unknown dataset_name: %s' % dataset_name)

    # load_partition_data_landmarks returns nine values; the last one is the
    # number of clients kept after filtering.
    (train_data_num, test_data_num, train_data_global, test_data_global,
     data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num,
     new_client_number) = \
        load_partition_data_landmarks(None, data_dir, fed_train_map_file, fed_test_map_file,
                                      partition_method=None, partition_alpha=None, client_number=client_number,
                                      batch_size=62)

    print(train_data_num, test_data_num, class_num)
    print(data_local_num_dict)

    # The local loader dict is keyed by the re-indexed kept clients, not by
    # range(client_number). The loop variable avoids the name `data`, which is
    # the module alias for torch.utils.data.
    for client_idx, local_loader in train_data_local_dict.items():
        for batch, label in local_loader:
            print("Client ID ", client_idx, "Length of data ", len(batch))
