from models.Net import Net
from numpy.core.numeric import zeros_like
from optimizers.SCAFFOLD import SCAFFOLD
from optimizers.ditto import Ditto
from types import MethodWrapperType
from torch._C import device
from torch.optim import optimizer
from torch.utils.data import dataset
from losses.cflLoss import CFLLoss
from numpy.core.fromnumeric import mean, partition
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.serialization import save
import torchvision
import torchvision.transforms as transforms
import torch.utils.data.dataloader as dataloader
import pdb
import os
from torch.distributions.dirichlet import Dirichlet
from copy import deepcopy
import torch.nn.functional as F
from torchvision import datasets, models
from torch.autograd import Variable, grad
import numpy as np
import random
import matplotlib.pyplot as plt
import argparse
import SplitDataset
from SplitDataset import SplitDataset
import optimizers.fedProx
from optimizers.fedProx import FedProx
import optimizers.fedNova
from optimizers.fedNova import FedNova
from optimizers.fcl import FCLR
from comm_helpers import combine_reduction, fedAvg_communicate, fedNova_communicate
from pyhessian import hessian # Hessian computation
from hessian_tools import combine_hessian, get_SI_omega, get_diag_fisher, get_diag_hessian, copy_hessian, get_grads
import os
from pathlib import Path
from samplers.MCMCSampler import MCMCSampler, MCMCSet
from optimizers.Mime import Mime

def schedule_lr(t, args):
    """Return the client learning rate for communication round ``t``.

    The schedule is currently constant: ``args.lr`` is returned for every
    round and ``t`` is not consulted. A step decay keyed on
    ``args.round_num`` is kept below, commented out.
    """
    # Disabled step decay:
    # if t >= 0.75 * args.round_num:
    #     return args.lr * 0.01
    # if t >= 0.5 * args.round_num:
    #     return args.lr * 0.1
    base_lr = args.lr
    return base_lr

def schedule_alpha(t, args):
    """Return the weight of the hessian term for communication round ``t``.

    Currently constant (``args.halpha``); ``t`` is not consulted.
    """
    # Disabled exponential decay: args.halpha * (0.96 ** t)
    hessian_weight = args.halpha
    return hessian_weight


def choose_optimizer(model, args, t, H_agg, G_agg):
    """Build the local optimizer selected by ``args.optimizer``.

    Args:
        model: model whose parameters will be optimized.
        args: parsed options; reads ``optimizer``, ``mu`` and whatever the
            learning-rate / alpha schedules read.
        t: current communication round, forwarded to the schedules.
        H_agg, G_agg: aggregates forwarded to the ``'CFL-R'`` optimizer only.

    Returns:
        An optimizer instance bound to ``model``'s parameters.

    Raises:
        ValueError: if ``args.optimizer`` is not one of the handled names.
    """
    step_size = schedule_lr(t, args)
    name = args.optimizer

    # 'FedAvg' and 'Replay' share the FedProx optimizer with mu fixed to 0.
    if name in ('FedAvg', 'Replay'):
        return FedProx(model.parameters(), lr=step_size, mu=0)
    if name == 'FedProx':
        return FedProx(model.parameters(), lr=step_size, mu=args.mu)
    if name == 'FedNova':
        return FedNova(model.parameters(), lr=step_size, mu=args.mu,
                       named_params=model.named_parameters())
    if name == 'CFL-R':
        return FCLR(model.parameters(), H_agg, G_agg,
                    alpha=schedule_alpha(t, args), lr=step_size, mu=args.mu)
    if name == 'SCAFFOLD':
        return SCAFFOLD(model.parameters(), lr=step_size, mu=args.mu)
    raise ValueError('Invalid optimizer: {}'.format(args.optimizer))

