from utils.DataLoader import DataLoader
import utils
import os
import torch
import torchvision
import random
from torchvision import transforms as transforms
import numpy as np
import copy
from torch.utils.data import ConcatDataset


class DataLoader_fashion_noise(DataLoader):
    """Fashion-MNIST client data pool whose training batches carry label noise.

    The train and test splits of Fashion-MNIST are concatenated, sorted by
    label and cut into ``split_num`` equally sized shards.  Each of the
    ``pool_size = split_num // pick_num`` pool entries receives ``pick_num``
    randomly drawn shards, which are shuffled and divided into a local
    training part and a local test part using the global train/test ratio.
    Within every local *training* batch a fraction ``noise_ratio`` of the
    labels is replaced by a different, randomly chosen class; test batches
    keep their original labels.

    The finished loader is saved under ``utils.pool_folder_path`` and is read
    back on later constructions with the same settings unless ``recreate`` is
    set.
    """

    def __init__(self,
                 split_num=200,
                 pick_num=2,
                 batch_size=100,
                 input_require_shape=None,
                 shuffle=True,
                 pool_size=None,
                 recreate=False,
                 params=None,
                 *args,
                 **kwargs):
        """Build the data pool, or load a previously saved one.

        Args:
            split_num: Number of label-sorted shards the data is cut into.
            pick_num: Number of shards handed to each pool entry.
            batch_size: Number of samples per stored batch.
            input_require_shape: Forwarded unchanged to the base class.
            shuffle: Accepted for interface compatibility; not read here.
            pool_size: When given together with ``params``, ``split_num`` is
                derived as ``pool_size * params['N']``.
            recreate: When true, rebuild the pool even if a saved file exists.
            params: Optional dict overriding the settings above.  Keys read:
                ``'N'`` (pick_num), ``'C'`` (only when ``pool_size`` is None;
                ``split_num = int(C * N)``), ``'batch_size'`` and ``'NR'``
                (noise ratio).
            *args: Ignored.
            **kwargs: Only ``noise_ratio`` is read, and only when ``params``
                is None; it defaults to ``0.0`` (no label noise).

        Raises:
            RuntimeError: If ``split_num`` is not a multiple of ``pick_num``.
        """
        # The noise ratio used to be bound only inside the ``params`` branch,
        # so constructing the loader without ``params`` crashed with a
        # NameError.  Give it a well-defined value in every case.
        noise_ratio = kwargs.get('noise_ratio', 0.0)
        if params is not None:
            pick_num = params['N']
            if pool_size is not None:
                split_num = pool_size * pick_num
            else:
                split_num = int(params['C'] * pick_num)
            batch_size = params['batch_size']
            noise_ratio = params['NR']
        if split_num % pick_num != 0:
            raise RuntimeError('split_num must be divisible by the number of pick_num.')

        pool_size = split_num // pick_num
        name = 'Fashion_noise_pool_' + str(pool_size) + '_batchsize_' + str(
            batch_size) + '_sort_split_input_require_shape_' + str(input_require_shape) + '_noise_ratio_' + str(noise_ratio)
        nickname = 'fashion_noise B' + str(batch_size) + ' S' + str(split_num) + ' P' + str(pick_num) + ' N' + str(pool_size)
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        file_path = utils.pool_folder_path + name + '.npy'
        if os.path.exists(file_path) and not recreate:
            # NOTE(security): allow_pickle=True unpickles arbitrary objects.
            # Only load pool files that this code wrote itself.
            data_loader = np.load(file_path, allow_pickle=True).item()
            for attr, value in vars(data_loader).items():
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
            return

        # NOTE(review): (0.1307, 0.3081) are the statistics commonly quoted
        # for MNIST digits rather than Fashion-MNIST - confirm this is
        # intended.  Left unchanged so regenerated pools keep the same scale.
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))
             ])
        trainset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=True,
                                                     download=True, transform=transform)
        testset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=False,
                                                    download=True, transform=transform)
        totalset = ConcatDataset([trainset, testset])

        # The training split is materialised only to obtain the tensor shape
        # handed to cal_data_shape.  The test split on its own is never used,
        # so it is no longer loaded separately.
        train_input_data, _ = self._load_whole_dataset(trainset)
        total_input_data, total_target_data = self._load_whole_dataset(totalset)

        self.cal_data_shape(train_input_data.shape)

        self.target_class_num = 10

        self.total_training_number = len(trainset)
        self.total_test_number = len(testset)

        self.output_size = 10
        self.model4data = 'mlp'
        self.task_name = 'fashion_classification'

        data_pool = [{} for _ in range(self.pool_size)]
        self._create_data_pool(data_pool, total_input_data, total_target_data,
                               split_num, pick_num, pool_size, noise_ratio)
        self.data_pool = data_pool

        # Persist the whole loader object so later runs can skip the build.
        np.save(file_path, self)

    @staticmethod
    def _load_whole_dataset(dataset):
        """Return ``(inputs, targets)`` for all of ``dataset`` in shuffled order."""
        loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset),
                                             shuffle=True, num_workers=1)
        inputs = targets = None
        # batch_size equals the dataset length, so the loop body runs once;
        # exhausting the iterator lets the worker process shut down normally.
        for inputs, targets in loader:
            pass
        return inputs, targets

    @staticmethod
    def _flip_labels(target, noise_ratio, num_classes):
        """Replace ``int(len(target) * noise_ratio)`` labels with a different class.

        ``target`` is modified in place and also returned.
        """
        num_samples = len(target)
        num_samples_to_modify = int(num_samples * noise_ratio)
        indices_to_modify = np.random.choice(num_samples, num_samples_to_modify, replace=False)
        for idx in indices_to_modify:
            new_label = np.random.randint(0, num_classes)
            # Redraw until the label really differs from the current one.
            while new_label == target[idx]:
                new_label = np.random.randint(0, num_classes)
            target[idx] = new_label
        return target

    def _create_data_pool(self, data_pool, input_data, target_data,
                          split_num, pick_num, pool_size, noise_ratio):
        """Fill ``data_pool`` with per-entry training and test batches.

        Each entry of ``data_pool`` receives the keys ``'local_training_data'``,
        ``'local_test_data'``, ``'local_training_number'``,
        ``'local_test_number'`` and ``'data_name'``.
        """
        # Sort by label so that every shard covers a narrow range of classes.
        order = torch.argsort(target_data)
        input_data = input_data[order, :]
        target_data = target_data[order]

        total = input_data.shape[0]
        # Size of one shard.
        amount = total // split_num
        indices = list(range(total))
        # Cut all split_num shards.  amount * split_num never exceeds total,
        # so no clipping is required.  The former clip to ``total - 1``
        # silently dropped the very last sample whenever split_num divided
        # the dataset evenly.  Any remainder after the last shard is unused.
        shards = [indices[s * amount:(s + 1) * amount] for s in range(split_num)]

        # Every pool entry draws pick_num shards without replacement.
        for pool_idx in range(pool_size):
            data_indices = []
            for _ in range(pick_num):
                data_indices += shards.pop(random.randint(0, len(shards) - 1))
            random.shuffle(data_indices)

            # Split locally with the same ratio as the global train/test sizes.
            train_test_split_idx = int(len(data_indices) * self.total_training_number / (
                    self.total_training_number + self.total_test_number))
            train_indices = data_indices[:train_test_split_idx]
            test_indices = data_indices[train_test_split_idx:]

            train_batch_data_indices_list = DataLoader.separate_list(train_indices, self.batch_size)
            test_batch_data_indices_list = DataLoader.separate_list(test_indices, self.batch_size)

            local_train_data = []
            for batch_data_indices in train_batch_data_indices_list:
                batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
                # Label noise is injected per batch and only into training data.
                batch_target_data = self._flip_labels(target_data[batch_data_indices],
                                                      noise_ratio, self.output_size)
                local_train_data.append((batch_input_data, batch_target_data))

            local_test_data = []
            for batch_data_indices in test_batch_data_indices_list:
                batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
                batch_target_data = target_data[batch_data_indices]
                local_test_data.append((batch_input_data, batch_target_data))

            entry = data_pool[pool_idx]
            entry['local_training_data'] = local_train_data
            entry['local_test_data'] = local_test_data
            entry['local_training_number'] = len(train_indices)
            entry['local_test_number'] = len(test_indices)
            entry['data_name'] = str(pool_idx)

    def allocate(self, client_list):
        """Give the i-th client the i-th pool entry via ``client.update_data``.

        Raises:
            IndexError: If ``client_list`` has more clients than pool entries.
        """
        for idx, client in enumerate(client_list):
            data_pool_item = self.data_pool[idx]
            client.update_data(idx,
                               data_pool_item['local_training_data'],
                               data_pool_item['local_training_number'],
                               data_pool_item['local_test_data'],
                               data_pool_item['local_test_number'])


