import pandas as pd
import multiprocessing as mp
from tensorflow.keras.datasets import mnist as mnist_dataset
import json
from NetworkClass import *
from NodeClass import *
from DSGDAlg import *
from Functions import *
import os
import datetime

# ------------------------------ configuration ------------------------------
# Name of the topology file under config/.  It is also parsed as an initial
# node count; that value is replaced once the topology file has been read.
config_file_name = '1'
num_of_node = int(config_file_name)

sampling_f = 100              # in ms
iteration = 10 ** 4           # forwarded to the simulation routine as `iteration`
repeat_simulation = 1         # repetitions per entry of lr_exp
child_process = False         # True -> repetitions run in a multiprocessing.Pool
mps = True                    # True -> prefer Apple MPS when picking a torch device
gpu = 0                       # CUDA device index, only consulted when mps is False
num_worker = 5
batch_size = 64
cte = True
# cte = [iteration//2, iteration*3//4, 5]
lr_exp = [*range(4, 5)]       # values swept by the main loop (printed as "exp")
iid = False                   # False -> class-sorted (non-IID) data partition
# Dataset selector; the setup code handles "cifar10", "SVHN", "Mnist",
# "mnist", "w8a" and "quadratic".
identity = ["cifar10", 0, 0]
H = 50
alpha = 10

# The result file name encodes the main run parameters.
simulation_result_file_name = (
    f"simulation_result_consensus_distance{config_file_name}nodes_{identity[0]}"
    f"_iid{iid}_H{H}_alpha{alpha}_iter{iteration}"
)
# ---------------------------- end configuration ----------------------------



delay_shift = 1




# Load the network topology from config/<config_file_name>.  The file holds
# three JSON documents, one per line:
#   line 0 -> node_connection   (JSON object keyed by node id)
#   line 1 -> connection_delay  (JSON object keyed by node id)
#   line 2 -> gap               (indexed by integer node id further down)
# JSON object keys are always strings, so the first two are re-keyed by int.
with open(f"config/{config_file_name}", "r") as f:
    fp = f.readlines()
    node_connection, connection_delay = (
        {int(key): value for key, value in json.loads(fp[line_no]).items()}
        for line_no in (0, 1)
    )
    gap = json.loads(fp[2])
# The topology, not the config value above, decides how many nodes exist.
num_of_node = len(node_connection)
# ---------------------------------------------------------------------------
# Data / model setup.
#
# Every branch of the chain below defines the module-level names consumed by
# my_func and the __main__ block: model, dataset, data_split, total_data,
# device, criterion and test_loader.  The torch-based "SVHN" and "Mnist"
# branches (and "cifar10" when not iid) also reset `identity` to
# ["cifar10", 0, 0] so downstream code treats them like CIFAR-10.
# ---------------------------------------------------------------------------


def _contiguous_split(n_samples, n_parts):
    """Return ``n_parts`` equal, contiguous ``(start, stop)`` index ranges.

    Each range holds ``n_samples // n_parts`` samples; the trailing
    ``n_samples % n_parts`` samples are not assigned to any node.

    Raises:
        ValueError: if ``n_parts`` is not positive or exceeds ``n_samples``
            (which would give empty parts).
    """
    if n_parts <= 0 or n_parts > n_samples:
        raise ValueError(
            f"cannot split {n_samples} samples into {n_parts} non-empty parts")
    step = n_samples // n_parts
    return [(k * step, (k + 1) * step) for k in range(n_parts)]


def _class_sorted_indices(labels, num_classes=10):
    """Return sample indices grouped by class: all of class 0, then 1, ...

    The original order is kept within each class.  Samples whose label is not
    in ``range(num_classes)`` are dropped.  Combined with ``_contiguous_split``
    this concentrates each node's contiguous slice on few classes (non-IID).
    """
    labels = np.asarray(labels)
    ordered = []
    for cls in range(num_classes):
        ordered.extend(np.flatnonzero(labels == cls))
    return ordered


def _select_device():
    """Return MPS (if ``mps`` is set) or ``cuda:<gpu>``, falling back to CPU."""
    if mps:
        return torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    return torch.device('cuda:' + str(gpu) if torch.cuda.is_available() else 'cpu')


def _make_test_loader(test_dataset):
    """Wrap ``test_dataset`` in an unshuffled DataLoader with batch size 100."""
    return torch.utils.data.DataLoader(dataset=test_dataset,
                                       batch_size=100,
                                       shuffle=False)


