from utils.DataLoader import DataLoader
import utils
import os, glob
import torch
import torchvision
import random
from torchvision import transforms as transforms
import numpy as np
import copy
from shutil import move
from torchvision import datasets


class DataLoader_tiny(DataLoader):
    """Federated data pool for Tiny-ImageNet-200 with a Dirichlet label split.

    The train and val splits are merged, partitioned across ``pool_size``
    clients with per-class Dirichlet(``alpha``) proportions, and each client's
    share is split into batched local training/test data. The finished object
    is cached with ``np.save`` under ``utils.pool_folder_path`` and reused on
    later constructions unless ``recreate`` is true.
    """

    def __init__(self,
                 pool_size=100,
                 alpha=0.1,
                 batch_size=100,
                 input_require_shape=None,
                 shuffle=True,
                 recreate=False,
                 params=None,
                 *args,
                 **kwargs):
        """Load the cached data pool or build (and cache) a new one.

        Args:
            pool_size: number of clients; overridden by ``params['C']``.
            alpha: Dirichlet concentration; overridden by ``params['dir_a']``.
            batch_size: local batch size; overridden by ``params['batch_size']``.
            input_require_shape: forwarded to the base class and embedded in the
                cache name; ``None`` means ``[3, -1, -1]``.
            shuffle: shuffle each client's samples before its train/test split.
            recreate: rebuild the pool even when a cache file already exists.
            params: optional dict overriding ``pool_size``, ``alpha`` and
                ``batch_size``.
        """
        if input_require_shape is None:
            # Fresh list per call instead of a shared mutable default. Its str()
            # is still '[3, -1, -1]', so existing cache file names keep matching.
            input_require_shape = [3, -1, -1]

        if params is not None:
            pool_size = params['C']
            alpha = params['dir_a']
            batch_size = params['batch_size']

        name = ('TINY200_dir_pool_' + str(pool_size) + 'alpha_' + str(alpha)
                + '_batchsize_' + str(batch_size)
                + '_sort_split_input_require_shape_' + str(input_require_shape))
        nickname = ('TINY dir B' + str(batch_size) + ' alpha' + str(alpha)
                    + ' N' + str(pool_size))
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        file_path = utils.pool_folder_path + name + '.npy'

        if os.path.exists(file_path) and not recreate:
            # NOTE: allow_pickle unpickles arbitrary objects; only load cache
            # files written by this class.
            data_loader = np.load(file_path, allow_pickle=True).item()
            for attr, value in data_loader.__dict__.items():
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
            return

        trainset, testset = self._load_tiny()

        # Train and val are merged; every client gets its own split below.
        images, labels = [], []
        for dataset in (trainset, testset):
            for img, target in dataset:
                images.append(img.numpy())
                labels.append(target)

        self.input_data_shape = [3, 64, 64]
        dataset_input = np.stack(images).reshape([-1] + self.input_data_shape)
        del images  # drop the per-image arrays once they are stacked
        dataset_label = np.array(labels)
        self.target_class_num = 200
        self.total_training_number = len(trainset)
        self.total_test_number = len(testset)

        self.task_name = 'tiny_classification'
        # Per-client train fraction; presumably mirrors the 100k/10k
        # train/val sizes of Tiny-ImageNet (10/11 of all samples).
        train_prob = 10 / 11

        X, y, statistic = self._separate_data(
            (dataset_input, dataset_label), pool_size, self.target_class_num, alpha)
        self.statistic = statistic
        self.data_pool = self._create_data_pool(
            X, y, pool_size, shuffle, train_prob, batch_size)
        np.save(file_path, self)

    @staticmethod
    def _load_tiny(data_dir='./utils/data/tiny-imagenet-200/'):
        """Return ``(trainset, testset)`` ImageFolder datasets for Tiny-ImageNet-200.

        The flat ``val/images`` folder is reorganized in place into
        ``val/<class>/images/`` using ``val/val_annotations.txt`` so that
        ``ImageFolder`` can label it. The step is a no-op once done.

        Raises:
            ValueError: if train and val do not expose the same class folders,
                which would silently misalign their label indices.
        """
        print(os.getcwd())
        val_dir = os.path.join(data_dir, 'val')
        flat_dir = os.path.join(val_dir, 'images')

        # Tab-separated: column 0 is the image file name, column 1 its class folder.
        val_dict = {}
        with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                fields = line.split('\t')
                val_dict[fields[0]] = fields[1]

        for path in glob.glob(os.path.join(flat_dir, '*')):
            file_name = os.path.basename(path)
            class_dir = os.path.join(val_dir, str(val_dict[file_name]), 'images')
            os.makedirs(class_dir, exist_ok=True)
            move(path, os.path.join(class_dir, file_name))

        # The emptied flat folder must go: ImageFolder would otherwise treat
        # 'images' as an extra class (shifting val label indices relative to
        # train) or reject it as an empty class folder.
        if os.path.isdir(flat_dir) and not os.listdir(flat_dir):
            os.rmdir(flat_dir)

        normalize = transforms.Normalize((0.4802, 0.4481, 0.3975),
                                         (0.2770, 0.2691, 0.2821))
        # NOTE(review): the random flip is sampled once while the pool is
        # built and then baked into the cache; it is not re-drawn per epoch.
        transform_train = transforms.Compose(
            [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        trainset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'),
                                        transform=transform_train)
        testset = datasets.ImageFolder(root=val_dir, transform=transform_test)
        if testset.class_to_idx != trainset.class_to_idx:
            raise ValueError('Tiny-ImageNet train and val class folders differ; '
                             'labels would be misaligned: ' + val_dir)
        return trainset, testset

    @staticmethod
    def _separate_data(data, num_clients, target_class_num, alpha):
        """Split ``(inputs, labels)`` across clients with Dirichlet class proportions.

        For each class, a Dirichlet(``alpha``) draw decides how its samples are
        divided among clients. The whole split is redrawn until every client
        holds at least ``target_class_num`` samples.

        Returns:
            ``(X, y, statistic)``: per-client inputs, per-client labels, and
            per-client lists of ``(label, count)`` pairs.
        """
        dataset_content, dataset_label = data
        num_samples = len(dataset_label)

        min_size = 0
        while min_size < target_class_num:
            idx_batch = [[] for _ in range(num_clients)]
            for k in range(target_class_num):
                idx_k = np.where(dataset_label == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
                # Clients already holding their even share receive nothing more.
                proportions = np.array(
                    [p * (len(idx_j) < num_samples / num_clients)
                     for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                cut_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist()
                             for idx_j, idx in zip(idx_batch, np.split(idx_k, cut_points))]
            min_size = min(len(idx_j) for idx_j in idx_batch)

        X, y, statistic = [], [], []
        for client, idxs in enumerate(idx_batch):
            np.random.shuffle(idxs)
            client_x = dataset_content[idxs]
            client_y = dataset_label[idxs]
            unique_labels, counts = np.unique(client_y, return_counts=True)
            client_stat = [(int(label), int(count))
                           for label, count in zip(unique_labels, counts)]
            X.append(client_x)
            y.append(client_y)
            statistic.append(client_stat)

            print(f"Client {client}\t Size of data: {len(client_x)}\t Labels: ",
                  unique_labels)
            print("\t\t Samples of labels: ", list(client_stat))
            print("-" * 50)

        return X, y, statistic

    @staticmethod
    def _batchify(input_data, target_data, batch_size):
        """Cut paired tensors into consecutive ``(inputs, targets)`` batches.

        The final batch may be shorter than ``batch_size``; empty input gives [].
        """
        # .clone() gives each batch its own storage rather than a view that
        # keeps the full per-client tensor alive.
        return [(input_data[start:start + batch_size].clone(),
                 target_data[start:start + batch_size].clone())
                for start in range(0, len(target_data), batch_size)]

    @classmethod
    def _create_data_pool(cls, X, y, pool_size, shuffle, train_prob, batch_size):
        """Turn per-client arrays into batched local train/test splits.

        Returns:
            One dict per client with keys ``'local_training_data'``,
            ``'local_training_number'``, ``'data_name'``, ``'local_test_data'``
            and ``'local_test_number'``.
        """
        data_pool = []
        for pool_idx in range(pool_size):
            input_data = torch.Tensor(X[pool_idx]).float()
            target_data = torch.Tensor(y[pool_idx]).long()
            if shuffle:
                order = list(range(len(target_data)))
                random.shuffle(order)
                input_data = input_data[order]
                target_data = target_data[order]

            # The first train_prob fraction is local training data; the rest is local test data.
            n_train = int(train_prob * len(target_data))
            train_x, train_y = input_data[:n_train], target_data[:n_train]
            test_x, test_y = input_data[n_train:], target_data[n_train:]

            data_pool.append({
                'local_training_data': cls._batchify(train_x, train_y, batch_size),
                'local_training_number': len(train_y),
                'data_name': str(pool_idx),
                'local_test_data': cls._batchify(test_x, test_y, batch_size),
                'local_test_number': len(test_y),
            })
        return data_pool

    def allocate(self, client_list):
        """Hand data-pool entry ``i`` to the ``i``-th client via ``client.update_data``.

        Raises:
            IndexError: if there are more clients than data-pool entries. The
                check runs before any client is updated.
        """
        if len(client_list) > self.pool_size:
            raise IndexError(f'{len(client_list)} clients but only '
                             f'{self.pool_size} data-pool entries')
        for pool_idx, client in enumerate(client_list):
            item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               item['local_training_data'],
                               item['local_training_number'],
                               item['local_test_data'],
                               item['local_test_number'])
