import copy
import warnings

import torch

class FLoRAAgg:
    """FLoRA-style server aggregator for federated LoRA fine-tuning.

    Each client's low-rank factors are merged into a dense update
    ``(B @ A) * (alpha / rank)``. That update is weighted by the client's
    share of the total ``"scalar"`` weight, summed across clients, and added
    onto the global backbone weights. Optional head tensors are
    weighted-averaged directly.
    """

    def aggregate(self, global_state: dict, buckets):
        """Fold one round of client updates into a new global state.

        Args:
            global_state: ``{"model": backbone_state_dict, "head": head_state_dict}``.
                ``"head"`` may be missing or None.
            buckets: iterable of ``(update, weight)`` pairs.
                ``weight["scalar"]`` is the client's aggregation weight
                (presumably its local sample count; confirm with the caller).
                ``update`` carries:
                - ``"lora"``: a state dict with ``<module>.lora_A*`` /
                  ``<module>.lora_B*`` entries;
                - ``"meta"``: with ``"rank"`` and ``"alpha"``;
                - ``"head"``: optional.

        Returns:
            ``global_state`` itself when there is nothing to aggregate (no
            buckets, or zero total weight). Otherwise a new
            ``{"model": ..., "head": ...}`` dict built from deep copies, so the
            input state is never modified. ``"head"`` is None when the input
            had no head.
        """
        # Materialize once: the weights are summed before the main loop, which
        # would otherwise exhaust a generator and silently drop every update.
        buckets = list(buckets or ())
        if not buckets:
            return global_state

        total_samples = sum(w["scalar"] for _, w in buckets)
        if total_samples == 0:
            return global_state

        new_backbone_state = copy.deepcopy(global_state["model"])
        # Dense deltas are accumulated in float32 so low-precision backbones
        # do not lose small contributions while summing across clients.
        delta_accum = {}

        # `.get` + None check: a previous round may have returned "head": None.
        head_src = global_state.get("head")
        new_head_state = copy.deepcopy(head_src) if head_src is not None else None
        head_accum = None
        if new_head_state is not None:
            head_accum = {
                k: torch.zeros_like(v, dtype=torch.promote_types(v.dtype, torch.float32))
                for k, v in new_head_state.items()
            }
        # Total weight of clients that actually sent a head. It is used to
        # renormalize, so clients without a head do not shrink the average.
        head_weight = 0

        for upd, w in buckets:
            scalar = w["scalar"]
            pk = scalar / total_samples

            client_head = upd.get("head")
            if head_accum is not None and client_head is not None:
                for k, acc in head_accum.items():
                    acc += client_head[k].to(device=acc.device, dtype=acc.dtype) * scalar
                head_weight += scalar

            lora_state = upd.get("lora")
            if not lora_state:
                continue
            meta = upd["meta"]
            scaling = meta["alpha"] / meta["rank"]

            for mod_name, mats in self._pair_lora_factors(lora_state).items():
                if "A" not in mats or "B" not in mats:
                    continue
                target_key = self._resolve_target_key(mod_name, new_backbone_state)
                if target_key is None:
                    continue

                lora_A = mats["A"].to(torch.float32)
                lora_B = mats["B"].to(torch.float32)
                contribution = ((lora_B @ lora_A) * scaling) * pk

                target = new_backbone_state[target_key]
                fitted = self._fit_to_target(contribution, target)
                if fitted is None:
                    warnings.warn(
                        f"FLoRAAgg: skipping LoRA update for {target_key!r}: delta shape "
                        f"{tuple(contribution.shape)} does not fit weight shape "
                        f"{tuple(target.shape)}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    continue

                acc = delta_accum.get(target_key)
                if acc is None:
                    acc = torch.zeros_like(target, dtype=torch.float32)
                    delta_accum[target_key] = acc
                acc += fitted.to(acc.device)

        for key, delta in delta_accum.items():
            target = new_backbone_state[key]
            target += delta.to(device=target.device, dtype=target.dtype)

        if head_accum is not None and head_weight > 0:
            for k, acc in head_accum.items():
                # copy_ writes in place, casting back to the head tensor's dtype.
                new_head_state[k].copy_(acc / head_weight)
        # If no client sent a head, the copied global head is kept unchanged.

        return {"model": new_backbone_state, "head": new_head_state}

    @staticmethod
    def _pair_lora_factors(lora_state):
        """Group LoRA entries by module: ``{module_name: {"A": ..., "B": ...}}``."""
        modules = {}
        for key, tensor in lora_state.items():
            if "lora_A" in key:
                modules.setdefault(key.split(".lora_A")[0], {})["A"] = tensor
            elif "lora_B" in key:
                modules.setdefault(key.split(".lora_B")[0], {})["B"] = tensor
        return modules

    @staticmethod
    def _resolve_target_key(mod_name, backbone_state):
        """Return the backbone weight key for a LoRA module name, or None.

        The name is tried first with the ``base_model.model.`` wrapper prefix
        removed, then as-is.
        """
        candidates = (
            mod_name.replace("base_model.model.", "") + ".weight",
            mod_name + ".weight",
        )
        for candidate in candidates:
            if candidate in backbone_state:
                return candidate
        return None

    @staticmethod
    def _fit_to_target(delta, target):
        """Arrange ``delta`` to match ``target``'s shape; return None if impossible."""
        if delta.shape == target.shape:
            return delta
        # Weights stored as (in, out), e.g. PEFT fan_in_fan_out layers, need
        # B @ A transposed. Reshaping instead would scramble the values.
        if delta.dim() == 2 and delta.t().shape == target.shape:
            return delta.t()
        if delta.numel() == target.numel():
            # Legacy fallback kept from the original view_as: same element
            # count, different layout. NOTE(review): confirm which layers need it.
            return delta.reshape(target.shape)
        return None