#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide utilities for analysing federated model weights and gradients."""

import copy
from collections import OrderedDict
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from numpy import floating
from scipy.stats import norm, ttest_1samp
from torch import Tensor

# Mapping from state_dict entry names to their tensors, as used throughout this module.
StateDict = Dict[str, Tensor]


def _flatten_state_dict(state_dict: StateDict, ignore_keys: List[str] = None) -> np.ndarray:
    """Flatten a PyTorch state dictionary into a single one-dimensional NumPy array.

    Args:
        state_dict: The model's state dictionary.
        ignore_keys: A list of substrings. Keys containing these substrings
            (e.g., 'mean', 'var') will be ignored.

    Returns:
        A 1D NumPy array containing all parameter values from the state_dict.
    """
    if ignore_keys is None:
        ignore_keys = []

    tensors = []
    for key, tensor in state_dict.items():
        if any(sub in key for sub in ignore_keys):
            continue
        if tensor.is_cuda:
            tensor = tensor.cpu()
        tensors.append(tensor.detach().numpy().flatten())

    return np.concatenate(tensors) if tensors else np.array([])


def calculate_l2_norm(params_a: StateDict, params_b: Optional[StateDict] = None) -> floating:
    """Calculate the L2 norm (Euclidean distance) between two state dictionaries.

    When both dictionaries contain exactly the same keys, ``params_b`` is
    flattened in ``params_a``'s key order. Matching parameters are then compared
    even if the two dicts were built with different insertion orders. When the
    key sets differ, both are flattened positionally, as before.

    Args:
        params_a: The first state dictionary.
        params_b: The second state dictionary. If omitted, the norm of
            ``params_a`` is returned.

    Returns:
        The L2 norm as a NumPy floating scalar.

    Raises:
        ValueError: If the two flattened parameter vectors differ in length.
    """
    vec_a = _flatten_state_dict(params_a)
    if params_b is None:
        return np.linalg.norm(vec_a)

    # ``_flatten_state_dict`` concatenates in dict iteration order; align the
    # order of params_b to params_a so element i refers to the same parameter.
    if params_b.keys() == params_a.keys():
        params_b = {key: params_b[key] for key in params_a}
    vec_b = _flatten_state_dict(params_b)

    # Without this check NumPy would silently broadcast a length-1 vector.
    if vec_a.shape != vec_b.shape:
        raise ValueError(f"Parameter vectors differ in length: {vec_a.size} vs {vec_b.size}")
    return np.linalg.norm(vec_a - vec_b)


def distance(w_local: Union[List[StateDict], StateDict], w_glob: StateDict) -> Dict[int, float]:
    """Compute L2 distances between client weights and the global reference model.

    Args:
        w_local: Collection (list or tuple) of client state dictionaries, or a
            single dictionary.
        w_glob: Global state dictionary used as the reference.

    Returns:
        Ordered mapping from client index (``0`` for a singleton input) to the
        distance value. A client whose distance cannot be computed, because
        ``calculate_l2_norm`` raised, is recorded as ``inf``. ``None`` input
        yields an empty mapping.

    Raises:
        TypeError: If ``w_local`` or ``w_glob`` has an unsupported type.
    """
    if w_local is None:
        return OrderedDict()

    # Normalize single-state_dict input to list form
    if isinstance(w_local, dict):
        locals_list = [w_local]
    elif isinstance(w_local, (list, tuple)):
        locals_list = list(w_local)
    else:
        raise TypeError("w_local must be a state_dict or a list/tuple of state_dicts")

    # Validate global weights type
    if not isinstance(w_glob, dict):
        raise TypeError("w_glob must be a state_dict (dict)")

    dists = OrderedDict()
    for i, wl in enumerate(locals_list):
        try:
            dists[i] = float(calculate_l2_norm(wl, w_glob))
        except Exception:
            # Deliberate best effort: one malformed client (mismatched keys or
            # shapes, bad dtype, ...) must not abort the whole round.
            dists[i] = float("inf")

    return dists