def get_information_loss(model, data, g_data, source_data, device, args, diag_hessian, P, t, W, Gs):
    """Compare one local epoch on ``data`` with one epoch on ``source_data``.

    Two ResNet-18 copies of ``model`` are trained for a single local epoch:
    the first on ``data`` (plus ``g_data``) with the configured optimizer, the
    second on ``source_data`` with ``'FedAvg'``. The result is the sum over
    parameters of the norm of the difference between the two size-normalised
    parameter displacements.

    Args:
        model: reference model; must expose ``fc`` (a linear head) and have a
            state dict loadable into a torchvision ResNet-18.
        data, g_data, source_data: sized iterables of ``(X, Y)`` batches;
            ``g_data`` may be ``None``.
        device, diag_hessian, P, t, W, Gs: forwarded to ``local_epoch``.
        args: parsed options; deep-copied, the caller's object is not modified.

    Returns:
        The accumulated norm (a tensor when the model has parameters).
    """
    model1, model2 = models.resnet.resnet18(), models.resnet.resnet18()
    num_ftrs = model.fc.in_features
    # Size the heads from the reference model instead of a hard-coded 10 so
    # that load_state_dict succeeds for any number of classes.
    num_classes = model.fc.out_features
    model1.fc, model2.fc = nn.Linear(num_ftrs, num_classes), nn.Linear(num_ftrs, num_classes)
    model1.load_state_dict(deepcopy(model.state_dict()))
    model2.load_state_dict(deepcopy(model.state_dict()))
    args1, args2 = deepcopy(args), deepcopy(args)

    args1.localepoch = 1
    model1, _, _, _ = local_epoch(model1.to(device), data, g_data, device, args1, diag_hessian, P, t, W, Gs)

    args2.localepoch = 1
    args2.optimizer = 'FedAvg'
    model2, _, _, _ = local_epoch(model2.to(device), source_data, None, device, args2, diag_hessian, P, t, W, Gs)

    information_loss = 0
    l1 = 0 if g_data is None else len(g_data)
    for p1, p2, p in zip(model1.parameters(), model2.parameters(), model.parameters()):
        information_loss += torch.norm((p1 - p) / (len(data) + l1) - (p2 - p) / len(source_data))

    return information_loss

def get_prob(model, P):
    """Return the sum, over parameters, of ``torch.norm(p - o)``.

    ``P`` is paired positionally with ``model.parameters()``; pairing stops at
    the shorter of the two. Returns the int ``0`` when there is nothing to
    pair.
    """
    distance = 0
    for current, reference in zip(model.parameters(), P):
        distance = distance + torch.norm(current - reference)
    return distance

def choose_hessians(hessians, Ps, Gs, model):
    """Sample stored snapshots and aggregate their hessians and gradients.

    Each snapshot ``i`` gets sampling weight ``1 / (0.1 + get_prob(model, Ps[i]))``.
    The last ``len(hessians) % 10`` snapshots are excluded, up to 50 of the
    remaining ones are drawn with ``torch.multinomial``, and the drawn
    snapshots are averaged using their weights normalised over the draw.

    Args:
        hessians: list of per-parameter lists (one entry per snapshot).
        Ps: parameter snapshots, parallel to ``hessians``.
        Gs: gradient snapshots, parallel to ``hessians``.
        model: current model; provides the shapes of the aggregates.

    Returns:
        ``(H_agg, G_agg)``: per-parameter lists. ``H_agg`` is the weighted
        mean of the hessian entries and ``G_agg`` the weighted mean of
        ``g - h * p``. Both are all-zero when fewer than 10 snapshots exist.
    """
    S = len(hessians)
    H_agg = [torch.zeros_like(p.data) for p in model.parameters()]
    G_agg = [torch.zeros_like(p.data) for p in model.parameters()]
    # Not enough history yet: return the zero aggregates without computing
    # any sampling weights.
    if S < 10:
        return H_agg, G_agg

    alpha = 0.1
    prob = torch.tensor([1 / (alpha + get_prob(model, Ps[i])) for i in range(S)])
    res = S % 10
    # Exclude the trailing `res` snapshots. The previous loop assigned
    # prob[-i] for i in range(res); for i == 0 that is prob[0], so the first
    # snapshot was masked instead of the last one.
    if res:
        prob[S - res:] = 0
    chosen_items = torch.multinomial(prob, min(50, S - res))
    sum_P = sum([prob[i] for i in chosen_items])
    # sum_P = 1.0
    for index in chosen_items:
        H_agg = [agg + prob[index] * h / sum_P for agg, h in zip(H_agg, hessians[index])]
        G_agg = [agg + prob[index] * (g - torch.mul(h, p)) / sum_P for agg, g, h, p in zip(G_agg, Gs[index], hessians[index], Ps[index])]
    return H_agg, G_agg

