import collections
import csv
import logging
import os
import random
import statistics

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from .datasets import Landmarks
from ..augmentations import get_aug


def _summarize_sizes(sizes):
    """Build the title, mean and std annotation strings for a list of sizes.

    ``statistics.stdev`` raises ``StatisticsError`` when given fewer than two
    samples, so a single-client distribution reports ``n/a`` instead of
    crashing the caller.
    """
    title = "Min = " + str(min(sizes)) + ", Max = " + str(max(sizes))
    mean_text = " Mean = " + str(statistics.mean(sizes))
    spread = statistics.stdev(sizes) if len(sizes) > 1 else "n/a"
    return title, mean_text, ' STD = ' + str(spread)


def _plot_size_distribution(dist, first_figure, hist_path, bar_path, hist_text_at, bar_text_at):
    """Log ``dist`` and save its histogram and its per-client bar chart.

    Args:
        dist: dict mapping a client index to its local dataset size.
        first_figure: matplotlib figure number used for the histogram; the bar
            chart uses ``first_figure + 1``.
        hist_path: output path passed to ``plt.savefig`` for the histogram.
        bar_path: output path passed to ``plt.savefig`` for the bar chart.
        hist_text_at: ``((x, y), (x, y))`` positions of the mean and std
            annotations on the histogram.
        bar_text_at: same as ``hist_text_at`` but for the bar chart.
    """
    logging.info(dist)
    if not dist:
        # min()/max()/mean() are undefined on an empty distribution; skipping
        # keeps a purely diagnostic plot from aborting data loading.
        logging.warning('No clients to plot; skipping %s and %s', hist_path, bar_path)
        return

    # Materialize the dict views once so matplotlib receives plain sequences.
    clients = list(dist.keys())
    sizes = list(dist.values())
    title, mean_text, std_text = _summarize_sizes(sizes)

    # Histogram: how many clients fall into each dataset-size bucket.
    plt.figure(first_figure)
    plt.hist(sizes, bins=list(range(0, 900, 100)), color='g')
    plt.ylabel('Client number')
    plt.xlabel('local training dataset size')
    plt.title(title)
    (mean_x, mean_y), (std_x, std_y) = hist_text_at
    plt.text(mean_x, mean_y, mean_text)
    plt.text(std_x, std_y, std_text)
    plt.savefig(hist_path)
    logging.info('Figure saved')

    # Bar chart: dataset size of every individual client.
    plt.figure(first_figure + 1)
    plt.bar(clients, sizes, color='g')
    plt.xlabel('Client number')
    plt.ylabel('local training dataset size')
    plt.title(title)
    (mean_x, mean_y), (std_x, std_y) = bar_text_at
    plt.text(mean_x, mean_y, mean_text)
    plt.text(std_x, std_y, std_text)
    plt.savefig(bar_path)
    logging.info('Figure saved')


def _plot(probb_dist, prob_dist):
    """Save size-distribution plots for the original and the filtered clients.

    Args:
        probb_dist: dict mapping every original client index to its local
            dataset size; drawn on figures 0 and 1 and saved as
            ``./../../../gldorghist`` and ``./../../../gldorgplots``.
        prob_dist: dict mapping the retained (re-numbered) client index to its
            local dataset size; drawn on figures 2 and 3 and saved as
            ``./gldnewhist`` and ``./gldnewplots``.

    An empty distribution is logged and skipped instead of raising.
    """
    _plot_size_distribution(probb_dist, 0, './../../../gldorghist', './../../../gldorgplots',
                            hist_text_at=((300, 60), (300, 50)),
                            bar_text_at=((100, 700), (100, 750)))
    _plot_size_distribution(prob_dist, 2, './gldnewhist', './gldnewplots',
                            hist_text_at=((400, 30), (400, 27)),
                            bar_text_at=((20, 700), (20, 750)))


