from utils.DataLoader import DataLoader
import utils
import os
import torch
import torchvision
import random
from torchvision import transforms as transforms
import numpy as np
import copy


class DataLoader_mnist(DataLoader):
    """Label-skewed (non-IID) MNIST data pool.

    Both the MNIST training and test sets are sorted by label and cut into
    ``split_num`` contiguous, equally sized shards. Each of the
    ``split_num // pick_num`` pool items receives ``pick_num`` shards drawn at
    random without replacement, so an item only covers a few labels. Every
    item's samples are shuffled and pre-batched into ``(inputs, targets)``
    tuples.

    The finished object is saved with ``np.save`` (pickled) to
    ``utils.pool_folder_path + name + '.npy'`` and that file is reloaded on
    later runs unless ``recreate`` is true.
    """

    def __init__(self,
                 split_num=200,
                 pick_num=2,
                 batch_size=100,
                 input_require_shape=None,
                 shuffle=True,
                 pool_size=None,
                 recreate=False,
                 params=None,
                 *args,
                 **kwargs):
        """Load the cached pool for this configuration, or build and cache a new one.

        Args:
            split_num: Number of label-sorted shards each dataset is cut into.
            pick_num: Number of shards assigned to each pool item.
            batch_size: Batch size used to pre-batch every item's data.
            input_require_shape: Forwarded to the base class; also part of the
                cache file name.
            shuffle: Not used by this class.
            pool_size: Only honoured together with ``params``, in which case
                ``split_num = pool_size * params['N']``. Without ``params`` it
                is ignored and recomputed as ``split_num // pick_num``.
            recreate: Rebuild the pool even if a cache file already exists.
            params: Optional dict overriding the arguments above: ``'N'`` sets
                ``pick_num``, ``'batch_size'`` sets ``batch_size`` and, when
                ``pool_size`` is None, ``split_num = int(params['C'] * N)``.
            *args, **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If ``split_num`` is not divisible by ``pick_num``.
        """
        if params is not None:
            pick_num = params['N']
            if pool_size is not None:
                split_num = pool_size * pick_num
            else:
                split_num = int(params['C'] * pick_num)
            batch_size = params['batch_size']
        if split_num % pick_num != 0:
            raise RuntimeError('split_num must be divisible by the number of pick_num.')
        pool_size = split_num // pick_num
        # The name doubles as the cache-file key, so its exact format must stay stable.
        name = 'MNIST_pool_' + str(pool_size) + 'split_' + str(split_num) + 'pick' + str(
            pick_num) + '_batchsize_' + str(batch_size) + '_sort_split_input_require_shape_' + str(input_require_shape)
        nickname = 'MNIST B' + str(batch_size) + ' S' + str(split_num) + ' P' + str(pick_num) + ' N' + str(pool_size)
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        file_path = utils.pool_folder_path + name + '.npy'

        if os.path.exists(file_path) and not recreate:
            # NOTE(security): allow_pickle=True unpickles arbitrary objects;
            # only load pool files that this code wrote itself.
            cached = np.load(file_path, allow_pickle=True).item()
            for attr, value in list(cached.__dict__.items()):
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
        else:
            transform = transforms.Compose([transforms.ToTensor()])
            trainset = torchvision.datasets.MNIST(root=utils.data_folder_path, train=True,
                                                  download=True, transform=transform)
            # batch_size equals the dataset size, so the loop below sees exactly
            # one batch holding the whole split.
            trainloader = torch.utils.data.DataLoader(trainset, batch_size=trainset.data.shape[0],
                                                      shuffle=True, num_workers=1)
            testset = torchvision.datasets.MNIST(root=utils.data_folder_path, train=False,
                                                 download=True, transform=transform)
            testloader = torch.utils.data.DataLoader(testset, batch_size=testset.data.shape[0],
                                                     shuffle=False, num_workers=1)
            # Disabled experiments kept for reference: global loaders and 50% label noise.
            # global_training_data = torch.utils.data.DataLoader(copy.deepcopy(trainset),
            #                                                    batch_size=self.batch_size,
            #                                                    shuffle=True, num_workers=1)
            # global_test_data = torch.utils.data.DataLoader(copy.deepcopy(testset),
            #                                                batch_size=self.batch_size,
            #                                                shuffle=False, num_workers=1)
            # # modify
            # num_samples = len(trainset)
            # noise_ratio = 0.5
            # num_samples_to_modify = int(num_samples * noise_ratio)
            # indices_to_modify = np.random.choice(num_samples, num_samples_to_modify, replace=False)
            # num_classes = 10
            # for idx in indices_to_modify:
            #     new_label = np.random.randint(0, num_classes)  # Generate random incorrect label
            #     while new_label == trainset.targets[idx]:  # Ensure the new label is different from the original one
            #         new_label = np.random.randint(0, num_classes)
            #     trainset.targets[idx] = new_label

            for input_data, targets in trainloader:
                train_input_data, train_target_data = input_data, targets
            for input_data, targets in testloader:
                test_input_data, test_target_data = input_data, targets

            # Presumably sets self.input_data_shape (a list), which the batch
            # reshape in _create_data_pool relies on -- TODO confirm in the base class.
            self.cal_data_shape(train_input_data.shape)

            self.target_class_num = 10

            self.global_training_data = []
            self.global_test_data = []
            # for (input_data, targets) in global_training_data:
            #     self.global_training_data.append((input_data.reshape([-1] + self.input_data_shape), targets))
            # for (input_data, targets) in global_test_data:
            #     self.global_test_data.append((input_data.reshape([-1] + self.input_data_shape), targets))
            self.total_training_number = len(trainset)
            self.total_test_number = len(testset)
            self.output_size = 10
            self.model4data = 'mlp'
            self.task_name = 'mnist_classification'

            data_pool = [{} for _ in range(self.pool_size)]
            self._create_data_pool(data_pool, train_input_data, train_target_data, 'local_training',
                                   split_num, pick_num)
            self._create_data_pool(data_pool, test_input_data, test_target_data, 'local_test',
                                   split_num, pick_num)
            self.data_pool = data_pool
            np.save(file_path, self)

    @staticmethod
    def _split_indices(num_samples, split_num):
        """Cut ``range(num_samples)`` into ``split_num`` contiguous shards of equal size.

        Each shard holds ``num_samples // split_num`` indices. The trailing
        ``num_samples % split_num`` indices belong to no shard.
        """
        amount = num_samples // split_num
        # (k + 1) * amount <= num_samples always holds, so no clamping is needed
        # and the last shard is as large as the others.
        return [list(range(k * amount, (k + 1) * amount)) for k in range(split_num)]

    def _create_data_pool(self, data_pool, input_data, target_data, key_name, split_num, pick_num):
        """Fill ``data_pool`` in place with label-sorted shards of one dataset split.

        Args:
            data_pool: List of dicts, one per pool item, indexed
                ``0 .. split_num // pick_num - 1``.
            input_data: Input tensor whose first dimension indexes samples.
            target_data: 1-D label tensor aligned with ``input_data``.
            key_name: Prefix of the keys written to each item
                (``key_name + '_data'`` and ``key_name + '_number'``).
            split_num: Number of shards to cut the sorted split into.
            pick_num: Number of shards given to each pool item.
        """
        # Sort by label so that each contiguous shard covers only a few classes.
        order = torch.argsort(target_data)
        input_data = input_data[order, :]
        target_data = target_data[order]

        remaining_shards = self._split_indices(input_data.shape[0], split_num)
        for pool_idx in range(split_num // pick_num):
            data_indices = []
            for _ in range(pick_num):
                # Draw a shard uniformly at random, without replacement.
                data_indices += remaining_shards.pop(random.randrange(len(remaining_shards)))
            random.shuffle(data_indices)

            local_data = []
            # separate_list presumably chunks the indices into lists of
            # at most self.batch_size elements -- TODO confirm in the base class.
            for batch_data_indices in DataLoader.separate_list(data_indices, self.batch_size):
                batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
                local_data.append((batch_input_data, target_data[batch_data_indices]))

            data_pool[pool_idx][key_name + '_data'] = local_data
            data_pool[pool_idx][key_name + '_number'] = len(data_indices)
            data_pool[pool_idx]['data_name'] = str(pool_idx)

    def allocate(self, client_list):
        """Give every client a distinct, randomly chosen pool item.

        Each client's ``update_data`` is called with the chosen pool index, the
        item's training batches and sample count, then its test batches and
        sample count.

        Raises:
            ValueError: If there are more clients than pool items.
        """
        if len(client_list) > self.pool_size:
            raise ValueError(f'Cannot allocate {len(client_list)} clients from a data pool '
                             f'of size {self.pool_size}.')
        chosen_indices = np.random.choice(self.pool_size, len(client_list), replace=False)
        for client, pool_idx in zip(client_list, chosen_indices):
            data_pool_item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               data_pool_item['local_training_data'],
                               data_pool_item['local_training_number'],
                               data_pool_item['local_test_data'],
                               data_pool_item['local_test_number'])
