#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Implement common Byzantine attack strategies for federated learning."""

import copy
from typing import Dict, List, Set

import torch
from scipy.stats import norm

from .utils.federated_metrics import (
    average_gradients,
    calculate_gradient_std_dev,
    calculate_gradients,
    calculate_l2_norm,
    distance,
)
from .utils.model_averaging import average_weights

# Shared type aliases for the attack signatures in this module.
# StateDict: a model's state dictionary (parameter/buffer name -> tensor).
StateDict = Dict[str, torch.Tensor]
# WeightsList: one StateDict per client; the attacks below index it position-by-position
# alongside the round's list of client indices.
WeightsList = List[StateDict]


def get_malicious_client_updates(
    all_updates: WeightsList,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
) -> WeightsList:
    """Collect the updates submitted by adversarial clients.

    Args:
        all_updates: Client updates, aligned position-by-position with ``all_client_indices``.
        all_client_indices: The client index that produced each position of ``all_updates``.
        malicious_client_indices: Indices identifying the adversarial clients.

    Returns:
        The entries of ``all_updates`` (the same objects, not copies) whose client index
        belongs to ``malicious_client_indices``, preserving their original order.
    """
    return [
        all_updates[position]
        for position, client_id in enumerate(all_client_indices)
        if client_id in malicious_client_indices
    ]


def gaussian_attack(
    client_updates: WeightsList,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
    attack_scale: float,
    device: torch.device,
) -> WeightsList:
    """
    Inject Gaussian noise into the model updates submitted by malicious clients.

    Every floating-point tensor in a malicious client's state dictionary receives
    additive noise ``randn * attack_scale``. Non-floating tensors (e.g. integer
    buffers such as ``num_batches_tracked``) are left untouched, because adding
    float noise to them in place would fail.

    Args:
        client_updates: Model weight state dictionaries for all participating clients.
        all_client_indices: Indices of the clients selected for the current round,
            aligned position-by-position with ``client_updates``.
        malicious_client_indices: Indices identifying the adversarial clients.
        attack_scale: Standard deviation of the Gaussian perturbation.
        device: Retained for backward compatibility. Noise is generated directly on
            the device of the tensor it perturbs, so tensors stay where they are.

    Returns:
        A deep copy of ``client_updates`` in which malicious entries have been noised.
        The input list and its tensors are not modified.
    """
    attacked_updates = copy.deepcopy(client_updates)
    for position, client_idx in enumerate(all_client_indices):
        if client_idx not in malicious_client_indices:
            continue
        for tensor in attacked_updates[position].values():
            if not tensor.dtype.is_floating_point:
                continue
            # randn_like allocates on tensor.device with tensor.dtype. The previous
            # `noise.to(device)` could move the noise to a different device than the
            # tensor and make the in-place add raise a device-mismatch error.
            tensor.add_(torch.randn_like(tensor) * attack_scale)
    return attacked_updates