import os, json
import gzip
import numpy as np

# Module-level list; nothing in the visible part of this file reads or writes it.
NAME = []


def load_mnist(path, kind='train'):
    """Read a gzip-compressed MNIST-format image/label pair from `path`.

    The files ``<kind>-labels-idx1-ubyte.gz`` and
    ``<kind>-images-idx3-ubyte.gz`` are decompressed; the first 8 bytes of
    the label file and the first 16 bytes of the image file are skipped.

    Returns:
        ``(images, labels)``: ``images`` is a uint8 array reshaped to
        ``(len(labels), 784)`` and ``labels`` is a 1-D uint8 array.
    """
    label_file = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    image_file = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    with gzip.open(label_file, 'rb') as handle:
        raw_labels = handle.read()
    labels = np.frombuffer(raw_labels, dtype=np.uint8, offset=8)

    with gzip.open(image_file, 'rb') as handle:
        raw_pixels = handle.read()
    pixels = np.frombuffer(raw_pixels, dtype=np.uint8, offset=16)
    images = pixels.reshape(len(labels), 784)

    return images, labels


def generate_dataset():
    """Load Fashion-MNIST from ``raw_data/fashion`` and standardise it.

    Each pixel column is shifted by the training-set mean and divided by the
    training-set standard deviation plus 0.001; the same statistics are
    applied to the test images.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` as plain Python lists.
    """
    train_images, train_labels = load_mnist('raw_data/fashion', kind='train')
    test_images, test_labels = load_mnist('raw_data/fashion', kind='t10k')

    # Convert once and reuse the float copies below.
    train_float = train_images.astype(np.float32)
    test_float = test_images.astype(np.float32)

    # some simple normalization, based on training statistics only
    mu = np.mean(train_float, 0)
    sigma = np.std(train_float, 0)
    scale = sigma + 0.001

    train_scaled = (train_float - mu) / scale
    test_scaled = (test_float - mu) / scale

    return (train_scaled.tolist(), train_labels.tolist(),
            test_scaled.tolist(), test_labels.tolist())