def _augmenting_transform(normalize, crop_size):
    """Training transform: random flip, 4-pixel-padded random crop, tensor, normalize."""
    # NOTE(review): horizontal flips mirror digits for MNIST/SVHN; kept as-is.
    return transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(crop_size, 4),
        transforms.ToTensor(),
        normalize,
    ])


def _partition_torch_dataset(train_dataset, labels):
    """Shuffle a torch training set and split it across ``num_of_node`` nodes.

    Returns ``(total_data, indices, dataset, data_split)``.  The random
    permutation is drawn in both modes, so NumPy RNG consumption does not
    depend on ``iid``.  For non-IID, the shuffled subset is replaced by a
    class-sorted one built from ``labels``.
    """
    n_samples = len(train_dataset)
    shuffled = np.random.permutation(np.arange(n_samples))
    subset = torch.utils.data.Subset(train_dataset, shuffled)
    split = _contiguous_split(n_samples, num_of_node)
    if iid:
        return n_samples, shuffled, subset, split
    ordered = _class_sorted_indices(labels)
    return n_samples, ordered, torch.utils.data.Subset(train_dataset, ordered), split


if identity[0] in ["mnist", "w8a"]:
    # NumPy pipelines: dataset is a (feature, label) pair of arrays.
    if identity[0] == "mnist":
        (x_train, y_train), (x_test, y_test) = mnist_dataset.load_data()
        x_train = x_train / 255.0
        n_train, num_row, num_col = x_train.shape
        feature_pre_shuffle = np.reshape(x_train, (n_train, num_row * num_col))
        label_pre_shuffle = pd.get_dummies(y_train).values  # one column per class
        d = (label_pre_shuffle.shape[1], feature_pre_shuffle.shape[1])
        d_hat = (4, label_pre_shuffle.shape[1], feature_pre_shuffle.shape[1])

    elif identity[0] == "w8a":
        d = 300
        url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/w8a'
        df_train = pd.read_csv(url, names=['train'])
        url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/w8a.t'
        df_test = pd.read_csv(url, names=['test'])
        feature_pre_shuffle, label_pre_shuffle = fetch(df_train, 'train', d)

    model = np.zeros(d)
    shuffler = np.random.permutation(len(feature_pre_shuffle))
    feature = feature_pre_shuffle[shuffler]
    label = label_pre_shuffle[shuffler]
    total_data = len(label)
    data_split = _contiguous_split(total_data, num_of_node)

    if not iid:
        if identity[0] == "mnist":
            # Stable regrouping by class 0..9 (argmax of the one-hot labels).
            order = _class_sorted_indices(np.argmax(label, axis=1))
            feature = feature[order]
            label = label[order]
        else:
            # Labels are compared against +1 / -1: positives first, then
            # negatives; any other label value is dropped.
            positive = label == 1
            negative = label == -1
            feature = np.vstack((feature[positive], feature[negative]))
            label = np.hstack((label[positive], label[negative]))
        if num_of_node == 100:
            data_split = unbalanced_data_split(num_of_node, total_data, 10, 10)
        else:
            data_split = unbalanced_data_split(num_of_node, total_data, 500, 1)
    dataset = (feature, label)
    device = None
    criterion = None
    # These pipelines build no torch test loader; define the name so the
    # __main__ block, which always passes test_loader, does not NameError.
    test_loader = None
elif identity[0] == "quadratic":
    d = 1
    d_hat = (4, d)
    feature = list(range(num_of_node))
    label = list(range(num_of_node))
    dataset = (feature, label)
    data_split = [(i, i + 1) for i in range(num_of_node)]
    total_data = len(label)  # one sample per node, matching data_split
    # NOTE(review): this branch previously defined no model at all (NameError
    # in my_func); mirror the NumPy branch's zero initialisation — confirm.
    model = np.zeros(d)
    device = None
    criterion = None
    test_loader = None
elif identity[0] == "cifar10":
    device = _select_device()
    criterion = nn.CrossEntropyLoss()
    # NUM_FEATURES = 32 * 32
    # NUM_CLASSES = 100
    # GRAYSCALE = False
    # model = resnet18(NUM_CLASSES,GRAYSCALE).to(device)
    model = ResNet(ResidualBlock, [3, 3, 3]).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = _augmenting_transform(normalize, crop_size=32)
    train_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                                 train=True,
                                                 transform=transform,
                                                 download=True)
    test_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                                train=False,
                                                transform=transforms.Compose([transforms.ToTensor(), normalize]))
    test_loader = _make_test_loader(test_dataset)
    total_data, indices, dataset, data_split = _partition_torch_dataset(
        train_dataset, train_dataset.targets)
    if not iid:
        # Kept from the original: identity is only reset in the non-IID case.
        identity = ["cifar10", 0, 0]