def calculate_inner_product(params_a: StateDict, params_b: StateDict) -> float:
    """Calculate the inner product of two state dictionaries.

    Only keys present in both dictionaries contribute. Each pair of tensors is
    detached, moved to the CPU and multiplied element-wise in float64 before
    summing. This avoids integer overflow and float32 accumulation error. It
    also works for tensors that require grad or use bfloat16, which
    ``.numpy()`` rejects.

    Args:
        params_a: The first state dictionary.
        params_b: The second state dictionary.

    Returns:
        The inner product as a Python float (0.0 when no keys are shared).
    """
    total = 0.0
    for key, tensor_a in params_a.items():
        if key not in params_b:
            continue
        vec_a = tensor_a.detach().cpu().to(torch.float64)
        vec_b = params_b[key].detach().cpu().to(torch.float64)
        total += torch.sum(vec_a * vec_b).item()
    return total


def calculate_gradients(
    weights_before: StateDict,
    weights_after: Union[StateDict, List[StateDict]],
    learning_rate: float,
    exclude_key_substrings: Iterable[str] = ("running_mean", "running_var", "num_batches_tracked"),
    keep_keys_shape: bool = True,
) -> Union[StateDict, List[StateDict]]:
    """Calculate pseudo-gradients as ``(weights_before - weights_after) / learning_rate``.

    Non-trainable buffers and unusable tensors are filtered out:

    - Keys containing any excluded substring (e.g. BatchNorm statistics).
    - Keys whose tensors are not both of a floating dtype.
    - Keys missing from ``weights_after`` or whose shapes differ.

    Filtered keys are handled according to ``keep_keys_shape``. When it is True,
    each one is replaced by a float32 zero tensor shaped like (and on the device
    of) its ``weights_before`` entry. When it is False, the key is omitted.
    Keys present only in ``weights_after`` are ignored.

    Args:
        weights_before: The global model weights before the update.
        weights_after:  The local model weights after the update, or a list of them.
        learning_rate:  Learning rate used for the update (must be > 0).
        exclude_key_substrings: Substrings of keys that should be excluded. Any
            iterable is accepted (it is materialised once).
        keep_keys_shape: If True, filtered keys are replaced by zero tensors;
                         if False, filtered keys are removed.

    Returns:
        A float32 gradient state_dict keyed by ``weights_before``, or a list of
        them (one per entry) when ``weights_after`` is a list.

    Raises:
        ValueError: If ``learning_rate`` is not positive.
    """
    # Validate up front so an empty list of clients is also rejected.
    if learning_rate <= 0:
        raise ValueError("learning_rate must be positive.")

    # Materialise once. A one-shot iterable (e.g. a generator) would otherwise be
    # exhausted by the first key's check and silently stop excluding later keys.
    exclude = tuple(exclude_key_substrings)

    if isinstance(weights_after, list):
        return [
            calculate_gradients(
                weights_before,
                w_after,
                learning_rate,
                exclude_key_substrings=exclude,
                keep_keys_shape=keep_keys_shape,
            )
            for w_after in weights_after
        ]

    grads: StateDict = {}

    for key, w_before in weights_before.items():
        w_after = weights_after.get(key, None)

        # Rule 1: Exclude if the key name contains specified substrings
        name_excluded = any(sub in key for sub in exclude)

        # Rule 2: Exclude if missing or shape mismatch
        shape_mismatch = (w_after is None) or (w_before.shape != w_after.shape)

        # Rule 3: Only compute if both tensors are floating type
        is_float_pair = (
            (w_after is not None) and torch.is_floating_point(w_before) and torch.is_floating_point(w_after)
        )

        if (not name_excluded) and (not shape_mismatch) and is_float_pair:
            # Bring w_after onto w_before's device so the subtraction also works
            # when the two snapshots live on different devices.
            before = w_before.to(torch.float32)
            after = w_after.to(device=w_before.device, dtype=torch.float32)
            grads[key] = (before - after) / learning_rate
        elif keep_keys_shape:
            # Replace with zero tensor to preserve consistent keys
            grads[key] = torch.zeros_like(w_before, dtype=torch.float32, device=w_before.device)
        # If keep_keys_shape=False, the key is skipped

    return grads


def average_gradients(gradients: List[StateDict]) -> StateDict:
    """Return the element-wise mean of a list of gradient state_dicts.

    Keys are taken from the first gradient and every other gradient must
    provide them as well. The first gradient is deep-copied before
    accumulation, so none of the inputs are modified.

    Args:
        gradients: The gradient state_dicts to average.

    Returns:
        The averaged state_dict, or ``{}`` for an empty input.
    """
    if not gradients:
        return {}

    count = len(gradients)
    result = copy.deepcopy(gradients[0])
    others = gradients[1:]

    for name in result:
        running = result[name]
        # Accumulate the remaining clients onto the copied first gradient.
        for other in others:
            running += other[name]
        result[name] = torch.div(running, count)
    return result