def lie_attack(
    client_updates: WeightsList,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
    learning_rate: float,
) -> WeightsList:
    """
    Construct Little Is Enough (LIE) attack updates from benign statistics.

    The malicious clients mimic the benign mean update and add a deviation
    scaled by the benign standard deviation to bias the aggregation outcome.

    Args:
        client_updates: Model updates received from all selected clients, aligned
            position-by-position with ``all_client_indices``.
        all_client_indices: Indices of the clients selected for the current round.
        malicious_client_indices: Indices identifying the adversarial clients. Only
            those actually selected this round are counted as attackers.
        learning_rate: Learning rate used during local optimisation. Not used by this
            implementation; kept so all attacks share a similar signature.

    Returns:
        A deep copy of ``client_updates`` in which every selected malicious entry is
        replaced by its own copy of the LIE update. If there are no benign clients or
        no selected attackers, the copy is returned unchanged.
    """
    attacked_updates = copy.deepcopy(client_updates)

    # Positions (within this round) of the clients controlled by the adversary.
    attacker_positions = [
        pos for pos, client_idx in enumerate(all_client_indices) if client_idx in malicious_client_indices
    ]
    if not attacker_positions:
        return attacked_updates  # No attacker participates in this round

    benign_updates = [
        update for pos, update in enumerate(client_updates) if all_client_indices[pos] not in malicious_client_indices
    ]
    if not benign_updates:
        return attacked_updates  # No benign clients to base attack on

    num_clients = len(all_client_indices)
    # Count only the attackers selected this round. Using len(malicious_client_indices)
    # mixes a global count with this round's client count, and can push the ppf
    # argument below outside (0, 1), which yields a NaN update.
    num_attackers = len(attacker_positions)
    num_benign = num_clients - num_attackers

    # Mean and standard deviation of the benign updates
    avg_benign_update = average_gradients(benign_updates)
    std_dev_benign = calculate_gradient_std_dev(benign_updates)

    # s: number of benign clients the attackers need on their side to reach a majority.
    s = (num_clients // 2 + 1) - num_attackers
    # NOTE(review): when s <= 0 (attackers already hold a majority) the ppf argument is
    # >= 1, giving an inf/NaN z-score. The expression also differs from the LIE paper's
    # Phi^-1((n - s) / n); confirm which variant is intended.
    z_score = norm.ppf((num_benign - s) / num_benign)

    # Craft the malicious update: benign mean shifted by z standard deviations.
    malicious_update = {
        key: avg_benign_update[key] + std_dev_benign[key] * z_score for key in avg_benign_update
    }

    for pos in attacker_positions:
        # Each attacker gets independent tensors so in-place edits of one entry
        # cannot silently alter the others.
        attacked_updates[pos] = {key: value.clone() for key, value in malicious_update.items()}

    return attacked_updates


def fang_attack(
    client_updates: WeightsList,
    global_weights_before: StateDict,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
) -> WeightsList:
    """
    Execute the Fang et al. model-poisoning attack against the aggregation step.

    The adversary derives element-wise extrema from the benign updates and
    crafts a replacement update that pushes the aggregated parameters in the
    opposite direction of the benign consensus.

    The benign direction is the sign of ``average_weights(benign) - global_weights_before``.
    Where it is +1 the attacker submits the element-wise benign minimum; otherwise
    (-1 or 0) it submits the element-wise benign maximum.

    Args:
        client_updates: Model updates submitted by all selected clients, aligned
            position-by-position with ``all_client_indices``.
        global_weights_before: Global weights prior to the current round.
        all_client_indices: Indices of the clients selected for the round.
        malicious_client_indices: Indices identifying the adversarial clients.

    Returns:
        A deep copy of ``client_updates`` in which every selected malicious entry is
        replaced by its own copy of the crafted update. If there are no benign clients
        or no selected attackers, the copy is returned unchanged.
    """
    attacked_updates = copy.deepcopy(client_updates)

    attacker_positions = [
        pos for pos, client_idx in enumerate(all_client_indices) if client_idx in malicious_client_indices
    ]
    benign_updates = [
        update for pos, update in enumerate(client_updates) if all_client_indices[pos] not in malicious_client_indices
    ]

    if not benign_updates or not attacker_positions:
        # Nothing to derive the attack from, or nobody to apply it to.
        return attacked_updates

    # Aggregated benign model; its offset from the previous global model gives the
    # direction the benign clients push each parameter.
    aggregated_benign_model = average_weights(benign_updates)

    malicious_update = {}
    for key in benign_updates[0]:
        # Stack along a new leading (client) dimension and reduce over it.
        stacked = torch.stack([update[key] for update in benign_updates])
        param_min = torch.min(stacked, dim=0).values
        param_max = torch.max(stacked, dim=0).values
        direction = torch.sign(aggregated_benign_model[key] - global_weights_before[key])
        # Oppose the benign consensus: benign moved up (+1) -> submit the minimum,
        # otherwise (-1 or 0) -> submit the maximum.
        malicious_update[key] = torch.where(direction == 1, param_min, param_max)

    for pos in attacker_positions:
        # Independent copies so an in-place edit of one attacker's entry does not
        # silently alter every other attacker's entry.
        attacked_updates[pos] = {key: value.clone() for key, value in malicious_update.items()}

    return attacked_updates


def attack_minmax(
    client_updates: List[StateDict],  # all client model updates
    global_weights_before: StateDict,  # previous global model weights
    all_client_indices: List[int],  # all selected clients this round
    malicious_client_indices: Set[int],  # malicious client set
    learning_rate: float,  # learning rate
    attacker_ability: str = "Full",  # "Part" or "Full"
) -> List[StateDict]:
    """
    Apply the Min-Max attack by maximising the deviation between malicious and benign gradients.

    A malicious gradient ``grad_avg + gamma * p`` is searched along the unit direction
    ``p = -grad_avg / ||grad_avg||``. Gamma is tuned so that the malicious gradient's
    largest distance to any observed gradient stays below the largest distance among
    the observed gradients themselves (both measured with the project's distance
    helpers). Every selected malicious client then submits
    ``global_weights_before - learning_rate * malicious_gradient``.

    Args:
        client_updates: Local state dictionaries produced in the current round.
        global_weights_before: Global reference weights from the previous round.
        all_client_indices: Indices of all clients selected in the round.
        malicious_client_indices: Indices that identify adversarial clients.
        learning_rate: Learning rate used during local optimisation.
        attacker_ability: Specifies whether adversaries see all updates ("Full")
            or only their own ("Part").

    Returns:
        Cloned copies of ``client_updates`` where malicious entries follow the Min-Max
        attack. If no malicious client was selected, the clones are returned unchanged.
    """

    attacked_updates = [{k: v.clone() for k, v in client.items()} for client in client_updates]
    malicious_users = set(all_client_indices) & set(malicious_client_indices)
    if not malicious_users:
        # Nothing to craft. In "Part" mode the attacker view would also be empty,
        # and max() over an empty sequence below would raise ValueError.
        return attacked_updates

    # Updates visible to the adversary
    if attacker_ability == "Part":
        w_original = [
            {k: v.clone() for k, v in client_updates[all_client_indices.index(user)].items()}
            for user in malicious_users
        ]
    else:
        w_original = [{k: v.clone() for k, v in client.items()} for client in client_updates]

    grad = calculate_gradients(global_weights_before, w_original, learning_rate)

    # Deviation budget: largest distance between any two observed gradients.
    # Loop-invariant, so it is computed once instead of on every search step.
    max_pairwise_dist = max(max(distance(grad, grad[i]).values()) for i in range(len(grad)))

    # Average gradient and the unit perturbation direction opposing it.
    # NOTE(review): if the average gradient has zero norm this divides by zero.
    grad_avg = average_gradients(grad)
    norm_value = calculate_l2_norm(grad_avg)
    nabla_p = {k: -grad_avg[k] / norm_value for k in grad_avg}

    # Binary search on gamma: grow it while the constraint holds, shrink it otherwise.
    gamma, gamma_last, step = 0.5, 2.5, 100.0
    while abs(gamma - gamma_last) > 1e-3:
        nabla_m = {k: nabla_p[k] * gamma + grad_avg[k] for k in nabla_p}
        max_deviation = max(calculate_l2_norm(grad[i], nabla_m) for i in range(len(grad)))

        gamma_last = gamma
        if max_deviation < max_pairwise_dist:
            gamma += step / 2
        else:
            gamma -= step / 2
        step /= 2
    # NOTE(review): nabla_m corresponds to the last *evaluated* gamma, which may be one
    # that failed the constraint; the reference Min-Max tracks the last successful gamma.

    # Apply malicious updates
    for user in malicious_users:
        idx = all_client_indices.index(user)
        for key in nabla_m:
            attacked_updates[idx][key] = global_weights_before[key] - nabla_m[key] * learning_rate

    return attacked_updates


def attack_minsum(
    client_updates: List[StateDict],
    global_weights_before: StateDict,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
    learning_rate: float,
    attacker_ability: str = "Full",
) -> List[StateDict]:
    """
    Apply the Min-Sum attack by minimising the aggregate distance to malicious gradients.

    A malicious gradient ``grad_avg + gamma * p`` is searched along the unit direction
    ``p = -grad_avg / ||grad_avg||``. Gamma is tuned so that the sum of squared
    distances from the malicious gradient to the observed gradients stays below a
    bound derived from the observed gradients' summed pairwise distances. Every
    selected malicious client then submits
    ``global_weights_before - learning_rate * malicious_gradient``.

    Args:
        client_updates: Local state dictionaries produced in the current round.
        global_weights_before: Global reference weights from the previous round.
        all_client_indices: Indices of all clients selected in the round.
        malicious_client_indices: Indices that identify adversarial clients.
        learning_rate: Learning rate used during local optimisation.
        attacker_ability: Specifies whether adversaries see all updates ("Full")
            or only their own ("Part").

    Returns:
        Cloned copies of ``client_updates`` where malicious entries follow the Min-Sum
        attack. If no malicious client was selected, the clones are returned unchanged.
    """

    attacked_updates = [{k: v.clone() for k, v in client.items()} for client in client_updates]
    malicious_users = set(all_client_indices) & set(malicious_client_indices)
    if not malicious_users:
        # Nothing to craft. In "Part" mode the attacker view would also be empty,
        # and max() over an empty sequence below would raise ValueError.
        return attacked_updates

    # Updates visible to the adversary
    if attacker_ability == "Part":
        w_original = [
            {k: v.clone() for k, v in client_updates[all_client_indices.index(user)].items()}
            for user in malicious_users
        ]
    else:
        w_original = [{k: v.clone() for k, v in client.items()} for client in client_updates]

    grad = calculate_gradients(global_weights_before, w_original, learning_rate)

    # Bound for the search: (max over i of the summed distances from grad[i] to all
    # gradients) squared. Loop-invariant, so it is computed once.
    # NOTE(review): the Min-Sum criterion (Shejwalkar & Houmansadr, NDSS'21) compares
    # sum_i ||m - g_i||^2 with max_i sum_j ||g_i - g_j||^2, i.e. a sum of *squared*
    # distances rather than a squared sum. Confirm against `distance`'s semantics.
    sum_dist_bound = max(sum(distance(grad, grad[i]).values()) for i in range(len(grad))) ** 2

    # Average gradient and the unit perturbation direction opposing it.
    # NOTE(review): if the average gradient has zero norm this divides by zero.
    grad_avg = average_gradients(grad)
    norm_value = calculate_l2_norm(grad_avg)
    nabla_p = {k: -grad_avg[k] / norm_value for k in grad_avg}

    # Binary search on gamma: grow it while the constraint holds, shrink it otherwise.
    gamma, gamma_last, step = 0.5, 2.5, 100.0
    while abs(gamma - gamma_last) > 1e-3:
        nabla_m = {k: nabla_p[k] * gamma + grad_avg[k] for k in nabla_p}
        sum_sq_deviation = sum(calculate_l2_norm(grad[i], nabla_m) ** 2 for i in range(len(grad)))

        gamma_last = gamma
        if sum_sq_deviation < sum_dist_bound:
            gamma += step / 2
        else:
            gamma -= step / 2
        step /= 2
    # NOTE(review): nabla_m corresponds to the last *evaluated* gamma, which may be one
    # that failed the constraint; the reference Min-Sum tracks the last successful gamma.

    # Apply malicious updates
    for user in malicious_users:
        idx = all_client_indices.index(user)
        for key in nabla_m:
            attacked_updates[idx][key] = global_weights_before[key] - nabla_m[key] * learning_rate

    return attacked_updates
