import torch


class LocalModel(torch.nn.Module):
    """Bias-free linear network: 1 input feature -> 2 hidden -> 1 output."""

    def __init__(self):
        super().__init__()
        # Creation order (l1, then l2) fixes the parameter registration order.
        self.l1 = torch.nn.Linear(1, 2, bias=False)
        self.l2 = torch.nn.Linear(2, 1, bias=False)

    def forward(self, x):
        """Apply ``l1`` then ``l2``; ``x`` must have 1 feature in its last dim."""
        return self.l2(self.l1(x))

    def trainable_parameters(self):
        """Print every parameter (debug aid) and return those requiring grad.

        Returns:
            List of the parameters with ``requires_grad=True``, in
            registration order.
        """
        every_param = list(self.parameters())
        for param in every_param:
            print("local model = %s" % (param,))
        return list(filter(lambda t: t.requires_grad, every_param))


class GlobalModel(torch.nn.Module):
    """Bias-free linear network: 1 input feature -> 2 hidden -> 1 output."""

    def __init__(self):
        super().__init__()
        # Creation order (l1, then l2) fixes the parameter registration order.
        self.l1 = torch.nn.Linear(1, 2, bias=False)
        self.l2 = torch.nn.Linear(2, 1, bias=False)

    def forward(self, x):
        """Apply ``l1`` then ``l2``; ``x`` must have 1 feature in its last dim."""
        return self.l2(self.l1(x))

    def trainable_parameters(self):
        """Print every parameter (debug aid) and return those requiring grad.

        Returns:
            List of the parameters with ``requires_grad=True``, in
            registration order.
        """
        every_param = list(self.parameters())
        for param in every_param:
            print("global model = %s" % (param,))
        return list(filter(lambda t: t.requires_grad, every_param))


def model_dist_norm_var(local_model, global_model):
    """Return the squared L2 distance between two models' trainable parameters.

    Corresponding parameters ``(p, g_p)`` are paired in order, their
    differences ``p - g_p`` are flattened and concatenated into one vector,
    and the squared Euclidean norm of that vector is returned. The result
    stays on the autograd graph, so it can be differentiated with respect to
    the parameters of either model.

    Args:
        local_model: object exposing ``trainable_parameters()``.
        global_model: object exposing ``trainable_parameters()`` whose tensors
            pair up, in order and shape, with those of ``local_model``.

    Returns:
        0-dim tensor holding the squared distance (``0.0`` when neither model
        has trainable parameters).

    Raises:
        ValueError: if the two models expose a different number of trainable
            parameters (``zip`` would otherwise silently drop the extras).
    """
    # Query each model once; trainable_parameters() may be costly or noisy.
    local_params = list(local_model.trainable_parameters())
    global_params = list(global_model.trainable_parameters())
    if len(local_params) != len(global_params):
        raise ValueError(
            f"local_model has {len(local_params)} trainable parameters but "
            f"global_model has {len(global_params)}"
        )
    if not local_params:
        return torch.zeros(())

    # reshape (unlike view) also accepts non-contiguous tensors, and torch.cat
    # keeps the parameters' dtype/device instead of a fixed CPU float32 buffer.
    diff = torch.cat(
        [(p - g_p).reshape(-1) for p, g_p in zip(local_params, global_params)]
    )
    return torch.linalg.norm(diff) ** 2


def model_dist_norm_var2(local_model, global_model):
    """Return the squared L2 distance between two models' parameters, per layer.

    Sums ``||p - g_p||**2`` over corresponding trainable parameters. Neither
    model is modified. The global model's parameters are detached, so
    gradients of the result flow only into ``local_model``'s parameters.

    Args:
        local_model: object exposing ``trainable_parameters()``.
        global_model: object exposing ``trainable_parameters()`` whose tensors
            pair up, in order and shape, with those of ``local_model``.

    Returns:
        0-dim tensor with the squared distance (float32 ``0.0`` when there
        are no trainable parameters).

    Raises:
        ValueError: if the two models expose a different number of trainable
            parameters (``zip`` would otherwise silently drop the extras).
    """
    local_params = list(local_model.trainable_parameters())
    global_params = list(global_model.trainable_parameters())
    if len(local_params) != len(global_params):
        raise ValueError(
            f"local_model has {len(local_params)} trainable parameters but "
            f"global_model has {len(global_params)}"
        )

    l2_norm_distance = torch.tensor(0.0, dtype=torch.float32)
    for p, g_p in zip(local_params, global_params):
        # Never write into p: the previous `p.data -= g_p.data` overwrote the
        # local model's weights. detach() keeps the global model out of the
        # graph, as the former `.data` access did.
        layer_sq = torch.linalg.norm(p - g_p.detach()) ** 2
        # Out-of-place add follows the parameters' device/dtype.
        l2_norm_distance = l2_norm_distance + layer_sq
    return l2_norm_distance


# Demo: build both models, run one forward pass, then compute the squared
# parameter distance with each implementation and its gradient with respect
# to the local model's trainable parameters.
local_model = LocalModel()
global_model = GlobalModel()
print(local_model)
# Forward pass on a single (1, 1) input of ones.
logits = local_model(torch.ones(1, 1))
print(logits)

print("#############")
# Distance via the flatten-and-concatenate implementation.
loss = model_dist_norm_var(local_model, global_model)
print(loss)
print("loss = " + str(loss))

# d(loss)/d(param) for each trainable local parameter; returns a tuple.
gradient = torch.autograd.grad(loss, local_model.trainable_parameters())
print(gradient)


print("************")
# Same squared distance, accumulated layer by layer; expected to match `loss`
# up to floating-point rounding.
loss2 = model_dist_norm_var2(local_model, global_model)
print("loss2 = " + str(loss2))

# Rebinds `gradient` to the gradients of loss2 w.r.t. the local parameters.
gradient = torch.autograd.grad(loss2, local_model.trainable_parameters())
print(gradient)