import copy, torch

class FedAvg:
    """Weighted federated averaging over grouped parameter dictionaries.

    The state is a two-level mapping ``{group_key: {param_name: tensor}}``.
    Each client contributes an ``(update, weights)`` pair where ``update`` has
    the same two-level layout and ``weights[group_key]`` is a dict that may
    hold a ``"scalar"`` weight (one number for the whole group) and/or a
    ``"row"`` weight (one number per leading-dimension row).
    """

    # Floor applied to per-row denominators so rows that received no weight do
    # not divide by zero; such rows are masked back to the global value anyway.
    _ROW_EPS = 1e-12

    def aggregate(self, global_state: dict, buckets) -> dict:
        """Return a new state holding the weighted average of client updates.

        For every group, row-weighted averaging is used when at least one
        client supplies a ``"row"`` weight for that group and the group has a
        parameter with at least one dimension; otherwise scalar-weighted
        averaging is used. Groups (or rows) that receive no positive weight
        keep their value from ``global_state``.

        Args:
            global_state: ``{group_key: {param_name: tensor}}``. Never mutated.
            buckets: Sequence of ``(update, weights)`` pairs. It is iterated
                several times, so it must be re-iterable (not a generator).

        Returns:
            A deep copy of ``global_state`` with aggregated tensors substituted
            in. Each result keeps the dtype of the corresponding global
            parameter and lives on the device of the group's first parameter.

        Raises:
            KeyError: If a contributing update contains a group but lacks one
                of that group's parameters.
        """
        # Always hand back a fresh copy so callers can never alias (and
        # accidentally mutate) the state they passed in, even with no updates.
        out = {k: copy.deepcopy(v) for k, v in global_state.items()}
        if not buckets:
            return out

        for key, params in out.items():
            if not params:
                # Nothing to average, and no tensor to take a device from.
                continue
            dev = next(iter(params.values())).device

            wants_rows = any("row" in w.get(key, {}) for _, w in buckets)
            n_rows = self._leading_dim(params) if wants_rows else None
            if n_rows is not None:
                self._aggregate_rows(key, params, buckets, n_rows, dev)
            else:
                self._aggregate_scalar(key, params, buckets, dev)
        return out

    @staticmethod
    def _leading_dim(params: dict):
        """Return ``size(0)`` of the first parameter with ndim >= 1, else None."""
        for tensor in params.values():
            if tensor.ndim >= 1:
                return tensor.size(0)
        return None

    @staticmethod
    def _new_accumulators(params: dict, dev) -> dict:
        """Create one zero accumulator per parameter on ``dev``.

        Floating/complex parameters accumulate in their own dtype. Integer and
        bool parameters accumulate in float32, because adding a float-weighted
        value in place into an integer tensor raises in PyTorch.
        """
        accumulators = {}
        for name, tensor in params.items():
            if tensor.is_floating_point() or tensor.is_complex():
                acc_dtype = tensor.dtype
            else:
                acc_dtype = torch.float32
            accumulators[name] = torch.zeros_like(tensor, dtype=acc_dtype, device=dev)
        return accumulators

    @staticmethod
    def _restore_dtype(value, reference):
        """Cast ``value`` to ``reference``'s dtype, rounding for integer types."""
        if value.dtype == reference.dtype:
            return value
        if not (reference.is_floating_point() or reference.is_complex()):
            # Round to nearest instead of truncating toward zero on the cast.
            value = torch.round(value)
        return value.to(reference.dtype)

    def _aggregate_rows(self, key, params: dict, buckets, n_rows: int, dev) -> None:
        """Average 1-D/2-D parameters of one group with per-row weights, in place.

        Only clients that provide a ``"row"`` weight for ``key`` contribute;
        clients with just a ``"scalar"`` weight are ignored in this mode.
        Parameters that are neither 1-D nor 2-D are left unchanged.
        """
        num = self._new_accumulators(params, dev)
        denom_row = torch.zeros(n_rows, dtype=torch.float32, device=dev)

        for upd, w in buckets:
            if key not in upd:
                continue
            rw = w.get(key, {}).get("row")
            if rw is None:
                continue
            rw = torch.as_tensor(rw, device=dev, dtype=torch.float32)
            # reshape (not view) so non-contiguous weight tensors also work.
            rw_col = rw.reshape(-1, 1)
            for name in params:
                val = upd[key][name].to(dev)
                if val.ndim == 2:
                    num[name] += val * rw_col
                elif val.ndim == 1:
                    num[name] += val * rw
            denom_row += rw

        has_weight = denom_row > 0
        safe_denom = denom_row.clamp_min(self._ROW_EPS)
        # list(...) because entries are reassigned while walking the dict.
        for name, ref in list(params.items()):
            current = ref.to(dev)
            if current.ndim == 2:
                denom, mask = safe_denom.view(-1, 1), has_weight.view(-1, 1)
            elif current.ndim == 1:
                denom, mask = safe_denom, has_weight
            else:
                continue
            avg = num[name] / denom
            # Rows without any weight fall back to the current global value.
            merged = torch.where(mask, avg, current.to(avg.dtype))
            params[name] = self._restore_dtype(merged, ref)

    def _aggregate_scalar(self, key, params: dict, buckets, dev) -> None:
        """Average every parameter of one group with per-client scalar weights.

        Clients whose ``"scalar"`` weight is missing or not positive are
        skipped. If no client contributes, the group is left unchanged.
        """
        num = self._new_accumulators(params, dev)
        # A plain Python float avoids a device sync (.item()) for every group.
        total = 0.0

        for upd, w in buckets:
            if key not in upd:
                continue
            s = float(w.get(key, {}).get("scalar", 0.0))
            if s <= 0:
                continue
            for name in params:
                num[name] += upd[key][name].to(dev) * s
            total += s

        if total > 0:
            for name, ref in list(params.items()):
                params[name] = self._restore_dtype(num[name] / total, ref)
