from copy import deepcopy
from models.Abstract import Abstract
import os
import numpy as np
import time
import argparse
# import logging

from random import Random

import torch
import torch.distributed as dist
import torch.utils.data.distributed
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.multiprocessing import Process
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms


class Partition(object):
    """Dataset-like object that only exposes a subset of another dataset.

    The subset is described by ``index``, a list of positions into ``data``.
    ``data`` must support ``data[i]`` returning a ``(sample, label)`` pair and
    must expose a ``targets`` sequence indexable by the same positions.

    Attributes set by the constructor:
        data:    the underlying (non-``Partition``) dataset.
        index:   positions into ``data`` that belong to this partition.
        targets: ``data.targets[i]`` for every ``i`` in ``index``.
        labels:  dict mapping each label (second element of ``data[i]``) to
                 the fraction of this partition carrying that label.
    """

    def __init__(self, data, index):
        self.data = data
        self.index = index
        # Collapse Partition-of-Partition chains so that ``index`` always
        # refers directly to the root dataset.
        self.reformulate()
        self.targets = [self.data.targets[i] for i in self.index]
        self.labels = {}
        total = len(self.index)
        for i in self.index:
            # Fetch the item once: ``self.data[i]`` may run the dataset's
            # transform pipeline, so repeated look-ups are wasteful.
            label = self.data[i][1]
            self.labels[label] = self.labels.get(label, 0.0) + 1.0 / total

    def __len__(self):
        """Return the number of samples in this partition."""
        return len(self.index)

    def __getitem__(self, index):
        """Return the item of the underlying dataset at partition position ``index``."""
        data_idx = self.index[index]
        return self.data[data_idx]

    def reformulate(self):
        """Flatten nested partitions.

        While ``self.data`` is itself a ``Partition``, replace it by its own
        underlying data and translate ``self.index`` through its index.
        """
        while isinstance(self.data, Partition):
            parent = self.data
            self.data, self.index = parent.data, [parent.index[i] for i in self.index]

    def subset(self, index):
        """Return a new ``Partition`` holding the given positions of this one.

        ``index`` contains positions relative to this partition (not to the
        underlying dataset).
        """
        new_index = [self.index[i] for i in index]
        return Partition(self.data, new_index)

    def combine(self, new_set):
        """Return a new ``Partition`` with this index followed by ``new_set.index``.

        The result is built on ``self.data``; ``new_set`` is assumed to index
        the same underlying dataset (this is not checked).
        """
        new_index = self.index + new_set.index
        return Partition(self.data, new_index)

    def naive_core_set(self, number):
        """Return a uniformly random subset of at most ``number`` samples.

        Positions are drawn with ``np.random.randint`` (i.e. with
        replacement), so the result may contain repeated samples. An empty
        partition or a non-positive ``number`` yields an empty subset.
        """
        count = min(len(self.index), number)
        if count <= 0:
            # np.random.randint rejects an empty range / negative size.
            return self.subset([])
        index = np.random.randint(0, len(self.index), count).tolist()
        return self.subset(index)

    def icarl_core_set(self, number, data, device, y):
        """Select up to ``number`` samples of class ``y`` closest to the class mean.

        Args:
            number: maximum number of samples to select.
            data:   iterable of ``(sample, label)`` pairs with a length. The
                    position of a pair in ``data`` is used as a position in
                    this partition (presumably ``data`` is this partition
                    itself -- confirm with callers).
            device: device the feature extractor and the samples are moved to.
            y:      label of the class to select from.

        Returns:
            ``self`` when ``number >= len(data)``; otherwise a new
            ``Partition`` with the selected samples, ordered from closest to
            farthest. The result is empty if no sample has label ``y``.
        """
        if number >= len(data):
            return self
        abstract_model = Abstract()
        abstract_model.to(device)
        abstract_model.eval()

        total = len(data)
        mean_feature_vector = None
        feature_vectors = []
        # Features are only compared, never back-propagated through.
        with torch.no_grad():
            for X, Y in data:
                if Y != y:
                    feature_vectors.append(None)
                    continue
                out = abstract_model(X.to(device))
                # The mean is normalised by len(data), not by the number of
                # samples of class ``y`` (kept from the original code).
                contribution = out / total
                if mean_feature_vector is None:
                    mean_feature_vector = contribution
                else:
                    mean_feature_vector = mean_feature_vector + contribution
                feature_vectors.append(out)

        if mean_feature_vector is None:
            # No sample of class ``y``: nothing to select.
            return self.subset([])

        # NOTE(review): the original code kept a running vector of already
        # selected features that was initialised to zero and never updated,
        # so the greedy selection is equivalent to ranking candidates by
        # distance to the mean. Confirm whether real iCaRL herding was meant.
        distances = {
            j: torch.norm(feature - mean_feature_vector).item()
            for j, feature in enumerate(feature_vectors)
            if feature is not None
        }
        # sorted() is stable, so ties keep the earliest position first and
        # every candidate is selected at most once.
        ranked = sorted(distances, key=distances.get)
        return self.subset(ranked[:max(number, 0)])

    def weighted_core_set(self, number):
        """Return ``number`` distinct samples, favouring later positions.

        Position ``i`` is drawn with weight ``i + 1`` using
        ``torch.multinomial`` (without replacement). ``number`` is capped at
        the partition size; a non-positive count yields an empty subset.
        """
        number = min(len(self.index), number)
        if number <= 0:
            # torch.multinomial cannot draw zero samples or use empty weights.
            return self.subset([])
        P = torch.arange(1, len(self.index) + 1, dtype=torch.float)
        index = torch.multinomial(P, number).tolist()
        return self.subset(index)

