from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from numpy.random import RandomState
import numpy as np
import os

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torch.utils.data import Dataset
import torch.utils.data as data_utils
import torch
from models import *
from params import args_parser
from data_preprocessing import *
import pickle
import collections

# Parse the run configuration once, at import time. Most functions in this
# module take ``args`` as an explicit parameter, but ``__getClusteredData__``
# reads this module-level object directly (``args.seed`` / ``args.dataset``).
args = args_parser()

def _dataset_to_arrays(dataset):
    """Materialise a whole map-style dataset as a pair of numpy arrays.

    A single batch spanning the full dataset is drawn from a ``DataLoader``,
    so the dataset's transform runs over every sample exactly once.

    Returns:
        ``(inputs, targets)`` as numpy arrays.
    """
    loader = DataLoader(dataset, batch_size=len(dataset))
    inputs, targets = next(iter(loader))
    return inputs.numpy(), targets.numpy()


def load_dataset(args):
    """Load the dataset named by ``args.dataset`` and split it across clients.

    Supported values of ``args.dataset`` are ``'fmnist'``, ``'cifar10'`` and
    ``'sent140'``. Image datasets are downloaded (if needed) under
    ``args.dir + '/data/' + args.dataset``; their training split feeds the
    "seen" client partitions and their test split the "unseen" ones.

    Side effects:
        For ``'fmnist'`` only, ``args.num_classes`` is overwritten with 10.

    Args:
        args: run configuration; must provide ``dir`` and ``dataset`` plus
            whatever the partitioning helpers read.

    Returns:
        ``(dataset_train, dataset_val, dataset_test, dataset_unseentrain,
        dataset_unseenval, dataset_unseentest)`` as produced by the
        ``partition_*`` helpers.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    data_path = args.dir+'/data/'+args.dataset
    if args.dataset == 'fmnist':
        trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        dataset_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=trans)
        dataset_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=trans)

        # Unpack images and labels from ONE full-size batch per split instead
        # of iterating the loader twice (which transformed the data twice).
        X_train, Y_train = _dataset_to_arrays(dataset_train)
        X_test, Y_test = _dataset_to_arrays(dataset_test)

        args.num_classes = 10

        dataset_train, dataset_val, dataset_test = partition_imagedataset(X_train, Y_train, args)
        dataset_unseentrain, dataset_unseenval, dataset_unseentest = partition_unseenimagedataset(X_test, Y_test, args)

    elif args.dataset == 'cifar10':
        trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

        dataset_train = datasets.CIFAR10(data_path, train=True, download=True, transform=trans)
        dataset_test = datasets.CIFAR10(data_path, train=False, download=True, transform=trans)

        X_train, Y_train = _dataset_to_arrays(dataset_train)
        X_test, Y_test = _dataset_to_arrays(dataset_test)

        dataset_train, dataset_val, dataset_test = partition_imagedataset(X_train, Y_train, args)
        dataset_unseentrain, dataset_unseenval, dataset_unseentest = partition_unseenimagedataset(X_test, Y_test, args)

    elif args.dataset == 'sent140':
        # load GloVe embeddings
        word2vectors, word2id = load_GloVe_twitter_emb()
        # load the twitter dataset and its predefined client partitions
        train, test, partition, partition_test, ratio = load_twitter_datasets()

        Xtrain, Ytrain = processAllTweets2vec(train, word2vectors)
        Xtest, Ytest = processAllTweets2vec(test, word2vectors)

        dataset_train, dataset_val, dataset_test = partition_textdataset(Xtrain, Ytrain, partition, args)
        dataset_unseentrain, dataset_unseenval, dataset_unseentest = partition_unseentextdataset(Xtest, Ytest, partition_test, args)

    else:
        # Previously this fell through to the return statement and failed
        # with an UnboundLocalError that hid the real cause.
        raise ValueError(
            "Unsupported dataset %r; expected 'fmnist', 'cifar10' or 'sent140'" % (args.dataset,)
        )

    return dataset_train, dataset_val, dataset_test, dataset_unseentrain, dataset_unseenval, dataset_unseentest

def __getDirichletData__(y, n, alpha, num_c, seed=0):
    """Split sample indices across ``n`` clients with Dirichlet label skew.

    Every client draws a class distribution from ``Dirichlet(alpha, ...)``.
    For each class ``k`` the (shuffled) indices of that class are then divided
    among the clients in proportion to the clients' probabilities for ``k``.
    Per-client class counts and relative partition sizes are printed.

    Args:
        y: array-like of integer class labels, one per sample.
        n: number of clients.
        alpha: Dirichlet concentration parameter shared by all classes.
        num_c: number of classes; labels ``0 .. num_c - 1`` are distributed.
        seed: seed for the private ``RandomState``. The default of 0 matches
            the value that used to be hard-coded.

    Returns:
        A list of ``n`` lists of sample indices (Python ints), each shuffled.
    """
    # Coerce once so that plain Python lists work too; ``list == k`` would
    # otherwise evaluate to a scalar ``False`` and select nothing.
    labels = np.asarray(y)
    rann = RandomState(seed)

    # Row i holds client i's probability for each class.
    p_client = np.zeros((n, num_c))
    for i in range(n):
        p_client[i] = rann.dirichlet(np.repeat(alpha, num_c))

    idx_batch = [[] for _ in range(n)]

    for k in range(num_c):
        idx_k = np.where(labels == k)[0]
        rann.shuffle(idx_k)

        # Normalise over clients and turn the cumulative shares into the
        # split points expected by ``np.split``.
        proportions = p_client[:, k]
        proportions = proportions / proportions.sum()
        split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
        idx_batch = [
            idx_j + idx.tolist()
            for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))
        ]

    # Shuffle in place so classes are interleaved within each client.
    for client_indices in idx_batch:
        rann.shuffle(client_indices)

    net_cls_counts = {}
    for net_i, dataidx in enumerate(idx_batch):
        unq, unq_cnt = np.unique(labels[dataidx], return_counts=True)
        net_cls_counts[net_i] = {unq[i]: unq_cnt[i] for i in range(len(unq))}

    local_sizes = np.array([len(dataidx) for dataidx in idx_batch])
    weights = local_sizes / np.sum(local_sizes)

    print('Data statistics: %s' % str(net_cls_counts))
    print('Data ratio: %s' % str(weights))

    return idx_batch

def data_to_statistics(data, labelList_true):
    """Count how many samples of each label every client holds.

    Args:
        data: indexable collection where ``data[j]`` holds the sample indices
            of client ``j`` for ``j`` in ``range(len(data))``.
        labelList_true: numpy array of labels, indexed by sample index.

    Returns:
        Dict mapping client position to a ``{label: count}`` dict.
    """
    stats = {}
    for client in range(len(data)):
        present, counts = np.unique(labelList_true[data[client]], return_counts=True)
        stats[client] = dict(zip(present, counts))
    return stats

def __getClusteredData__(Y_train, n_clients):
    """Assign sample indices to clients grouped into label-based clusters.

    Consecutive label ranges form clusters (pairs of labels for ``'fmnist'``,
    single labels otherwise). Each cluster is given a fixed number of randomly
    chosen, not yet used clients, and the cluster's samples are handed out to
    those clients without overlap, with sizes drawn from a power distribution.

    Reads the module-level ``args`` object for ``args.seed`` and
    ``args.dataset``.

    Args:
        Y_train: array-like of integer labels, one per sample.
        n_clients: total number of clients; must be at least the sum of the
            cluster sizes (100 for both built-in layouts).

    Returns:
        A list of length ``n_clients``. Entries of clients that were assigned
        to a cluster are numpy arrays of sample indices; the others stay as
        empty lists.

    Raises:
        ValueError: if ``n_clients`` is smaller than the number of clients
            the cluster layout needs.
    """
    labelList = np.array(Y_train)
    rann = RandomState(args.seed)

    if args.dataset == 'fmnist':
        num_cluster_clients = [10, 10, 20, 55, 5]
        labels_per_cluster = [2, 2, 2, 2, 2]
        alpha = [2, 2, 2, 2, 2]
    else:
        num_cluster_clients = [5, 5, 10, 40, 5, 5, 10, 5, 5, 10]
        labels_per_cluster = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        alpha = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

    # Fail early with a clear message; sampling clients without replacement
    # below would otherwise raise a less helpful ValueError part-way through.
    required_clients = sum(num_cluster_clients)
    if n_clients < required_clients:
        raise ValueError(
            "n_clients=%s is too small: the cluster layout needs %d clients"
            % (n_clients, required_clients)
        )

    idx_batch = [[] for _ in range(n_clients)]
    clients = list(np.arange(n_clients))
    label_is = 0
    for n_members, n_labels, a in zip(num_cluster_clients, labels_per_cluster, alpha):
        # Gather every sample whose label belongs to this cluster.
        idx_k = []
        for label_i in range(label_is, label_is + n_labels):
            idx_k += list(np.where(labelList == label_i)[0])
        label_is += n_labels
        rann.shuffle(idx_k)
        idx_clust = idx_k

        rand_idx = rann.choice(clients, size=n_members, replace=False)
        p_dat = rann.power(a, n_members)
        p_dat = p_dat / sum(p_dat)

        for pos, j in enumerate(rand_idx):
            # Share sizes are relative to the full cluster (len(idx_k)), while
            # sampling is done from what is still unassigned (idx_clust).
            picked = rann.choice(idx_clust, size=int(p_dat[pos] * len(idx_k)), replace=False)
            idx_batch[j] = picked
            # A set makes the filtering linear; testing membership against the
            # numpy array scanned the whole array for every candidate.
            taken = set(picked.tolist())
            idx_clust = [s for s in idx_clust if s not in taken]

        assigned = set(rand_idx.tolist())
        clients = [c for c in clients if c not in assigned]

    return idx_batch




def select_model(args, model_n):
    """Build the model identified by ``model_n`` and move it to ``args.device``.

    Args:
        args: run configuration; ``args.num_classes`` sizes the output layer
            (except for ``'MLP'``, whose dimensions are fixed) and
            ``args.device`` is the target device.
        model_n: one of ``'VGG11'``, ``'VGG11bn'``, ``'VGG13'``, ``'VGG13bn'``,
            ``'VGG16'``, ``'VGG19'``, ``'RES152'``, ``'RES50'``, ``'RES34'``,
            ``'RES18'``, ``'RES8'``, ``'CNN'``, ``'MLP'`` or ``'MLP_sent'``.

    Returns:
        The freshly constructed model, already moved to ``args.device``.

    Raises:
        ValueError: if ``model_n`` is not a known model name.
    """
    # TODO: automate the arg num class
    # Builders are wrapped in lambdas so that only the requested model is
    # constructed and only its constructor name needs to resolve.
    builders = {
        'VGG11': lambda: vgg11(args.num_classes),
        'VGG11bn': lambda: vgg11_bn(args.num_classes),
        'VGG13': lambda: vgg13(args.num_classes),
        'VGG13bn': lambda: vgg13_bn(args.num_classes),
        'VGG16': lambda: vgg16(args.num_classes),
        'VGG19': lambda: vgg19(args.num_classes),
        'RES152': lambda: resnet152(pretrained=False, num_classes=args.num_classes),
        'RES50': lambda: resnet50(pretrained=False, num_classes=args.num_classes),
        'RES34': lambda: resnet34(pretrained=False, num_classes=args.num_classes),
        'RES18': lambda: resnet18(pretrained=False, num_classes=args.num_classes),
        'RES8': lambda: resnet8(pretrained=False, num_classes=args.num_classes),
        'CNN': lambda: CNNCifar(args.num_classes),
        # TODO: MLP needs a look before using (input/output sizes are fixed)
        'MLP': lambda: MLP(28*28, 62),
        'MLP_sent': lambda: MLP_sent140(input_size=200, dim_hidden1=128, dim_hidden2=86,
                                        dim_hidden3=30, dim_out=args.num_classes),
    }

    build = builders.get(model_n)
    if build is None:
        # Without this guard ``model_t`` stayed unbound and the function
        # failed with an UnboundLocalError that did not name the bad input.
        raise ValueError(
            "Unknown model name %r; expected one of: %s" % (model_n, ', '.join(builders))
        )

    model_t = build()
    model_t.to(args.device)

    return model_t



def partition_imagedataset(X_train, Y_train, args):
    """Split an image dataset into per-client train/val/test ``TensorDataset``s.

    Sample indices are assigned to clients by ``__getClusteredData__`` when
    ``args.isclust == 1`` and by ``__getDirichletData__`` otherwise. When
    ``args.isclust == 0`` a label shift is applied on top: clients 0-29 get
    their labels rotated by +1 (mod 10) and clients 30-59 by +2 (mod 10).

    Within each client the first ``val_ratio`` share of samples is the
    validation split, the next ``train_ratio`` share is the training split and
    everything that remains is the test split. If the validation split would
    be empty, the client's training dataset is reused as its validation
    dataset.

    Args:
        X_train: numpy array of inputs, indexed by sample along axis 0.
        Y_train: numpy array of integer labels aligned with ``X_train``.
        args: run configuration providing ``isclust``, ``ensize``, ``alpha``,
            ``num_classes``, ``train_ratio`` and ``val_ratio``.

    Returns:
        ``(dataset_train, dataset_val, dataset_test)``: three lists with one
        ``TensorDataset`` per client.
    """
    # Cast once so both partitioners receive an integer client count, as the
    # sibling ``partition_unseenimagedataset`` already did for clustering.
    n_clients = int(args.ensize)
    if args.isclust == 1:
        partition = __getClusteredData__(Y_train, n_clients)
    else:
        partition = __getDirichletData__(Y_train, n_clients, args.alpha, args.num_classes)

    dataset_train = []
    dataset_val = []
    dataset_test = []

    train_ratio = args.train_ratio
    val_ratio = args.val_ratio

    for i, ind in enumerate(partition):
        x = X_train[ind]
        y = Y_train[ind]

        if args.isclust == 0:  # no clustering: apply label shifting instead
            if i < 30:
                y = (y + 1) % 10
            elif i < 60:
                y = (y + 2) % 10

        n_i = len(ind)
        train_size = int(train_ratio * n_i)
        val_size = int(val_ratio * n_i)
        train_end = val_size + train_size

        # Layout inside a client: [ val | train | test (remainder) ]
        train_ds = TensorDataset(torch.Tensor(x[val_size:train_end]),
                                 torch.LongTensor(y[val_size:train_end]))

        if val_size == 0:
            # Too few samples for a validation split: validate on train data.
            val_ds = train_ds
        else:
            val_ds = TensorDataset(torch.Tensor(x[:val_size]),
                                   torch.LongTensor(y[:val_size]))

        test_ds = TensorDataset(torch.Tensor(x[train_end:]),
                                torch.LongTensor(y[train_end:]))

        dataset_train.append(train_ds)
        dataset_test.append(test_ds)
        dataset_val.append(val_ds)

    return dataset_train, dataset_val, dataset_test





def partition_unseenimagedataset(X_train, Y_train, args):
    """Split held-out image data into per-client train/val/test datasets.

    Same procedure as the one used for the seen clients, applied to the arrays
    passed in. Sample indices come from ``__getClusteredData__`` when
    ``args.isclust == 1`` and from ``__getDirichletData__`` otherwise. When
    ``args.isclust == 0`` the labels of clients 0-29 are rotated by +1
    (mod 10) and those of clients 30-59 by +2 (mod 10).

    Within each client the first ``val_ratio`` share of samples is the
    validation split, the next ``train_ratio`` share is the training split and
    the remainder is the test split. If the validation split would be empty,
    the client's training dataset is reused as its validation dataset.

    Args:
        X_train: numpy array of inputs, indexed by sample along axis 0.
        Y_train: numpy array of integer labels aligned with ``X_train``.
        args: run configuration providing ``isclust``, ``ensize``, ``alpha``,
            ``num_classes``, ``train_ratio`` and ``val_ratio``.

    Returns:
        ``(dataset_train, dataset_val, dataset_test)``: three lists with one
        ``TensorDataset`` per client.
    """
    # The clustered path already cast ``ensize`` to int; do the same for the
    # Dirichlet path so both partitioners see an integer client count.
    n_clients = int(args.ensize)
    if args.isclust == 1:
        partition = __getClusteredData__(Y_train, n_clients)
    else:
        partition = __getDirichletData__(Y_train, n_clients, args.alpha, args.num_classes)

    dataset_train = []
    dataset_val = []
    dataset_test = []

    train_ratio = args.train_ratio
    val_ratio = args.val_ratio

    for i, ind in enumerate(partition):
        x = X_train[ind]
        y = Y_train[ind]

        if args.isclust == 0:  # no clustering: apply label shifting instead
            if i < 30:
                y = (y + 1) % 10
            elif i < 60:
                y = (y + 2) % 10

        n_i = len(ind)
        train_size = int(train_ratio * n_i)
        val_size = int(val_ratio * n_i)
        train_end = val_size + train_size

        # Layout inside a client: [ val | train | test (remainder) ]
        train_ds = TensorDataset(torch.Tensor(x[val_size:train_end]),
                                 torch.LongTensor(y[val_size:train_end]))

        if val_size == 0:
            # Too few samples for a validation split: validate on train data.
            val_ds = train_ds
        else:
            val_ds = TensorDataset(torch.Tensor(x[:val_size]),
                                   torch.LongTensor(y[:val_size]))

        test_ds = TensorDataset(torch.Tensor(x[train_end:]),
                                torch.LongTensor(y[train_end:]))

        dataset_train.append(train_ds)
        dataset_test.append(test_ds)
        dataset_val.append(val_ds)

    return dataset_train, dataset_val, dataset_test



def partition_textdataset(X_train, Y_train, partition, args):
    """Split a text dataset into per-client train/val/test ``TensorDataset``s.

    ``partition[i]`` lists the sample indices of client ``i``. The labels of
    the first 100 clients are rotated by +1 (mod 3) for the label-shift
    experiments. Within each client the first ``val_ratio`` share of samples
    is the validation split, the next ``train_ratio`` share is the training
    split and the remainder is the test split. If the validation split would
    be empty, the training dataset is reused for validation.

    Datasets are only produced when ``args.dataset == 'sent140'``; for any
    other value the returned lists are empty.

    Args:
        X_train: numpy array of inputs, indexed by sample along axis 0.
        Y_train: numpy array of labels aligned with ``X_train``.
        partition: collection indexable by ``0 .. len(partition) - 1`` giving
            each client's sample indices.
        args: run configuration providing ``train_ratio``, ``val_ratio`` and
            ``dataset``.

    Returns:
        ``(dataset_train, dataset_val, dataset_test)``: three lists with one
        ``TensorDataset`` per client. Labels are stored as float tensors.
    """
    dataset_train = []
    dataset_val = []
    dataset_test = []

    train_ratio = args.train_ratio
    val_ratio = args.val_ratio
    n_clients = len(partition)

    for i in range(n_clients):
        ind = partition[i]
        x = X_train[ind]
        # ``squeeze`` collapses a single-sample client to a 0-d array, which
        # cannot be sliced below; ``atleast_1d`` restores the sample axis.
        y = np.atleast_1d(np.squeeze(Y_train[ind]))
        n_i = len(ind)

        if i < 100:  # label shift experiments
            y = (y + 1) % 3  # 0, 1, 2 -> 1, 2, 0

        if args.dataset != 'sent140':
            continue

        train_size = int(train_ratio * n_i)
        val_size = int(val_ratio * n_i)
        train_end = val_size + train_size

        # Layout inside a client: [ val | train | test (remainder) ]
        train_ds = TensorDataset(torch.Tensor(x[val_size:train_end]),
                                 torch.FloatTensor(y[val_size:train_end]))

        if val_size == 0:
            val_ds = train_ds
        else:
            val_ds = TensorDataset(torch.Tensor(x[:val_size]),
                                   torch.FloatTensor(y[:val_size]))

        test_ds = TensorDataset(torch.Tensor(x[train_end:]),
                                torch.FloatTensor(y[train_end:]))

        dataset_train.append(train_ds)
        dataset_test.append(test_ds)
        dataset_val.append(val_ds)

    return dataset_train, dataset_val, dataset_test






def partition_unseentextdataset(X_train, Y_train, partition, args):
    """Split held-out text data into per-client train/val/test datasets.

    ``partition[i]`` lists the sample indices of client ``i``. The labels of
    the first 100 clients are rotated by +1 (mod 3) for the label-shift
    experiments. Within each client the first ``val_ratio`` share of samples
    is the validation split, the next ``train_ratio`` share is the training
    split and the remainder is the test split. If the validation split would
    be empty, the training dataset is reused for validation.

    Datasets are only produced when ``args.dataset == 'sent140'``; for any
    other value the returned lists are empty.

    Args:
        X_train: numpy array of inputs, indexed by sample along axis 0.
        Y_train: numpy array of labels aligned with ``X_train``.
        partition: collection indexable by ``0 .. len(partition) - 1`` giving
            each client's sample indices.
        args: run configuration providing ``train_ratio``, ``val_ratio`` and
            ``dataset``.

    Returns:
        ``(dataset_train, dataset_val, dataset_test)``: three lists with one
        ``TensorDataset`` per client. Labels are stored as float tensors.
    """
    dataset_train = []
    dataset_val = []
    dataset_test = []

    train_ratio = args.train_ratio
    val_ratio = args.val_ratio
    n_clients = len(partition)

    for i in range(n_clients):
        ind = partition[i]
        x = X_train[ind]
        # ``squeeze`` collapses a single-sample client to a 0-d array, which
        # cannot be sliced below; ``atleast_1d`` restores the sample axis.
        y = np.atleast_1d(np.squeeze(Y_train[ind]))
        n_i = len(ind)

        if i < 100:  # label shift experiments
            y = (y + 1) % 3  # 0, 1, 2 -> 1, 2, 0

        if args.dataset != 'sent140':
            continue

        train_size = int(train_ratio * n_i)
        val_size = int(val_ratio * n_i)
        train_end = val_size + train_size

        # Layout inside a client: [ val | train | test (remainder) ]
        train_ds = TensorDataset(torch.Tensor(x[val_size:train_end]),
                                 torch.FloatTensor(y[val_size:train_end]))

        if val_size == 0:
            val_ds = train_ds
        else:
            val_ds = TensorDataset(torch.Tensor(x[:val_size]),
                                   torch.FloatTensor(y[:val_size]))

        test_ds = TensorDataset(torch.Tensor(x[train_end:]),
                                torch.FloatTensor(y[train_end:]))

        dataset_train.append(train_ds)
        dataset_test.append(test_ds)
        dataset_val.append(val_ds)

    return dataset_train, dataset_val, dataset_test


class CustomDataset(Dataset):
    """Token-sequence dataset whose sentences are encoded through a vocabulary.

    ``data_x`` and ``data_y`` are iterated in lockstep. Each element of
    ``data_x`` is an iterable of input sentences and each element of
    ``data_y`` is a mapping whose ``"target_tokens"`` entry is an iterable of
    target sentences. All sentences are flattened into one list per side and
    every token is replaced by its vocabulary entry.
    """

    def __init__(self, data_x, data_y):
        # The vocabulary and its metadata come from the project's loader.
        self.vocab, self.vocab_size, self.unk_symbol, self.pad_symbol = load_vocab()

        # Flatten the per-group sentence collections.
        inputs = []
        targets = []
        for group_x, group_y in zip(data_x, data_y):
            inputs.extend(group_x)
            targets.extend(group_y["target_tokens"])

        # Encode all inputs first, then all targets.
        vocab = self.vocab
        encoded_inputs = [[vocab[token] for token in sentence] for sentence in inputs]
        encoded_targets = [[vocab[token] for token in sentence] for sentence in targets]

        # Inputs are cast to long tensors; targets stay default float tensors.
        self.data_x = [torch.Tensor(ids).type(torch.LongTensor) for ids in encoded_inputs]
        self.data_y = [torch.Tensor(ids) for ids in encoded_targets]

    def __len__(self):
        """Return the number of target sentences."""
        return len(self.data_y)

    def __getitem__(self, i):
        """Return the ``(input, target)`` pair stored at position ``i``."""
        return self.data_x[i], self.data_y[i]