import torch
import copy

class FMoEAgg:
    """Weighted-average aggregator for the ``"trainable"`` parameter group."""

    def aggregate(self, global_state: dict, buckets) -> dict:
        """Merge weighted updates into a copy of ``global_state``.

        Each entry of ``buckets`` is unpacked as ``(update, weights)``, where
        ``update["trainable"]`` maps parameter names to tensors and
        ``weights["trainable"]["scalar"]`` is the weight applied to every
        tensor of that update.
        NOTE(review): the weight is presumably the client's sample count;
        confirm against the caller.

        For every parameter name present in ``global_state["trainable"]``
        the result is the weighted mean over the updates that actually
        contain that name. Parameters that no update mentions, or whose
        contributing weights do not sum to a positive value, keep their
        value from ``global_state``. Names found only in an update are
        ignored.

        Args:
            global_state: Current state; only the ``"trainable"`` entry is
                modified in the returned copy.
            buckets: Sized iterable of ``(update, weights)`` pairs.

        Returns:
            ``global_state`` itself when ``buckets`` is empty, otherwise a
            deep copy holding the aggregated ``"trainable"`` tensors.
        """
        if not buckets:
            return global_state

        out = copy.deepcopy(global_state)
        g_params = out["trainable"]

        # Per-parameter accumulators: a parameter reported by only some
        # updates must be normalised by the weight of those updates alone,
        # otherwise the missing ones act as zero-valued contributions.
        weighted_sums = {}
        weight_totals = {}

        # no_grad: the copied state may hold leaf tensors with
        # requires_grad=True, on which in-place writes are otherwise rejected.
        with torch.no_grad():
            for upd, weights in buckets:
                weight = weights["trainable"]["scalar"]
                for p_name, p_val in upd["trainable"].items():
                    if p_name not in g_params:
                        continue
                    contribution = p_val.to(g_params[p_name].device) * weight
                    if p_name in weighted_sums:
                        weighted_sums[p_name] = weighted_sums[p_name] + contribution
                        weight_totals[p_name] = weight_totals[p_name] + weight
                    else:
                        weighted_sums[p_name] = contribution
                        weight_totals[p_name] = weight

            for p_name, acc in weighted_sums.items():
                total = weight_totals[p_name]
                # A non-positive total cannot be normalised; keep the
                # global value rather than emit zeros or a raw sum.
                if total > 0:
                    g_params[p_name].copy_(acc / total)

        return out