class DataPartitioner(object):
    """Partitions a dataset into different chunks.

    After construction ``partitions`` is a list of index lists (one per
    entry of ``sizes``) and ``ratio`` holds the relative size of each chunk:
    ``sizes`` itself in the IID case, the measured fractions in the
    Dirichlet (non-IID) case.
    """

    def __init__(self, data, sizes=None, seed=1234, isNonIID=False, alpha=0, dataset=None):
        """Split ``data`` into ``len(sizes)`` chunks.

        Args:
            data:     dataset to split; needs ``len()`` in the IID case and a
                      ``targets`` attribute in the non-IID case.
            sizes:    fraction of the data per chunk (default
                      ``[0.7, 0.2, 0.1]``). In the non-IID case only its
                      length is used.
            seed:     seed for the shuffling / sampling.
            isNonIID: use the Dirichlet label split instead of a uniform one.
            alpha:    Dirichlet concentration parameter (non-IID only).
            dataset:  stored on the instance, not used by this class.
        """
        if sizes is None:
            # Avoid a shared mutable default argument.
            sizes = [0.7, 0.2, 0.1]
        self.data = data
        self.dataset = dataset
        if isNonIID:
            self.partitions, self.ratio = self.__getDirichletData__(data, sizes, seed, alpha)
        else:
            self.partitions = []
            self.ratio = sizes
            rng = Random()
            rng.seed(seed)
            data_len = len(data)
            indexes = list(range(data_len))
            rng.shuffle(indexes)

            # Hand out consecutive slices of the shuffled indices.
            for frac in sizes:
                part_len = int(frac * data_len)
                self.partitions.append(indexes[:part_len])
                indexes = indexes[part_len:]

    def use(self, partition):
        """Return chunk number ``partition`` as a ``Partition`` over ``self.data``."""
        return Partition(self.data, self.partitions[partition])

    def __getNonIIDdata__(self, data, sizes, seed, alpha):
        """Build label-skewed chunks: two major labels each, then random fill.

        Every chunk first receives a fraction ``alpha`` of the samples of two
        labels, then is topped up from the shuffled remaining samples to
        ``len(labels) // len(sizes)`` entries.

        Returns:
            A list with one index list per entry of ``sizes``.
        """
        # Older datasets expose ``train_labels``; the rest of this file (and
        # current torchvision datasets) use ``targets``.
        labelList = getattr(data, 'train_labels', None)
        if labelList is None:
            labelList = data.targets
        # Plain Python scalars are hashable by value, unlike tensors.
        labelList = np.asarray(labelList).tolist()
        rng = Random()
        rng.seed(seed)

        # Group sample indices by label, keeping first-seen label order.
        labelIdxDict = dict()
        for idx, label in enumerate(labelList):
            labelIdxDict.setdefault(label, []).append(idx)
        labelNum = len(labelIdxDict)
        labelNameList = list(labelIdxDict)
        labelIdxPointer = [0] * labelNum
        # sizes = number of nodes
        partitions = [list() for _ in range(len(sizes))]
        if labelNum == 0:
            # Nothing to distribute.
            return partitions
        eachPartitionLen = int(len(labelList) / len(sizes))
        majorLabelNumPerPartition = 2
        basicLabelRatio = alpha

        interval = 1
        labelPointer = 0

        # basic part: assign the major labels of every chunk
        for partPointer in range(len(partitions)):
            requiredLabelList = list()
            for _ in range(majorLabelNumPerPartition):
                requiredLabelList.append(labelPointer)
                labelPointer += interval
                if labelPointer > labelNum - 1:
                    # Wrap around; the modulo keeps the pointer a valid label
                    # position even once ``interval`` outgrows ``labelNum``.
                    labelPointer = interval % labelNum
                    interval += 1
            for labelIdx in requiredLabelList:
                pool = labelIdxDict[labelNameList[labelIdx]]
                start = labelIdxPointer[labelIdx]
                idxIncrement = int(basicLabelRatio * len(pool))
                partitions[partPointer].extend(pool[start:start + idxIncrement])
                labelIdxPointer[labelIdx] += idxIncrement

        # random part: fill every chunk from the shuffled leftovers
        remainLabels = list()
        for labelIdx in range(labelNum):
            remainLabels.extend(labelIdxDict[labelNameList[labelIdx]][labelIdxPointer[labelIdx]:])
        rng.shuffle(remainLabels)
        for partPointer in range(len(partitions)):
            # A chunk may already exceed its target size; a negative count
            # would otherwise slice the leftovers from the wrong end.
            idxIncrement = max(0, eachPartitionLen - len(partitions[partPointer]))
            partitions[partPointer].extend(remainLabels[:idxIncrement])
            rng.shuffle(partitions[partPointer])
            remainLabels = remainLabels[idxIncrement:]

        return partitions

    def __getDirichletData__(self, data, psizes, seed, alpha):
        """Split ``data.targets`` into ``len(psizes)`` chunks via a Dirichlet prior.

        For every class the sample indices are shuffled and divided among the
        chunks following proportions drawn from ``Dirichlet(alpha)``; chunks
        that already hold at least ``N / n_nets`` samples get no more. The
        draw is repeated until every chunk holds at least 10 samples.

        Note: this seeds NumPy's global random state with ``seed``.

        Returns:
            ``(idx_batch, weights)``: the index lists and each chunk's share
            of all assigned samples.

        Raises:
            ValueError: if the dataset is too small for every chunk to reach
                the minimum size (the sampling loop could never finish).
        """
        min_required = 10
        n_nets = len(psizes)
        labelList = np.array(data.targets)
        N = len(labelList)
        if N < min_required * n_nets:
            raise ValueError(
                'need at least {} samples to give each of the {} partitions {} samples, got {}'.format(
                    min_required * n_nets, n_nets, min_required, N))
        # Cover every label value present (labels are assumed to be the
        # integers 0..K-1). Never below the historical value of 10 so the
        # random stream for datasets with up to 10 classes is unchanged.
        K = max(10, int(labelList.max()) + 1)
        np.random.seed(seed)

        min_size = 0
        while min_size < min_required:
            idx_batch = [[] for _ in range(n_nets)]
            # for each class in the dataset
            for k in range(K):
                idx_k = np.where(labelList == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
                # Balance: chunks that reached their fair share get nothing.
                proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                # Turn proportions into cut points inside idx_k.
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
            min_size = min(len(idx_j) for idx_j in idx_batch)

        for idx_j in idx_batch:
            np.random.shuffle(idx_j)

        local_sizes = np.array([len(idx_j) for idx_j in idx_batch])
        weights = local_sizes / np.sum(local_sizes)

        return idx_batch, weights


class SplitDataset():
    """Loads a training set and splits it into one ``Partition`` per client.

    The training set is cut into ``client_num * round_num`` chunks; each
    client's ``round_num`` consecutive chunks are concatenated into one
    ``Partition`` stored in ``combine_round_sets``. The result is cached on
    disk under ``./data``.
    """

    def __init__(self, args) -> None:
        """Build (or load from cache) the per-client partitions.

        Reads ``client_num``, ``split_num``, ``client_drift``,
        ``round_drift``, ``datapath`` and ``NIID`` from ``args``.
        """
        self.client_num = args.client_num
        self.round_num = args.split_num
        self.client_drift = args.client_drift
        self.round_drift = args.round_drift

        self.trainset, self.test_loader = self.load_data(args)
        # self.trainset, self.test_loader = self.load_femnist_data(args)
        # self.trainset, self.test_loader = self.load_cifar100_data(args)

        # NOTE(review): the cache key ignores the dataset and ``args.NIID``,
        # so a cache written under other settings is reused silently.
        cache_path = './data/{}-{}-{}-data.pt'.format(self.client_num, self.round_num, self.round_drift)
        if os.path.exists(cache_path):
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # cache files produced by this code.
            self.combine_round_sets = torch.load(cache_path)
        else:
            self.round_sets = self.split(self.trainset, self.round_num * self.client_num, self.round_drift, args)
            self.combine_round_sets = []
            print('combing')
            for i in range(self.client_num):
                print(i)
                # Concatenate the client's chunk indices first and build one
                # Partition: every Partition construction rescans all of its
                # samples, so combining chunk by chunk repeats that work.
                client_index = []
                for j in range(self.round_num):
                    client_index.extend(self.round_sets[i * self.round_num + j].index)
                self.combine_round_sets.append(Partition(self.trainset, client_index))
            # The cache directory need not coincide with ``args.datapath``.
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            torch.save(self.combine_round_sets, cache_path)

        self.total_set = Partition(self.trainset, [i for i in range(len(self.trainset))])

    @classmethod
    def empty_set(cls, train_sets):
        """Return an empty ``Partition`` over ``train_sets``."""
        return Partition(train_sets, [])

    def load_cifar100_data(self, args):
        """Return ``(trainset, test_loader)`` for CIFAR-100 under ``args.datapath``.

        The training set is augmented (random crop, flip, rotation); the test
        loader uses batch size 64 without shuffling.
        """
        CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
        CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
        CIFAR100_TEST_MEAN = (0.5088964127604166, 0.48739301317401956, 0.44194221124387256)
        CIFAR100_TEST_STD = (0.2682515741720801, 0.2573637364478126, 0.2770957707973042)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(CIFAR100_TEST_MEAN, CIFAR100_TEST_STD),
        ])
        trainset = torchvision.datasets.CIFAR100(root=args.datapath,
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=args.datapath,
                                                train=False,
                                                download=True,
                                                transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=64,
                                                  shuffle=False)
        return trainset, test_loader

    def load_femnist_data(self, args):
        """Return ``(trainset, test_loader)`` for FashionMNIST under ``args.datapath``.

        NOTE(review): despite the method name this loads
        ``torchvision.datasets.FashionMNIST``, and 0.1307 / 0.3081 look like
        the usual MNIST statistics -- confirm both are intended.
        """
        transform = transforms.Compose([
            transforms.ToTensor(),
            # std is written as a 1-tuple to match the mean.
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        trainset = torchvision.datasets.FashionMNIST(root=args.datapath,
                                                     train=True,
                                                     download=True,
                                                     transform=transform)
        testset = torchvision.datasets.FashionMNIST(root=args.datapath,
                                                    train=False,
                                                    download=True,
                                                    transform=transform)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=64,
                                                  shuffle=False)
        return trainset, test_loader

    def load_data(self, args):
        """Return ``(trainset, test_loader)`` for CIFAR-10 under ``args.datapath``.

        The training set is augmented (random crop and flip); the test loader
        uses batch size 64 without shuffling.
        """
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainset = torchvision.datasets.CIFAR10(root=args.datapath,
                                                train=True,
                                                download=True,
                                                transform=transform_train)

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        testset = torchvision.datasets.CIFAR10(root=args.datapath,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=64,
                                                  shuffle=False)
        return trainset, test_loader

    def split(self, data, size, alpha, args):
        """Split ``data`` into ``size`` equally weighted ``Partition`` objects.

        ``args.NIID`` selects the Dirichlet split (with concentration
        ``alpha``) instead of the uniform one.
        """
        partition_sizes = [1.0 / size for _ in range(size)]
        partition = DataPartitioner(data, partition_sizes, isNonIID=args.NIID, alpha=alpha)
        partitions = partition.partitions
        return [Partition(data, partitions[k]) for k in range(size)]
            
if __name__ == '__main__':
    # Stand-alone entry point: parse the options and build the per-client
    # split of the dataset (which also writes the cache under ./data).
    parser = argparse.ArgumentParser(description='CIFAR-10 baseline')
    parser.add_argument('--client_num','-cN', 
                    default=10, 
                    type=int, 
                    help='the number of clients')
    parser.add_argument('--round_num','-rN', 
                    default=10, 
                    type=int, 
                    help='the number of communication rounds')
    # SplitDataset reads ``args.split_num``; without this option the script
    # fails with AttributeError.
    parser.add_argument('--split_num','-sN',
                    default=None,
                    type=int,
                    help='the number of data splits per client (defaults to --round_num)')
    parser.add_argument('--round_drift','-rd', 
                    default=1, 
                    type=float, 
                    help='round drift') 
    parser.add_argument('--client_drift','-cd', 
                    default=0.3, 
                    type=float, 
                    help='client drift')
    parser.add_argument('--lr', 
                    default=0.1, 
                    type=float, 
                    help='client learning rate')
    parser.add_argument('--bs', 
                    default=32, 
                    type=int, 
                    help='batch size on each worker/client')
    parser.add_argument('--NIID',
                    default=True,
                    action='store_true',
                    help='whether the dataset is non-iid or not')
    # ``--NIID`` defaults to True and can only store True, so on its own it
    # can never be switched off; ``--IID`` is the way to disable it.
    parser.add_argument('--IID',
                    dest='NIID',
                    action='store_false',
                    help='use an iid split instead of the non-iid one')
    parser.add_argument('--datapath',
                    default='./data/',
                    type=str,
                    help='directory to load data')
    args = parser.parse_args()
    if args.split_num is None:
        args.split_num = args.round_num
    splited = SplitDataset(args)        

