from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from click import progressbar
import jax
import numpy as np
import pdb
import sys

from .pate import compute_logpr_answered, compute_logpr_answered_fair
from .pate import compute_logq_gnmax
from .pate import compute_logq_multilabel_pate
from .pate import compute_rdp_data_dependent_gnmax
from .pate import compute_rdp_data_dependent_gnmax_no_upper_bound
from .pate import compute_rdp_data_dependent_threshold
from .pate import compute_rdp_data_independent_multilabel
from .pate import rdp_to_dp
from .pate import calculate_fairness_gaps

from autodp import rdp_acct

from jax import numpy as jnp
from jax.experimental.host_callback import id_print, call



# from autodp import rdp_bank


def analyze_results(votes, max_num_query, dp_eps):
    """Print summary statistics of per-query privacy costs and vote gaps.

    Args:
        votes: 2-D array of vote counts with one row per query. It is sorted
            along the last axis and the two largest entries of every row are
            compared, so each row needs at least two entries.
        max_num_query: value echoed to the output; not used otherwise.
        dp_eps: 1-D sequence of cumulative epsilon values, one per query.

    Returns:
        gap_eps: dict mapping every distinct gap value to the mean per-query
            epsilon of the queries that have this gap.
        gaps: 1-D array holding, for every query, the difference between the
            highest and the second-highest vote count.

    Raises:
        AssertionError: if some query has a gap that is not strictly
            positive (i.e. the two top vote counts are tied).
    """
    print('max_num_query;', max_num_query)

    # dp_eps is cumulative; differencing consecutive entries recovers the
    # cost of each single query (the first entry is kept as is).
    dp_eps_items = np.array(dp_eps)
    dp_eps_items[1:] = dp_eps_items[1:] - dp_eps_items[:-1]

    avg_dp_eps = np.mean(dp_eps_items)
    print('avg_dp_eps;', avg_dp_eps)
    print('min_dp_eps;', np.min(dp_eps_items))
    print('median_dp_eps;', np.median(dp_eps_items))
    print('mean_dp_eps;', np.mean(dp_eps_items))
    print('max_dp_eps;', np.max(dp_eps_items))
    print('sum_dp_eps;', np.sum(dp_eps_items))
    print('std_dp_eps;', np.std(dp_eps_items))

    # Sort votes in ascending order and subtract the runner-up count from
    # the top count.
    sorted_votes = np.sort(votes, axis=-1)
    gaps = sorted_votes[:, -1] - sorted_votes[:, -2]

    assert np.all(gaps > 0)
    print('min gaps;', np.min(gaps))
    print('avg gaps;', np.mean(gaps))
    print('median gaps;', np.median(gaps))
    print('max gaps;', np.max(gaps))
    # These two lines used to report statistics of dp_eps_items by mistake.
    print('sum gaps;', np.sum(gaps))
    print('std gaps;', np.std(gaps))

    # Aggregate: mean per-query epsilon for every distinct gap value.
    # np.unique already returns its result sorted.
    gap_eps = {}
    print('gap;mean_eps')
    for gap in np.unique(gaps):
        mean_eps = dp_eps_items[gaps == gap].mean()
        gap_eps[gap] = mean_eps
        print(f'{gap};{mean_eps}')

    return gap_eps, gaps


