#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math
from abc import ABC
from enum import IntEnum
from itertools import cycle
from typing import Iterator, List, Union

import torch
import numpy as np
from utils.minimize_helper import *

# scikit-image is optional: ``otsu`` is only called by ``_otsu``. When the
# package is missing, a stub is installed that raises once OTSU is used.
try:
    from skimage.filters import threshold_otsu as otsu
except ImportError:

    def otsu(*args, **kwargs) -> float:
        """Stand-in for ``threshold_otsu`` when scikit-image is missing; always raises."""
        raise NotImplementedError("Install skimage!")


def _mean_plus_r_var(data: torch.Tensor, ratio: float = 0, **kwargs) -> float:
    """
    Computes ``mean + ratio * std`` of ``data`` (plus a 1e-8 offset), floored
    at the smallest element of ``data``.

    The floor only matters when ``ratio`` is negative enough to push the
    statistic below every observed value.

    Parameters
    ----------
    data: torch.Tensor
        Tensor whose mean and (unbiased) standard deviation are evaluated.

    ratio: float, optional
        Multiplier applied to the standard deviation.

    Returns
    -------
    float
        ``max(min(data), mean(data) + ratio * std(data) + 1e-8)``.
    """
    floor = data.min().item()
    candidate = data.mean().item() + ratio * data.std().item() + 1e-8
    return max(floor, candidate)


def _pvalue(data: torch.Tensor, ratio: float = 0.25, **kwargs) -> float:
    """
    Returns the ``cut``-th smallest element of ``data``, with
    ``cut = int(numel * (1 - ratio))`` clamped to ``[1, numel]``. This is
    roughly the ``ratio * numel``-th largest value.

    Parameters
    ----------
    data: torch.Tensor
        Tensor of values; it is flattened, so any shape is accepted.

    ratio: float, optional
        Fraction of elements expected to lie above the returned value.
        Values above 1 select the minimum and values below 0 select the
        maximum.

    Returns
    -------
    float
        The selected order statistic.
    """
    flat = data.reshape(-1)
    numel = flat.numel()
    # Clamp the rank into [1, numel]. A negative ratio (e.g. the -1 default
    # of _calculate_thresh_value) would otherwise request a rank beyond the
    # number of elements and make kthvalue raise.
    cut = min(numel, max(1, int(numel * (1 - ratio))))
    return torch.kthvalue(flat, cut)[0].item()


def _static(data: torch.Tensor, group: torch.Tensor, current_thresh: float, **kwargs) -> float:
    """
    Identity strategy: ignores the observed norms and hands back
    ``current_thresh`` unchanged.

    Parameters
    ----------
    data: torch.Tensor
        Per-sample norms (unused).

    group: torch.Tensor
        Group labels (unused).

    current_thresh: float
        Threshold to pass through.

    Returns
    -------
    float
        ``current_thresh``, unchanged.
    """
    threshold = current_thresh
    return threshold


def _otsu(data: torch.Tensor, **kwargs) -> float:
    """
    Returns a threshold that splits the values of ``data`` into a
    "background" and a "foreground" class using Otsu's method
    (``skimage.filters.threshold_otsu``).

    Otsu's method chooses the cut that maximises the between-class variance
    of the value histogram. The histogram uses ``2 ** int(1 + log2(n) / 2)``
    bins (roughly ``2 * sqrt(n)``), where ``n`` is the number of elements.

    Parameters
    ----------
    data: torch.Tensor
        Tensor of values (per-sample norms).

    Returns
    -------
    float
        Threshold value determined via Otsu's method.

    Raises
    ------
    NotImplementedError
        If scikit-image is not installed (raised by the module's fallback
        ``otsu``).
    """
    numel = data.numel()
    nbins = 2 ** int(1 + math.log2(numel) / 2)
    # Otsu's threshold depends only on the histogram of the values, not on
    # the image layout. A single-row 2-D image accepts any element count,
    # whereas an (nbins, -1) layout requires nbins to divide it (e.g. 10
    # norms give nbins=4). detach() lets numpy() run on grad-tracking tensors.
    fake_img = data.detach().reshape(1, -1).cpu().numpy()
    return float(otsu(fake_img, nbins))


