import numpy as np
from Settings import *
from Util import *
from Comp_FIM.object import PMatKFAC, PMatEKFAC, PMatDiag, PMatDense
from Comp_FIM.metrics import FIM, FIM_MonteCarlo
from Comp_FIM.object.vector import random_pvector, PVector

class Client_Sim:
    """Simulated federated-learning client that trains a private copy of a model.

    The client holds deep copies of its data loader and model, trains locally
    with one of ``"SGD"``, ``"FedProx"``, ``"FedNova"`` or ``"VRL"``, and
    exposes its parameters so a server can aggregate them.
    """

    # Optimizer names accepted by ``selftrain``.
    _OPTIMIZERS = ("SGD", "FedProx", "FedNova", "VRL")

    def __init__(self, Loader, Model, Lr, wdecay, epoch=1, fixlr=False, optzer="SGD"):
        """Create a client.

        Args:
            Loader: iterable of ``(inputs, targets)`` batches (deep-copied).
            Model: model to train (deep-copied, so the caller's model is untouched).
            Lr: initial learning rate.
            wdecay: weight decay passed to every optimizer.
            epoch: number of local epochs per ``selftrain`` call.
            fixlr: stored as ``self.FixLR``; not read anywhere in this class.
            optzer: optimizer name, one of ``_OPTIMIZERS``.
        """
        self.TrainData = cp.deepcopy(Loader)
        self.Model = cp.deepcopy(Model)
        self.Optzer = optzer
        self.Wdecay = wdecay
        self.Epoch = epoch
        self.Mu = 0.001  # passed as ``mu`` to FedProx
        self.Round = 0
        self.LR = Lr
        # Every ``decay_step`` rounds the LR is multiplied by ``decay_rate``.
        self.decay_step = 10
        self.decay_rate = 0.9
        self.local_steps = 1
        # Single persistent VRL optimizer, reused across rounds when optzer == "VRL".
        self.optimizer = VRL(self.Model.parameters(), lr=self.LR, momentum=0.9, weight_decay=self.Wdecay, vrl=True, local=True)
        self.loss_fn = nn.CrossEntropyLoss()
        self.FixLR = fixlr
        self.gradnorm = 0
        self.trainloss = 0

    def reload_data(self, loader):
        """Replace the training data with a deep copy of ``loader``."""
        self.TrainData = cp.deepcopy(loader)

    def getParas(self):
        """Return a deep copy of the model's ``state_dict``."""
        return cp.deepcopy(self.Model.state_dict())

    def updateParas(self, Paras):
        """Load ``Paras`` (a ``state_dict``) into the local model."""
        self.Model.load_state_dict(Paras)

    def updateLR(self, lr):
        """Set the learning rate externally and stop the per-round decay.

        ``decay_rate`` becomes 1, so ``selftrain`` no longer changes the LR.
        """
        self.LR = lr
        self.decay_rate = 1

    def getLR(self):
        """Return the current client learning rate."""
        return self.LR

    def selftrain(self):
        """Run ``self.Epoch`` local epochs over ``self.TrainData``.

        Side effects:
            - Increments ``self.Round``.
            - Every ``self.decay_step`` rounds, multiplies ``self.LR`` by
              ``self.decay_rate``.
            - Sets ``self.trainloss`` to the mean per-batch loss, averaged over
              epochs.
            - Sets ``self.gradnorm`` to the per-epoch sum of squared gradient
              norms, averaged over epochs and multiplied by the current LR.
            - Sets ``self.local_steps`` to batches-per-epoch times ``self.Epoch``.

        Returns:
            ``optimizer.local_normalizing_vec`` for FedNova, otherwise ``1``.

        Raises:
            ValueError: if ``self.Optzer`` is not one of ``_OPTIMIZERS``.
        """
        if self.Optzer not in self._OPTIMIZERS:
            raise ValueError(
                f"Unknown optimizer {self.Optzer!r}; expected one of {self._OPTIMIZERS}"
            )
        self.Round += 1
        if self.Round % self.decay_step == 0:
            self.LR *= self.decay_rate
        # Keep the persistent VRL optimizer's LR in sync with the (possibly
        # decayed) client LR, regardless of which optimizer runs this round.
        self.optimizer.param_groups[0]['lr'] = self.LR
        if self.Optzer == "SGD":
            optimizer = torch.optim.SGD(self.Model.parameters(), lr=self.LR, momentum=0.9, weight_decay=self.Wdecay)
        elif self.Optzer == "FedProx":
            optimizer = FedProx(self.Model.parameters(), lr=self.LR, momentum=0.9, weight_decay=self.Wdecay, mu=self.Mu)
        elif self.Optzer == "FedNova":
            optimizer = FedNova(self.Model.parameters(), lr=self.LR, momentum=0.9, weight_decay=self.Wdecay)
        else:  # "VRL": reuse the single instance created in __init__
            optimizer = self.optimizer

        self.gradnorm = 0
        self.trainloss = 0
        epoch_losses = []
        epoch_gnorms = []
        self.Model.train()
        steps_per_epoch = 0
        for _ in range(self.Epoch):
            sum_loss = 0.0
            grad_norm = 0.0
            n_batches = 0
            for inputs, targets in self.TrainData:
                n_batches += 1
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.Model(inputs)
                optimizer.zero_grad()
                loss = self.loss_fn(outputs, targets)
                loss.backward()
                optimizer.step()
                sum_loss += loss.item()
                # Squared L2 norm of the whole gradient, read after the step.
                # Parameters without a gradient (e.g. frozen) are skipped.
                grad_norm += sum(
                    p.grad.detach().data.norm(2).item() ** 2
                    for p in self.Model.parameters()
                    if p.grad is not None
                )
            epoch_losses.append(sum_loss / n_batches if n_batches else 0.0)
            epoch_gnorms.append(grad_norm)
            steps_per_epoch = n_batches
        self.trainloss = np.mean(epoch_losses)
        self.gradnorm = np.mean(epoch_gnorms) * self.getLR()
        self.local_steps = steps_per_epoch * self.Epoch
        if self.Optzer == "VRL":
            self.optimizer.update_params()
        if self.Optzer == "FedNova":
            return optimizer.local_normalizing_vec
        return 1

    def evaluate(self, loader=None, max_samples=100000):
        """Evaluate the model without gradients.

        Args:
            loader: batches to evaluate on. ``None`` uses ``self.TrainData``.
            max_samples: stop after at least this many samples. The batch that
                crosses the limit is still counted in full.

        Returns:
            ``(accuracy, mean_batch_loss)``.
        """
        self.Model.eval()
        loss, correct, samples, iters = 0, 0, 0, 0
        loss_fn = nn.CrossEntropyLoss()
        if loader is None:
            loader = self.TrainData
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                y_ = self.Model(x)
                _, preds = torch.max(y_.data, 1)
                correct += (preds == y).sum().item()
                loss += loss_fn(y_, y).item()
                samples += y_.shape[0]
                iters += 1
                if samples >= max_samples:
                    break
        return correct / samples, loss / iters

    def fim(self, loader=None, n_output=10, max_samples=10000, batch_size=500):
        """Return the trace of the diagonal Fisher Information Matrix of the model.

        Samples are collected from ``loader`` until at least ``max_samples``
        have been gathered. The last batch is kept whole, so slightly more
        samples may be used.

        Args:
            loader: iterable of ``(inputs, targets)`` batches. ``None`` uses a
                deep copy of ``self.TrainData``.
            n_output: number of model outputs, passed to ``FIM``.
            max_samples: stop collecting once at least this many samples are
                gathered.
            batch_size: batch size of the loader handed to ``FIM``.

        Returns:
            The trace of the diagonal FIM as a Python number.
        """
        if loader is None:
            loader = cp.deepcopy(self.TrainData)
        self.Model.eval()
        samples = []
        for x, y in loader:
            xs = x.cpu().detach().numpy()
            ys = y.cpu().detach().numpy()
            samples.extend([xi, yi] for xi, yi in zip(xs, ys))
            if len(samples) >= max_samples:
                break
        fim_loader = torch.utils.data.DataLoader(dataset=samples, batch_size=batch_size, shuffle=False)
        f_diag = FIM(
            model=self.Model,
            loader=fim_loader,
            representation=PMatDiag,
            n_output=n_output,
            variant="classif_logits",
            # Same device used for the model/batches elsewhere in this class
            # (was hard-coded to "cuda").
            device=device,
        )
        return f_diag.trace().item()

