from __future__ import annotations
import math
import numpy as np
import torch

from utils.register import Register
from utils.tensor_utils import flatten_named_tensors, restore_tensor

# Registry of flat aggregation rules. Functions in this module are added with
# the `@flat_aggregators('<name>')` decorator and fetched by name with
# `flat_aggregators[agg_type]` inside `aggregate`.
flat_aggregators = Register()

@torch.no_grad()
@flat_aggregators('mean')
def _mean(
    flat_client_updates: torch.Tensor, n_samples: list[int] | None = None,
    **kwargs,
) -> tuple[torch.Tensor, list[int]]:
    '''
    Weighted average of the client updates.

    Args:
        flat_client_updates: one row per client (indexed along dim 0);
            assumed to be (n_clients, n_params) -- TODO confirm with caller.
        n_samples: per-client sample counts. When given, each client is
            weighted by its share of the total sample count. When None or
            empty, all clients are weighted equally.
        **kwargs: ignored; accepted so every aggregator can be called with
            the same keyword arguments.

    Returns:
        (aggregated update, indices of all clients).
    '''
    n_clients = len(flat_client_updates)

    if n_samples is None or len(n_samples) == 0:
        weights = torch.ones(n_clients) / n_clients
    else:
        sample_counts = torch.Tensor(n_samples)
        weights = sample_counts / sample_counts.sum()

    # Follow the device of the updates instead of forcing the current CUDA
    # device, so CPU tensors and tensors on a non-default GPU also work.
    weights = weights.to(flat_client_updates.device)

    flat_agg_update = (flat_client_updates * weights.view(-1, 1)).sum(dim=0)

    # Plain averaging never discards a client.
    candidate_idxs = list(range(n_clients))
    return flat_agg_update, candidate_idxs

