#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Implement defense mechanisms that harden federated learning against Byzantine attacks."""

import copy
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import torch
from torch import Tensor

from .utils.federated_metrics import (
    _delta_params_vec,
    calculate_gradients,
    calculate_inner_product,
    calculate_l2_norm,
    perform_t_test,
)
from .utils.model_averaging import average_weights

# Type alias for a model's state dictionary
StateDict = Dict[str, torch.Tensor]
WeightsList = List[StateDict]


def update_reputations(
    participating_client_indices: List[int],
    detected_malicious_indices: Set[int],
    current_alphas: np.ndarray,
    current_betas: np.ndarray,
    discount_factor: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Update per-client Beta distribution parameters after a detection round.

    Each participating client has its historical counts multiplied by
    `discount_factor`; a flagged client then gains one failure (beta), any
    other participant gains one success (alpha). Clients that did not
    participate keep their counts untouched. The inputs are never modified.

    Args:
        participating_client_indices: Indices of clients that contributed in the round.
        detected_malicious_indices: Indices that were flagged as malicious.
        current_alphas: Current alpha parameters for all registered clients.
        current_betas: Current beta parameters for all registered clients.
        discount_factor: Multiplicative decay applied to historical counts.

    Returns:
        Tuple of (reputations, updated_alphas, updated_betas). `reputations`
        is aligned with `participating_client_indices` and holds
        alpha / (alpha + beta) for each participant. The updated arrays are
        always floating point.
    """
    # Work on float copies: with integer input arrays the discounted
    # (fractional) counts would otherwise be truncated on assignment.
    updated_alphas = np.array(current_alphas, dtype=float)
    updated_betas = np.array(current_betas, dtype=float)
    reputations = np.ones(len(participating_client_indices), dtype=float)

    for position, client_idx in enumerate(participating_client_indices):
        # Decay is always applied to the values from before this round.
        decayed_alpha = discount_factor * current_alphas[client_idx]
        decayed_beta = discount_factor * current_betas[client_idx]

        if client_idx in detected_malicious_indices:
            # Detected as malicious: increment beta (failure count)
            updated_alphas[client_idx] = decayed_alpha
            updated_betas[client_idx] = decayed_beta + 1
        else:
            # Assumed benign: increment alpha (success count)
            updated_alphas[client_idx] = decayed_alpha + 1
            updated_betas[client_idx] = decayed_beta

        # Mean of the Beta(alpha, beta) distribution is the reputation score.
        total = updated_alphas[client_idx] + updated_betas[client_idx]
        reputations[position] = updated_alphas[client_idx] / total

    return reputations, updated_alphas, updated_betas


def calculate_gradient_cosine_similarities(
    local_weights: WeightsList,
    server_weights: StateDict,
    global_weights_before: StateDict,
    learning_rate: float,
) -> np.ndarray:
    """
    Measure how well each client gradient aligns with the aggregated gradient.

    Gradients are derived from weight differences via `calculate_gradients`,
    and the cosine similarity is built from `calculate_inner_product`
    results: <g_i, g_s> / sqrt(<g_i, g_i> * <g_s, g_s>).

    Args:
        local_weights: Locally updated model weights from participating clients.
        server_weights: Aggregated global weights produced after the round.
        global_weights_before: Global weights prior to the current round.
        learning_rate: Learning rate used for local optimisation.

    Returns:
        NumPy array with one similarity per client, in the order of
        `local_weights`. A similarity of 0.0 is reported whenever the
        normalising denominator is zero.
    """
    client_grads = calculate_gradients(global_weights_before, local_weights, learning_rate)
    reference_grad = calculate_gradients(global_weights_before, server_weights, learning_rate)

    # The reference (server) squared norm is shared by every comparison.
    reference_sq_norm = calculate_inner_product(reference_grad, reference_grad)

    def _cosine_to_reference(client_grad):
        dot = calculate_inner_product(client_grad, reference_grad)
        client_sq_norm = calculate_inner_product(client_grad, client_grad)
        scale = np.sqrt(client_sq_norm * reference_sq_norm)
        # A zero-length gradient has no direction; report it as orthogonal.
        return 0.0 if scale == 0 else dot / scale

    return np.array([_cosine_to_reference(client_grad) for client_grad in client_grads])


def aggregate_multi_krum(
    local_weights: WeightsList,
    global_weights_before: StateDict,
    num_attackers: int,
    num_benign_to_select: int,
    learning_rate: float,
) -> StateDict:
    """
    Execute Multi-Krum aggregation to select reliable client updates.

    Every client is scored by the sum of squared distances from its gradient
    to its closest peers; the `num_benign_to_select` clients with the lowest
    scores are averaged with `average_weights`.

    Args:
        local_weights: Locally updated model weights from participating clients.
        global_weights_before: Global weights prior to the current round.
        num_attackers: Expected number of Byzantine adversaries.
        num_benign_to_select: Number of benign candidates to average.
        learning_rate: Learning rate used for local optimisation.

    Returns:
        Aggregated state dictionary produced by the Multi-Krum rule. When
        there are no more clients than expected attackers, the plain average
        of all client weights is returned instead.
    """
    num_clients = len(local_weights)
    if num_clients <= num_attackers:
        # Not enough clients to perform Krum, return simple average
        return average_weights(local_weights)

    # Calculate gradients from weights
    gradients = calculate_gradients(global_weights_before, local_weights, learning_rate)

    # Pairwise squared distances; the matrix is symmetric with a zero diagonal.
    distances = np.zeros((num_clients, num_clients))
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            dist = calculate_l2_norm(gradients[i], gradients[j]) ** 2
            distances[i, j] = distances[j, i] = dist

    # Krum uses the n - f - 2 nearest neighbours. For f < n <= f + 2 that
    # count is <= 0, which would leave every score at zero and make the
    # selection arbitrary, so at least one neighbour is always used (and
    # never more than the n - 1 peers that exist).
    neighbour_count = min(max(num_clients - num_attackers - 2, 1), num_clients - 1)

    # After sorting each row, column 0 is the zero self-distance; skip it.
    nearest_first = np.sort(distances, axis=1)
    scores = nearest_first[:, 1 : neighbour_count + 1].sum(axis=1)

    # Select the clients with the lowest scores
    selected_indices = np.argsort(scores)[:num_benign_to_select]
    selected_weights = [local_weights[i] for i in selected_indices]

    # Return the average of the selected weights
    return average_weights(selected_weights)


def detect_by_gradient_norm_outliers(
    local_weights: WeightsList,
    global_weights_before: StateDict,
    learning_rate: float,
    outlier_threshold: float = 1.04,
) -> List[int]:
    """
    Report clients whose gradient norm is an outlier in every single layer.

    For each parameter key, a client's score is the average ratio of its
    gradient L2 norm to every client's norm for that key. A client is
    suspicious for that key when its score leaves the band
    [median / outlier_threshold, median * outlier_threshold]. Only clients
    that are suspicious for all keys are returned.

    This implementation follows the FGNV defence intuition of spotting
    gradient magnitude outliers.

    Args:
        local_weights: Client model updates subject to evaluation.
        global_weights_before: Global weights from the previous round.
        learning_rate: Learning rate used during local optimisation.
        outlier_threshold: Multiplicative band around the median score.

    Returns:
        Positions (within `local_weights`) of clients considered potential
        adversaries; an empty list when there are no clients.
    """
    client_grads = calculate_gradients(global_weights_before, local_weights, learning_rate)
    n_clients = len(client_grads)
    if n_clients == 0:
        return []

    layer_names = client_grads[0].keys()
    # flagged[c, l] is True when client c is an outlier for layer l.
    flagged = np.zeros((n_clients, len(layer_names)), dtype=bool)

    for column, name in enumerate(layer_names):
        norms = np.array([torch.norm(client_grad[name]).item() for client_grad in client_grads])

        # ratios[a, b] = norm_a / norm_b; the epsilon guards against zero norms.
        ratios = norms[:, np.newaxis] / (norms[np.newaxis, :] + 1e-8)
        scores = np.sum(ratios, axis=1) / n_clients

        centre = np.median(scores)
        above_band = scores > centre * outlier_threshold
        below_band = scores < centre / outlier_threshold
        flagged[:, column] = above_band | below_band

    # A client is reported only if it was flagged for every layer.
    return np.flatnonzero(flagged.all(axis=1)).tolist()


def malicious_detection_candidate(
    cosine_similar: List[float], chosenUsers: List[int], cos_threshold: float
) -> Tuple[List[int], List[int]]:
    """
    Identify clients whose cosine similarity falls below the configured threshold.

    Args:
        cosine_similar: Cosine similarity scores aligned with `chosenUsers`.
        chosenUsers: Global client identifiers selected for the round.
        cos_threshold: Threshold below which a client is considered suspicious
            (strict comparison: a score equal to the threshold is not flagged).

    Returns:
        Tuple of (malicious_client_ids, malicious_indices_relative), where the
        second list holds positions within `chosenUsers`. Both lists are
        empty when no scores are given or nothing falls below the threshold.
    """
    # `len(...) == 0` rather than truthiness so NumPy arrays are accepted too.
    if cosine_similar is None or len(cosine_similar) == 0:
        return [], []

    scores = np.asarray(cosine_similar, dtype=float)
    # Positions in `chosenUsers` whose score is strictly below the threshold.
    malicious_indices = np.flatnonzero(scores < cos_threshold).tolist()

    # Map positions to the global ids stored in `chosenUsers`.
    malicious_ids = [chosenUsers[position] for position in malicious_indices]
    return malicious_ids, malicious_indices


def client_detection(
    list_acc: List[float], list_loss: List[float], chosen_users: List[int], significance: float
) -> List[int]:
    """
    Decide which candidate clients are malicious by comparing loss traces.

    The last entry of `list_loss` is the benign reference; every earlier
    entry is one candidate. A candidate is flagged when its loss exceeds the
    reference at every point. Otherwise `perform_t_test` is run on the
    positions where the candidate's loss is below the reference, and the
    candidate is flagged unless every returned result is marked significant.

    Args:
        list_acc: Accuracy traces where the final entry corresponds to the
            benign aggregate (not used by the current rule).
        list_loss: Loss traces aligned with `list_acc`.
        chosen_users: Global client identifiers considered in the round.
        significance: Significance level used for the statistical test fallback.

    Returns:
        Global client identifiers that should be treated as malicious.
    """
    # At least one candidate plus the benign reference are required.
    if not list_loss or len(list_loss) < 2:
        return []

    reference_loss = np.array(list_loss[-1])
    flagged_positions = []

    for position, trace in enumerate(list_loss[:-1]):
        gap = np.array(trace) - reference_loss
        if gap.size == 0:
            continue

        if np.min(gap) > 0:
            # Worse than the benign aggregate everywhere.
            is_malicious = True
        else:
            below_reference = np.where(gap < 0)[0]
            outcome = perform_t_test(trace, below_reference, significance)
            # Each value is unpacked as a pair whose second item is the
            # significance flag; the candidate is cleared only if all are set.
            is_malicious = not all(flag for (_, flag) in outcome.values())

        if is_malicious:
            flagged_positions.append(position)

    # Translate positions in the round into global client ids.
    return [chosen_users[position] for position in flagged_positions]


def client_reputation(
    discount: float, malicious: List[int], alpha: np.ndarray, beta: np.ndarray, chosenUsers: List[int]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Update reputation parameters and derive per-client scores for the current round.

    Every participating client has its historical counts multiplied by
    `discount`. A client listed in `malicious` then gains one failure (beta)
    and receives a hard reputation of 0.0 for this round; any other
    participant gains one success (alpha) and receives
    alpha / (alpha + beta). The input arrays are never modified.

    Args:
        discount: Exponential decay applied to historical counts.
        malicious: Global client identifiers labelled as malicious (may be
            None or empty).
        alpha: Alpha parameters tracked per global client.
        beta: Beta parameters tracked per global client.
        chosenUsers: Global client identifiers participating in this round.

    Returns:
        Tuple of (reputation scores, updated alpha parameters, updated beta
        parameters). The scores are aligned with `chosenUsers`; the updated
        parameter arrays are always floating point.

    Raises:
        ValueError: If `alpha` or `beta` is None.
    """
    if alpha is None or beta is None:
        raise ValueError("alpha and beta must be provided as numpy arrays")

    # Float snapshots of the previous state: with integer inputs the
    # discounted (fractional) counts would otherwise be truncated.
    alpha_last = np.array(alpha, dtype=float)
    beta_last = np.array(beta, dtype=float)
    alpha_update = alpha_last.copy()
    beta_update = beta_last.copy()

    rep = np.ones(len(chosenUsers), dtype=float)
    malicious_set = set(malicious or [])

    for position, uid in enumerate(chosenUsers):
        # uid is a global client id and indexes the alpha/beta arrays directly.
        decayed_alpha = discount * alpha_last[uid]
        decayed_beta = discount * beta_last[uid]

        if uid in malicious_set:
            # Penalise: only the failure count grows, and the client is
            # excluded from this round outright (hard down-weight).
            alpha_update[uid] = decayed_alpha
            beta_update[uid] = decayed_beta + 1.0
            rep[position] = 0.0
        else:
            # Reward: the success count grows after decay.
            alpha_update[uid] = decayed_alpha + 1.0
            beta_update[uid] = decayed_beta
            denom = alpha_update[uid] + beta_update[uid]
            rep[position] = (alpha_update[uid] / denom) if denom > 0 else 0.0

    return rep, alpha_update, beta_update


def _laplace_like(
    t: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    dev = t.device if device is None else device
    dt = dtype if dtype is not None else (t.dtype if torch.is_floating_point(t) else torch.float32)

    e = torch.empty(t.shape, device=dev, dtype=dt).exponential_()  # Exp(1)
    s = torch.randint(0, 2, t.shape, device=dev, dtype=torch.int8)  # {0,1}
    s = s.to(dt).mul_(2).sub_(1)  # -> {-1, +1}
    return e.mul_(s)  # Laplace(0,1)


def wbc(
    lr: float,
    w: List[Dict[str, torch.Tensor]],
    w_before: List[Dict[str, torch.Tensor]],
    delta_w_before: List[Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
) -> List[Dict[str, torch.Tensor]]:
    """
    Perturb client weights with masked Laplace noise.

    For every floating-point parameter, Laplace noise scaled by `lr` is drawn
    and compared element-wise with |w - w_before - delta_w_before|. The noise
    is kept only where that magnitude does not exceed the noise value (so
    only sufficiently large positive noise survives) and is zeroed elsewhere;
    the result is added to `w`. Non-floating-point tensors are copied
    unchanged. The inputs are not modified.

    NOTE(review): the name suggests the "White Blood Cell" federated-learning
    defence; the exact correspondence to that method is not verifiable from
    this file.

    Args:
        lr: Scale applied to the unit Laplace noise.
        w: Current per-client state dictionaries.
        w_before: Per-client state dictionaries to subtract from `w`.
        delta_w_before: Per-client reference updates, subtracted as well.
        device: Device on which to compute and return the result. Defaults
            to the device of each individual parameter.

    Returns:
        One new state dictionary per client, in the order of `w`.

    Raises:
        ValueError: If any input is None, the three lists differ in length,
            or a client's three dictionaries do not share the same keys.
    """
    if w is None or w_before is None or delta_w_before is None:
        raise ValueError("w, w_before, delta_w_before 不能为空")
    if not (len(w) == len(w_before) == len(delta_w_before)):
        raise ValueError("w, w_before, delta_w_before 的长度必须一致")

    # Validate every client up front so no partial result is built on failure.
    for i, (wi, wbi, dwbi) in enumerate(zip(w, w_before, delta_w_before)):
        if wi.keys() != wbi.keys() or wi.keys() != dwbi.keys():
            raise ValueError(f"第 {i} 个客户端的参数键不一致")

    w_new: List[Dict[str, torch.Tensor]] = []

    with torch.no_grad():
        for wi, wbi, dwbi in zip(w, w_before, delta_w_before):
            out_dict: Dict[str, torch.Tensor] = {}
            for k, wt in wi.items():
                dev = wt.device if device is None else device

                if not torch.is_floating_point(wt):
                    # Integer buffers cannot take noise; pass a copy through.
                    out_dict[k] = wt.detach().clone().to(dev)
                    continue

                # Bring every operand onto the target device: the noise is
                # created there, so mixing it with tensors left on another
                # device would fail.
                wt_dev = wt.to(dev)
                noise = _laplace_like(wt, device=dev, dtype=wt.dtype).mul_(lr)
                delta = (wt_dev - wbi[k].to(dev) - dwbi[k].to(dev)).abs()

                # Keep the noise only where delta <= noise; written as the
                # negation of `<=` so NaN entries are zeroed as well.
                noise.masked_fill_(~(delta <= noise), 0)

                out_dict[k] = wt_dev + noise

            w_new.append(out_dict)

    return w_new


class FLDetector:
    """
    Stateful detector that flags clients whose updates are hard to predict.

    Each round, a client's update is predicted as its previous update plus an
    approximate Hessian-vector product built from recent global history. The
    prediction errors are normalised per round, averaged per client over a
    sliding window, and clustered (gap statistic + 1-D k-means); the cluster
    with the highest mean score is reported as malicious.

    NOTE(review): the structure resembles the FLDetector defence; the exact
    correspondence to the published algorithm is not verifiable from this
    file.
    """

    def __init__(
        self,
        window_size: int = 10,
        kmax: int = 2,
        b_ref: int = 10,
        ridge: float = 1e-6,
        start_iter: int = 10,
    ):
        """
        Args:
            window_size: Number of recent rounds kept for the history used by
                the Hessian-vector product and for score averaging.
            kmax: Largest number of clusters considered by the gap statistic.
            b_ref: Number of uniform reference data sets for the gap statistic.
            ridge: Diagonal regularisation added before solving the
                least-squares system.
            start_iter: First round (1-based) in which detection is attempted.

        Raises:
            ValueError: If `window_size` is smaller than 1 (an empty window
                leaves nothing to score).
        """
        if int(window_size) < 1:
            raise ValueError("window_size must be at least 1")

        self.N = int(window_size)
        self.Kmax = int(kmax)
        self.B = int(b_ref)
        self.ridge = float(ridge)
        self.start_iter = int(start_iter)

        # Sliding histories of global weight differences and of changes in
        # the global update, one entry per round.
        self._dW_hist: List[Tensor] = []
        self._dG_hist: List[Tensor] = []

        # One mapping per round: client id -> normalised prediction error.
        # Keyed by id so that changing participation cannot misalign scores.
        self._d_norm_window: List[Dict[int, float]] = []

        # Most recent update seen from each client id.
        self._last_client_update: Dict[int, Tensor] = {}

        # Global update of the previous round (None before the first round).
        self._last_global_update: Optional[Tensor] = None

        # Number of rounds processed so far.
        self._t: int = 0

    def _hvp_ls(self, v: Tensor) -> Tensor:
        """
        Approximate a Hessian-vector product from the stored history.

        With S holding the weight differences and Y the update differences as
        columns, returns Y (S^T S + ridge * I)^{-1} S^T v.

        Args:
            v: Vector to multiply.

        Returns:
            The approximate product, or zeros shaped like `v` when no history
            is available.
        """
        if not self._dW_hist or not self._dG_hist:
            return torch.zeros_like(v)

        S = torch.stack(self._dW_hist, dim=1)
        Y = torch.stack(self._dG_hist, dim=1)

        G = S.transpose(0, 1) @ S
        m = G.shape[0]
        G = G + self.ridge * torch.eye(m, device=G.device, dtype=G.dtype)
        r = S.transpose(0, 1) @ v

        try:
            L = torch.linalg.cholesky(G)
            alpha = torch.cholesky_solve(r.unsqueeze(1), L).squeeze(1)
        except RuntimeError:
            # Cholesky failed (matrix not numerically positive definite);
            # fall back to a general linear solve.
            alpha = torch.linalg.solve(G, r)

        return Y @ alpha

    @torch.no_grad()
    def step_and_detect(
        self,
        *,
        chosen_users: List[int],
        local_weights: List[Dict[str, Tensor]],
        global_weights_before: Dict[str, Tensor],
        global_weights_after: Dict[str, Tensor],
        lr: float,
        global_update_vec: Optional[Tensor] = None,
    ) -> Tuple[List[int], np.ndarray]:
        """
        Record one round and report the clients that look malicious.

        Args:
            chosen_users: Global client ids participating in this round.
            local_weights: Client state dictionaries aligned with `chosen_users`.
            global_weights_before: Global weights at the start of the round.
            global_weights_after: Global weights after aggregation.
            lr: Learning rate used to turn weight differences into updates.
            global_update_vec: Optional precomputed global update; when None,
                the mean of the client updates is used.

        Returns:
            Tuple of (malicious client ids, float32 score array aligned with
            `chosen_users`). The list is empty until `start_iter` rounds have
            passed and enough rounds are in the window.

        Raises:
            ValueError: If `chosen_users` and `local_weights` differ in length.
        """
        if len(chosen_users) != len(local_weights):
            raise ValueError("chosen_users and local_weights must have the same length")

        self._t += 1

        # `_delta_params_vec` is defined elsewhere; presumably it returns the
        # flattened parameter difference (first minus second) — TODO confirm.
        dW_t = _delta_params_vec(global_weights_after, global_weights_before)
        dev = dW_t.device

        # Per-client updates, computed once and reused below.
        client_updates = [_delta_params_vec(li, global_weights_before) / lr for li in local_weights]

        if global_update_vec is None:
            g_t = torch.stack(client_updates, dim=0).mean(dim=0)
        else:
            g_t = global_update_vec.to(dev)

        if self._last_global_update is None:
            dG_t = torch.zeros_like(g_t)
        else:
            dG_t = g_t - self._last_global_update

        # The current round is part of the history used for this round's
        # Hessian-vector product.
        self._dW_hist.append(dW_t)
        self._dG_hist.append(dG_t)
        if len(self._dW_hist) > self.N:
            self._dW_hist.pop(0)
            self._dG_hist.pop(0)

        self._last_global_update = g_t

        hv = self._hvp_ls(dW_t)

        # Prediction error per client: || (previous update + hv) - actual ||.
        d_list: List[float] = []
        for uid, g_i_t in zip(chosen_users, client_updates):
            g_prev = self._last_client_update.get(uid)
            if g_prev is None:
                # First time this client is seen: predict from a zero update.
                g_prev = torch.zeros_like(g_i_t)
            d_list.append(float((g_prev + hv - g_i_t).norm(p=2).item()))

        # Normalise the errors of this round to unit L2 norm.
        d_vec = torch.tensor(d_list, dtype=torch.float32, device=dev)
        denom = d_vec.norm(p=2)
        if denom.item() == 0.0:
            d_norm = torch.zeros_like(d_vec)
        else:
            d_norm = d_vec / denom

        self._d_norm_window.append(dict(zip(chosen_users, d_norm.cpu().tolist())))
        if len(self._d_norm_window) > self.N:
            self._d_norm_window.pop(0)

        # Only now overwrite the cached updates, so that every prediction
        # above used the values from before this round.
        for uid, g_i_t in zip(chosen_users, client_updates):
            self._last_client_update[uid] = g_i_t

        # Average each client's score over the rounds it took part in. The
        # current round is always in the window, so no history is empty.
        per_client: List[float] = []
        for uid in chosen_users:
            history = [round_scores[uid] for round_scores in self._d_norm_window if uid in round_scores]
            per_client.append(sum(history) / len(history))
        scores = torch.tensor(per_client, dtype=torch.float32)
        scores_np = scores.numpy()

        malicious: List[int] = []
        if self._t >= self.start_iter and len(self._d_norm_window) >= max(2, min(self.N, 3)):
            labels = self._cluster_via_gap(scores)
            n_clusters = int(labels.max().item()) + 1 if labels.numel() > 0 else 0
            if n_clusters > 1:
                means = []
                for k in range(n_clusters):
                    members = scores[labels == k]
                    # An empty cluster must never win the argmax (its mean
                    # would be NaN, which argmax treats as the maximum).
                    means.append(members.mean().item() if members.numel() > 0 else float("-inf"))
                mal_cluster = int(np.argmax(means))
                malicious = [uid for uid, lb in zip(chosen_users, labels.tolist()) if lb == mal_cluster]

        return malicious, scores_np

    # ---------------- Gap Statistics + 1D k-means ----------------
    @staticmethod
    def _kmeans_1d(x_: np.ndarray, k: int, iters: int = 50) -> Tuple[np.ndarray, np.ndarray, float]:
        """
        Run a deterministic k-means on a column of scalar values.

        Args:
            x_: Array of shape (n, 1).
            k: Number of clusters; must not exceed n.
            iters: Maximum number of refinement iterations.

        Returns:
            Tuple of (labels, centers, within-cluster sum of squares). A tiny
            constant is added to the sum so that its logarithm is finite.
        """
        # Fixed seed: the initial centers are k distinct data points.
        rng = np.random.default_rng(0)
        centers = rng.choice(x_.reshape(-1), size=k, replace=False)
        for _ in range(iters):
            dists = np.abs(x_ - centers.reshape(1, -1))
            labels = dists.argmin(axis=1)
            # A cluster that lost all its points keeps its previous center.
            new_centers = np.array(
                [x_[labels == j].mean() if np.any(labels == j) else centers[j] for j in range(k)]
            )
            if np.allclose(new_centers, centers):
                break
            centers = new_centers
        W = 0.0
        for j in range(k):
            cluster = x_[labels == j]
            if cluster.size > 0:
                W += np.sum((cluster - centers[j]) ** 2)
        return labels, centers, float(W + 1e-12)

    def _cluster_via_gap(self, scores: Tensor) -> Tensor:
        """
        Cluster scalar scores, choosing the cluster count by the gap statistic.

        Args:
            scores: 1-D tensor of per-client scores.

        Returns:
            Long tensor of cluster labels, one per score. All labels are zero
            when there is at most one score, at most one cluster is allowed,
            or all scores are equal.
        """
        x = scores.detach().cpu().float().reshape(-1, 1).numpy()
        n = x.shape[0]

        # k-means cannot form more clusters than there are points.
        k_max = min(self.Kmax, n)
        if n <= 1 or k_max <= 1:
            return torch.zeros(n, dtype=torch.long)

        xmin, xmax = float(np.min(x)), float(np.max(x))
        if xmin == xmax:
            return torch.zeros(n, dtype=torch.long)

        # log(W_k) of the observed scores for k = 1..k_max.
        logW = []
        labels_cache = []
        for k in range(1, k_max + 1):
            lb, _, Wk = self._kmeans_1d(x, k)
            labels_cache.append(lb)
            logW.append(np.log(Wk))
        logW = np.array(logW)

        # log(W_k) of B reference sets drawn uniformly over the data range.
        logW_ref = np.zeros((self.B, k_max), dtype=float)
        rng = np.random.default_rng(1)
        for b in range(self.B):
            u = rng.uniform(low=xmin, high=xmax, size=(n, 1)).astype(np.float32)
            for k in range(1, k_max + 1):
                _, _, Wk = self._kmeans_1d(u, k)
                logW_ref[b, k - 1] = np.log(Wk)

        gap = logW_ref.mean(axis=0) - logW
        sdk = np.sqrt(1 + 1.0 / self.B) * logW_ref.std(axis=0, ddof=1)

        # Smallest k with Gap(k) >= Gap(k+1) - s_{k+1}; k_max if none qualifies.
        k_hat = k_max
        for k in range(1, k_max):
            if gap[k - 1] - (gap[k] - sdk[k]) >= 0:
                k_hat = k
                break

        final_labels = labels_cache[k_hat - 1]
        return torch.tensor(final_labels, dtype=torch.long)


def FGNV(
    w_locals: List[Dict[str, torch.Tensor]],
    w_glob_before: Dict[str, torch.Tensor],
    chosenUsers: List[int],
    learning_rate: float,
    device: Optional[torch.device] = None,
) -> List[int]:
    """
    Flag the clients whose layer-wise gradient norms are most often outliers.

    For every parameter key, a client's score is its gradient L2 norm times
    the mean reciprocal norm over all clients. A score more than 4% above or
    below the per-key median counts as suspicious. The clients with the
    highest number of suspicious keys are returned.

    Args:
        w_locals: Client state dictionaries.
        w_glob_before: Global weights from before the round.
        chosenUsers: Global client ids aligned with `w_locals`.
        learning_rate: Learning rate passed on to `calculate_gradients`.
        device: Device for the norm matrix; defaults to the device of the
            first floating-point gradient of the first client, or CPU.

    Returns:
        Global ids of the flagged clients; empty when there are no clients or
        no key is suspicious for anyone.
    """
    client_grads = calculate_gradients(w_glob_before, w_locals, learning_rate)
    n_clients = len(client_grads)
    if n_clients == 0:
        return []

    layer_names = list(client_grads[0].keys())
    n_layers = len(layer_names)

    if device is None:
        first_client = client_grads[0]
        device = next(
            (
                first_client[name].device
                for name in layer_names
                if torch.is_tensor(first_client[name]) and torch.is_floating_point(first_client[name])
            ),
            torch.device("cpu"),
        )

    # norms[c, l]: L2 norm of client c's gradient for layer l; entries that
    # are not floating point contribute a norm of zero.
    norm_rows = [
        [
            float(client_grad[name].reshape(-1).norm(p=2).item())
            if torch.is_floating_point(client_grad[name])
            else 0.0
            for name in layer_names
        ]
        for client_grad in client_grads
    ]
    norms = torch.tensor(norm_rows, dtype=torch.float32, device=device).reshape(n_clients, n_layers)

    eps = 1e-12
    mean_reciprocal = (1.0 / (norms + eps)).mean(dim=0)  # (n_layers,)
    relative = norms * mean_reciprocal  # (n_clients, n_layers)

    centre = relative.median(dim=0).values
    out_of_band = (relative > centre * 1.04) | (relative < centre * 0.96)

    # Number of suspicious layers per client.
    hits = out_of_band.sum(dim=1)
    worst = int(hits.max().item())
    if worst < 1:
        return []

    flagged_positions = torch.nonzero(hits == worst, as_tuple=False).view(-1).tolist()
    return [chosenUsers[position] for position in flagged_positions]