class Server_Sim:
    """Simulated federated-learning server holding the global model.

    Client parameter dicts are buffered with ``recvInfo``. ``aggParas``
    averages them into the global model, weighting client ``i`` by
    ``Len_i * Scale_i``.
    """

    def __init__(self, Loader, Model, Lr, wdecay=0, Fixlr=False):
        """Create the server.

        Args:
            Loader: iterable of ``(inputs, targets)`` batches (deep-copied).
            Model: global model (deep-copied).
            Lr: initial server learning rate.
            wdecay: weight decay for the server SGD optimizer.
            Fixlr: if true, ``aggParas`` does not step the optimizer or scheduler.
        """
        self.TrainData = cp.deepcopy(Loader)
        self.Gamma = load_gamma()
        # Total number of samples. Only batch lengths are needed, so batches
        # are not moved to the compute device.
        self.DLen = sum(len(inputs) for inputs, _ in self.TrainData)
        self.Model = cp.deepcopy(Model)
        self.optimizer = torch.optim.SGD(self.Model.parameters(), lr=Lr, momentum=0.9, weight_decay=wdecay)
        # LR is multiplied by Gamma on every scheduler step.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=self.Gamma)
        self.loss_fn = nn.CrossEntropyLoss()
        self.FixLr = Fixlr
        # Buffers filled by recvInfo and consumed by aggParas.
        self.RecvParas = []
        self.RecvLens = []
        self.RecvScale = []
        self.RecvAs = []  # not read anywhere in this class
        self.LStep = 0
        self.CStep = 0

    def reload_data(self, loader):
        """Store a deep copy of ``loader`` as ``self.TestData``.

        NOTE(review): unlike ``Client_Sim.reload_data`` this sets ``TestData``,
        while ``evaluate`` defaults to ``TrainData``. Kept as-is for callers
        that may read ``TestData``; confirm intent.
        """
        self.TestData = cp.deepcopy(loader)

    def getParas(self):
        """Return a deep copy of the global model's ``state_dict``."""
        return cp.deepcopy(self.Model.state_dict())

    def getLR(self):
        """Return the current learning rate of the server optimizer."""
        return self.optimizer.param_groups[0]['lr']

    def updateParas(self, Paras):
        """Load ``Paras`` (a ``state_dict``) into the global model."""
        self.Model.load_state_dict(Paras)

    def avgParas(self, Paras, Ps, Scale):
        """Return the weighted average of several parameter dicts.

        Client ``i`` is weighted by ``Ps[i] * Scale[i]``. The weights are
        normalized to sum to 1.

        Args:
            Paras: list of parameter dicts sharing the same keys.
            Ps: per-client weights (presumably sample counts; TODO confirm).
            Scale: per-client multiplicative scale factors.

        Returns:
            A copy of ``Paras[0]`` with every entry replaced by the weighted
            average.

        Raises:
            ValueError: if the lists are empty or of different lengths, or if
                the total weight is zero. Numpy division by a zero total would
                otherwise silently yield NaN parameters.
        """
        if not Paras or not (len(Paras) == len(Ps) == len(Scale)):
            raise ValueError(
                "avgParas needs equally long, non-empty Paras/Ps/Scale lists"
            )
        weights = [p * s for p, s in zip(Ps, Scale)]
        total = np.sum(weights)
        if total == 0:
            raise ValueError("avgParas: total aggregation weight is zero")
        coeffs = [w / total for w in weights]
        Res = cp.deepcopy(Paras[0])
        for key in Res.keys():
            Res[key] = sum(paras[key] * c for paras, c in zip(Paras, coeffs))
        return Res

    def aggParas(self):
        """Average the buffered client parameters into the global model.

        Returns ``0`` (model untouched) when nothing has been received.
        Otherwise it loads the average, clears the buffers and returns ``None``.
        Unless ``FixLr`` is set, it also steps the optimizer and the scheduler.
        """
        if not self.RecvLens:
            return 0
        GParas = self.avgParas(self.RecvParas, self.RecvLens, self.RecvScale)
        self.updateParas(GParas)
        self.RecvParas = []
        self.RecvLens = []
        self.RecvScale = []
        if not self.FixLr:
            # No backward pass happens in this class, so presumably
            # optimizer.step() changes nothing. scheduler.step() decays the LR
            # by Gamma.
            self.optimizer.step()
            self.scheduler.step()
        return None

    def recvInfo(self, Para, Len, Scale):
        """Buffer one client's parameters with its weight ``Len`` and ``Scale``."""
        self.RecvParas.append(Para)
        self.RecvLens.append(Len)
        self.RecvScale.append(Scale)

    def evaluate(self, loader=None, max_samples=100000):
        """Evaluate the global model without gradients.

        Args:
            loader: batches to evaluate on. ``None`` uses ``self.TrainData``.
            max_samples: stop after at least this many samples. The batch that
                crosses the limit is still counted in full.

        Returns:
            ``(mean_batch_loss, accuracy)``. Note this order is the reverse of
            ``Client_Sim.evaluate``.
        """
        self.Model.eval()
        loss, correct, samples, iters = 0, 0, 0, 0
        if loader is None:
            loader = self.TrainData
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                y_ = self.Model(x)
                _, preds = torch.max(y_.data, 1)
                loss += self.loss_fn(y_, y).item()
                correct += (preds == y).sum().item()
                samples += y_.shape[0]
                iters += 1
                if samples >= max_samples:
                    break
        return loss / iters, correct / samples

    def saveModel(self, Path):
        """Serialize the whole global model object to ``Path`` with ``torch.save``."""
        torch.save(self.Model, Path)