def _data_transforms_gld_ssl(ssl_method, transform_type):
    """Return the ``(train_transform, test_transform)`` pair for a pipeline stage.

    Args:
        ssl_method: name forwarded to ``get_aug`` as its first argument.
        transform_type: one of ``"ssl"``, ``"knn"`` or ``"linear"``.

    Returns:
        A ``(train_transform, test_transform)`` tuple. All transforms use a
        32x32 image size and mean/std of 0.5 per channel.

    Raises:
        Exception: if ``transform_type`` is not one of the supported values.
    """
    side = 32
    mean_std = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]

    if transform_type not in ("ssl", "knn", "linear"):
        raise Exception("no such transform type")

    if transform_type == "knn":
        # No ToTensor step here: Normalize is applied directly, so the dataset
        # presumably already yields tensors -- TODO confirm against Landmarks.
        def _resize_and_normalize():
            return transforms.Compose([
                transforms.Resize((side, side)),
                transforms.Normalize((0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5)),
            ])

        knn_train = _resize_and_normalize()
        knn_test = _resize_and_normalize()
        return knn_train, knn_test

    # "ssl" and "linear" only differ in how the training augmentation is built;
    # the evaluation augmentation is requested identically for both.
    if transform_type == "ssl":
        train_transform = get_aug(ssl_method, mean_std=mean_std, image_size=side, train=True)
    else:
        train_transform = get_aug(ssl_method, mean_std=mean_std, image_size=side, train=False,
                                  train_classifier=True)
    test_transform = get_aug(ssl_method, mean_std=mean_std, image_size=side, train=False,
                             train_classifier=False)
    return train_transform, test_transform


def _read_csv(path: str):
    """Reads a csv file, and returns the content inside a list of dictionaries.

    Args:
      path: The path to the csv file.

    Returns:
      A list of dictionaries. Each row in the csv file will be a list entry. The
      dictionary is keyed by the column names.
    """
    # The csv module requires files to be opened with newline='' so that it can
    # do its own line-ending handling; otherwise newlines embedded in quoted
    # fields are rewritten by universal-newline translation.
    with open(path, 'r', newline='') as f:
        return list(csv.DictReader(f))


class Cutout(object):
    """Zero out one randomly placed square patch of an image tensor.

    The patch is centred on a uniformly drawn pixel and has a side of roughly
    ``length`` pixels; it is clipped at the image borders. The input tensor is
    modified in place and returned.
    """

    def __init__(self, length):
        # Side length (in pixels) of the square that gets erased.
        self.length = length

    def __call__(self, img):
        """Apply the cutout to ``img`` (indexed as channel, height, width)."""
        height, width = img.size(1), img.size(2)

        # Draw the patch centre: row first, then column.
        center_y = np.random.randint(height)
        center_x = np.random.randint(width)
        half = self.length // 2

        def _bounded(value, upper):
            # Clamp a coordinate into [0, upper].
            return min(max(value, 0), upper)

        top = _bounded(center_y - half, height)
        bottom = _bounded(center_y + half, height)
        left = _bounded(center_x - half, width)
        right = _bounded(center_x + half, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        # Broadcast the 2-D mask over every channel and erase in place.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_gld23_fednas():
    """Build the 224x224 train and test transforms.

    Both pipelines normalize with mean 0.5 and std 0.5 per channel. According
    to the original author's note, the std-0.5 normalization was proposed by
    BiT (Big Transfer) and can increase the accuracy by around 3%.

    Returns:
        A ``(transform_train, transform_test)`` tuple of ``transforms.Compose``.
    """
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]

    # Original author's note: transforms.RandomSizedCrop((args.img_size,
    # args.img_size), scale=(0.05, 1.0)) leads to a very low training accuracy,
    # hence the Resize + RandomCrop combination below.
    train_steps = [
        transforms.Resize(224),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    test_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    transform_train = transforms.Compose(train_steps)
    transform_test = transforms.Compose(test_steps)
    return transform_train, transform_test


def _data_transforms_landmarks():
    """Build the train and validation transforms.

    Training: random resized crop to 224, random horizontal flip, tensor
    conversion, normalization with mean/std 0.5 and a final ``Cutout(16)``.
    Validation: centre crop to 224, tensor conversion and the same
    normalization.

    Returns:
        A ``(train_transform, valid_transform)`` tuple.
    """
    # Alternative statistics kept from the original for reference:
    # mean = [0.5071, 0.4865, 0.4409], std = [0.2673, 0.2564, 0.2762]
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    crop_size = 224

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        # Cutout works on the normalized tensor, so it has to come last.
        Cutout(16),
    ])

    valid_transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    return train_transform, valid_transform


