import os
from Sims import *
from Settings import *
from Util import *
# Expose only GPU index 2 to CUDA for this process (an import-time side effect).
# NOTE(review): this only takes effect if no CUDA context was already created by
# the star imports above — confirm; the device index is hard-coded.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

class FL_Proc:
    """Simulated federated-learning run driven by a configuration dict.

    ``main()`` builds one ``Server_Sim`` and ``NClients`` ``Client_Sim``
    objects, all constructed from the same ``model`` object. It then runs
    ``MaxIter`` rounds. In each round it:

    1. optionally consults the ``CPCheck`` detector (``configs["critical"]``)
       to choose how many clients participate;
    2. selects those clients through ``RandomGet``;
    3. trains each selected client locally;
    4. sends the resulting parameters to the server, which aggregates them.

    Args:
        configs: Mapping of run settings. Every key read in ``__init__`` must
            be present; a missing key raises ``KeyError``.
        model: Initial model handed to the server and to every client.
    """

    def __init__(self, configs, model):
        self.Name = configs["name"]
        # Stored for callers; not read by this class's methods.
        self.ModelName = configs["mname"]
        self.NClients = configs["nclients"]
        self.PClients = configs["pclients"]
        self.IsIID = configs["isIID"]
        self.Alpha = configs["alpha"]
        self.Aug = configs["aug"]
        self.MaxIter = configs["iters"]
        self.LogStep = configs["logstep"]
        self.LR = configs["learning_rate"]
        self.Normal = configs["normal"]
        self.Optmzer = configs["optimizer"]
        self.FixLR = configs["fixlr"]
        self.WDecay = configs["wdecay"]
        self.DShuffle = configs["data_shuffle"]
        self.BatchSize = configs["batch_size"]
        self.Epoch = configs["epoch"]
        self.GlobalLR = configs["global_lr"]
        # The next three are stored but not read by this class's methods.
        self.FixFedAvg = configs['fix_fedavg']
        self.SaveModel = configs['save_model']
        self.FIM = configs["fim"]
        self.UseCP = configs["critical"]
        self.CThresh = configs["CThresh"]
        self.GModel = model
        self.Server = None
        self.Clients = {}
        self.ClientLoaders = None
        self.TrainLoader = None
        self.TestLoader = None
        self.logpath = None
        # Ids of the clients trained in the previous round. Before round one
        # this is simply the first PClients ids.
        self.updateIDs = list(range(self.PClients))
        self.Detection = CPCheck(self.NClients, self.PClients, threshold=self.CThresh)
        self.Selection = RandomGet(self.NClients)

    def get_train_datas(self):
        """Build the per-client loaders, the server's train loader and the test loader.

        The positional ``False, False`` flags are forwarded to ``get_loaders``
        unchanged. NOTE(review): their meaning is defined in ``get_loaders``.
        The fourth value ``get_loaders`` returns is discarded.
        """
        self.ClientLoaders, self.TrainLoader, self.TestLoader, _stat = get_loaders(
            self.Name, self.NClients, self.IsIID, self.Alpha, self.Aug,
            False, False, self.Normal, self.DShuffle, self.BatchSize)

    def logging(self):
        """Evaluate the server on the test loader, print and return the result.

        Returns:
            The pair returned by ``Server.evaluate`` (named loss/accuracy here).
        """
        teloss, teaccu = self.Server.evaluate(self.TestLoader)
        print("Test loss:", teloss, "Test accuracy:", teaccu)
        return teloss, teaccu

    def main(self):
        """Run the whole simulation: setup followed by ``MaxIter`` rounds.

        The server is evaluated once before training and every ``LogStep``
        rounds. It is evaluated once more after the last round if that round
        was not already a logging round, so the final model is always reported.
        """
        self.get_train_datas()
        self._build_sims()
        num_partens = self.PClients
        det_step = 1  # participant detection runs every round
        self.logging()
        clp = 1

        for it in range(self.MaxIter):
            print(it + 1, "-th Round")
            if (it + 1) % det_step == 0:
                num_partens, clp = self._detect_participants(num_partens)
            update_ids = self.Selection.select_participant(num_partens, clp, self.UseCP)

            trans_lens, trans_paras, trans_vecs = self._train_clients(update_ids)
            weights = self._aggregation_weights(trans_lens, trans_vecs)
            for paras, (glen, gnvec) in zip(trans_paras, weights):
                self.Server.recvInfo(paras, glen, gnvec)
            self.Server.aggParas()

            if self.Optmzer == "VRL":
                self._sync_vrl(update_ids)
            self.updateIDs = update_ids
            if (it + 1) % self.LogStep == 0:
                self.logging()

        # Without this, the final model goes unevaluated when MaxIter is not
        # a multiple of LogStep.
        if self.MaxIter > 0 and self.MaxIter % self.LogStep != 0:
            self.logging()

    def _build_sims(self):
        """Create the server and one client per id, registering each client with the selector."""
        self.Server = Server_Sim(self.TrainLoader, self.GModel, self.LR, self.WDecay, self.FixLR)
        for c in range(self.NClients):
            client = Client_Sim(self.ClientLoaders[c], self.GModel, self.LR, self.WDecay,
                                self.Epoch, self.FixLR, self.Optmzer)
            self.Clients[c] = client
            self.Selection.register_client(c, client.DLen)

    def _detect_participants(self, num_partens):
        """Return the ``(participant count, CLP flag)`` pair used for selection.

        Without ``UseCP``, the count is unchanged and the flag is 0. With it,
        the ``gradnorm`` of each client from the previous round is fed to the
        detector, and ``WinCheck``'s result is returned.
        """
        if not self.UseCP:
            return num_partens, 0
        norms = [self.Clients[ky].gradnorm for ky in self.updateIDs]
        self.Detection.recvInfo(norms)
        return self.Detection.WinCheck(len(self.updateIDs))

    def _train_clients(self, update_ids):
        """Push the server state to each selected client and train it locally.

        Returns:
            Three lists aligned with ``update_ids``: each client's ``DLen``,
            its parameters after training, and the value ``selftrain()``
            returned.
        """
        lr_now = self.Server.getLR()
        global_paras = self.Server.getParas()
        lens, paras, vecs = [], [], []
        for ky in update_ids:
            client = self.Clients[ky]
            if self.GlobalLR:
                client.updateLR(lr_now)
            client.updateParas(global_paras)
            vecs.append(client.selftrain())
            paras.append(client.getParas())
            lens.append(client.DLen)
        return lens, paras, vecs

    @staticmethod
    def _aggregation_weights(lens, vecs):
        """Compute the ``(GLen, GNvec)`` pair sent to the server for each client.

        ``GLen`` is the client's share of the total ``lens``. ``GNvec`` is
        ``tau_eff / vec``, where ``tau_eff`` is the ``GLen``-weighted sum of
        ``vecs``. With no clients this returns an empty list.
        """
        total = np.sum(lens)
        tau_eff = np.sum([ln / total * v for ln, v in zip(lens, vecs)])
        return [(ln / total, tau_eff / v) for ln, v in zip(lens, vecs)]

    def _sync_vrl(self, update_ids):
        """For the "VRL" optimizer, reload the aggregated parameters into this
        round's clients and call ``update_delta`` with each client's ``local_steps``."""
        global_paras = self.Server.getParas()
        for ky in update_ids:
            client = self.Clients[ky]
            client.updateParas(global_paras)
            client.optimizer.update_delta(client.local_steps)


if __name__ == '__main__':
    # Script entry: model type "alex" on dataset "fmnist".
    dataset = "fmnist"
    arch = "alex"
    global_model = load_Model(arch, dataset)

    # Run configuration consumed by FL_Proc.__init__ (every key is required).
    run_cfg = {
        "critical": True,
        "normal": True,
        "fixlr": False,
        "global_lr": True,
        "aug": False,
        "fix_fedavg": True,
        "data_shuffle": True,
        "save_model": False,
        "fim": True,
        "logstep": 2,
        "name": dataset,
        "mname": arch,
        "nclients": 128,
        "pclients": 16,
        "isIID": False,
        "alpha": 10.0,
        "CThresh": 0.01,
        "epoch": 2,
        "optimizer": "SGD",
        "learning_rate": 0.01,
        "wdecay": 1e-5,
        "batch_size": 8,
        "iters": 200,
    }

    FL_Proc(run_cfg, global_model).main()




