import math
import numpy as np
from poisson_binomial import PoissonBinomial
from collections import defaultdict
from scipy.spatial.distance import cdist
# from sklearn.metrics.pairwise import manhattan_distances

class Homogenization:
    def __init__(self):
        """Record the supported metric names and the threshold range each accepts."""
        # A limit of ``None`` places no restriction on the threshold ``t``;
        # a tuple is an inclusive (lowest, highest) range for ``t``.
        unrestricted = None
        self.threshold_limits = dict(
            ProdExp=unrestricted,
            ProdExp_rootk=unrestricted,
            MinExp=(0, 0),
        )
        self.supported_options = ["MinExp", "ProdExp", "ProdExp_rootk"]

    def fix_distribution(self, p):
        """Clean up a vector that is already close to a probability distribution.

        Entries that are slightly negative are clipped to zero, then the amount
        by which the total differs from 1 is removed from the first entry that
        is larger than ten times that amount.

        Parameters:
        p: The probabilities. A numpy array is corrected in place; any other
            sequence (e.g. a list) is first copied into a float numpy array.

        Returns:
        The corrected numpy array (the very same object as ``p`` when ``p`` was
        already a numpy array).

        Raises:
        AssertionError: if an entry is below -0.0001, if the entries do not sum
            to 1 within 0.0001, or if a negative entry remains after correction.
        """
        # Element-wise comparisons below need an array; leave arrays untouched
        # so existing callers keep the in-place behaviour.
        if not isinstance(p, np.ndarray):
            p = np.asarray(p, dtype=float)

        # Verify p is close to a distribution
        assert all(-0.0001 <= p) and abs(sum(p) - 1.0) < 0.0001
        for index, value in enumerate(p):
            p[index] = max(value, 0)
        excess = sum(p) - 1.0

        for index, value in enumerate(p):
            # Only take the correction from an entry that dwarfs it, so the
            # adjusted entry cannot become negative.
            if value > 10 * excess:
                p[index] = value - excess
                excess = 0.0
                break  # the whole excess has been absorbed; nothing left to do
        if excess > 0.0:
            print("Failed to fix distribution: unfixed excess of {}".format(excess))

        assert all(0.0 <= p)
        return p

    def fail(self, rejections, num_interactions, t, threshold_type, strict=False):
        """Decide whether a number of rejections counts as a systemic failure.

        Parameters:
        rejections: Either an integer number of rejections, or a sequence of
            length ``num_interactions + 1`` whose entry ``r`` is the probability
            of exactly ``r`` rejections.
        num_interactions: Number of models that were interacted with.
        t: How many successes are tolerated: a count when ``threshold_type`` is
            "absolute", a fraction of ``num_interactions`` when it is "percent".
        threshold_type: "percent" or "absolute".
        strict: Only used for "percent". When True a fractional threshold is
            rounded down; otherwise it is randomly rounded down or up, rounding
            up with probability equal to its fractional part.

        Returns:
        For an integer ``rejections``: whether it reaches the rejection
        threshold. For a distribution: the probability mass at or above the
        rejection threshold.

        Raises:
        ValueError: if ``threshold_type`` is not "percent" or "absolute".
        AssertionError: if ``rejections`` is out of range, has the wrong
            length, or is not close to a probability distribution.
        """
        if threshold_type == "percent":
            float_threshold = t * num_interactions
            lower_threshold = math.floor(float_threshold)
            upper_threshold = math.ceil(float_threshold)
            if strict or lower_threshold == upper_threshold:
                success_threshold = lower_threshold
            else:
                p = float_threshold - lower_threshold
                # probability should satisfy strict inequalities because the
                # threshold is known to be fractional here
                assert 0 < p < 1
                success_threshold = int(
                    np.random.choice(
                        [lower_threshold, upper_threshold], p=[1 - p, p]
                    )
                )
        elif threshold_type == "absolute":
            success_threshold = t
        else:
            raise ValueError(
                "threshold_type must be 'percent' or 'absolute', got {!r}".format(
                    threshold_type
                )
            )

        rejection_threshold = num_interactions - success_threshold

        if isinstance(rejections, (int, np.integer)):
            assert rejections in range(num_interactions + 1)
            return bool(rejections >= rejection_threshold)

        rejections = np.asarray(rejections, dtype=float)
        assert len(rejections) == (num_interactions + 1)
        # Verify rejections is a distribution
        assert all(-0.0001 <= rejections) and abs(sum(rejections) - 1.0) < 0.0001
        # A negative threshold means every outcome is a failure. Clamp it so it
        # is not interpreted as a "count from the end" slice index.
        first_failing_count = max(rejection_threshold, 0)
        return sum(rejections[first_failing_count:])
    """ 
    Description:
    This function simply generates all alpha and delta combinations for a given range of alpha and delta values and a given step size for alpha.
    Alpha is the fraction of hard examples. Delta is the percent difference from the original model error rate.

    # Parameters:
     - Alpha range is a tuple controlling the lower and upper bounds of alpha. 
     - Alpha step is the step size for alpha. 
     - Deltas is a list of floats or tuples floats for deltas. Each float or tuple in deltas is a value of delta to try; 
        if it is a tuple, it represents the percent different from the original model error rate for each model in the order of the models in expected_failure_probs (see generate_easy_hard_error_rates for more details)
     - Specified max error rates is a list of floats where each float represents the error rate of the worst model in a system. Each float in the list is a value to test. These max error rates will be used to calculate delta values, so they're just another way for the user to specify delta values.
     - Observed worst error rate is the error rate of the worst model in the observed system. This does not need to be provided if specified_max_error_rates is None, but it's required if specified_max_error_rates is not None. This is used to calculate delta values from the specified max error rates.
    """
    def generate_alpha_delta_combinations(self, alpha_range, alpha_step, deltas, specified_max_error_rates=None, observed_worst_error_rate=None):
        """Build every (alpha, delta tuple) pair to evaluate.

        Parameters:
        alpha_range: (lower_bound, upper_bound) for alpha; both bounds inclusive.
        alpha_step: Positive spacing between consecutive alpha values.
        deltas: Iterable of deltas; each item is a number or a tuple/list of
            numbers. It is not modified.
        specified_max_error_rates: Optional error rates; each one is turned
            into the delta ``error_rate / observed_worst_error_rate - 1``.
        observed_worst_error_rate: Required when ``specified_max_error_rates``
            is given.

        Returns:
        A set of ``(alpha, delta_tuple)`` pairs with alpha rounded to 3 decimals
        and strictly between 0 and 1, plus the baseline pair ``(0, (0,))``.

        Raises:
        AssertionError: if ``observed_worst_error_rate`` is missing when needed,
            or a generated alpha falls outside [0, 1].
        ValueError: if alpha values must be generated and ``alpha_step`` is not
            positive.
        """
        alpha_lower_bound, alpha_upper_bound = alpha_range

        # Work on a copy: appending to the caller's list would leak the implied
        # deltas into later calls that reuse the same list object.
        candidate_deltas = list(deltas)
        if specified_max_error_rates is not None:
            assert observed_worst_error_rate is not None
            candidate_deltas.extend(
                (error_rate / observed_worst_error_rate) - 1
                for error_rate in specified_max_error_rates
            )

        # The alpha grid is the same for every delta, so build it once.
        alphas = []
        if candidate_deltas and alpha_lower_bound <= alpha_upper_bound:
            if alpha_step <= 0:
                raise ValueError("alpha_step must be positive")
            tolerance = 1e-9
            step_index = 0
            while True:
                # Multiply instead of repeatedly adding the step: accumulated
                # float error would otherwise skip the inclusive upper bound
                # (0.1 + 0.1 + 0.1 > 0.3) or let 0.9999999999999999 through as 1.0.
                raw_alpha = alpha_lower_bound + step_index * alpha_step
                if raw_alpha > alpha_upper_bound + tolerance:
                    break
                assert -tolerance <= raw_alpha <= 1 + tolerance
                alpha = round(raw_alpha, 3)
                if 0 < alpha < 1:
                    alphas.append(alpha)
                step_index += 1

        all_alpha_delta_combinations = set()
        for delta_tup in candidate_deltas:  # (delta1, delta2, delta3)
            if isinstance(delta_tup, list):
                delta_tup = tuple(delta_tup)  # lists are not hashable
            elif not isinstance(delta_tup, tuple):
                delta_tup = (delta_tup, )  # convert a single delta to a tuple
            all_alpha_delta_combinations.update(
                (alpha, delta_tup) for alpha in alphas
            )

        # This recovers the expected distribution under the simple assumption of
        # independence and uniformly distributed difficulty.
        all_alpha_delta_combinations.add((0, (0, )))
        return all_alpha_delta_combinations

    """
    Description:
    This function calculates the expected failure probabilities of each model for different alpha and delta values.
    Alpha represents the fraction of hard examples. Delta represents the percent different from the original model error rate

    Parameters:
    expected_failure_probs: A list of the expected failure probabilities of each model
    alpha: The fraction of hard examples; should be a float between 0 and 1
    delta: Delta can be a float or a tuple of floats. If it is a float, it represents the percent difference from the original model error rate for all models. If it is a tuple, it represents the percent difference from the original model error rate for each model in the order of the models in expected_failure_probs
            in the event that expected_failure_probs is longer than delta_tup, the last value in delta_tup will be used for the remaining models
    """
    def generate_easy_hard_error_rates(self, expected_failure_probs, alpha, delta_tup):
        """Split each model's error rate into an "easy" and a "hard" error rate.

        The hard rate of model i is ``p_i * (1 + delta_i)``. The easy rate is
        ``p_i * (1 - alpha * (1 + delta_i)) / (1 - alpha)``, so that mixing the
        two with weights ``1 - alpha`` and ``alpha`` gives back ``p_i``.

        Parameters:
        expected_failure_probs: Sequence with one error rate per model.
        alpha: Fraction of hard examples. With ``alpha == 1`` the easy rates are
            all set to 0.
        delta_tup: A single delta or a sequence of per-model deltas. A sequence
            shorter than the number of models is padded with its last value; a
            longer one is truncated.

        Returns:
        ``(easy_failure_probs, hard_failure_probs)``, two lists with one entry
        per model.

        Raises:
        ValueError: if any resulting rate lies outside [-0.001, 1.001].
        """
        if np.isscalar(delta_tup):
            delta_tup = (delta_tup, )  # a single delta applies to every model

        num_models = len(expected_failure_probs)
        # Drop surplus deltas, then repeat the last delta for models without one.
        hard_factors = [1 + delta for delta in delta_tup][:num_models]
        if len(hard_factors) < num_models:
            padding = num_models - len(hard_factors)
            hard_factors.extend([1 + delta_tup[-1]] * padding)

        hard_failure_probs = [
            p * factor for p, factor in zip(expected_failure_probs, hard_factors)
        ]

        if alpha == 1:
            # Test explicitly rather than catching ZeroDivisionError: numpy
            # scalars divide by zero to inf/NaN without raising, and NaN would
            # slip through the range checks below.
            easy_factors = [0] * num_models
        else:
            easy_factors = [
                (1 - alpha * hard_factor) / (1 - alpha)
                for hard_factor in hard_factors
            ]

        easy_failure_probs = [
            p * factor for p, factor in zip(expected_failure_probs, easy_factors)
        ]

        # Both checks allow a small tolerance around the valid range [0, 1].
        if any(p > 1.001 or p < -0.001 for p in hard_failure_probs):
            print(f"Delta value of {delta_tup}, alpha value of {alpha} resulted in an invalid hard distribution")
            print(f'original error rates == {expected_failure_probs}')
            print("hard error rates:", hard_failure_probs)
            raise ValueError("Invalid hard distribution")
        if any(p > 1.001 or p < -0.001 for p in easy_failure_probs):
            print(f"Delta value of {delta_tup}, alpha value of {alpha} resulted in an invalid easy distribution")
            print(f'original error rates == {expected_failure_probs}')
            print("easy error rates:", easy_failure_probs)
            raise ValueError("Invalid easy distribution")
        return (easy_failure_probs, hard_failure_probs)

    """
    Description:
    This function calculates the expected failure probabilities of each model for different alpha and delta values using the Poisson Binomial distribution.
    Alpha represents the fraction of hard examples. Delta represents the percent different from the original model error rate.

    Parameters:
    all_hard_error_rates: A dictionary containing the hard error rates for each model and each alpha and delta combination.
    all_easy_error_rates: A dictionary containing the easy error rates for each model and each alpha and delta combination.
    interactions: A list indicating which models a user has interacted with 

    Returns:
    A dictionary containing expected probability mass functions (PMFs) calculated using the Poisson binomial distribution. The dictionary is indexed by a tuple of (alpha, delta) values."""
    def generate_expected_pmf(self, all_hard_error_rates, all_easy_error_rates, interactions):
        """Mix the hard and easy rejection-count PMFs for every (alpha, delta) key.

        Parameters:
        all_hard_error_rates: Dict keyed by (alpha, delta) holding the per-model
            hard error rates.
        all_easy_error_rates: Dict with the same keys holding the per-model
            easy error rates.
        interactions: Indices of the models to include.

        Returns:
        Dict keyed by (alpha, delta) with the mixture
        ``alpha * hard_pmf + (1 - alpha) * easy_pmf`` after it has been passed
        through ``fix_distribution``. When that check fails, the rates are
        printed and the mixture is stored anyway.
        """
        hard_pmf_cache = {}  # hard PMFs are looked up by delta alone
        pmfs_by_setting = {}
        for setting in all_hard_error_rates:
            alpha, delta = setting
            hard_rates = all_hard_error_rates[setting]
            bernoullis_hard = [hard_rates[model] for model in interactions]
            easy_rates = all_easy_error_rates[setting]
            bernoullis_easy = [easy_rates[model] for model in interactions]

            if delta not in hard_pmf_cache:
                hard_pmf_cache[delta] = np.array(PoissonBinomial(bernoullis_hard).pmf)
            hard_pmf = hard_pmf_cache[delta]
            easy_pmf = np.array(PoissonBinomial(bernoullis_easy).pmf)

            mixture = alpha * hard_pmf + (1 - alpha) * easy_pmf
            try:
                mixture = self.fix_distribution(mixture)
            except AssertionError:
                # Best effort: report the offending rates and keep the mixture.
                print(f'easy = {bernoullis_easy}, hard = {bernoullis_hard}')
            pmfs_by_setting[setting] = mixture

        return pmfs_by_setting
    

    """
	I: Interaction matrix that encodes if a user (j) has interacted with a model (i). I[j, i] = 1 when the user has interacted with the model, else 0
	R: Rejection matrix that encodes the outcome from user j interacting with model i. R[j, i] = 1 when user j has been rejected/failed by model i
	expected_rejection_probs: length k numpy array of probability of failure conditional on interaction
	N: The total number of users .
	k: The total number of models 
    expected_rejection_probs: A length k numpy array of the probability of failure for each model; equivalent to the error rate of each model.
	t: Threshold for how many models can be correct to still be counted a systemic failure. A value of 0 means that no model can be correct. 
		t will be either a 0-1 bounded ratio or an absolute value based on threshold_type. 
	threshold_type: 'absolute' or 'percent'
    alpha_range: A tuple of the form (lower_bound, upper_bound) that specifies the range of alpha values to be tested. Alpha defines the fraction of hard examples.
    alpha_step: The step size for alpha values to be tested.
    deltas: A list of delta values to be tested. Delta represents how much harder the hard examples are than the original model error rates such that the hard error rate is (1 + delta) * original error rate. 
            Each element of deltas can be a float or a tuple; if it is a float, then the same delta value will be applied to all models. If it is a tuple, then element i of the tuple defines the delta_factor for model $h_i$. 
            If the length of the tuple is less than the number of models, then the last element of the tuple will be used as the delta value for all remaining models $h_j$ where j >= len(delta tuple).
	metrics: List of metrics to calculate. Should be a subset of ("MinExp", "ProdExp", "ProdExp_rootk")
	verbose: If true, will output histograms of observed failures, expected failures, and sampled failures if sample_failures is True
	skip_singletons: If true, ignore data points that interact with only a single system, since the notion of systemic failure is arguably degenerate.
	sample_failures: If true, will sample rejections from the expected failure distribution. 

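	Returns None when no individual is included; otherwise a dict with the following scheme
    {
        metric_name : {
            (alpha, delta): {
                "Homogenization": H,
                "Numerator": N,
                "Denominator": D },
        },
        "Easy Error Rates": {(alpha, delta): easy_error_rates},
        "Hard Error Rates": {(alpha, delta): hard_error_rates},
        "Histograms": {   (only when verbose is True)
            "Observed Failure Histogram": observed_failure_histogram,
            "Expected Failure Histogram": {(alpha, delta): expected_failure_histogram},
            "distance": {(alpha, delta): L1 distance between the normalized observed and expected histograms},
        },
        "Sampled Failure Rate a=<alpha> d=<delta>": sampled_failure_rate,   (only when verbose and sample_failures are True)
    }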
	"""

    def measure_homogenization(
        self,
        I,
        R,
        N,
        k,
        expected_rejection_probs,
        t,
        threshold_type,
        metrics=["ProdExp"],
        alpha_range = (0, 0), # inclusive (lower_bound, upper_bound)
        alpha_step = .2, 
        deltas = [],
        specified_max_error_rates = None,
        verbose=True,
        skip_singletons=True,
        sample_failures=False,
    ):
        """Compare the observed systemic failure rate with the expected one.

        For every individual the number of observed rejections is turned into
        a failure indicator with ``fail``. For every valid (alpha, delta)
        setting an expected failure value is accumulated per metric. The
        homogenization of a metric is the observed failure rate divided by the
        expected failure rate.

        Parameters:
        I: Interaction matrix indexed as ``I[j, i]`` (individual j, model i);
            a truthy entry means the individual interacted with the model.
        R: Rejection matrix indexed as ``R[j, i]``; entries are summed over the
            interacted models to count rejections.
        N: Number of individuals (rows to process).
        k: Number of models (columns to process).
        expected_rejection_probs: Per-model error rates, indexable by model.
        t: Threshold passed to ``fail``; for "MinExp" it is also the index into
            the individual's sorted error rates.
        threshold_type: "absolute" or "percent".
        metrics: Names from ``self.supported_options``.
        alpha_range: Inclusive (lower_bound, upper_bound) for alpha.
        alpha_step: Step between alpha values; must exceed 0.0001.
        deltas: Deltas to test (numbers or tuples). Never modified here.
        specified_max_error_rates: Optional worst-model error rates that are
            converted into additional deltas.
        verbose: When True, histograms and their distances are collected.
        skip_singletons: When True, individuals with at most one interaction
            are ignored.
        sample_failures: When True, a rejection count is sampled from every
            expected distribution as well.

        Returns:
        None when no individual is included. Otherwise a dict with one entry
        per metric mapping (alpha, delta) to a dict with "Homogenization",
        "Numerator" and "Denominator", plus 'Easy Error Rates',
        'Hard Error Rates' and, when ``verbose`` is True, "Histograms" (and the
        sampled failure rates when ``sample_failures`` is True as well).

        Raises:
        AssertionError: for an unsupported metric, a threshold outside the
            metric's allowed range, a too-small ``alpha_step`` or inconsistent
            histogram totals.
        """
        assert alpha_step > .0001
        for metric in metrics:
            assert (
                metric in self.supported_options
            ), "Provided metric not in supported options: {}".format(
                self.supported_options
            )
            threshold_limit = self.threshold_limits[metric]
            if threshold_limit is not None:
                assert (
                    t >= threshold_limit[0] and t <= threshold_limit[1]
                ), "{} supports threshold value range of {}".format(
                    metric, threshold_limit
                )

        # With no individuals there is no row to take the maximum over.
        interactions_per_individual = np.sum(I, axis=1)
        if interactions_per_individual.size:
            max_interactions = int(np.amax(interactions_per_individual))
        else:
            max_interactions = 0

        sampled_failure_histograms = {}
        expected_failure_histograms = {}
        sampled_failures_dict = {}
        expected_failures_dict = defaultdict(dict)

        worst_observed_error_rate = max(expected_rejection_probs)
        # Hand over a copy: the callee may append to the list it receives, and
        # ``deltas`` may be the shared default list of this method.
        all_alpha_delta_combinations = self.generate_alpha_delta_combinations(
            alpha_range,
            alpha_step,
            list(deltas),
            specified_max_error_rates,
            worst_observed_error_rate,
        )
        all_easy_error_rates = {}
        all_hard_error_rates = {}
        invalid_alpha_delta_values = set()

        for alpha, delta_tup in all_alpha_delta_combinations:
            key = (alpha, delta_tup)
            try:
                easy_error_rates, hard_error_rates = self.generate_easy_hard_error_rates(
                    expected_rejection_probs, alpha, delta_tup
                )
            except ValueError:
                invalid_alpha_delta_values.add(key)
                continue
            all_easy_error_rates[key] = easy_error_rates
            all_hard_error_rates[key] = hard_error_rates
            sampled_failure_histograms[key] = [0] * (max_interactions + 1)
            expected_failure_histograms[key] = [0] * (max_interactions + 1)
            sampled_failures_dict[key] = 0

            # Metrics can vary in how they treat 'expected' failures, so each
            # metric keeps its own running total per setting.
            for metric in metrics:
                expected_failures_dict[metric][key] = 0

        # Remove alpha_delta values that result in invalid distributions
        all_alpha_delta_combinations -= invalid_alpha_delta_values
        observed_failure_histogram = [0] * (max_interactions + 1)
        observed_failures = 0
        included_individuals = 0
        needs_min_exp = "MinExp" in metrics

        # Loop through all individuals. Calculate how many rejections they get
        # on the models they interact with and the expected failure
        # distribution over those models.
        for j in range(N):
            interactions = [i for i in range(k) if I[j, i]]
            num_interactions = len(interactions)
            num_observed_rejections = int(sum(R[j, i] for i in interactions))
            assert num_interactions <= max_interactions
            assert num_observed_rejections <= num_interactions

            if skip_singletons and num_interactions <= 1:
                continue
            included_individuals += 1

            observed_failures += self.fail(
                num_observed_rejections, num_interactions, t, threshold_type
            )
            if verbose:
                # Histograms are indexed by the number of successes. Only
                # included individuals are counted so that the totals match
                # ``included_individuals``.
                observed_failure_histogram[
                    num_interactions - num_observed_rejections
                ] += 1

            if needs_min_exp:
                bernoullis = [expected_rejection_probs[i] for i in interactions]
                min_error_rate = np.sort(bernoullis)[int(t)]

            # Only computed for included individuals; skipped ones never use it.
            expected_pmfs = self.generate_expected_pmf(
                all_hard_error_rates, all_easy_error_rates, interactions
            )

            for (alpha, delta), expected_pmf in expected_pmfs.items():
                key = (alpha, delta)
                if sample_failures:
                    # Take the single drawn element explicitly; int() on a
                    # one-element array is deprecated in recent numpy.
                    num_sampled_rejections = int(
                        np.random.choice(num_interactions + 1, 1, p=expected_pmf)[0]
                    )

                if verbose:  # Only construct histograms if verbose
                    assert len(expected_pmf) == num_interactions + 1
                    expected_failure_histogram = expected_failure_histograms[key]
                    for num_rejections, probability in enumerate(expected_pmf):
                        expected_failure_histogram[
                            num_interactions - num_rejections
                        ] += probability
                    if sample_failures:
                        # Exactly one sample per individual and setting.
                        sampled_failure_histograms[key][
                            num_interactions - num_sampled_rejections
                        ] += 1

                for metric in metrics:  # Metric choice determines the denominator
                    if metric == "MinExp":
                        expected_failure = min_error_rate
                    else:  # "ProdExp" or "ProdExp_rootk"
                        expected_failure = self.fail(
                            expected_pmf, num_interactions, t, threshold_type
                        )
                    expected_failures_dict[metric][key] += expected_failure

                if sample_failures:
                    sampled_failures_dict[key] += self.fail(
                        num_sampled_rejections, num_interactions, t, threshold_type
                    )

        if included_individuals == 0:
            return None

        numerator = observed_failures / included_individuals

        result = defaultdict(dict)
        if verbose:
            result["Histograms"] = {
                "Observed Failure Histogram": observed_failure_histogram,
                "Expected Failure Histogram": expected_failure_histograms,
            }
        result['Easy Error Rates'] = all_easy_error_rates
        result['Hard Error Rates'] = all_hard_error_rates

        for alpha, delta in all_alpha_delta_combinations:
            key = (alpha, delta)
            if verbose:
                expected_failure_distribution = expected_failure_histograms[key]

                np.testing.assert_allclose(np.sum(expected_failure_distribution), included_individuals)
                np.testing.assert_allclose(np.sum(observed_failure_histogram), included_individuals)

                observed_normalized = np.divide(observed_failure_histogram, included_individuals)
                expected_normalized = np.divide(expected_failure_distribution, included_individuals)
                observed_expected_l1_distance = sum(abs(observed_normalized - expected_normalized))
                result["Histograms"].setdefault("distance", {})[key] = observed_expected_l1_distance

                # NOTE: sampled results are only reported together with the
                # histograms, i.e. when verbose is True.
                if sample_failures:
                    sampled_failure_rate = sampled_failures_dict[key] / included_individuals
                    result[f"Sampled Failure Rate a={alpha} d={delta}"] = sampled_failure_rate
                    result["Histograms"][
                        f"Sampled Failure Histogram ({alpha}, {delta})"
                    ] = sampled_failure_histograms[key]

            for metric in metrics:
                print(metric, alpha, delta)
                denominator = expected_failures_dict[metric][key] / included_individuals
                if metric == "ProdExp_rootk":
                    denominator = denominator ** (1 / k)
                H = numerator / denominator
                result[metric][key] = {
                    "Homogenization": H,
                    "Numerator": numerator,
                    "Denominator": denominator,
                }
        return result