def get_mapping_per_user(fn):
    """Group the rows of a user-mapping csv by user and index them contiguously.

    The csv rows look like ``{'user_id': xxx, 'image_id': xxx, 'class': xxx}``.
    Rows are grouped by ``user_id`` (users keep their order of first
    appearance) and concatenated, so every user owns one contiguous,
    half-open index range in the returned list.

    Args:
        fn: path to a csv file with ``user_id``, ``image_id`` and ``class``
            columns. ``user_id`` values must be convertible with ``int``.

    Returns:
        A tuple ``(data_files, data_local_num_dict, net_dataidx_map)``:
          * ``data_files``: list of all rows, grouped user by user.
          * ``data_local_num_dict``: ``{int(user_id): number of rows}``.
          * ``net_dataidx_map``: ``{int(user_id): (start, end)}`` with
            ``data_files[start:end]`` being the rows of that user.

    Raises:
        ValueError: if the file has no data rows or lacks a required column.
    """
    mapping_table = _read_csv(fn)
    expected_cols = ['user_id', 'image_id', 'class']

    if not mapping_table:
        # Without a first row the column check below cannot run.
        logging.error('%s contains no data rows.', fn)
        raise ValueError('The mapping file %s contains no data rows.' % fn)

    existing_cols = mapping_table[0].keys()
    if not all(col in existing_cols for col in expected_cols):
        logging.error('%s has wrong format.', fn)
        raise ValueError(
            'The mapping file must contain user_id, image_id and class columns. '
            'The existing columns are %s' % ','.join(existing_cols))

    mapping_per_user = collections.defaultdict(list)
    for row in mapping_table:
        mapping_per_user[row['user_id']].append(row)

    data_local_num_dict = dict()
    net_dataidx_map = {}
    data_files = []
    sum_temp = 0  # running start offset of the next user's range
    for user_id, user_rows in mapping_per_user.items():
        num_local = len(user_rows)
        net_dataidx_map[int(user_id)] = (sum_temp, sum_temp + num_local)
        data_local_num_dict[int(user_id)] = num_local
        sum_temp += num_local
        data_files.extend(user_rows)
    assert sum_temp == len(data_files)

    return data_files, data_local_num_dict, net_dataidx_map


# for centralized training
def get_dataloader(dataset, datadir, ssl_method, transform_type, train_files, test_files, train_bs, test_bs,
                   train_dataidxs=None,
                   test_dataidxs=None):
    """Build the ``(train, test)`` data loaders for the Landmarks dataset.

    Thin dispatcher around ``get_dataloader_Landmarks``; ``dataset`` is
    accepted for signature compatibility but is not used.
    """
    loaders = get_dataloader_Landmarks(
        datadir, ssl_method, transform_type,
        train_files, test_files,
        train_bs, test_bs,
        train_dataidxs=train_dataidxs,
        test_dataidxs=test_dataidxs,
    )
    return loaders


def get_dataloader_Landmarks(datadir, ssl_method, transform_type, train_files, test_files, train_bs, test_bs,
                             train_dataidxs=None,
                             test_dataidxs=None):
    """Create the train and test ``DataLoader`` objects over ``Landmarks``.

    Args:
        datadir: forwarded to ``Landmarks`` as its first argument.
        ssl_method: forwarded to ``_data_transforms_gld_ssl``.
        transform_type: ``"ssl"``, ``"knn"`` or ``"linear"``; selects the
            transforms and whether incomplete batches are dropped.
        train_files: rows used for the training dataset (and for the test
            dataset when ``test_dataidxs`` is given).
        test_files: rows of the global test set, used when ``test_dataidxs``
            is ``None``.
        train_bs: training batch size.
        test_bs: test batch size.
        train_dataidxs: forwarded as ``dataidxs`` of the training dataset.
        test_dataidxs: forwarded as ``dataidxs`` of the test dataset.

    Returns:
        A ``(train_dl, test_dl)`` tuple; only the training loader shuffles.
    """
    transform_train, transform_test = _data_transforms_gld_ssl(ssl_method, transform_type)

    train_ds = Landmarks(datadir, train_files, dataidxs=train_dataidxs, train=True,
                         transform=transform_train, download=True)

    # With explicit test indices the test split is carved out of the local
    # training files; otherwise the global test files are used.
    eval_files = test_files if test_dataidxs is None else train_files
    test_ds = Landmarks(datadir, eval_files, dataidxs=test_dataidxs, train=False,
                        transform=transform_test, download=True)

    # "knn" keeps every sample; all other modes drop the last partial batch.
    drop_incomplete = transform_type != "knn"
    train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True,
                               drop_last=drop_incomplete)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False,
                              drop_last=drop_incomplete)
    return train_dl, test_dl


