# federated/server.py - FedPer-only server

from __future__ import annotations
from typing import Dict, List, Optional

import torch
import torch.nn as nn

# Public API for `from ... import *`: only the Server class is exported.
__all__ = ["Server"]


def _should_aggregate_param(name: str) -> bool:
    """
    Determine if a parameter should be aggregated in FedPer.

    Excludes (Personalized):
    - head.* : personalized prediction heads
    - slot_embed.* : personalized QAP slot embeddings
    - projection.* : personalized projection layers (Projection mode)

    Includes (Shared):
    - backbone.* : shared backbone parameters
    - qap.* (except slot_embed) : shared QAP parameters (QAP mode)
    - time_proj.*, cat_proj.* : shared time feature processing (both modes)
    """
    # Exclude personalized heads
    if name.startswith("head") or ".head." in name:
        return False

    # Exclude personalized slot embeddings (QAP mode)
    if name.endswith("slot_embed.weight") or ".slot_embed." in name:
        return False

    # Exclude personalized projection layers (Projection mode)
    if name.startswith("projection.") or ".projection." in name:
        return False

    # Include everything else (backbone and shared parameters)
    return True


def fedper_aggregate(client_params: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    FedPer parameter aggregation.

    Averages only the shared parameters, i.e. the names accepted by
    ``_should_aggregate_param``. Personalized parameters are never returned.

    The key set comes from the first client's state dict. Each key is
    averaged over the clients that actually contain it, so a client that is
    missing a key does not drag the mean toward zero.

    Args:
        client_params: List of client state dictionaries mapping parameter
            names to tensors.

    Returns:
        Dict mapping each shared parameter name to its element-wise mean.
        Each result has the dtype of the first client's tensor for that key
        (integer/bool buffers are rounded back to their integer dtype) and
        lives on the first client's device. Returns an empty dict if
        ``client_params`` is empty or no shared keys exist.
    """
    if not client_params:
        return {}

    first_client = client_params[0]
    aggregated: Dict[str, torch.Tensor] = {}

    for key in first_client:
        if not _should_aggregate_param(key):
            continue

        reference = first_client[key]
        # Only clients holding this key contribute. The first client always
        # does, so the divisor below is >= 1.
        contributions = [state[key] for state in client_params if key in state]

        is_float_like = reference.is_floating_point() or reference.is_complex()
        # Accumulate in at least float32. Integer/bool buffers (e.g. BatchNorm
        # num_batches_tracked) use float64 so the mean is neither truncated
        # nor overflowed.
        if is_float_like:
            acc_dtype = torch.promote_types(reference.dtype, torch.float32)
        else:
            acc_dtype = torch.float64

        total = torch.zeros_like(reference, dtype=acc_dtype)
        for tensor in contributions:
            # detach(): never carry an autograd graph into the global state.
            total += tensor.detach().to(device=total.device, dtype=acc_dtype)
        mean = total / len(contributions)

        if not is_float_like:
            mean = mean.round()
        # Cast back so the aggregated entry matches the original state-dict dtype.
        aggregated[key] = mean.to(reference.dtype)

    return aggregated


class Server:
    """
    FedPer server that only manages shared parameters.

    Holds the most recent aggregated shared state (``global_state``) and a
    round counter. It optionally holds a global backbone module, which is
    kept in sync with the aggregated ``backbone.*`` entries.
    """

    def __init__(self, global_backbone: Optional[nn.Module] = None):
        # Optional module mirrored from the aggregated "backbone.*" entries.
        self.global_backbone = global_backbone
        # Latest aggregated shared parameters (name -> tensor).
        self.global_state: Dict[str, torch.Tensor] = {}
        # Number of aggregation rounds performed so far.
        self.round_num = 0

    def aggregate_and_update(self, client_params: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Aggregate client parameters and update the global model.

        Args:
            client_params: List of client state dictionaries (shared parameters only)

        Returns:
            Updated global state dictionary. If ``client_params`` is empty,
            the current state is returned unchanged and the round counter is
            not advanced.
        """
        if not client_params:
            return self.global_state

        # FedPer aggregation (only shared parameters)
        self.global_state = fedper_aggregate(client_params)
        self.round_num += 1

        if self.global_backbone is not None and self.global_state:
            self._update_backbone(self.global_state)

        return self.global_state

    def _update_backbone(self, state: Dict[str, torch.Tensor]) -> None:
        """
        Load the ``backbone.*`` entries of *state* into ``self.global_backbone``.

        Works both when ``global_backbone`` is a wrapper whose own state-dict
        keys start with ``backbone.`` and when it is the bare backbone module
        (keys without the prefix).

        A key found in neither form is passed through unchanged, so a strict
        ``load_state_dict`` still reports the mismatch.
        """
        if self.global_backbone is None:
            return

        prefix = "backbone."
        current_sd = self.global_backbone.state_dict()
        updates: Dict[str, torch.Tensor] = {}
        for key, value in state.items():
            if not key.startswith(prefix):
                continue
            target = key
            if key not in current_sd:
                stripped = key[len(prefix):]
                if stripped in current_sd:
                    target = stripped
            updates[target] = value

        if updates:
            current_sd.update(updates)
            self.global_backbone.load_state_dict(current_sd)

    def broadcast_parameters(self) -> Dict[str, torch.Tensor]:
        """
        Return a copy of the current global state for broadcasting to clients.

        Tensors are detached clones, so in-place edits made by a client cannot
        corrupt the server's ``global_state``.
        """
        return {k: v.detach().clone() for k, v in self.global_state.items()}

    def get_round(self) -> int:
        """Get current round number."""
        return self.round_num

    def save_state(self, path: str):
        """Save global state, round counter and (if present) backbone weights to *path*."""
        state = {
            'global_state': self.global_state,
            'round_num': self.round_num,
        }
        if self.global_backbone is not None:
            state['global_backbone'] = self.global_backbone.state_dict()

        torch.save(state, path)

    def load_state(self, path: str):
        """
        Restore state written by ``save_state``, mapping all tensors to CPU.

        SECURITY NOTE(review): ``torch.load`` unpickles the file. Only load
        checkpoints from trusted sources.
        """
        state = torch.load(path, map_location='cpu')
        self.global_state = state.get('global_state', {})
        self.round_num = state.get('round_num', 0)

        if self.global_backbone is not None and 'global_backbone' in state:
            self.global_backbone.load_state_dict(state['global_backbone'])