import torch
from models.Fed import FedAvg_fsu

def create_state(val):
    """Build a minimal state dict for the aggregation tests.

    Args:
        val: Any value accepted by ``float()``.

    Returns:
        A dict with the single key ``"w"`` mapped to a zero-dimensional
        tensor holding ``float(val)``.
    """
    scalar = float(val)
    weight = torch.tensor(scalar)
    return {"w": weight}


def test_fedavg_fsu_multi_clients():
    """Compare FedAvg_fsu's output with a hand-computed weighted mean.

    Three scalar states (0.0, 1.0, 2.0) are passed with the index list
    ``[0, 1]`` and ``lamda=0.5``. The expected value gives states 0 and 1
    a weight of 1 and state 2 a weight of 0.5, then divides by the sum of
    those weights.

    NOTE(review): presumably ``[0, 1]`` selects the states averaged at full
    weight and ``lamda`` scales the rest -- that is what the expected
    formula encodes; confirm against ``models.Fed.FedAvg_fsu``.
    """
    states = [create_state(value) for value in (0.0, 1.0, 2.0)]
    aggregated = FedAvg_fsu(states, [0, 1], lamda=0.5)

    # Read the tensors after the call, in the same order the reference
    # formula sums them.
    first, second, third = (state["w"] for state in states)
    numerator = first + second + 0.5 * third
    denominator = 2 + 1 * 0.5  # two full-weight states + one state at 0.5
    expected = numerator / denominator

    assert torch.allclose(aggregated["w"], expected)