def get_global_data_loader(args, batch_size):
    """Build the global ``(train, test)`` loaders with the ``"knn"`` transforms.

    Args:
        args: object providing ``dataset`` and ``data_dir`` attributes; the
            gld23k user-mapping csv files are read from
            ``<data_dir>/data_user_dict`` and images from ``<data_dir>/images``.
        batch_size: batch size used for both loaders.

    Returns:
        A ``(train_data_global, test_data_global)`` tuple.
    """
    logging.info(args)
    dataset = args.dataset
    root = args.data_dir
    image_dir = os.path.join(root, 'images')
    train_map_path = os.path.join(root, 'data_user_dict/gld23k_user_dict_train.csv')
    test_map_path = os.path.join(root, 'data_user_dict/gld23k_user_dict_test.csv')
    logging.info(train_map_path)

    # Only the grouped file list is needed here; the per-user bookkeeping that
    # get_mapping_per_user also returns is ignored.
    train_files = get_mapping_per_user(train_map_path)[0]
    test_files = _read_csv(test_map_path)

    return get_dataloader(dataset, image_dir, "", "knn",
                          train_files, test_files,
                          batch_size, batch_size)


def load_partition_data_landmarks(dataset, data_dir, ssl_method, fed_train_map_file, fed_test_map_file,
                                  gradient_acc_steps,
                                  partition_method=None, partition_alpha=None, client_number=233, batch_size=10):
    """Load the federated Landmarks data and build global and per-client loaders.

    Clients with fewer than 100 local samples are dropped. Every retained
    client gets its samples shuffled and split 70/30 into a local train and a
    local test range, and is re-numbered consecutively from 0.

    Args:
        dataset: forwarded to ``get_dataloader``.
        data_dir: image directory forwarded to ``get_dataloader``.
        ssl_method: name of the SSL method, forwarded to ``get_dataloader``.
        fed_train_map_file: csv mapping training images to users.
        fed_test_map_file: csv listing the global test images.
        gradient_acc_steps: multiplier applied to ``batch_size`` for the
            per-client ``"knn"`` loaders.
        partition_method: unused.
        partition_alpha: unused.
        client_number: number of original clients; ids ``0..client_number-1``
            must all be present in ``fed_train_map_file``.
        batch_size: batch size of the global and the per-client ``"ssl"``
            loaders.

    Returns:
        A tuple ``(train_data_num, test_data_num, train_data_global,
        test_data_global, data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num, new_client_number)``.
        ``train_data_local_dict`` / ``test_data_local_dict`` map the
        re-numbered client index to ``(ssl_loader, knn_loader)``.

        NOTE(review): ``data_local_num_dict`` is still keyed by the ORIGINAL
        client ids and holds the unsplit sample counts, while the two local
        loader dicts are keyed by the re-numbered ids -- confirm that callers
        do not index both with the same key.
    """
    train_files, data_local_num_dict, net_dataidx_map = get_mapping_per_user(fed_train_map_file)
    test_files = _read_csv(fed_test_map_file)

    class_num = len({item['class'] for item in train_files})
    train_data_num = len(train_files)
    test_data_num = len(test_files)

    train_data_global, test_data_global = get_dataloader(dataset, data_dir, ssl_method, "linear", train_files,
                                                         test_files, batch_size,
                                                         batch_size)

    train_data_local_dict = dict()
    test_data_local_dict = dict()

    prob_dist = {}   # sizes of the retained clients, keyed by the new index
    probb_dist = {}  # sizes of all original clients, keyed by the original id
    min_local_samples = 100
    train_fraction = 0.70
    knn_batch_size = batch_size * gradient_acc_steps
    newclient_idx = 0
    for client_idx in range(client_number):
        start, end = net_dataidx_map[client_idx]
        local_data_num = end - start
        probb_dist[client_idx] = local_data_num
        if local_data_num < min_local_samples:
            continue

        train_len = int(train_fraction * local_data_num)
        # Index ranges are half-open, like the (start, start + n) entries built
        # by get_mapping_per_user and the slice used for shuffling below, so
        # the test range starts exactly where the train range ends. The
        # previous "+ 1" silently dropped one sample per client.
        # NOTE(review): assumes Landmarks slices files[dataidxs[0]:dataidxs[1]]
        # -- confirm.
        train_dataidx = (start, start + train_len)
        test_dataidx = (start + train_len, end)

        # Shuffle this client's rows in place so the split is random.
        client_rows = train_files[start:end]
        random.shuffle(client_rows)
        train_files[start:end] = client_rows

        train_data_local, test_data_local = get_dataloader(dataset, data_dir, ssl_method, "ssl",
                                                           train_files, test_files,
                                                           batch_size, batch_size,
                                                           train_dataidx, test_dataidx)

        train_data_local_knn, test_data_local_knn = get_dataloader(dataset, data_dir, ssl_method, "knn",
                                                                   train_files, test_files,
                                                                   knn_batch_size,
                                                                   knn_batch_size,
                                                                   train_dataidx, test_dataidx)
        logging.info("len test_data_local_knn = %d" % len(test_data_local_knn))
        train_data_local_dict[newclient_idx] = (train_data_local, train_data_local_knn)
        test_data_local_dict[newclient_idx] = (test_data_local, test_data_local_knn)
        prob_dist[newclient_idx] = local_data_num
        newclient_idx += 1

    new_client_number = newclient_idx
    logging.info("New client Number: %d", new_client_number)

    _plot(probb_dist, prob_dist)

    return train_data_num, test_data_num, train_data_global, test_data_global, \
           data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num, new_client_number


