import os
from fedbase.baselines import *
from fedbase.model.model import *
from fedbase.nodes.node import node
from fedbase.utils.tools import unpack_args
from fedbase.utils.data_loader import data_process
import torch.optim as optim
import torch.nn as nn
import torch
from functools import partial
import numpy as np
import torch.multiprocessing as mp
import time
import random

# Switch the working directory to the folder containing this script
# (presumably so relative data/log paths used by fedbase resolve the same way
# regardless of where the script is launched from -- confirm).
_script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(_script_dir)

# Hyperparameters shared by every experiment launched from this file.
num_nodes = 200
batch_size = 32
global_rounds, local_steps = 100, 10

# Optimizer factory: SGD with learning rate and momentum pre-bound.
optimizer = partial(optim.SGD, momentum=0.9, lr=0.001)
# Every run is placed on the first CUDA device.
device = torch.device('cuda:0')

@unpack_args
def main1(seed, dataset_splited, model):
    """Run the single-model baselines (FedAvg, then FedProx) for one seed.

    Args:
        seed: Value used to seed NumPy, PyTorch and the stdlib ``random`` module.
        dataset_splited: Pre-split dataset forwarded to each baseline's ``run``.
        model: Model forwarded unchanged to each baseline's ``run``.

    NOTE(review): ``unpack_args`` presumably lets this be called with a single
    argument tuple (as pool workers would do) -- confirm in fedbase.utils.tools.
    """
    # Seed every random number generator this script knows about.
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed)

    # Positional arguments shared by both baselines, in the order they expect.
    shared = (
        dataset_splited,
        batch_size,
        num_nodes,
        model,
        nn.CrossEntropyLoss,
        optimizer,
        global_rounds,
        local_steps,
    )
    fedavg.run(*shared, device=device)
    # The extra 0.1 is presumably FedProx's proximal coefficient -- confirm.
    fedprox.run(*shared, 0.1, device=device)

@unpack_args
def main2(seed, dataset_splited, model, K):
    """Run the ensemble/clustered baselines, one after another, for one seed.

    Order: fedavg_ensemble, fedprox_ensemble, fesem, ifca, ifca_res, fesem_res.

    Args:
        seed: Value used to seed NumPy, PyTorch and the stdlib ``random`` module.
        dataset_splited: Pre-split dataset forwarded to each baseline's ``run``.
        model: Model forwarded unchanged to each baseline's ``run``.
        K: Forwarded to every baseline (presumably the number of
            clusters / ensemble members -- confirm against fedbase).

    NOTE(review): ``unpack_args`` presumably lets this be called with a single
    argument tuple, which is how the pool in ``__main__`` invokes it -- confirm.
    """
    # Seed every random number generator this script knows about.
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed)

    # Recurring runs of positional arguments; each baseline splices its own
    # extras (K, 0.1, 30) in between at the position its signature expects.
    data_args = (dataset_splited, batch_size)
    train_args = (num_nodes, model, nn.CrossEntropyLoss, optimizer)
    schedule_args = (global_rounds, local_steps)

    fedavg_ensemble.run(*data_args, *train_args, *schedule_args, K, device=device)
    # 0.1 is presumably the proximal coefficient -- confirm.
    fedprox_ensemble.run(*data_args, *train_args, *schedule_args, 0.1, K, device=device)
    fesem.run(*data_args, K, *train_args, *schedule_args, device=device)
    ifca.run(*data_args, K, *train_args, *schedule_args, device=device)
    # NOTE(review): meaning of the literal 30 is not visible here -- check the
    # *_res baselines' signatures.
    ifca_res.run(*data_args, K, *train_args, 30, *schedule_args, device=device)
    fesem_res.run(*data_args, K, *train_args, 30, *schedule_args, reg_lam=0.001, device=device)

  
# multiprocessing
# Entry point: fan the main2 experiments out over a pool of spawned worker
# processes, first on group-wise (cluster-wise) splits, then on client-wise
# splits, and report the total wall-clock time.
if __name__ == '__main__':
    multi_processes = 24
    seed_0 = 20
    seeds = 5
    seed_range = range(seed_0, seed_0 + seeds)

    # (dataset name, model class) pairs evaluated in both experiment families.
    benchmarks = list(zip(
        ['cifar10', 'fashion_mnist', 'medmnist_pathmnist', 'medmnist_tissuemnist'],
        [CNNCifar, CNNFashion_Mnist, CNNPath, CNNTissue],
    ))

    start = time.perf_counter()
    with mp.get_context('spawn').Pool(multi_processes) as p:
        # cluster_wise
        cluster_tasks = [
            (i, data_process(dataset).split_dataset_groupwise(n0, j0, k0, n1, j1, k1), model, K)
            for i in seed_range
            for dataset, model in benchmarks
            for n0, n1 in zip([10], [20])
            for j0, k0, j1, k1 in zip([3, 0.1], ['class', 'dirichlet'], [2, 10], ['class', 'dirichlet'])
            for K in [5, 10]
        ]
        # map() blocks until every task has finished. The pool must stay open
        # afterwards: calling close() here would make the next map() raise
        # "ValueError: Pool not running".
        p.map(main2, cluster_tasks)

        # client_wise
        client_tasks = [
            (i, data_process(dataset).split_dataset(num_nodes, j, k), model, K)
            for i in seed_range
            for dataset, model in benchmarks
            for j, k in zip([2, 0.1], ['class', 'dirichlet'])
            for K in [5, 10]
        ]
        p.map(main2, client_tasks)

        # All work has been submitted and completed; shut the workers down
        # gracefully before the context manager terminates the pool.
        p.close()
        p.join()

    print(time.perf_counter()-start, "seconds")