def ditto_epoch(global_model, local_model, data, device, args):
    """Run one Ditto round: train the global model, then the local model.

    Phase 1 trains ``global_model`` for ``args.localepoch`` epochs with a
    ``Ditto`` optimizer, passing ``local_model``'s parameters to each step.
    Phase 2 does the same with the roles swapped, with ``global_model`` put
    in eval mode.

    Args:
        global_model, local_model: the two models; both are moved to ``device``.
        data: iterable of ``(X, Y)`` batches.
        device: device to run on.
        args: parsed options; reads ``lr``, ``mu``, ``lamda``, ``localepoch``.

    Returns:
        ``(global_model, local_model)`` after training.
    """
    global_model.train()
    global_model.to(device)
    local_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Ditto(global_model.parameters(), local_model.parameters(), lr=args.lr, mu=args.mu, lamda=args.lamda)
    # optimizer = FedProx(global_model.parameters(), lr=args.lr)
    for k in range(args.localepoch):
        for (X, Y) in data:
            X, Y = X.to(device), Y.to(device)
            X, Y = Variable(X), Variable(Y)
            global_model.zero_grad()
            output = global_model(X)
            loss = criterion(output, Y)
            loss.backward()
            optimizer.step(local_model.parameters())

    local_model.train()
    global_model.eval()
    optimizer = Ditto(local_model.parameters(), global_model.parameters(), lr=args.lr, mu=args.mu, lamda=args.lamda)
    for k in range(args.localepoch):
        for (X, Y) in data:
            X, Y = X.to(device), Y.to(device)
            X, Y = Variable(X), Variable(Y)
            # Clear the gradients of the model being trained in this phase.
            # Zeroing global_model here left local_model's gradients
            # accumulating from one batch to the next.
            local_model.zero_grad()
            output = local_model(X)
            loss = criterion(output, Y)
            loss.backward()
            optimizer.step(global_model.parameters())

    return global_model, local_model

def mime_epoch(model, data, device, args, s):
    """Train ``model`` locally with the ``Mime`` optimizer.

    Args:
        model: model to train; moved to ``device``.
        data: iterable of ``(X, Y)`` batches.
        device: device to run on.
        args: parsed options; reads ``lr``, ``mu``, ``localepoch``.
        s: per-parameter list passed to every ``optimizer.step`` call, or
            ``None`` to pass a list of ``None`` (one per parameter).

    Returns:
        ``(model, Gs)`` where ``Gs`` is ``get_grads`` evaluated on a deep copy
        of the trained model over ``data``.
    """
    model.train()
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    if s is None:
        s = [None for _ in model.parameters()]
    optimizer = Mime(model.parameters(), lr=args.lr, mu=args.mu)
    for k in range(args.localepoch):
        for (X, Y) in data:
            X, Y = X.to(device), Y.to(device)
            X, Y = Variable(X), Variable(Y)
            model.zero_grad()
            output = model(X)
            loss = criterion(output, Y)
            loss.backward()
            optimizer.step(s)
    Gs = get_grads(deepcopy(model), data, device, criterion)
    return model, Gs

def scaffold_epoch(model, data, device, args, C, local_C):
    """Train ``model`` locally with the ``SCAFFOLD`` optimizer.

    Args:
        model: model to train; moved to ``device``.
        data: iterable of ``(X, Y)`` batches.
        device: device to run on.
        args: parsed options; reads ``lr``, ``mu``, ``localepoch``.
        C, local_C: per-parameter lists passed to every ``optimizer.step``.

    Returns:
        ``(model, c_plus)`` where, per parameter,
        ``c_plus = local_C - C + (start - end) / (args.localepoch * args.lr)``
        and ``start`` / ``end`` are the parameter values before / after
        training.
    """
    model.train()
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    # Snapshot of the starting point, taken after the move to `device`.
    start_params = deepcopy(list(model.parameters()))
    opt = SCAFFOLD(model.parameters(), lr=args.lr, mu=args.mu)

    for _ in range(args.localepoch):
        for inputs, targets in data:
            inputs = Variable(inputs.to(device))
            targets = Variable(targets.to(device))
            model.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            opt.step(C, local_C)

    c_plus = []
    for c_i, c, x, y in zip(local_C, C, start_params, model.parameters()):
        c_plus.append(c_i - c + 1 / (args.localepoch * args.lr) * (x - y))

    return model, c_plus

    