if __name__ == '__main__':
    # Manual smoke test: load the federated split and iterate every retained
    # client's training loader once.
    # data_dir = './cache/images'
    data_dir = '../../../data/gld/images'
    fed_g23k_train_map_file = '../../../data/gld/data_user_dict/gld23k_user_dict_train.csv'
    fed_g23k_test_map_file = '../../../data/gld/data_user_dict/gld23k_user_dict_test.csv'

    fed_g160k_train_map_file = '../../../data/gld/data_user_dict/gld160k_user_dict_train.csv'
    fed_g160k_test_map_file = '../../../data/gld/data_user_dict/gld160k_user_dict_test.csv'

    dataset_name = 'g160k'

    if dataset_name == 'g23k':
        client_number = 233
        fed_train_map_file = fed_g23k_train_map_file
        fed_test_map_file = fed_g23k_test_map_file
    elif dataset_name == 'g160k':
        client_number = 1262
        fed_train_map_file = fed_g160k_train_map_file
        fed_test_map_file = fed_g160k_test_map_file
    else:
        raise ValueError('unknown dataset_name: %s' % dataset_name)

    # load_partition_data_landmarks requires both of these; the previous call
    # omitted them and shifted every positional argument.
    # NOTE(review): 'simsiam' is an assumed value -- it must be a method name
    # that get_aug accepts; adjust as needed.
    ssl_method = 'simsiam'
    gradient_acc_steps = 1

    (train_data_num, test_data_num, train_data_global, test_data_global,
     data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num, new_client_number) = load_partition_data_landmarks(
        None, data_dir, ssl_method, fed_train_map_file, fed_test_map_file, gradient_acc_steps,
        partition_method=None, partition_alpha=None, client_number=client_number,
        batch_size=62)

    print(train_data_num, test_data_num, class_num)
    print(data_local_num_dict)

    # The local dicts are keyed by the re-numbered clients and hold
    # (ssl_loader, knn_loader) pairs, so iterate over new_client_number and
    # unpack the pair. The loop variable is not called "data" to avoid
    # rebinding the module-level torch.utils.data import.
    for client_idx in range(new_client_number):
        train_loader, _train_loader_knn = train_data_local_dict[client_idx]
        for batch, label in train_loader:
            print("Client ID ", client_idx, "Length of data ", len(batch))