elif identity[0] == "SVHN":
    device = _select_device()
    criterion = nn.CrossEntropyLoss()
    # SVHN has 10 digit classes (the non-IID grouping below iterates labels
    # 0..9); the head was previously 100 wide.
    model = ResNet(ResidualBlock, [3, 3, 3], num_classes=10).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = _augmenting_transform(normalize, crop_size=32)
    train_dataset = torchvision.datasets.SVHN(root='./data/',
                                              split='train',
                                              transform=transform,
                                              download=True)
    test_dataset = torchvision.datasets.SVHN(root='./data/',
                                             split='test',
                                             download=True,
                                             transform=transforms.Compose([transforms.ToTensor(), normalize]))
    test_loader = _make_test_loader(test_dataset)
    total_data, indices, dataset, data_split = _partition_torch_dataset(
        train_dataset, train_dataset.labels)
    identity = ["cifar10", 0, 0]
elif identity[0] == "Mnist":
    # Torch MNIST pipeline (distinct from the lowercase NumPy "mnist" one).
    device = _select_device()
    criterion = nn.CrossEntropyLoss()
    NUM_FEATURES = 32 * 32
    NUM_CLASSES = 10
    GRAYSCALE = True
    model = resnet18(NUM_CLASSES, GRAYSCALE).to(device)
    # model = ResNet(ResidualBlock, [3, 3, 3]).to(device)
    normalize = transforms.Normalize(mean=[0.485, ],
                                     std=[0.229, ])
    # Train at the 28x28 resolution the test transform feeds the model.  The
    # previous Resize((224, 224)) + RandomCrop(32) trained on random 32x32
    # patches of an 8x-upscaled digit, i.e. fragments never seen at test time.
    transform = _augmenting_transform(normalize, crop_size=28)
    train_dataset = torchvision.datasets.MNIST(root='./data/',
                                               train=True,
                                               transform=transform,
                                               download=True)
    test_dataset = torchvision.datasets.MNIST(root='./data/',
                                              train=False,
                                              transform=transforms.Compose([transforms.ToTensor(), normalize]))
    test_loader = _make_test_loader(test_dataset)
    total_data, indices, dataset, data_split = _partition_torch_dataset(
        train_dataset, train_dataset.targets)
    identity = ["cifar10", 0, 0]

#%%
def my_func(f, stream, args):
    """Build a fresh Network populated with one Node per topology entry, then run ``f``.

    Args:
        f: simulation routine, invoked as ``f(*args, network)``.
        stream: forwarded unchanged to every ``Node``.
        args: positional arguments for ``f``; the network is appended last.

    Returns:
        Whatever ``f`` returns.
    """
    network = Network(model, [], node_connection, 0, 0)
    # Re-seed NumPy from OS entropy on every call.  With fork-based
    # multiprocessing, workers inherit the parent's NumPy RNG state, so this
    # keeps repeated runs from drawing identical random streams.
    np.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))
    nodes = network.all_node
    for idx in range(num_of_node):
        nodes.append(
            Node(identity, model, data_split[idx], dataset, node_connection[idx],
                 gap[idx], connection_delay[idx], stream, delay_shift, device,
                 criterion, num_worker, batch_size)
        )
    return f(*args, network)


if __name__ == '__main__':
    # mp.set_start_method('spawn')

    # Create the output directory before any simulation runs.  Results are
    # only written after every run has finished, so a missing ./result/
    # directory would otherwise throw away all of the computed results.
    os.makedirs("result", exist_ok=True)
    result = []

    ########layer###############
    stream = 1
    for exp in lr_exp:
        # my_func(f, stream, args) -> f(*args, network)
        input_arg = (parallel_scaffold_layer, stream,
                     (identity, total_data, dataset, exp, cte, H, iteration,
                      sampling_f, device, criterion, num_worker, batch_size,
                      test_loader, alpha))
        if not child_process:
            res = [my_func(*input_arg) for _ in range(repeat_simulation)]
        else:
            # One worker per repetition; starmap blocks until all finish.
            with mp.Pool(repeat_simulation) as pool:
                res = pool.starmap(my_func, [input_arg] * repeat_simulation)
        result.append(res)
        print("All done -", "Layer-wise - exp =", exp, repeat_simulation, "times at",
              datetime.datetime.now().strftime("%a, %d %B %Y %H:%M:%S"))
    ########layer###############

    with open(os.path.join("result", simulation_result_file_name), 'w') as f:
        json.dump(result, f)