def local_epoch(model, data, g_data, device, args, hessians, Ps, t, W, Gs):
    """Run the local training epochs for one client and, for ``'CFL-R'``,
    estimate a diagonal curvature term and gradients afterwards.

    Args:
        model: model to train; moved to ``device``.
        data: sized iterable of ``(X, Y)`` batches.
        g_data: optional second iterable of batches trained on after each
            epoch over ``data``; ``None`` to skip.
        device: device to run on.
        args: parsed options; reads ``optimizer``, ``localepoch``, ``hessian``,
            ``epsilon`` and whatever ``choose_optimizer`` reads.
        hessians, Ps, Gs: stored snapshots, used only when
            ``args.optimizer == 'CFL-R'``.
        t: current communication round.
        W: per-parameter list of tensors, updated in place after every epoch
            with ``-grad * (param - param_at_epoch_start)``.

    Returns:
        ``(model, diagH, len(data), Gs)``. ``diagH`` and ``Gs`` are ``None``
        unless ``args.optimizer == 'CFL-R'``.

    Raises:
        ValueError: if ``args.optimizer == 'CFL-R'`` and ``args.hessian`` is
            not ``'DiagH'``, ``'Fisher'`` or ``'SI'``.
    """
    # Reject an unknown hessian type before training rather than failing with
    # an unbound local after the whole pass.
    if args.optimizer == 'CFL-R' and args.hessian not in ('DiagH', 'Fisher', 'SI'):
        raise ValueError('Invalid hessian: {}'.format(args.hessian))

    model.train()
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    H_agg, G_agg = None, None
    if args.optimizer == 'CFL-R':
        H_agg, G_agg = choose_hessians(hessians, Ps, Gs, model)
    optimizer = choose_optimizer(model, args, t, H_agg, G_agg)

    P_old = [p.detach().clone() for p in model.parameters()]

    for k in range(args.localepoch):
        for (X, Y) in data:
            X, Y = X.to(device), Y.to(device)
            X, Y = Variable(X), Variable(Y)
            model.zero_grad()
            output = model(X)
            loss = criterion(output, Y)
            loss.backward()
            optimizer.step()
        for i, (p, p_old) in enumerate(zip(model.parameters(), P_old)):
            # p.grad is None when no batch was processed (e.g. the loader
            # dropped its only, incomplete batch); skip the update then.
            if p.requires_grad and p.grad is not None:
                W[i].add_(-p.grad*(p.detach()-p_old))
            P_old[i] = p.detach().clone()
        if g_data is not None:
            for (X, Y) in g_data:
                X, Y = X.to(device), Y.to(device)
                X, Y = Variable(X), Variable(Y)
                model.zero_grad()
                output = model(X)
                loss = criterion(output, Y)
                loss.backward()
                optimizer.step()

    # approximate hessian
    if args.optimizer != 'CFL-R':
        return model, None, len(data), None
    if args.hessian == 'DiagH':
        diagH, Gs = get_diag_hessian(deepcopy(model), data, device, criterion)
    elif args.hessian == 'Fisher':
        diagH = get_diag_fisher(deepcopy(model), data, device, criterion)
        Gs = get_grads(deepcopy(model), data, device, criterion)
    else:  # 'SI' (validated above)
        diagH = get_SI_omega(deepcopy(model), W, args.epsilon, P_old)
        Gs = get_grads(deepcopy(model), data, device, criterion)
    return model, diagH, len(data), Gs