@torch.no_grad()
@flat_aggregators('median')
def _median(
    flat_client_updates: torch.Tensor,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    '''
    Coordinate-wise median of the client updates.

    Args:
        flat_client_updates: one row per client; the median is taken along
            dim 0, independently for every remaining coordinate.
        **kwargs: ignored; accepted so every aggregator can be called with
            the same keyword arguments.

    Returns:
        (aggregated update, None). No client subset is selected, so the
        second element is always None.
    '''
    # Tensor.median(dim=...) returns (values, indices); only values are used.
    flat_agg_update, _ = flat_client_updates.median(dim=0)
    return flat_agg_update, None

@torch.no_grad()
@flat_aggregators('rbtm')
def _rbtm(
    flat_client_updates: torch.Tensor, n_attackers: int,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    '''
    Coordinate-wise trimmed mean of the client updates.

    For every coordinate the values are sorted across clients, the
    `n_trim` smallest and `n_trim` largest are dropped, and the rest are
    averaged.

    Args:
        flat_client_updates: one row per client; sorting and averaging
            happen along dim 0.
        n_attackers: requested number of values to trim from each end.
        **kwargs: ignored; accepted so every aggregator can be called with
            the same keyword arguments.

    Returns:
        (aggregated update, None).
    '''
    n_clients = len(flat_client_updates)
    # Cap the trim count at (n_clients - 1) // 2 so that at least one value
    # per coordinate survives the two-sided trim.
    n_trim = min(n_attackers, (n_clients - 1) // 2)

    # Nothing to trim: this is a plain mean and the sort can be skipped.
    if n_trim == 0:
        return flat_client_updates.mean(dim=0), None

    ordered_updates = torch.sort(flat_client_updates, dim=0).values
    kept_updates = ordered_updates[n_trim:-n_trim]
    return kept_updates.mean(dim=0), None

@torch.no_grad()
def get_squared_distances(
    tensors: torch.Tensor,
) -> torch.Tensor:
    '''
    Pairwise squared distances between the rows of `tensors`.

    Args:
        tensors: the vectors to compare, one per row.

    Returns:
        Matrix whose entry (i, j) is the squared `torch.cdist` distance
        between row i and row j.
    '''
    # torch.cdist works on batches, so add a leading batch axis of size one
    # and strip it again afterwards.
    batched = tensors.unsqueeze(0)
    pairwise = torch.cdist(batched, batched).squeeze(0)
    return torch.square(pairwise)

@torch.no_grad()
@flat_aggregators('gmedian')
def geometric_median(
    flat_client_updates: torch.Tensor,
    budget: int=3,
    lower_threshold: float=1e-6,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    '''
    Approximate geometric median of the client updates.

    Starts from the plain mean and runs `budget` re-weighting steps. In
    each step every client is weighted by the inverse of its distance to
    the current estimate, and the estimate becomes the weighted average.

    Args:
        flat_client_updates: one row per client; distances are norms over
            the last dimension.
        budget: number of re-weighting iterations. With 0 the plain mean
            is returned.
        lower_threshold: lower clamp for the distances, which keeps the
            inverse finite when a client coincides with the estimate.
        **kwargs: ignored; accepted so every aggregator can be called with
            the same keyword arguments.

    Returns:
        (aggregated update, None).
    '''
    n_clients = len(flat_client_updates)
    # Allocate on the same device as the updates rather than forcing CUDA.
    e = torch.ones(n_clients, device=flat_client_updates.device)
    agg_update = flat_client_updates.mean(dim=0)

    for _ in range(budget):
        distances: torch.Tensor = (flat_client_updates - agg_update).norm(dim=-1)
        weights = e / distances.clamp_min(lower_threshold)
        # Normalise so the update is a weighted *average*. Without this the
        # estimate is multiplied by sum(1 / distance) on every iteration.
        weights = weights / weights.sum()
        agg_update = (flat_client_updates * weights.view(-1, 1)).sum(dim=0)

    candidate_idxs = None

    return agg_update, candidate_idxs

@torch.no_grad()
@flat_aggregators('krum')
def _krum(
    flat_client_updates: torch.Tensor, n_attackers: int,
    n_candidates: int=1,
    **kwargs,
)->tuple[torch.Tensor, list[int]]:
    '''
    Select the clients with the smallest neighbourhood distance and average them.

    Each client is scored by the sum of its `n_clients - 2 - n_attackers`
    smallest squared distances to the clients. The `n_candidates`
    lowest-scoring clients are selected and their updates are averaged.

    Args:
        flat_client_updates: one row per client.
        n_attackers: assumed number of malicious clients; capped at
            (n_clients - 3) // 2 before use.
        n_candidates: how many clients to select. None selects
            n_clients - n_attackers (using the capped value).
        **kwargs: ignored; accepted so every aggregator can be called with
            the same keyword arguments.

    Returns:
        (mean of the selected updates, indices of the selected clients
        ordered by increasing score).
    '''
    n_clients = len(flat_client_updates)
    n_attackers = min(n_attackers, (n_clients - 3) // 2)
    if n_candidates is None:
        n_candidates = n_clients - n_attackers

    pairwise_sq_dists = get_squared_distances(tensors=flat_client_updates)

    # NOTE(review): every row also contains the client's distance to itself,
    # so the n_neighbours smallest entries normally include that entry.
    n_neighbours = n_clients - 2 - n_attackers
    nearest_sq_dists = torch.topk(
        pairwise_sq_dists, n_neighbours, dim=1, largest=False,
    ).values
    scores = nearest_sq_dists.sum(dim=1)

    candidate_idxs = torch.argsort(scores)[:n_candidates].tolist()
    agg_update = flat_client_updates[candidate_idxs].mean(dim=0)
    return agg_update, candidate_idxs

@torch.no_grad()
@flat_aggregators('bulyan')
def _bulyan(
    flat_client_updates: torch.Tensor, n_attackers: int,
    **kwargs,
)->tuple[torch.Tensor, None]:
    '''
    Two-stage aggregation: client selection via `_krum`, then a
    coordinate-wise average of the values closest to the median.

    Stage 1 asks `_krum` for `n_clients - 2 * f` client indices, where
    `f = min(n_attackers, (n_clients - 3) // 4)`. Stage 2 works per
    coordinate on the selected clients: it keeps the `beta` values closest
    to that coordinate's median and averages them, with
    `beta = (n_clients - 2 * f) - 2 * f`.

    Args:
        flat_client_updates: one row per client.
        n_attackers: assumed number of malicious clients.
        **kwargs: ignored; accepted so every aggregator can be called with
            the same keyword arguments.

    Returns:
        (aggregated update, None).
    '''
    n_clients = len(flat_client_updates)

    # Both stages use the same attacker bound, capped at (n_clients - 3) // 4.
    n_attackers_capped = min(n_attackers, (n_clients - 3) // 4)
    n_selected = n_clients - 2 * n_attackers_capped
    beta = n_selected - 2 * n_attackers_capped

    # Stage 1: pick the clients to keep. Only the index list returned by
    # _krum is used. NOTE(review): _krum receives the raw n_attackers and
    # applies its own cap, not the one computed above.
    selected_idxs = _krum(
        flat_client_updates=flat_client_updates,
        n_attackers=n_attackers,
        n_candidates=n_selected,
    )[1]
    selected_updates = flat_client_updates[selected_idxs]

    # Stage 2: per coordinate, rank the selected clients by absolute
    # deviation from the median and average the beta closest values.
    coordinate_median = selected_updates.median(dim=0).values
    closeness_order = torch.argsort((selected_updates - coordinate_median).abs(), dim=0)
    closest_values = torch.gather(selected_updates, 0, closeness_order[:beta])
    agg_update = closest_values.mean(dim=0)

    return agg_update, None

@torch.no_grad()
@flat_aggregators('dnc')
def _divide_and_conquer(
    flat_client_updates: torch.Tensor, n_attackers: int,
    n_iters: int=1, filtering_fraction: float=2.0,
    n_sampled_coordinates: int | None = None,
    **kwargs,
) -> tuple[torch.Tensor, list[int]]:
    '''
    Filter outlying clients by their projection onto the leading singular
    direction, then average the remaining clients.

    In every iteration the (optionally coordinate-subsampled) updates are
    centred, the first column of the V factor from `torch.svd` is taken,
    and each client is scored by its squared projection onto it. The
    highest-scoring clients are dropped. A client dropped in any iteration
    stays dropped.

    Args:
        flat_client_updates: one row per client; coordinates are the last
            dimension.
        n_attackers: assumed number of malicious clients.
        n_iters: number of filtering rounds.
        filtering_fraction: ceil(n_attackers * filtering_fraction) clients
            are dropped per round, capped at n_clients - 1.
        n_sampled_coordinates: number of randomly chosen coordinates used
            per round. None, or a value of at least the coordinate count,
            uses all of them.
        **kwargs: ignored; accepted so every aggregator can be called with
            the same keyword arguments.

    Returns:
        (mean of the surviving updates, indices of the surviving clients).
    '''
    n_clients = len(flat_client_updates)
    # Work on whatever device holds the updates instead of forcing CUDA.
    device = flat_client_updates.device
    logical_idxs = torch.ones(n_clients, dtype=torch.bool, device=device)

    n_coordinates = flat_client_updates.size(-1)
    if n_sampled_coordinates is None:
        n_sampled_coordinates = n_coordinates
    else:
        n_sampled_coordinates = min(n_sampled_coordinates, n_coordinates)

    n_filtered_updates = min(math.ceil(n_attackers * filtering_fraction), n_clients - 1)

    for _ in range(n_iters):
        if n_sampled_coordinates == n_coordinates:
            sampled_updates = flat_client_updates
        else:
            # The permutation is still drawn on the CPU and then moved, so
            # the sampled coordinates for a given seed do not change.
            sampled_coordinates = torch.randperm(n=n_coordinates)[:n_sampled_coordinates].to(device)
            sampled_updates = flat_client_updates[:, sampled_coordinates]
        avg_sampled_update = sampled_updates.mean(dim=0)
        centered_sampled_updates = sampled_updates - avg_sampled_update
        _, _, V = torch.svd(centered_sampled_updates)
        v = V[:, 0]
        score = (centered_sampled_updates @ v).square()
        _, filtered_idxs = score.topk(k=n_filtered_updates)
        logical_idxs[filtered_idxs] = False

    candidate_idxs = logical_idxs.nonzero(as_tuple=False).view(-1).tolist()
    agg_update = flat_client_updates[candidate_idxs].mean(dim=0)

    return agg_update, candidate_idxs


def split(n_params: int, n_groups: int) -> list[np.ndarray]:
    '''
    Partition the indices 0..n_params-1 into `n_groups` contiguous chunks.

    The chunks are disjoint, in increasing order, and together cover every
    index exactly once. The first `n_params % n_groups` chunks hold one
    more index than the others.

    Args:
        n_params: number of indices to partition.
        n_groups: number of chunks to produce.

    Returns:
        A list of `n_groups` index arrays. Some are empty when
        n_groups > n_params.
    '''
    partition_sz, res_num = divmod(n_params, n_groups)
    all_idxs = np.arange(n_params)

    param_idxs_lt = []
    start = 0
    for partition_idx in range(n_groups):
        # Carry a running offset. Deriving the start from
        # partition_idx * partition_sz ignores the extra element held by
        # each of the first res_num chunks, which made later chunks overlap
        # earlier ones and dropped the trailing indices.
        stop = start + partition_sz + (1 if partition_idx < res_num else 0)
        param_idxs_lt.append(all_idxs[start:stop])
        start = stop

    return param_idxs_lt
 
@torch.no_grad()
def get_outlier_scores(
    tensors: torch.Tensor, center_tensor: torch.Tensor,
) -> torch.Tensor:
    '''
    Score every row by its Euclidean distance to a reference vector.

    Args:
        tensors: (N, D)
        center_tensor: (D, )

    Returns:
        One score per row: the distance divided by 1000, moved to the CPU.
    '''
    offsets = tensors - center_tensor
    squared_norms = offsets.square().sum(dim=-1)
    distances = squared_norms.sqrt()
    # Scale down to avoid overflow, then hand the result back on the CPU.
    return (distances / 1000).cpu()

@torch.no_grad()
def get_candidate_idxs(
    outlier_scores: torch.Tensor,
    n_candidates: int=None, n_attackers:int=None,
) -> list[int]:
    '''
    Return the indices of the lowest-scoring entries.

    Args:
        outlier_scores: one score per client; lower scores are selected first.
        n_candidates: number of indices to return. When None it is derived
            as len(outlier_scores) - n_attackers.
        n_attackers: only used when n_candidates is None.

    Returns:
        The selected indices as a plain list, in the order given by
        `torch.topk(..., largest=False)`.
    '''
    k = n_candidates
    if k is None:
        k = len(outlier_scores) - n_attackers
    lowest = torch.topk(outlier_scores, k, largest=False)
    return lowest.indices.tolist()

@torch.no_grad()
def aggregate(
    client_updates: list[dict[str, torch.Tensor]], n_samples: list[int]=[],
    n_attackers: int=0,
    agg_type: str='mean', n_subvectors: int=0,
    **kwargs,
):
    '''
    Aggregate named client updates with the registered rule `agg_type`.

    The updates are flattened with `flatten_named_tensors`, aggregated, and
    converted back with `restore_tensor`.

    With `n_subvectors <= 0` the chosen rule runs once on the full
    flattened updates and its result is returned.

    With `n_subvectors > 0` the coordinates are cut into that many groups
    by `split`. The rule runs on each group, and every client accumulates
    an outlier score from its distance to each group's aggregate. The
    `n_clients - n_attackers` lowest-scoring clients are kept and their
    updates are averaged with a plain mean.

    Args:
        client_updates: one name-to-tensor mapping per client.
        n_samples: per-client sample counts, forwarded to the rule.
        n_attackers: assumed number of malicious clients; forwarded to the
            rule and used for the final selection in sub-vector mode.
        agg_type: key of the rule in `flat_aggregators`.
        n_subvectors: number of coordinate groups; values <= 0 disable
            sub-vector mode.
        **kwargs: extra keyword arguments forwarded to the rule.

    Returns:
        (restored aggregated update, candidate indices). The indices come
        from the rule in the direct mode and from the outlier-score
        selection in sub-vector mode.
    '''
    flat_client_updates, restore_info = flatten_named_tensors(client_updates)

    # Direct mode: one call to the selected rule on the full vectors.
    if n_subvectors <= 0:
        flat_agg_update, candidate_idxs = flat_aggregators[agg_type](
            flat_client_updates=flat_client_updates, n_samples=n_samples,
            n_attackers=n_attackers,
            **kwargs,
        )
        return restore_tensor(flat_agg_update, restore_info), candidate_idxs

    # Sub-vector mode: accumulate a per-client outlier score over all groups.
    n_params = flat_client_updates.size(-1)
    outlier_scores = torch.zeros(len(client_updates))
    for param_idxs in split(n_params, n_groups=n_subvectors):
        flat_partition = flat_client_updates[:, param_idxs]
        flat_partition_agg, _ = flat_aggregators[agg_type](
            flat_client_updates=flat_partition, n_samples=n_samples,
            n_attackers=n_attackers,
            **kwargs,
        )
        outlier_scores += get_outlier_scores(flat_partition, flat_partition_agg)
    candidate_idxs = get_candidate_idxs(outlier_scores, n_attackers=n_attackers)

    # Re-flatten only the kept clients and average them.
    kept_updates = [client_updates[i] for i in candidate_idxs]
    flat_kept_updates, restore_info = flatten_named_tensors(kept_updates)
    agg_update = restore_tensor(flat_kept_updates.mean(dim=0), restore_info)

    return agg_update, candidate_idxs