def main():
    """Export three Fashion-MNIST classes as per-user JSON files.

    Classes 0 (T-shirt), 2 (pullover) and 6 (shirt) each become one "user"
    whose samples all belong to that class.  Labels are remapped to 0, 1, 2
    in that order.  The result is written to ``./train/mytrain.json`` and
    ``./test/mytest.json``; the output directories are created if missing.
    """
    train_output = "./train/mytrain.json"
    test_output = "./test/mytest.json"

    X_train, y_train, X_test, y_test = generate_dataset()

    # Create data structure
    train_data = {'users': [], 'user_data': {}, 'num_samples': []}
    test_data = {'users': [], 'user_data': {}, 'num_samples': []}

    # label 0: T-shirt(top); 2: pullover; 6: Shirt
    label_dict = {0: 'T-shirt', 2: 'pullover', 6: 'shirt'}
    selected = [0, 2, 6]
    # Original label -> compact label (position within `selected`).
    cvt_labels = {label: position for position, label in enumerate(selected)}

    def group_selected(samples, labels):
        """Bucket samples by label, keeping only the selected classes."""
        buckets = {label: [] for label in selected}
        for sample, label in zip(samples, labels):
            if label in buckets:
                buckets[label].append(sample)
        return buckets

    train_buckets = group_selected(X_train, y_train)
    test_buckets = group_selected(X_test, y_test)

    for label in selected:
        uname = label_dict[label]
        user_train = train_buckets[label]
        user_test = test_buckets[label]
        print("training set for {}: {}".format(label, len(user_train)))

        # Every sample of this user has the same (remapped) label.
        train_data['users'].append(uname)
        train_data['user_data'][uname] = {'x': user_train,
                                          'y': [cvt_labels[label]] * len(user_train)}
        train_data['num_samples'].append(len(user_train))

        test_data['users'].append(uname)
        test_data['user_data'][uname] = {'x': user_test,
                                         'y': [cvt_labels[label]] * len(user_test)}
        test_data['num_samples'].append(len(user_test))

    for path, payload in ((train_output, train_data), (test_output, test_data)):
        # open(..., 'w') does not create parent directories; without this the
        # export failed with FileNotFoundError after all the processing.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w') as outfile:
            json.dump(payload, outfile)


# Run the JSON export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()