class FedTrain():
    """Federated training driver.

    Owns the global model, one local model per client, the per-client
    training sets and the test loader produced by ``SplitDataset``, and runs
    the communication rounds in :meth:`run`.
    """

    def __init__(self, args) -> None:
        # useful parameters
        self.N = args.client_num
        self.T = args.round_num
        self.S = args.split_num
        self.global_lr = args.globallr
        # Batch size is kept on the instance so train_test() does not depend
        # on a module-level ``args`` being defined.
        self.bs = args.bs
        # generate data
        new_partitioner = SplitDataset(args)
        self.total_train_sets = new_partitioner.trainset
        # per-client start offset of the sliding window used by
        # get_dataloader_overlap()
        self.index = [0 for _ in range(self.N)]
        # self.train_sets = new_partitioner.round_sets
        self.train_sets = new_partitioner.combine_round_sets
        self.test_loader = new_partitioner.test_loader
        self.original_sets = new_partitioner.total_set
        # initial global model and local models
        self.global_model, self.models = self.init_models()

        self.chosen_clients = [[False for _ in range(self.S)] for _ in range(self.N)]

    def init_models(self):
        """Create pretrained ResNet-18 models (N local, one global) with a
        fresh 10-way linear head each.

        Returns:
            ``(global_model, local_models)``.
        """
        # load finetune models
        local_models = [models.resnet18(pretrained=True) for i in range(self.N)]
        global_model = models.resnet18(pretrained=True)
        # change output layer
        num_ftrs = global_model.fc.in_features
        global_model.fc = nn.Linear(num_ftrs, 10)
        for i in range(self.N):
            local_models[i].fc = nn.Linear(num_ftrs, 10)

        return global_model, local_models

    def set_weights(self):
        """Copy the global model's state dict into every local model."""
        global_dict = self.global_model.state_dict()

        for local_model in self.models:
            local_model.load_state_dict(deepcopy(global_dict))

    def combine_weights(self, args, Ls, models):
        """Aggregate the local ``models`` into ``self.global_model``.

        Raises:
            ValueError: if ``args.optimizer`` is not a handled name.
        """
        fedavg_style = ('FedAvg', 'FedProx', 'Replay', 'Ditto', 'SCAFFOLD', 'Mime', 'CFL-R')
        if args.optimizer in fedavg_style:
            self.global_model = fedAvg_communicate(self.global_model, models, args, Ls)
        elif args.optimizer == 'FedNova':
            self.global_model = fedNova_communicate(self.global_model, models, args, Ls)
        else:
            raise ValueError('Invalid optimizer: {}'.format(args.optimizer))

    def _replay_loader(self, i, args, before_dataset):
        """Return a loader over the combined replay core set, or ``None``.

        A loader is built only for the ``'Replay'`` optimizer and only when
        client ``i`` already has stored samples.
        """
        if args.optimizer != 'Replay' or len(before_dataset[i]) <= 0:
            return None
        # NOTE(review): client 0 contributes coreSetSize samples (and is not
        # checked for emptiness) while every other client contributes
        # coreSetSize // N -- confirm this asymmetry is intended.
        subset = before_dataset[0].weighted_core_set(args.coreSetSize)
        for j in range(1, self.N):
            if len(before_dataset[j]) > 0:
                subset = subset.combine(before_dataset[j].weighted_core_set(args.coreSetSize // self.N))
        return dataloader.DataLoader(subset, args.bs, True, drop_last=True)

    def get_dataloader_overlap(self, i, length, args, before_dataset):
        """Return loaders for client ``i``'s current sliding window.

        The window covers ``length`` consecutive indices of the client's
        training set starting at ``self.index[i]`` (wrapping around), after
        which the start offset advances by ``args.step``.

        Returns:
            ``(data_loader, subloader, data_set)``; ``subloader`` is the
            replay loader or ``None``.
        """
        size = len(self.train_sets[i])
        start = self.index[i]
        data_set = self.train_sets[i].subset([j % size for j in range(start, start + length)])
        self.index[i] = (start + args.step) % size

        subloader = self._replay_loader(i, args, before_dataset)
        data_loader = dataloader.DataLoader(data_set, args.bs, True, drop_last=True)

        return data_loader, subloader, data_set

    def get_dataloader(self, i, index, args, before_dataset):
        """Return ``(data_loader, subloader)`` for ``self.train_sets[index]``;
        ``subloader`` is the replay loader for client ``i`` or ``None``.
        """
        data_set = self.train_sets[index]

        subloader = self._replay_loader(i, args, before_dataset)
        data_loader = dataloader.DataLoader(data_set, args.bs, True, drop_last=True)

        return data_loader, subloader

    def mean_feature_vectors(self, model, data, label, device, args):
        """Return the mean model output over the samples of ``data`` whose
        second element equals ``label``.

        Outputs are averaged per batch (batch size 32, incomplete last batch
        dropped) and then across batches.

        Returns:
            A detached tensor, or ``None`` when no full batch is available.
        """
        data = [d for d in data if d[1] == label]
        if not data:
            return None
        data_loader = dataloader.DataLoader(data, 32, True, drop_last=True)
        model.to(device)
        count = 0
        mean_vector = None
        # Inference only: the result is detached anyway, so do not build
        # autograd graphs while accumulating.
        with torch.no_grad():
            for (X, Y) in data_loader:
                count += 1
                X = X.to(device)
                out = torch.mean(model(X), dim=0)
                if mean_vector is None:
                    mean_vector = out
                else:
                    mean_vector = mean_vector + out
        if mean_vector is None:
            return None
        return mean_vector.clone().detach() / count

    def init_variables(self, args):
        """Create the empty per-run state.

        Returns:
            ``(hessians, Ps, before_dataset, Gs)``: three empty lists and one
            empty core-set container per client.

        Raises:
            ValueError: if ``args.coreSetMethod`` is not a handled name.
        """
        hessians = []
        Ps = []
        Gs = []
        if args.coreSetMethod == 'Naive' or args.coreSetMethod == 'iCaRL':
            before_dataset = [SplitDataset.empty_set(self.total_train_sets) for _ in range(self.N)]
        elif args.coreSetMethod == 'MCMC':
            before_dataset = [MCMCSet() for _ in range(self.N)]
        else:
            raise ValueError('Invalid coreSetMethod: {}'.format(args.coreSetMethod))

        return hessians, Ps, before_dataset, Gs

    def combine_variables(self, new_before_set, i, data_set, args):
        """Add a core set of ``data_set`` to client ``i``'s stored samples and
        cap the store at ``30 * args.coreSetSize`` entries.

        Only the ``'Naive'`` method adds samples here; for other methods the
        store is only capped.
        """
        if args.coreSetMethod == 'Naive':
            new_before_set[i] = new_before_set[i].combine(data_set.naive_core_set(args.coreSetSize))
        if len(new_before_set[i]) > args.coreSetSize * 30:
            new_before_set[i] = new_before_set[i].naive_core_set(args.coreSetSize * 30)

        return new_before_set

    def copy_variables(self, hessians, new_hessians, Ps, new_Ps, L, alpha):
        """Merge new hessians / parameters into the stored ones via
        ``copy_hessian`` and return ``(hessians, Ps)``.
        """
        hessians = copy_hessian(hessians, new_hessians, L, alpha=alpha)
        Ps = copy_hessian(Ps, new_Ps, L, alpha=alpha)
        return hessians, Ps

    def update_C(self, C, local_Cs, c_plus_s, Ls, sum_L, chosen_clients):
        """Update ``C`` with every client's change, weighted by ``Ls[i] / sum_L``.

        Args:
            C: per-parameter list.
            local_Cs: per-parameter lists indexed by
                ``i * self.S + chosen_clients[i]``.
            c_plus_s: per-client, per-parameter updated values.
            Ls: per-client weights; ``sum_L`` is their normaliser.

        Returns:
            ``(C, c_plus_s)`` with ``C`` updated.
        """
        for i in range(self.N):
            # The weight belongs to the client, not to the parameter: zipping
            # Ls with the parameter lists paired parameter k with client k's
            # weight and truncated the result to len(Ls) parameters.
            weight = Ls[i] / sum_L
            local_c = local_Cs[i * self.S + chosen_clients[i]]
            C = [c + (c_plus - c_i) * weight for c, c_plus, c_i in zip(C, c_plus_s[i], local_c)]
        return C, c_plus_s

    def update_S(self, S, Gs):
        """Return the updated per-parameter state ``S``.

        With an existing ``S`` the result is ``0.9 * S + 0.1 * mean(Gs)``;
        with ``S is None`` it is ``mean(Gs)``. The mean is over the ``N``
        clients.
        """
        if S is not None:
            new_S = [0.9 * s for s in S]
            for i in range(self.N):
                new_S = [s + 0.1 * g / self.N for s, g in zip(new_S, Gs[i])]
        else:
            # zeros_like already allocates on the parameter's device; forcing
            # 'cuda' here broke CPU-only runs.
            new_S = [torch.zeros_like(p.data) for p in self.global_model.parameters()]
            for i in range(self.N):
                new_S = [s + g / self.N for s, g in zip(new_S, Gs[i])]
        return new_S

    def run(self, args):
        """Run ``self.T`` communication rounds.

        Each round broadcasts the global weights, trains every client on its
        current sliding window, aggregates the local models and evaluates the
        global model on the test loader.

        Returns:
            A list with one summary string per round.
        """
        # init some variables
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        hessians, Ps, before_dataset, Gs = self.init_variables(args)
        self.global_model.to(device)
        C = [torch.zeros_like(p.data).to(device) for p in self.global_model.parameters()]
        S = None
        saves = []
        # train procedure
        for t in range(self.T):
            W = [torch.zeros_like(p.data).to(device) for p in self.global_model.parameters()]
            self.set_weights()
            Ls = []
            information_loss = 0
            mime_Gs = []

            for i in range(self.N):
                data_loader, subloader, data_set = self.get_dataloader_overlap(i, 285, args, before_dataset)

                if args.optimizer == 'Mime':
                    self.models[i], mime_G = mime_epoch(self.models[i], data_loader, device, args, S)
                    Ls.append(len(data_loader))
                    mime_Gs.append(mime_G)
                else:
                    self.models[i], new_hessian, L, new_G = local_epoch(self.models[i].to(device), data_loader, subloader, device, args, hessians, Ps, t, W, Gs)

                    mime_Gs.append(new_G)

                    Ls.append(len(data_loader))
                    before_dataset = self.combine_variables(before_dataset, i, data_set, args)

                    hessians.append(new_hessian)
                    Ps.append([p.data.clone().detach() for p in self.models[i].parameters()])
                    Gs.append(new_G)
                # NOTE(review): pop() discards the most recently appended
                # snapshot, so once 40 entries exist the history stops
                # changing. If a sliding window over the latest 40 is
                # intended this should be pop(0) -- confirm before changing.
                while len(hessians) > 40:
                    hessians.pop()
                    Ps.pop()
                    Gs.pop()

            self.combine_weights(args, Ls, self.models)
            if args.optimizer == 'Mime':
                S = self.update_S(S, mime_Gs)

            acc, loss = self.test(self.test_loader, device)
            train_acc, train_loss = 0, 0
            # train_acc, train_loss = self.train_test(device)
            save_string = 'acc: {}, loss: {}, t_acc: {}, t_loss: {} on round {}, information loss: {}'.format(acc, loss, train_acc, train_loss, t, information_loss)
            print(save_string)
            saves.append(save_string)
        return saves

    def train_test(self, device):
        """Evaluate the global model on every training subset marked in
        ``self.chosen_clients``.

        Returns:
            ``(mean_accuracy, mean_loss)`` over the evaluated subsets, or
            ``(0, 0)`` when no subset was evaluated.
        """
        corrects = []
        losses = []
        self.global_model.eval()
        with torch.no_grad():
            for i in range(self.N):
                for j in range(self.S):
                    if not self.chosen_clients[i][j]:
                        continue
                    data = dataloader.DataLoader(self.train_sets[i * self.S + j], self.bs, True, drop_last=True)
                    loss = 0
                    total = 0
                    correct = 0
                    for (X, Y) in data:
                        X, Y = X.to(device), Y.to(device)
                        X, Y = Variable(X), Variable(Y)

                        out = self.global_model(X)
                        loss += F.cross_entropy(out, Y) * Y.size(0)
                        _, predicted = torch.max(out.data, 1)
                        total += Y.size(0)
                        correct += (predicted == Y).sum().item()
                    if total == 0:
                        # subset smaller than one batch: nothing to average
                        continue
                    corrects.append(correct / total)
                    losses.append(loss / total)
                    torch.cuda.empty_cache()
        if not corrects:
            return 0, 0
        return sum(corrects) / len(corrects), sum(losses) / len(losses)

    def test(self, data, device):
        """Evaluate the global model on ``data``.

        Args:
            data: iterable of ``(X, Y)`` batches.
            device: device to run on.

        Returns:
            ``(accuracy, mean_loss)``; ``mean_loss`` is a 0-dim tensor that
            carries no autograd graph.
        """
        loss = 0
        total = 0
        correct = 0
        self.global_model.to(device)

        self.global_model.eval()
        # Evaluation only: without no_grad the accumulated loss kept the
        # autograd graph of every batch alive until the end of the pass.
        with torch.no_grad():
            for (X, Y) in data:
                X, Y = X.to(device), Y.to(device)
                X, Y = Variable(X), Variable(Y)

                out = self.global_model(X)
                loss += F.cross_entropy(out, Y) * Y.size(0)
                _, predicted = torch.max(out.data, 1)
                total += Y.size(0)
                correct += (predicted == Y).sum().item()
        return correct / total, loss / total

if __name__ == '__main__':
    # Command-line entry point: parse the options, seed every RNG, run the
    # federated training and save the per-round summary strings.
    parser = argparse.ArgumentParser(description='CIFAR-10 baseline')
    parser.add_argument('--client_num', '-cN', default=7, type=int,
                        help='the number of clients')
    parser.add_argument('--round_num', '-rN', default=500, type=int,
                        help='the number of communication rounds')
    parser.add_argument('--split_num', '-sN', default=30, type=int,
                        help='the number of subsets of a client')
    parser.add_argument('--round_drift', '-rd', default=0.2, type=float,
                        help='round drift')
    parser.add_argument('--client_drift', '-cd', default=0.1, type=float,
                        help='client drift')
    parser.add_argument('--lr', default=0.01, type=float,
                        help='client learning rate')
    parser.add_argument('--bs', default=10, type=int,
                        help='batch size on each worker/client')
    # NOTE(review): default=True together with store_true means this flag can
    # never be False from the command line; left unchanged to keep the CLI.
    parser.add_argument('--NIID', default=True, action='store_true',
                        help='whether the dataset is non-iid or not')
    parser.add_argument('--datapath', default='./data/', type=str,
                        help='directory to load data')
    parser.add_argument('--savepath', default='./results/checkpoints/', type=str,
                        help='directory to save results')
    parser.add_argument('--optimizer', default='FedAvg', type=str,
                        help='type of optimizer')
    parser.add_argument('--repeat', default=1, type=int,
                        help='rounds of same local datasets repeat')
    parser.add_argument('--globallr', default=1.0, type=float,
                        help='global learning rate')
    parser.add_argument('--localepoch', default=5, type=int,
                        help='number of local epochs')
    parser.add_argument('--ratio', default=1.0, type=float,
                        help='ratio of chosen clients in each round')
    parser.add_argument('--mu', default=0, type=float,
                        help='weight of proximal term')
    parser.add_argument('--halpha', default=1, type=float,
                        help='weight of hessian term')
    parser.add_argument('--splittype', default='R', type=str,
                        help='type of splitting data')
    # NOTE(review): help text duplicates --splittype; the meaning of gamma is
    # not visible in this file, so the string is left as is.
    parser.add_argument('--gamma', default=0.1, type=float,
                        help='type of splitting data')
    parser.add_argument('--lamda', default=0.1, type=float,
                        help='ratio of regularization term in ditto')
    parser.add_argument('--useGs', default=0, type=int,
                        help='if use gradient of previous rounds')
    parser.add_argument('--hessian', default='DiagH', type=str,
                        help='the type of hessian')
    parser.add_argument('--epsilon', default=0.0001, type=float,
                        help='epsilon in SI')
    parser.add_argument('--clip', default=1, type=int,
                        help='if clip the gradient of clients')
    parser.add_argument('--coreSetSize', default=10, type=int,
                        help='the size of coresets of each round')
    parser.add_argument('--coreSetMethod', default='Naive', type=str,
                        help='method to generate core set')
    parser.add_argument('--seed', default=1234, type=int,
                        help='random seed')
    parser.add_argument('--step', default=30, type=int,
                        help='step of slide window')
    args = parser.parse_args()

    # Prefer GPU 1 as before, but only when it exists: an unconditional
    # set_device(1) crashes on hosts with fewer than two GPUs or without CUDA.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    fedTrain = FedTrain(args)
    saves = fedTrain.run(args)
    torch.save(saves, './data/{}-{}-{}-{}-{}.pt'.format(args.optimizer, args.mu, args.round_drift, seed, args.step))