def analyze_multiclass_confident_gnmax(
        votes, threshold, sigma_threshold, sigma_gnmax, budget, delta, file, log=print,
        show_dp_budget='disable', args=None):
    """
    Analyze how the pre-defined privacy budget will be exhausted when answering
    queries using the Confident GNMax mechanism.

    The per-query accounting runs inside a single `jax.lax.scan` over the rows
    of `votes`.

    Args:
        votes: a 2-D array of raw ensemble votes, with each row
            corresponding to a query.
        threshold: threshold value (a scalar) in the threshold mechanism.
        sigma_threshold: std of the Gaussian noise in the threshold mechanism.
        sigma_gnmax: std of the Gaussian noise in the GNMax mechanism.
        budget: pre-defined epsilon value for (eps, delta)-DP.
        delta: pre-defined delta value for (eps, delta)-DP.
        file: accepted for interface compatibility; not used here.
        log: accepted for interface compatibility; not used here.
        show_dp_budget: accepted for interface compatibility; not used here.
        args: accepted for interface compatibility; not used here.

    Returns:
        max_num_query: number of queries whose cumulative privacy cost is
            still within `budget`.
        dp_eps: an array of length L = num-queries, with each entry
            corresponding to the cumulative privacy cost at that query.
        partition: an array with one row per query; each row has three
            entries (threshold part, GNMax part, delta part) of the privacy
            cost at that query, each divided by the corresponding eps.
        answered: an array of length L = num-queries, with each entry
            corresponding to the expected number of answered queries up to
            and including that query (cumulative).
        order_opt: an array of length L = num-queries, with each entry
            corresponding to the order minimizing the privacy cost at that
            query.
    """
    # RDP orders.
    orders = np.concatenate((np.arange(2, 100, .5),
                             np.logspace(np.log10(100), np.log10(1000),
                                         num=200)))

    def compute_partition(order_opt, eps, rdp_eps_threshold_curr, rdp_eps_total_curr):
        """Analyze how the current privacy cost is divided."""
        idx = jnp.searchsorted(orders, order_opt)
        rdp_eps_threshold = rdp_eps_threshold_curr[idx]
        rdp_eps_gnmax = rdp_eps_total_curr[idx] - rdp_eps_threshold
        p = jnp.array([rdp_eps_threshold, rdp_eps_gnmax,
                       -jnp.log(delta) / (order_opt - 1)])
        # Normalize p by eps (the parts are expected to add up to eps).
        return p / eps

    rdp_to_dp_jitted = jax.jit(rdp_to_dp)

    def run_vote(progress, v):
        """Scan step: account for one query `v` and emit its statistics."""
        num_within_budget, answered_curr, rdp_eps_threshold_curr, rdp_eps_total_curr = progress

        logpr = compute_logpr_answered(threshold, sigma_threshold, v)
        pr_answered = jnp.exp(logpr)
        rdp_eps_threshold = compute_rdp_data_dependent_threshold(
            logpr, sigma_threshold, orders)
        logq = compute_logq_gnmax(v, sigma_gnmax)
        rdp_eps_gnmax = compute_rdp_data_dependent_gnmax(
            logq, sigma_gnmax, orders)
        # The GNMax cost is only paid when the query is answered, hence it is
        # weighted by the probability of answering.
        rdp_eps_total = rdp_eps_threshold + pr_answered * rdp_eps_gnmax

        # Update current cumulative results.
        rdp_eps_threshold_curr = rdp_eps_threshold_curr + rdp_eps_threshold
        rdp_eps_total_curr = rdp_eps_total_curr + rdp_eps_total
        answered_curr = answered_curr + pr_answered

        dp_eps, order_opt = rdp_to_dp_jitted(orders, rdp_eps_total_curr, delta)
        partition = compute_partition(
            order_opt, dp_eps, rdp_eps_threshold_curr, rdp_eps_total_curr)

        # Count the query only while the pre-defined budget is not exhausted.
        # Uses the documented cond(pred, true_fun, false_fun, *operands) form.
        num_within_budget = jax.lax.cond(
            dp_eps <= budget, lambda x: x + 1, lambda x: x, num_within_budget)

        carry = [num_within_budget, answered_curr,
                 rdp_eps_threshold_curr, rdp_eps_total_curr]
        # The cumulative answered count is emitted (not the per-query
        # probability) so that `answered` matches the documented contract.
        outputs = [order_opt, dp_eps, answered_curr, partition]
        return carry, outputs

    # Iterating over all queries.
    init = [0, 0.0, np.zeros(len(orders)), np.zeros(len(orders))]
    progress, output = jax.lax.scan(run_vote, init, votes, length=len(votes))
    max_num_query, _, _, _ = progress
    order_opt, dp_eps, answered, partition = output

    return max_num_query, dp_eps, partition, answered, order_opt

