"""
This file implements the classical DP accountant of Kairouz, Oh, and Viswanath (KOV'15).
"""

import numpy as np


class DP_acct:
    """
    Track the (eps, delta) of every mechanism run so far and answer
    composition queries.

    Running sums are maintained so that each query is O(1):

    * ``eps_state1``   -- sum of eps_i (basic composition).
    * ``eps_state2``   -- sum of eps_i * (e^eps_i - 1) / (e^eps_i + 1).
    * ``eps_state3``   -- sum of eps_i ** 2.
    * ``delta_state``  -- sum of log(1 - delta_i).
    * ``delta_state2`` -- sum of delta_i (basic composition).

    ``DPlosses`` keeps the raw ``[eps, delta]`` pairs in insertion order.
    """

    def __init__(self, disp=False):
        """Create an empty accountant; ``disp`` is stored as a verbosity flag."""
        self.DPlosses = []
        self.eps_state1 = 0
        self.eps_state2 = 0
        self.eps_state3 = 0
        self.delta_state = 0
        self.delta_state2 = 0
        self.disp = disp

    def update_DPlosses(self, eps, delta):
        """Record one more (eps, delta)-DP mechanism and update the running sums."""
        self.DPlosses.append([eps, delta])
        self.eps_state1 += eps
        # (e^eps - 1) / (e^eps + 1) == tanh(eps / 2). The tanh form neither
        # overflows for large eps nor cancels for tiny eps.
        self.eps_state2 += eps * np.tanh(eps / 2)
        self.eps_state3 += eps ** 2
        # log1p keeps deltas far below machine epsilon, which log(1 - delta)
        # would silently round to zero.
        self.delta_state += np.log1p(-delta)
        self.delta_state2 += delta

    def get_eps_delta_naive(self):
        """Return the basic-composition pair ``(sum of eps, sum of delta)``."""
        return self.eps_state1, self.delta_state2

    def get_minimum_possible_delta(self):
        """Return ``1 - prod(1 - delta_i)``, the smallest delta that `get_eps`
        can turn into a finite epsilon."""
        # 1 - exp(x) evaluated as -expm1(x) to avoid cancellation near x == 0.
        return -np.expm1(self.delta_state)

    def get_eps(self, delta):
        """
        Return an epsilon such that the composition of all recorded
        mechanisms is (epsilon, delta)-DP, using the KOV'15 bounds for
        heterogeneous mechanisms.

        Returns 0 when ``delta >= 1`` and ``np.inf`` when ``delta`` is not
        achievable (including ``delta == 0`` after any mechanism with a
        non-zero delta). For ``delta == 0`` with only pure-DP mechanisms the
        result is the sum of their epsilons. ``delta`` must be non-negative;
        a negative value trips the assertion.
        """
        # TODO: further improve upon this with Murtagh & Vadhan's approximation algorithm
        assert delta >= 0, "delta must be non-negative"
        if delta >= 1:
            return 0
        # Compare by value, not identity, so 0, 0.0 and np.float64(0) all
        # take the pure-DP path.
        if delta == 0:
            if self.delta_state2 == 0:
                # Every mechanism was pure DP: epsilons simply add up.
                return self.eps_state1
            return np.inf

        # Solve 1 - delta = (1 - deltatilde) * exp(delta_state) for deltatilde,
        # the slack left over once the recorded deltas are accounted for.
        deltatilde = -np.expm1(np.log1p(-delta) - self.delta_state)
        if deltatilde <= 0:
            # delta is below get_minimum_possible_delta(): not feasible.
            return np.inf

        eps1 = self.eps_state1
        eps2 = self.eps_state2 + (self.eps_state3 * 2 * np.log(
            np.exp(1) + self.eps_state3 ** 0.5 / deltatilde)) ** 0.5
        eps3 = self.eps_state2 + (
            2 * self.eps_state3 * np.log(1 / deltatilde)) ** 0.5

        # All three are valid bounds, so report the tightest.
        return np.minimum(np.minimum(eps1, eps2), eps3)
