#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide model-averaging utilities for federated learning workflows."""

import copy
from typing import Dict, List, Union

import numpy as np
import torch
from torch import nn

# Type aliases used in the function signatures of this module.
# Weights: either a PyTorch state_dict (name -> tensor) or a single NumPy array.
Weights = Union[Dict[str, torch.Tensor], np.ndarray]
# ClientUpdate: a (sample count, state_dict) pair; average_fsvrg_weights unpacks
# each one as (num_samples, local_weights).
# NOTE: subscripting the builtin `tuple` at runtime requires Python 3.9+.
ClientUpdate = tuple[int, Dict[str, torch.Tensor]]


def average_weights(weights_list: List[Weights]) -> Weights:
    """Return the element-wise mean of several sets of model weights.

    The type of the first element selects the code path: NumPy arrays are
    summed into a single array, PyTorch state dictionaries are averaged key by
    key. Accumulation happens in float32, so the result is float32 whatever the
    input dtype was. The inputs themselves are not modified.

    Args:
        weights_list: The model weights to average; either NumPy arrays or
            PyTorch state_dicts. For state_dicts, the keys of the first entry
            define the keys of the result.

    Returns:
        A float32 NumPy array, or a dict mapping each key to a float32 tensor.

    Raises:
        ValueError: If weights_list is empty.
        TypeError: If the first element is neither a NumPy array nor a dict.
    """
    if not weights_list:
        raise ValueError("The weights_list cannot be empty.")

    first = weights_list[0]
    count = len(weights_list)

    # Reject unsupported containers up front so both supported paths can
    # return early below.
    if not isinstance(first, (np.ndarray, dict)):
        raise TypeError("Unsupported weight type: {}".format(type(first)))

    if isinstance(first, np.ndarray):
        total = np.zeros_like(first, dtype=np.float32)
        for array in weights_list:
            total += array.astype(np.float32)
        total /= count
        return total

    # State-dict path: start from a detached float32 copy of the first entry
    # so the in-place accumulation never touches the caller's tensors.
    mean_state = {}
    for name, tensor in first.items():
        mean_state[name] = tensor.clone().detach().float()

    for name in mean_state:
        for index in range(1, count):
            mean_state[name] += weights_list[index][name].float()
        mean_state[name] /= count
    return mean_state


def average_weights_resilient(
    weights_list: List[Dict[str, torch.Tensor]], reputations: Union[List[float], np.ndarray, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Compute a reputation-weighted average of model state dictionaries.

    Every state dict is scaled by its reputation divided by the sum of all
    reputations, and the scaled tensors are summed key by key. Tensors are
    cast to float32 before accumulation, so the result is float32.

    Args:
        weights_list: State dicts to combine. The keys of the first entry
            define the keys of the result; every other entry must provide
            those keys.
        reputations: One weight per entry of weights_list, given as a list,
            a NumPy array, or a tensor.

    Returns:
        A new dict mapping each key of the first state dict to its weighted
        average tensor.

    Raises:
        ValueError: If weights_list or reputations is None or empty, if their
            lengths differ, if any reputation is NaN or infinite, or if the
            reputations do not sum to a positive value.
    """
    if weights_list is None or len(weights_list) == 0:
        raise ValueError("weights_list cannot be empty.")

    if reputations is None:
        raise ValueError("reputations cannot be None.")

    # Normalise every accepted input form to a float NumPy array.
    if isinstance(reputations, torch.Tensor):
        reputations = reputations.detach().cpu().numpy()
    reputations = np.asarray(reputations, dtype=float)

    if reputations.size == 0:
        raise ValueError("reputations cannot be empty.")
    if len(weights_list) != len(reputations):
        raise ValueError("Mismatch between number of weights and reputations.")

    # NaN slips past the `<= 0` test below (every comparison with NaN is False)
    # and inf / inf is NaN, so a single non-finite reputation would turn every
    # averaged tensor into NaN. Reject such input explicitly.
    if not np.all(np.isfinite(reputations)):
        raise ValueError("reputations must be finite.")

    total_reputation = np.sum(reputations)
    if total_reputation <= 0:
        raise ValueError("Total reputation must be positive.")

    # Normalised factors; they sum to 1.
    factors = reputations / total_reputation

    pairs = zip(weights_list, factors)
    first_state, first_factor = next(pairs)

    # Seed the result from the first state dict. The multiplication allocates
    # new tensors, so the in-place additions below never alias caller data.
    averaged_weights = {key: tensor.float() * float(first_factor) for key, tensor in first_state.items()}

    for state, factor in pairs:
        scale = float(factor)
        for key in averaged_weights:
            averaged_weights[key] += state[key].float() * scale

    return averaged_weights


def average_fsvrg_weights(
    client_updates: List[ClientUpdate], aggregation_scalar: float, global_model: nn.Module, gpu_id: int = -1
) -> Dict[str, torch.Tensor]:
    """Update global model parameters using the FSVRG algorithm.

    FSVRG (Federated Stochastic Variance Reduced Gradient) uses a different
    aggregation scheme than simple averaging, involving pseudo-gradients.
    For every state-dict key this computes

        new = global + (aggregation_scalar / total_samples)
                       * sum_k(num_samples_k * (local_k - global))

    Integer entries (for example BatchNorm's ``num_batches_tracked``) cannot
    absorb a floating-point update in place, so their update is rounded to the
    nearest integer and cast back to the entry's dtype.

    Note:
        ``global_model`` is moved to the selected device as a side effect. Its
        parameter values are not changed; the update is applied to a deep copy
        of its state_dict.

    Args:
        client_updates: A list of tuples, where each tuple contains the number
            of local data points and the client's model state_dict.
        aggregation_scalar: A scalar factor (often the learning rate) for the
            aggregation step.
        global_model: The current global model (instance of nn.Module).
        gpu_id: The GPU device ID to use. A value of -1 indicates CPU usage;
            the CPU is also used when CUDA is unavailable.

    Returns:
        An updated global model state_dict. If the client sample counts sum
        to zero, an unmodified copy of the global state_dict is returned.
    """
    device = torch.device(f"cuda:{gpu_id}" if gpu_id != -1 and torch.cuda.is_available() else "cpu")
    global_model.to(device)
    global_weights = copy.deepcopy(global_model.state_dict())

    total_data_size = sum(num_samples for num_samples, _ in client_updates)
    if total_data_size == 0:
        return global_weights  # No updates to apply.

    # Running sum of sample-weighted (local - global) differences per key.
    pseudo_gradients = {key: torch.zeros_like(tensor, device=device) for key, tensor in global_weights.items()}

    for num_samples, local_weights in client_updates:
        for key in pseudo_gradients:
            # Bring the client tensor onto the global model's device first.
            weight_diff = local_weights[key].to(device) - global_weights[key]
            pseudo_gradients[key] += weight_diff * num_samples

    step = aggregation_scalar / total_data_size
    for key, tensor in global_weights.items():
        update_term = step * pseudo_gradients[key]
        if not (tensor.is_floating_point() or tensor.is_complex()):
            # `step * integer_tensor` is promoted to float, and an in-place
            # add cannot change an integer tensor's dtype (RuntimeError), so
            # round and cast the update back before applying it.
            update_term = update_term.round().to(tensor.dtype)
        tensor += update_term

    return global_weights
