import torch
import copy
from .pfpt_nonparametric_agg import NonparametricAgg

class PfptAgg:
    """Aggregate client updates into a new global trainable state.

    Parameters whose name contains ``'prompt_embeddings'`` are merged by
    ``NonparametricAgg``; every other parameter is replaced by the weighted
    mean of the client values.
    """

    # Substring that marks a parameter name as prompt embeddings; it is also
    # the exact key under which the prompt tensor is read and written.
    _PROMPT_KEY = 'prompt_embeddings'

    def aggregate(self, global_state: dict, buckets) -> dict:
        """Merge the client updates in *buckets* into a copy of *global_state*.

        Args:
            global_state: Mapping with a ``"trainable"`` dict of name -> tensor.
                It is deep-copied, so the caller's tensors are never modified.
            buckets: Sequence of ``(update, weights)`` pairs. ``update`` has a
                ``"trainable"`` dict of name -> tensor and ``weights`` carries
                the client weight at ``weights["trainable"]["scalar"]``. The
                sequence is iterated more than once, so it must not be a
                one-shot iterator.

        Returns:
            ``global_state`` itself (same object) when *buckets* is empty,
            otherwise a new dict holding only the ``"trainable"`` mapping.

        Raises:
            KeyError: If a client sends a non-prompt parameter that the global
                state does not have, or sends prompt embeddings while the
                global state has no ``'prompt_embeddings'`` entry.
        """
        if not buckets:
            return global_state

        out = copy.deepcopy(global_state)
        g_params = out["trainable"]

        self._average_shared_params(g_params, buckets)
        self._merge_prompts(g_params, buckets)

        return {"trainable": g_params}

    def _average_shared_params(self, g_params: dict, buckets) -> None:
        """Overwrite non-prompt entries of *g_params* with the weighted client mean."""
        sums = {}
        total_weight = 0.0

        # no_grad: the targets may be leaf tensors that require grad, on which
        # in-place arithmetic is otherwise rejected by autograd.
        with torch.no_grad():
            for upd, weights in buckets:
                weight = float(weights["trainable"]["scalar"])
                total_weight += weight
                for p_name, p_val in upd["trainable"].items():
                    if self._PROMPT_KEY in p_name:
                        continue
                    target = g_params[p_name]
                    # Accumulate in a scratch tensor so parameters that no
                    # client reports keep their current global value.
                    if p_name not in sums:
                        sums[p_name] = torch.zeros_like(target)
                    contribution = torch.as_tensor(p_val, device=target.device)
                    sums[p_name] += contribution * weight

            # Without positive total weight there is no meaningful mean;
            # keep the global values rather than emit zeroed tensors.
            if total_weight <= 0:
                return

            for p_name, acc in sums.items():
                acc /= total_weight
                g_params[p_name].copy_(acc)

    def _merge_prompts(self, g_params: dict, buckets) -> None:
        """Replace the global prompt embeddings with the NonparametricAgg result."""
        key = self._PROMPT_KEY
        client_prompts = [
            upd["trainable"][key].squeeze(0)
            for upd, _ in buckets
            if key in upd["trainable"]
        ]
        # No client carried prompts: leave the global prompts as they are.
        if not client_prompts:
            return

        target = g_params[key]
        client_prompts_tensor = torch.stack(client_prompts, dim=0).to(target.device)

        agg_module = NonparametricAgg(
            prompt_dim=client_prompts_tensor.shape[-1]
        ).to(client_prompts_tensor.device)
        # NOTE(review): the result is indexed as a 2-D (num_prompts, dim)
        # tensor below -- confirm against NonparametricAgg.forward.
        global_prompts = agg_module(client_prompts_tensor)

        # The global tensor keeps a leading singleton axis; axis 1 is the
        # number of prompt slots that must be filled.
        num_global_prompts = target.shape[1]
        missing = num_global_prompts - global_prompts.shape[0]
        if missing > 0:
            # new_zeros matches dtype and device of the aggregated prompts.
            padding = global_prompts.new_zeros(missing, global_prompts.shape[1])
            global_prompts = torch.cat([global_prompts, padding], dim=0)

        # Extra aggregated prompts beyond the global capacity are dropped.
        with torch.no_grad():
            target.copy_(global_prompts[:num_global_prompts].unsqueeze(0))