from decimal import *

import numpy as np

import torch
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection

from .base import _BaseAggregator


class Krum(_BaseAggregator):
    r"""Multi-Krum aggregator.

    Scores every worker by the sum of its smallest ``n - f - 2`` pairwise
    distance terms, keeps the ``m`` best-scored workers and returns the
    average of their inputs. Distances can optionally be computed on a
    random projection of the inputs, and the scoring/averaging can
    optionally be carried out with ``decimal.Decimal`` arithmetic.

    Blanchard, Peva, Rachid Guerraoui, and Julien Stainer.
    "Machine learning with adversaries: Byzantine tolerant gradient descent."
    Advances in Neural Information Processing Systems. 2017.
    """

    def __init__(self, n, f, m, p_norm, dim_reduce_method, n_reduced_dims, fixed_point):
        """Store the aggregator configuration.

        Arguments:
            n {int} -- Total number of workers.
            f {int} -- Number of Byzantine workers.
            m {int} -- Number of workers whose inputs are averaged.
            p_norm -- Order ``p`` passed to ``Tensor.norm`` for distances.
            dim_reduce_method -- ``None`` (no projection), ``'gaussian'`` or ``'sparse'``.
            n_reduced_dims -- ``n_components`` of the random projection.
            fixed_point {bool} -- Use ``Decimal`` arithmetic for scores and the mean.
        """
        self.n = n
        self.f = f
        self.m = m
        self.p_norm = p_norm
        self.n_reduced_dims = n_reduced_dims
        self.dim_reduce_method = dim_reduce_method
        self.fixed_point = fixed_point
        # Base initialisation intentionally stays after the assignments,
        # preserving the original ordering.
        super().__init__()

    def __call__(self, inputs):
        """Aggregate ``inputs`` with Multi-Krum.

        Arguments:
            inputs {list} -- One tensor per worker.

        Returns:
            The average of the ``m`` selected worker inputs.
        """
        # Distances are measured on the (optionally) projected vectors, but the
        # aggregation below always averages the original, unprojected inputs.
        if self.dim_reduce_method is not None:
            vectors = self.random_projection(inputs)
        else:
            vectors = inputs

        distances = self.pairwise_euclidean_distances(vectors)
        return self.multi_krum(inputs, distances)

    def __str__(self):
        return "Krum (m={})".format(self.m)

    def _compute_scores(self, distances, i):
        """Compute the Krum score of worker ``i``.

        Uses ``self.n`` (number of workers) and ``self.f`` (number of
        Byzantine workers) to keep the ``n - f - 2`` smallest terms.

        Arguments:
            distances {dict} -- A dict of dict of distance. distances[i][j] = dist
                for i < j only. i, j starts with 0.
            i {int} -- index of worker, starting from 0.

        Returns:
            Sum of the ``n - f - 2`` smallest squared entries for worker ``i``
            (a ``Decimal`` when ``self.fixed_point`` is set).
        """
        # Only the upper triangle is stored, so the pair (i, j) is looked up
        # with the smaller index first.
        to_others = [distances[j][i] for j in range(i)]
        to_others += [distances[i][j] for j in range(i + 1, self.n)]

        # NOTE(review): pairwise_euclidean_distances already returns squared
        # norms, so squaring again here yields fourth powers. Left unchanged
        # because altering it would change which workers get selected --
        # confirm whether this is intended.
        if self.fixed_point:
            terms = [Decimal(d) ** 2 for d in to_others]
        else:
            terms = [d ** 2 for d in to_others]

        return sum(sorted(terms)[: self.n - self.f - 2])

    def multi_krum(self, inputs, distances):
        """Select the ``m`` best-scored workers and average their inputs.

        Reads ``self.n``, ``self.f``, ``self.m`` and ``self.fixed_point``.

        Arguments:
            inputs {list} -- One tensor per worker (indexable by worker id).
            distances {dict} -- A dict of dict of distance. distances[i][j] = dist
                for i < j. i, j starts with 0.

        Returns:
            Tensor -- Average of the inputs of the ``m`` lowest-score workers.
            In fixed-point mode the result is a float32 tensor placed on the
            device of the selected inputs.

        Raises:
            ValueError -- If ``n < 1``, ``m`` is outside ``[1, n]``,
                ``2 * f + 2 > n``, or any distance is negative.
        """
        if self.n < 1:
            raise ValueError(
                "Number of workers should be positive integer. Got {}.".format(self.n)
            )

        if self.m < 1 or self.m > self.n:
            raise ValueError(
                "Number of workers for aggregation should be >=1 and <= {}. Got {}.".format(
                    self.n, self.m
                )
            )

        if 2 * self.f + 2 > self.n:
            raise ValueError("Too many Byzantine workers: 2 * {} + 2 > {}.".format(self.f, self.n))

        for i in range(self.n - 1):
            for j in range(i + 1, self.n):
                if distances[i][j] < 0:
                    raise ValueError(
                        "The distance between node {} and {} should be non-negative: Got {}.".format(
                            i, j, distances[i][j]
                        )
                    )

        # compute_scores will convert distances to fixed point (if specified)
        scores = [(i, self._compute_scores(distances, i)) for i in range(self.n)]
        # sorted() is stable, so ties keep the lower worker index first.
        ranked = sorted(scores, key=lambda pair: pair[1])
        top_m_indices = [index for index, _ in ranked[: self.m]]

        if self.fixed_point:
            # convert to fixed point
            decimal_inputs = np.array(
                [[Decimal(x) for x in inputs[i].tolist()] for i in top_m_indices]
            )

            # carry out server operations
            mean_inputs = np.mean(decimal_inputs, axis=0)

            # convert back to floating point, on the device the inputs live on
            # (a hard-coded "cuda:0" breaks CPU-only hosts); m >= 1 was checked
            # above, so top_m_indices is never empty here.
            target_device = inputs[top_m_indices[0]].device
            value = torch.tensor(mean_inputs.tolist(), dtype=torch.float32).to(target_device)
        else:
            value = sum(inputs[i] for i in top_m_indices) / self.m

        return value

    def _compute_euclidean_distance(self, v1, v2):
        """Return the ``self.p_norm``-norm of ``v1 - v2`` (Euclidean only when p is 2)."""
        return (v1 - v2).norm(p=self.p_norm)

    def pairwise_euclidean_distances(self, vectors):
        """Compute the squared pairwise distance between all vectors.

        Arguments:
            vectors {list} -- A list of tensors; each one is flattened first.

        Returns:
            dict -- A dict of dict of squared distances {i:{j:distance}} holding
            only the pairs with i < j. Values are ``Decimal`` in fixed-point
            mode and 0-dim tensors otherwise.
        """
        n = len(vectors)
        flat = [v.flatten() for v in vectors]

        distances = {}
        for i in range(n - 1):
            row = {}
            for j in range(i + 1, n):
                norm = self._compute_euclidean_distance(flat[i], flat[j])
                if self.fixed_point:
                    # .item() yields a Python float, which Decimal converts exactly.
                    row[j] = Decimal(norm.item()) ** 2
                else:
                    row[j] = norm ** 2
            distances[i] = row
        return distances

    def random_projection(self, inputs):
        """Project the inputs to ``self.n_reduced_dims`` dimensions.

        A new projection is fitted on every call.

        Arguments:
            inputs {list} -- Same-shaped tensors, one per worker.

        Returns:
            list -- One float32 tensor per worker, on the device of the inputs.

        Raises:
            NotImplementedError -- If ``self.dim_reduce_method`` is neither
                ``'gaussian'`` nor ``'sparse'``.
        """
        if self.dim_reduce_method == 'gaussian':
            projector = GaussianRandomProjection(n_components=self.n_reduced_dims)
        elif self.dim_reduce_method == 'sparse':
            projector = SparseRandomProjection(n_components=self.n_reduced_dims)
        else:
            raise NotImplementedError(self.dim_reduce_method)

        stacked_inputs = torch.stack(inputs)
        # scikit-learn needs a host NumPy array; detach so tensors that track
        # gradients can be converted as well.
        reduced = projector.fit_transform(stacked_inputs.detach().cpu().numpy())
        # Return to the inputs' own device instead of forcing CUDA whenever it
        # happens to be available.
        reduced_tensor = torch.from_numpy(reduced).float().to(stacked_inputs.device)

        reduced_inputs = list(reduced_tensor)
        assert reduced_inputs[0].shape[0] == self.n_reduced_dims

        return reduced_inputs