# import os, json
# import gzip
# import numpy as np
#
# NAME=[]
# def load_mnist(path, kind='train'):
#
#
#     """Load MNIST data from `path`"""
#     labels_path = os.path.join(path,
#                                '%s-labels-idx1-ubyte.gz'
#                                % kind)
#     images_path = os.path.join(path,
#                                '%s-images-idx3-ubyte.gz'
#                                % kind)
#
#     with gzip.open(labels_path, 'rb') as lbpath:
#         labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
#                                offset=8)
#
#     with gzip.open(images_path, 'rb') as imgpath:
#         images = np.frombuffer(imgpath.read(), dtype=np.uint8,
#                                offset=16).reshape(len(labels), 784)
#
#     return images, labels
#
#
# def generate_dataset():
#
#   X_train, y_train = load_mnist('raw_data/fashion', kind='train')
#   X_test, y_test = load_mnist('raw_data/fashion', kind='t10k')
#
#
#   # some simple normalization
#   mu = np.mean(X_train.astype(np.float32), 0)
#   sigma = np.std(X_train.astype(np.float32), 0)
#
#   X_train = (X_train.astype(np.float32) - mu)/(sigma+0.001)
#   X_test = (X_test.astype(np.float32) - mu)/(sigma+0.001)
#
#   return X_train.tolist(), y_train.tolist(), X_test.tolist(), y_test.tolist()
#
#
# def main():
#     train_output = "./train/mytrain.json"
#     test_output = "./test/mytest.json"
#
#
#     X_train, y_train, X_test, y_test = generate_dataset()
#
#
#     # Create data structure
#     train_data = {'users': [], 'user_data':{}, 'num_samples':[]}
#     test_data = {'users': [], 'user_data':{}, 'num_samples':[]}
#
#
#     # label 0: T-shirt(top); 2: pullover; 6: Shirt
#     X_trains=[[] for i in range(10)]
#     y_trains = [[] for i in range(10)]
#     for idx, item in enumerate(X_train):
#         i=y_train[idx]
#         X_trains[i].append(X_train[idx])
#         y_trains[i].append(y_train[idx])
#
#     X_tests = [[] for i in range(10)]
#     y_tests = [[] for i in range(10)]
#     for idx, item in enumerate(X_test):
#         i=y_test[idx]
#         X_tests[i].append(X_test[idx])
#         y_tests[i].append(y_test[idx])
#     label_dict={0:'T-shirt', 2:'pullover', 6:'shirt'}
#     selected=[0,2,6]
#     cvt_labels= {}
#     for i in range(len(selected)):
#         cvt_labels[selected[i]]=i
#     for i in selected:
#         train_len=len(X_trains[i])
#         print("training set for {}: {}".format(i,train_len))
#         test_len = len(X_tests[i])
#         uname=label_dict[i]
#         train_data['users'].append(uname)
#         train_data['user_data'][uname] = {'x': X_trains[i], 'y': [cvt_labels[lb] for lb in y_trains[i]]}
#         train_data['num_samples'].append(train_len)
#         test_data['users'].append(uname)
#         test_data['user_data'][uname] = {'x': X_tests[i], 'y': [cvt_labels[lb] for lb in y_tests[i]]}
#         test_data['num_samples'].append(test_len)
#
#     with open(train_output,'w') as outfile:
#         json.dump(train_data, outfile)
#     with open(test_output, 'w') as outfile:
#         json.dump(test_data, outfile)
#
#
# if __name__ == "__main__":
#     main()