def calculate_gradient_std_dev(gradients: List[StateDict]) -> StateDict:
    """Return the element-wise sample standard deviation across gradients.

    The mean is obtained from ``average_gradients``. The summed squared
    deviations are divided by ``n - 1`` (Bessel's correction) before taking
    the square root.

    Args:
        gradients: The gradient state_dicts to compare.

    Returns:
        A state_dict holding the per-element standard deviation.

    Raises:
        ValueError: If fewer than two gradients are supplied.
    """
    count = len(gradients)
    if count < 2:
        raise ValueError("Standard deviation requires at least two gradients.")

    mean = average_gradients(gradients)
    squared = {name: torch.zeros_like(mu) for name, mu in mean.items()}

    for grad in gradients:
        for name, mu in mean.items():
            delta = grad[name] - mu
            squared[name] += delta * delta

    return {name: torch.sqrt(total / (count - 1)) for name, total in squared.items()}


def normalize_gradients(
    gradients: Union[StateDict, List[StateDict]],
    grad_mean: StateDict,
    grad_std: StateDict,
    ignore_keys: Optional[List[str]] = None,
    eps: float = 1e-8,
) -> Union[StateDict, List[StateDict]]:
    """Normalise gradients as ``(g - mean) / (std + eps)``.

    Only keys present in ``grad_mean`` are normalised. Any other entry, and any
    key containing an ``ignore_keys`` substring, is returned as a deep copy of
    the input value. The input gradients are not modified.

    Args:
        gradients: A single gradient state_dict or a list of them.
        grad_mean: The average gradient state_dict.
        grad_std: The standard deviation state_dict.
        ignore_keys: Substrings of keys to leave untouched. Defaults to
            ``["tracked"]``.
        eps: Added to the standard deviation to avoid division by zero.

    Returns:
        The normalized gradients, in the same container form as ``gradients``.
    """
    if ignore_keys is None:
        ignore_keys = ["tracked"]

    if isinstance(gradients, list):
        return [normalize_gradients(grad, grad_mean, grad_std, ignore_keys, eps) for grad in gradients]

    normalized_grad = copy.deepcopy(gradients)
    for key in grad_mean:
        if any(sub in key for sub in ignore_keys):
            continue
        normalized_grad[key] = torch.true_divide(gradients[key] - grad_mean[key], grad_std[key] + eps)
    return normalized_grad


def calculate_cosine_similarity_to_normal(
    weights: Union[StateDict, List[StateDict]], num_bins: int, bin_range: Tuple[float, float] = (-20.0, 20.0)
) -> float:
    """Calculate the cosine similarity between the weight distribution
    and a standard normal distribution.

    Keys containing "mean", "var" or "tracked" are excluded from the histogram.

    Args:
        weights: A single state_dict or a list of them.
        num_bins: The number of bins to use for creating the distribution histogram.
        bin_range: The min and max range for binning the weights.

    Returns:
        The cosine similarity score. It is 0.0 when there are no weights, when
        no weight falls inside ``bin_range``, or when either probability
        vector has zero norm.
    """
    ignore = ["mean", "var", "tracked"]
    state_dicts = weights if isinstance(weights, list) else [weights]
    flat_parts = [_flatten_state_dict(sd, ignore_keys=ignore) for sd in state_dicts]
    if not flat_parts:
        # An empty list would otherwise make np.concatenate raise.
        return 0.0
    all_weights_flat = np.concatenate(flat_parts)

    if all_weights_flat.size == 0:
        return 0.0

    # Create bins and get empirical distribution
    bins = np.linspace(bin_range[0], bin_range[1], num_bins + 1)
    hist = np.histogram(all_weights_flat, bins=bins)[0]
    total = hist.sum()
    if total == 0:
        # Every value fell outside bin_range. Normalising would compute 0/0 and
        # the resulting NaN would slip past the zero-norm guard below.
        return 0.0
    empirical_prob = hist / total

    # Get theoretical probabilities from a standard normal distribution
    theoretical_prob = np.diff(norm.cdf(bins))

    # Calculate cosine similarity
    dot_product = np.dot(empirical_prob, theoretical_prob)
    norm_empirical = np.linalg.norm(empirical_prob)
    norm_theoretical = np.linalg.norm(theoretical_prob)

    if norm_empirical == 0 or norm_theoretical == 0:
        return 0.0

    return float(dot_product / (norm_empirical * norm_theoretical))