def analyze_multiclass_confident_fair_gnmax(
        votes, sensitives, threshold, fair_threshold, sigma_threshold,
        sigma_fair_threshold, sigma_gnmax, budget, delta,
        num_sensitive_attributes=2, num_classes=10, minimum_group_count=50,
        log=print):
    """
    Analyze how the pre-defined privacy budget will be exhausted when answering
    queries using the Confident GNMax mechanism combined with a fairness-gap
    check on the sensitive group of every query.

    The per-query accounting runs inside a single `jax.lax.scan`. Group and
    per-class counts are updated "probabilistically": each query contributes
    its probability of being answered rather than a hard 0/1 decision.

    Args:
        votes: a 2-D array of raw ensemble votes, with each row
            corresponding to a query.
        sensitives: one sensitive-attribute value per query. It is appended
            to `votes` as an extra last column and compared against
            `arange(num_sensitive_attributes)`.
        threshold: threshold value (a scalar) in the threshold mechanism.
        fair_threshold: threshold passed to `compute_logpr_answered_fair`
            together with the tentative fairness gap.
        sigma_threshold: std of the Gaussian noise in the threshold mechanism.
        sigma_fair_threshold: noise scale passed to
            `compute_logpr_answered_fair` for the fairness check.
        sigma_gnmax: std of the Gaussian noise in the GNMax mechanism.
        budget: pre-defined epsilon value for (eps, delta)-DP.
        delta: pre-defined delta value for (eps, delta)-DP.
        num_sensitive_attributes: number of distinct sensitive groups.
        num_classes: number of classes; sizes the per-class counters.
        minimum_group_count: accepted for interface compatibility; not used
            here (it was only referenced by a disabled hard-decision variant).
        log: accepted for interface compatibility; not used here.

    Returns:
        max_num_query: number of queries whose cumulative privacy cost is
            still within `budget`.
        dp_eps: per-query cumulative privacy cost.
        partition: one row per query with three entries (threshold part,
            GNMax part, delta part) of the privacy cost, divided by eps.
        answered: per-query cumulative expected number of answered queries.
        order_opt: per-query order minimizing the privacy cost.
        sensitive_group_count: final expected number of answered queries per
            sensitive group.
        pos_prediction_one_hot: final expected counts with shape
            num_classes x num_sensitive_attributes.
        answered_curr: final expected number of answered queries.
        gaps: per-query output of `calculate_fairness_gaps` after the counts
            were updated with that query.
        pr_answered: per-query probability of answering.
    """
    # RDP orders.
    orders = np.concatenate((np.arange(2, 100, .5),
                             np.logspace(np.log10(100), np.log10(1000),
                                         num=200)))

    def compute_partition(order_opt, eps, rdp_eps_threshold_curr, rdp_eps_total_curr):
        """Analyze how the current privacy cost is divided."""
        idx = jnp.searchsorted(orders, order_opt)
        rdp_eps_threshold = rdp_eps_threshold_curr[idx]
        rdp_eps_gnmax = rdp_eps_total_curr[idx] - rdp_eps_threshold
        p = jnp.array([rdp_eps_threshold, rdp_eps_gnmax,
                       -jnp.log(delta) / (order_opt - 1)])
        # Normalize p by eps (the parts are expected to add up to eps).
        return p / eps

    def run_vote(progress, v_and_sensitive):
        """Scan step: account for one query and emit its statistics."""
        (num_within_budget, answered_curr, rdp_eps_threshold_curr,
         rdp_eps_total_curr, sensitive_group_count,
         per_class_pos_classified_group_count) = progress
        # The sensitive attribute travels as the last column of the row.
        v, sensitive = v_and_sensitive[:-1], v_and_sensitive[-1]

        # Selector one-hot vectors for the sensitive feature and the
        # predicted class, and their outer product (classes x groups).
        sensitive_one_hot = jnp.arange(num_sensitive_attributes) == sensitive
        prediction_one_hot = jnp.arange(num_classes) == jnp.argmax(v)
        class_group_one_hot = prediction_one_hot[:, None] & sensitive_one_hot[None, :]

        # Tentative gaps, i.e. the gaps if this query were answered.
        tentative_class_counts = per_class_pos_classified_group_count + class_group_one_hot
        tentative_group_counts = sensitive_group_count + sensitive_one_hot
        tentative_gaps = calculate_fairness_gaps(
            tentative_group_counts, tentative_class_counts)
        # Gap of the group this query belongs to.
        group_tentative_gap = tentative_gaps.dot(sensitive_one_hot)

        # Probability of answering under the fairPATE analysis.
        logpr = compute_logpr_answered_fair(
            threshold, fair_threshold, sigma_threshold, sigma_fair_threshold,
            v, group_tentative_gap)
        pr_answered = jnp.exp(logpr)

        # Update counts probabilistically and recompute the gaps from them.
        sensitive_group_count = sensitive_group_count + pr_answered * sensitive_one_hot
        per_class_pos_classified_group_count = (
            per_class_pos_classified_group_count + pr_answered * class_group_one_hot)
        new_gaps = calculate_fairness_gaps(
            sensitive_group_count, per_class_pos_classified_group_count)

        rdp_eps_threshold = compute_rdp_data_dependent_threshold(
            logpr, sigma_threshold, orders)
        logq = compute_logq_gnmax(v, sigma_gnmax)
        rdp_eps_gnmax = compute_rdp_data_dependent_gnmax(logq, sigma_gnmax, orders)
        # The GNMax cost is only paid when the query is answered, hence it is
        # weighted by the probability of answering.
        rdp_eps_total = rdp_eps_threshold + pr_answered * rdp_eps_gnmax

        # Update current cumulative results.
        rdp_eps_threshold_curr = rdp_eps_threshold_curr + rdp_eps_threshold
        rdp_eps_total_curr = rdp_eps_total_curr + rdp_eps_total
        answered_curr = answered_curr + pr_answered

        dp_eps, order_opt = rdp_to_dp(orders, rdp_eps_total_curr, delta)
        partition = compute_partition(
            order_opt, dp_eps, rdp_eps_threshold_curr, rdp_eps_total_curr)

        # Count the query only while the pre-defined budget is not exhausted.
        # Uses the documented cond(pred, true_fun, false_fun, *operands) form.
        num_within_budget = jax.lax.cond(
            dp_eps <= budget, lambda x: x + 1, lambda x: x, num_within_budget)

        carry = [num_within_budget, answered_curr, rdp_eps_threshold_curr,
                 rdp_eps_total_curr, sensitive_group_count,
                 per_class_pos_classified_group_count]
        outputs = [order_opt, dp_eps, answered_curr, partition, new_gaps, pr_answered]
        return carry, outputs

    votes_and_sensitives = jnp.c_[votes, sensitives]

    # Initial carry. Note the shape of the positive counter: in a k-class
    # classification problem it is num_classes x num_sensitive_attributes.
    init = [
        0,
        0.0,
        np.zeros(len(orders)),
        np.zeros(len(orders)),
        np.zeros(shape=(num_sensitive_attributes)),
        np.zeros(shape=(num_classes, num_sensitive_attributes)),
    ]

    # Iterating over all queries.
    progress, output = jax.lax.scan(
        run_vote, init, votes_and_sensitives, length=len(votes))
    max_num_query, answered_curr, _, _, sensitive_group_count, pos_prediction_one_hot = progress
    order_opt, dp_eps, answered, partition, gaps, pr_answered = output

    return max_num_query, dp_eps, partition, answered, order_opt, sensitive_group_count, pos_prediction_one_hot, answered_curr, gaps, pr_answered