def _opt_q(data: torch.Tensor, group: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Per-group quantile clipping thresholds.

    Every sample receives the ``q``-quantile of the norms of the samples
    sharing its group label, where
    ``q = 1 - sqrt(2 / pi) * noise_multiplier / e``, clamped to ``[0, 1]``.

    Parameters
    ----------
    data: torch.Tensor
        1-D tensor of per-sample norms (assumed aligned with ``group``).

    group: torch.Tensor
        1-D tensor of group labels, one per sample.

    **kwargs
        Must contain ``noise_multiplier``. Other keys (including ``device``)
        are accepted and ignored; the result follows ``data.device``.

    Returns
    -------
    torch.Tensor
        float32 tensor shaped like ``group`` on ``data``'s device, holding
        each sample's group threshold.
    """
    noise_multiplier = kwargs.get('noise_multiplier')
    quantile = 1 - math.sqrt(2. / math.pi) * noise_multiplier / math.e
    # np.quantile only accepts q in [0, 1]; noise multipliers above ~3.4
    # push q below 0.
    quantile = min(1.0, max(0.0, quantile))

    # Allocate on data's device so the thresholds can divide the norms
    # directly. Every element is written below, because each label is
    # in torch.unique(group).
    threshs = torch.empty(group.shape, dtype=torch.float32, device=data.device)
    values = data.detach()
    for g in torch.unique(group):
        idx = torch.nonzero(group == g, as_tuple=False).flatten()
        threshs[idx] = float(np.quantile(values[idx].cpu().numpy(), quantile))
    return threshs


def _dpsgd_f(data: torch.Tensor, group: torch.Tensor, current_thresh, **kwargs) -> torch.Tensor:
    """
    Group-adaptive clipping thresholds (DPSGD-F style rule).

    Notation:

    - ``C = current_thresh``.
    - ``b_k`` is the number of samples in group ``k``.
    - ``m_k`` is how many of those have a norm above ``C``.
    - ``b`` is the batch size.
    - ``m`` is the total count of norms above ``C``.

    Group ``k`` gets ``C * (1 + (m_k / b_k) / (m / b))``. If no norm exceeds
    ``C``, every group keeps ``C``.

    Parameters
    ----------
    data: torch.Tensor
        1-D tensor of per-sample norms (assumed aligned with ``group``).

    group: torch.Tensor
        1-D tensor of group labels, one per sample.

    current_thresh: float
        Base threshold ``C``.

    **kwargs
        Accepted and ignored (including ``device``); the result follows
        ``data.device``.

    Returns
    -------
    torch.Tensor
        float32 tensor shaped like ``group`` on ``data``'s device, holding
        each sample's group threshold.
    """
    grp, cnt = torch.unique(group, return_counts=True)
    members = [torch.nonzero(group == g, as_tuple=False).flatten() for g in grp]
    # m_k := |{i in group k : |g_i| > C}|
    group_m = [torch.sum(torch.gt(data[idx], current_thresh).float()) for idx in members]
    # m := |{i : |g_i| > C}|,  b := batch size
    m = sum(group_m)
    b = sum(cnt)
    nothing_exceeds = m == 0

    # Allocate on data's device so the thresholds can divide the norms
    # directly. Every element is written below, because each label is
    # in torch.unique(group).
    threshs = torch.empty(group.shape, dtype=torch.float32, device=data.device)
    for idx, m_k, b_k in zip(members, group_m, cnt):
        if nothing_exceeds:
            threshs[idx] = current_thresh
        else:
            threshs[idx] = current_thresh * (1.0 + m_k / b_k / m * b)
    return threshs


def _fairdp(data: torch.Tensor, group: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Per-group clipping thresholds obtained from ``compute_minimize``.

    For each group label (in ``torch.unique`` order), the group's norms are
    gathered and their median taken. Three things are then passed to
    ``compute_minimize``:

    - the list of group medians;
    - the per-group norm arrays (keyed by label);
    - ``lr``, ``batch_size``, ``nabla``, ``noise_multiplier`` and
      ``param_numel`` from ``kwargs``.

    Each sample is assigned the value ``compute_minimize`` returns for its
    group. NOTE(review): this assumes ``compute_minimize`` returns one
    scalar per group in that same order; confirm in
    ``utils.minimize_helper``.

    Parameters
    ----------
    data: torch.Tensor
        1-D tensor of per-sample norms (assumed aligned with ``group``).

    group: torch.Tensor
        1-D tensor of group labels, one per sample.

    Returns
    -------
    torch.Tensor
        float32 tensor shaped like ``group`` on ``data``'s device.
    """
    idx = {}
    group_norm = {}
    l2_norm_clip = []
    grp = torch.unique(group)
    # detach() lets numpy() run even if the norms still track gradients.
    values = data.detach()
    for g in grp:
        key = g.item()
        idx[key] = torch.nonzero(group == g, as_tuple=False).flatten()
        group_norm[key] = values[idx[key]].cpu().numpy()
        l2_norm_clip.append(np.median(group_norm[key]))
    group_clip = compute_minimize(l2_norm_clip, kwargs.get('lr'), group_norm, kwargs.get('batch_size'),
                                  kwargs.get('nabla'), kwargs.get('noise_multiplier'), kwargs.get('param_numel'))

    # Allocate on data's device so the thresholds can divide the norms
    # directly; the ``device`` kwarg is no longer consulted.
    threshs = torch.empty(group.shape, dtype=torch.float32, device=data.device)
    for g, c in zip(grp, group_clip):
        threshs[idx[g.item()]] = c
    return threshs


class ClippingMethod(IntEnum):
    """
    Strategies for computing a clipping threshold from per-sample norms.

    Members are looked up in ``_thresh_`` to find the implementing function.
    ``GMM`` has no entry there, so selecting it fails on that lookup.
    """
    STATIC = 0  # ``_static``: pass the current threshold through unchanged
    PVALUE = 1  # ``_pvalue``: order statistic of the norms chosen by ``ratio``
    MEAN = 2  # ``_mean_plus_r_var``: mean + ratio * std, floored at the minimum
    GMM = 3  # NOTE(review): no implementation registered in ``_thresh_``
    OTSU = 4  # ``_otsu``: Otsu's histogram threshold (needs scikit-image)
    OPT_Q = 5  # ``_opt_q``: per-group quantile that depends on the noise multiplier
    DPSGD_F = 6  # ``_dpsgd_f``: per-group scaling of the threshold by clip rates
    FAIRDP = 7  # ``_fairdp``: per-group thresholds from ``compute_minimize``


# Dispatch table: ClippingMethod -> function computing the threshold.
# ClippingMethod.GMM has no implementation here, so looking it up raises
# KeyError.
_thresh_ = dict(
    [
        (ClippingMethod.STATIC, _static),
        (ClippingMethod.PVALUE, _pvalue),
        (ClippingMethod.MEAN, _mean_plus_r_var),
        (ClippingMethod.OTSU, _otsu),
        (ClippingMethod.OPT_Q, _opt_q),
        (ClippingMethod.DPSGD_F, _dpsgd_f),
        (ClippingMethod.FAIRDP, _fairdp),
    ]
)


def _calculate_thresh_value(
        data: torch.Tensor,
        device: int,
        group: torch.Tensor,
        lr: float,
        batch_size: int,
        nabla: float,
        noise_multiplier: float,
        param_numel: int,
        current_thresh: float,
        clipping_mehod: ClippingMethod = ClippingMethod.STATIC,
        ratio: float = -1
) -> Union[float, torch.Tensor]:
    """
    Computes a clipping threshold from per-sample norms by dispatching to
    the strategy registered for ``clipping_mehod`` in ``_thresh_``.

    The strategy receives ``data`` positionally and every other argument
    as a keyword. Each strategy reads only what it needs.

    - STATIC, PVALUE, MEAN and OTSU produce a single scalar threshold.
    - OPT_Q, DPSGD_F and FAIRDP produce one threshold per sample, based on
      the sample's group.

    Parameters
    ----------
    data: torch.Tensor
        1-D tensor of per-sample norms.
    device
        Forwarded unchanged to the strategy.
    group: torch.Tensor
        Group label of each sample (used by the per-group strategies).
    lr, batch_size, nabla, noise_multiplier, param_numel
        Forwarded unchanged to the strategy. ``noise_multiplier`` is read by
        OPT_Q; FAIRDP forwards all five to ``compute_minimize``.
    current_thresh: float
        Current threshold. STATIC returns it as-is; DPSGD_F scales it.
    clipping_mehod: ClippingMethod
        Strategy to apply.
    ratio: float
        Interpretation depends on the strategy:

        - PVALUE: percentile parameter.
        - MEAN: standard-deviation multiplier.
        - Other strategies ignore it.

    Returns
    -------
    float or torch.Tensor
        Scalar threshold, or one threshold per sample.

    Raises
    ------
    KeyError
        If ``clipping_mehod`` has no registered strategy (e.g. GMM).
    """
    strategy = _thresh_[clipping_mehod]
    forwarded = dict(
        device=device,
        group=group,
        lr=lr,
        batch_size=batch_size,
        nabla=nabla,
        noise_multiplier=noise_multiplier,
        param_numel=param_numel,
        ratio=ratio,
        current_thresh=current_thresh,
    )
    return strategy(data, **forwarded)


class NormClipper(ABC):
    """
    Interface for strategies that turn per-sample gradient norms into
    clipping factors.

    Every member here is a placeholder that returns ``None``; concrete
    clippers override all three.
    """

    def calc_clipping_factors(
            self, norms: List[torch.Tensor], device: int, group: torch.Tensor, lr: float, batch_size: int, nabla: float,
            noise_multiplier: float, param_numel: int
    ) -> Union[List[torch.Tensor], Iterator[torch.Tensor]]:
        """
        Turns ``norms`` into clipping factors. Subclasses must override
        this; the base version returns ``None``.
        """
        return None

    @property
    def thresholds(self) -> torch.Tensor:
        """
        Threshold values used by the clipper; their exact meaning depends
        on the subclass.

        Returns
        -------
        torch.Tensor
            Threshold values (``None`` in this base class).
        """
        return None

    @property
    def is_per_layer(self) -> bool:
        """
        Whether each layer gets its own clipping factor.

        Returns
        -------
        bool
            Per-layer flag (``None`` in this base class).
        """
        return None


class ConstantFlatClipper(NormClipper):
    """
    Clipper with a single norm bound shared by the whole model.

    Each sample's gradients are scaled by
    ``min(1, flat_value / (norm + 1e-6))``, where ``norm`` is the single
    norm supplied for all layers. Because the factor never exceeds 1,
    gradients already below the bound are left untouched. ``flat_value``
    is therefore best read as an upper bound on the clipped norm.
    """

    def __init__(self, flat_value: float):
        """
        Parameters
        ----------
        flat_value: float
            Norm bound applied to every layer; stored as a Python float.
        """
        self.flat_value = float(flat_value)

    def calc_clipping_factors(
            self, norms: List[torch.Tensor], device: int, group: torch.Tensor, lr: float, batch_size: int, nabla: float,
            noise_multiplier: float, param_numel: int
    ) -> Iterator[torch.Tensor]:
        """
        Builds the shared clipping factor from the one norm tensor in
        ``norms``. All other arguments are ignored.

        Parameters
        ----------
        norms: List[torch.Tensor]
            Must hold exactly one tensor: the norm taken over all layers.

        Returns
        -------
        Iterator[torch.Tensor]
            Endless iterator yielding the same factor tensor for every layer.

        Raises
        ------
        ValueError
            If ``norms`` does not contain exactly one tensor.
        """
        if len(norms) == 1:
            overall_norm = norms[0]
        else:
            raise ValueError(
                "Waring: flat norm selected but "
                f"received norm for {len(norms)} layers"
            )
        # Cap at 1.0 so gradients are only ever scaled down, never up.
        scale = (self.flat_value / (overall_norm + 1e-6)).clamp(max=1.0)
        # The same factor applies to every layer, so repeat it indefinitely.
        return cycle([scale])

    @property
    def thresholds(self) -> torch.Tensor:
        """
        Returns
        -------
        torch.Tensor
            One-element tensor holding ``flat_value``.
        """
        return torch.tensor([self.flat_value])

    @property
    def is_per_layer(self) -> bool:
        """
        Returns
        -------
        bool
            Always False: one bound covers the whole model.
        """
        return False


class ConstantPerLayerClipper(NormClipper):
    """
    A clipper that bounds the gradient norm of each layer separately.

    The ``norms`` given to :meth:`calc_clipping_factors` are consumed in
    consecutive (weight, bias) pairs, one pair per layer. For each layer:

    - the two norms are combined as ``sqrt(w**2 + b**2)``;
    - the layer's factor is ``min(1, threshold / (layer_norm + 1e-6))``;
    - that factor is emitted for both the weight and the bias.

    Because the factor never exceeds 1, the thresholds are best read as
    upper bounds on the clipped layer norms.
    """

    def __init__(self, flat_values: List[float]):
        """
        Parameters
        ----------
        flat_values: List[float]
            Either a single value shared by every layer, or one value per
            layer (i.e. per (weight, bias) pair of norms).
        """
        self.flat_values = [float(fv) for fv in flat_values]

    def calc_clipping_factors(self, norms: List[torch.Tensor], device: int, group: torch.Tensor, lr: float,
                              batch_size: int, nabla: float, noise_multiplier: float, param_numel: int) -> List[
        torch.Tensor]:
        """
        Calculates one clipping factor per norm. The weight and bias of a
        layer share the factor derived from their combined norm. All
        arguments other than ``norms`` are ignored.

        When a single shared value was configured, ``self.flat_values`` is
        expanded to one copy per norm. Later calls recognise that expanded
        form and keep working.

        Parameters
        ----------
        norms: List[torch.Tensor]
            Per-parameter norms, ordered as (weight, bias) pairs.

        Returns
        -------
        List[torch.Tensor]
            One factor tensor per entry in ``norms``.

        Raises
        ------
        ValueError
            If ``norms`` has odd length, or if the number of configured
            thresholds does not match the number of layers.
        """
        n_norms = len(norms)
        if n_norms % 2 != 0:
            # An unpaired trailing norm would be silently dropped by the
            # pairwise zip below, leaving that parameter unclipped.
            raise ValueError(
                f"Expected norms in (weight, bias) pairs but received {n_norms} norms"
            )
        n_layers = n_norms // 2
        n_values = len(self.flat_values)

        if n_values == 1:
            # Share the single value; keep one copy per norm so that
            # ``thresholds`` lines up with the returned factors.
            self.flat_values = self.flat_values * n_norms
            layer_thresholds = self.flat_values[0::2]
        elif n_values == n_layers:
            layer_thresholds = self.flat_values
        elif n_values == n_norms and self.flat_values[0::2] == self.flat_values[1::2]:
            # One value per norm, as left behind by an earlier call that
            # expanded a shared value. Each (weight, bias) pair agrees.
            layer_thresholds = self.flat_values[0::2]
        else:
            raise ValueError(
                f"{len(norms)} layers have provided norms but the "
                f"number of clipping thresholds is {len(self.flat_values)}"
            )

        clipping_factor = []
        for weight_norm, bias_norm, threshold in zip(norms[0::2], norms[1::2], layer_thresholds):
            layer_norm = torch.sqrt(torch.square(weight_norm) + torch.square(bias_norm))
            per_sample_clip_factor = threshold / (layer_norm + 1e-6)
            # Weight and bias of the layer are scaled by the same factor.
            clipping_factor.append(per_sample_clip_factor.clamp(max=1.0))
            clipping_factor.append(per_sample_clip_factor.clamp(max=1.0))
        return clipping_factor

    @property
    def thresholds(self) -> torch.Tensor:
        """
        Returns
        -------
        torch.Tensor
            The current ``flat_values``. These are either the values as
            configured or, after a call that expanded a single shared value,
            one copy per norm.
        """
        return torch.tensor(self.flat_values)

    @property
    def is_per_layer(self) -> bool:
        """
        Returns
        -------
        bool
            Always True: each layer gets its own clipping factor.
        """
        return True


class _Dynamic_Clipper_(NormClipper):
    """
    Experimental clipper that derives its threshold from statistics of the
    per-sample norms, using the strategy selected by ``clipping_method``.

    Notes
    -----
        This clipper breaks DP guarantees [use only for experimentation]
    """

    def __init__(
            self,
            group: torch.Tensor,
            flat_values: List[float],
            clip_per_layer: bool = False,
            clipping_method: ClippingMethod = ClippingMethod.STATIC,
            ratio: float = 0.0
    ):
        """
        Parameters
        ----------
        group: torch.Tensor
            Accepted for interface compatibility but not stored; group
            labels are passed to :meth:`calc_clipping_factors` instead.

        flat_values: List[float]
            Base threshold(s). Either a single value (replicated per layer
            when ``clip_per_layer`` is set) or one value per layer.

        clip_per_layer: bool
            Whether each layer gets its own clipping factor.

        clipping_method: ClippingMethod
            Threshold strategy. Anything other than STATIC prints a privacy
            warning.

        ratio: float
            Strategy parameter: percentile for PVALUE, std multiplier for
            MEAN.
        """
        self.flat_values = [float(float_value) for float_value in flat_values]
        self.clip_per_layer = clip_per_layer
        if clipping_method != ClippingMethod.STATIC:
            print(
                "Warning! Current implementations of dynamic clipping "
                "are not privacy safe; Caclulated privacy loss is not "
                "indicative of a proper bound."
            )
        self.clipping_method = clipping_method
        self.ratio = ratio
        # Most recently computed threshold; [0.0] until the first call.
        self.thresh = [0.0]

    def calc_clipping_factors(
            self, norms: List[torch.Tensor], device: int, group: torch.Tensor, lr: float, batch_size: int, nabla: float,
            noise_multiplier: float, param_numel: int
    ) -> Union[List[torch.Tensor], Iterator[torch.Tensor]]:
        """
        Computes a threshold for each entry of ``norms`` via the configured
        strategy, then turns it into a factor
        ``min(1, thresh / (norm + 1e-6))``.

        This is experimental, does not guarantee privacy and is not
        recommended for production use.

        Parameters
        ----------
        norms: List[torch.Tensor]
            Per-sample norms, one tensor per layer (or a single tensor when
            not clipping per layer).

        device, group, lr, batch_size, nabla, noise_multiplier, param_numel
            Forwarded to the threshold strategy.

        Returns
        -------
        List[torch.Tensor] or Iterator[torch.Tensor]
            A list with one factor per layer when clipping per layer;
            otherwise an endless iterator that cycles over the factors.

        Raises
        ------
        ValueError
            If the number of base thresholds differs from ``len(norms)``.
        """
        if len(self.flat_values) == 1:
            # A single shared value: replicate it once per layer when
            # clipping per layer, otherwise keep it as the flat threshold.
            current_threshs = self.flat_values * (
                len(norms) if self.clip_per_layer else 1
            )
        else:
            # One base threshold per layer was supplied.
            current_threshs = list(self.flat_values)

        if len(norms) != len(current_threshs):
            raise ValueError(
                f"Provided grad norm max's size {len(current_threshs)}"
                f" does not match the number of layers {len(norms)}"
            )

        clipping_factor = []
        for norm, current_thresh in zip(norms, current_threshs):
            thresh = _calculate_thresh_value(
                norm, device, group, lr, batch_size, nabla, noise_multiplier, param_numel, current_thresh,
                self.clipping_method, self.ratio
            )
            # Only the last layer's threshold is kept; it backs ``thresholds``.
            self.thresh = thresh
            per_sample_clip_factor = thresh / (norm + 1e-6)
            clipping_factor.append(per_sample_clip_factor.clamp(max=1.0))
        return clipping_factor if self.is_per_layer else cycle(clipping_factor)

    @property
    def thresholds(self) -> torch.Tensor:
        """
        Returns
        -------
        torch.Tensor
            A detached copy of the most recently computed threshold. It is
            a scalar or a per-sample tensor depending on the strategy, and
            ``tensor([0.0])`` before any call.
        """
        if isinstance(self.thresh, torch.Tensor):
            # Same result as torch.tensor(t), without the copy-construct
            # UserWarning.
            return self.thresh.detach().clone()
        return torch.tensor(self.thresh)

    @property
    def is_per_layer(self) -> bool:
        """
        Returns
        -------
        bool
            The ``clip_per_layer`` flag given at construction.
        """
        return self.clip_per_layer