def perform_t_test(
    data_vector: np.ndarray, target_indices: List[int], significance_level: float = 0.05
) -> Dict[int, Tuple[float, bool]]:
    """Perform a one-sample t-test for selected elements against the remainder.

    The comparison population is every element that is not a target. For
    each target index, ``ttest_1samp`` checks whether that population's mean
    differs from the target's value. Negative indices are accepted
    (Python-style) and are removed from the comparison population like any
    other target.

    Args:
        data_vector: A 1D array-like of data.
        target_indices: Indices to test (any iterable of ints).
        significance_level: The alpha level for the test.

    Returns:
        A dictionary mapping each target index, as given, to its p-value and a
        boolean indicating whether the result is statistically significant.

    Raises:
        IndexError: If a target index is out of range for ``data_vector``.
    """
    data_vector = np.array(data_vector)
    size = len(data_vector)
    # Materialise once so one-shot iterables are not consumed before the loop.
    targets = list(target_indices)

    # Normalise negative indices so a target like -1 is actually excluded from
    # its own comparison group.
    excluded = set()
    for i in targets:
        if not -size <= i < size:
            raise IndexError(f"index {i} is out of bounds for data of length {size}")
        excluded.add(i % size)

    comparison_group = data_vector[[j for j in range(size) if j not in excluded]]

    results = {}
    for i in targets:
        _, p_value = ttest_1samp(comparison_group, data_vector[i], nan_policy="omit")
        results[i] = (float(p_value), bool(p_value < significance_level))
    return results


def delta_parameters(
    current: List[Dict[str, torch.Tensor]],
    previous: List[Dict[str, torch.Tensor]],
) -> List[Dict[str, torch.Tensor]]:
    """Compute per-client parameter deltas ``current[i] - previous[i]``.

    Entries that are floating point on both sides are subtracted. Any other
    entry is replaced by zeros shaped like the ``current`` tensor. The
    computation runs under ``torch.no_grad()``.

    Args:
        current: State dictionaries after the update, one per client.
        previous: State dictionaries before the update, aligned with ``current``.

    Returns:
        A list of delta dictionaries, one per client, keyed like ``current``.

    Raises:
        ValueError: If either argument is None, the lists differ in length, or
            a client's two dictionaries do not contain the same keys.
    """
    if current is None or previous is None:
        raise ValueError("current and previous cannot be None")
    if len(current) != len(previous):
        raise ValueError("current and previous numbers should be the same")

    deltas: List[Dict[str, torch.Tensor]] = []
    with torch.no_grad():
        for i, (c_dict, p_dict) in enumerate(zip(current, previous)):
            if c_dict.keys() != p_dict.keys():
                raise ValueError(f"Client {i}'s keys differ between current and previous")
            out_i: Dict[str, torch.Tensor] = {}
            for k in c_dict.keys():
                ck, pk = c_dict[k], p_dict[k]
                if torch.is_floating_point(ck) and torch.is_floating_point(pk):
                    out_i[k] = ck - pk
                else:
                    # Non-float entries (e.g. integer counters) get a zero delta.
                    out_i[k] = torch.zeros_like(ck)
            deltas.append(out_i)
    return deltas


def _is_float_tensor(t: Tensor) -> bool:
    return torch.is_tensor(t) and torch.is_floating_point(t)


def _flatten_float_params(param_dict: Dict[str, Tensor]) -> Tensor:
    vecs: List[Tensor] = []
    dev: Optional[torch.device] = None
    for v in param_dict.values():
        if _is_float_tensor(v):
            if dev is None:
                dev = v.device
            vecs.append(v.detach().reshape(-1).to(dev))
    if not vecs:
        return torch.zeros(1)
    return torch.cat(vecs, dim=0)


def _delta_params_vec(curr: Dict[str, Tensor], prev: Dict[str, Tensor]) -> Tensor:
    keys = [k for k in curr.keys() if (k in prev and _is_float_tensor(curr[k]) and _is_float_tensor(prev[k]))]
    if not keys:
        return torch.zeros(1)
    dev = curr[keys[0]].device
    return torch.cat([(curr[k] - prev[k]).detach().reshape(-1).to(dev) for k in keys], dim=0)
