import numpy as np
import pandas as pd

# Seed NumPy's global (legacy) RNG once when the module is executed so that
# every np.random.* draw below -- tie-breaks, bootstrap samples, weight
# perturbations -- is reproducible from run to run.
np.random.seed(42)

def argmax_with_tie_break(scores):
    """
    Return the index of the largest score, breaking ties uniformly at random.

    Args:
        scores (array-like): Candidate scores. Lists and tuples are accepted
            as well as NumPy arrays; multi-dimensional input is indexed as if
            flattened.

    Returns:
        int: Index of a maximal entry, drawn with ``np.random.choice`` from
        all indices that attain the maximum.

    Raises:
        ValueError: If ``scores`` is empty.
    """
    # Coerce so plain Python sequences work too (they have no ``.max()``).
    scores = np.asarray(scores)
    if scores.size == 0:
        raise ValueError("scores must contain at least one element")

    # Every index that attains the maximum is a legitimate winner.
    max_indices = np.flatnonzero(scores == scores.max())

    # Always draw, even when there is a single maximiser, so the global RNG
    # stream advances by the same amount on every call.
    return np.random.choice(max_indices)

def inflated_argmax(scores, epsilon):
    """
    Inflated argmax (sargmax_epsilon): all candidates within epsilon of the best.

    Args:
        scores (array-like): Candidate scores. Lists and tuples are accepted
            as well as NumPy arrays; multi-dimensional input is indexed as if
            flattened.
        epsilon (float): Tolerance below the maximum score. With ``epsilon=0``
            the result is exactly the set of maximisers; a negative value puts
            the threshold above the maximum, so the result is the empty set.

    Returns:
        set: Indices ``i`` with ``scores[i] >= max(scores) - epsilon``.

    Raises:
        ValueError: If ``scores`` is empty.
    """
    # Coerce so plain Python sequences work too (they have no ``.max()``).
    scores = np.asarray(scores)
    if scores.size == 0:
        raise ValueError("scores must contain at least one element")

    threshold = scores.max() - epsilon
    return set(np.flatnonzero(scores >= threshold))

def bagged_argmax(sampling_weights, B=50):
    """
    Bagged argmax: resample the voter groups B times and return the plurality winner.

    Each bootstrap replicate redraws ``int(sum(sampling_weights))`` voters from
    the groups with probability proportional to their weight, scores the
    candidates with ``calculate_scores_exp1`` and selects a winner with
    ``argmax_with_tie_break``. The candidate that wins the most replicates is
    returned; ties between candidates are broken uniformly at random.

    Args:
        sampling_weights (array-like): Non-negative weights [w1, w2, w3] of the
            voter groups. They need not be integers.
        B (int): The number of bootstrap samples to generate (at least 1).

    Returns:
        int: The winning candidate (0, 1, or 2).

    Raises:
        ValueError: If ``B < 1``, or if the weights contain a negative value,
            are not finite, or do not sum to a positive number.
    """
    if B < 1:
        raise ValueError("B must be at least 1")

    weights = np.asarray(sampling_weights, dtype=float)
    weight_sum = weights.sum()
    if np.any(weights < 0) or not np.isfinite(weight_sum) or weight_sum <= 0:
        raise ValueError(
            "sampling_weights must be non-negative, finite, and sum to a positive value"
        )

    # The number of draws per replicate has to be an integer, so the total
    # weight is truncated here ...
    total_voters = int(weight_sum)

    # ... but the probabilities must be normalised by the *untruncated* sum.
    # Dividing by the truncated integer makes them sum to more than 1, and
    # np.random.multinomial then silently assigns the last group only the
    # leftover mass 1 - sum(pvals[:-1]) (or raises when that sum exceeds 1).
    probabilities = weights / weight_sum

    # Generate B bootstrap samples at once; each row is the resampled group
    # sizes for one replicate.
    bootstrap_samples = np.random.multinomial(total_voters, probabilities, size=B)

    # One vote per replicate: the (tie-broken) argmax of that replicate's scores.
    votes = [
        argmax_with_tie_break(calculate_scores_exp1(sample_weights))
        for sample_weights in bootstrap_samples
    ]

    # Count votes for candidates 0, 1, 2; minlength keeps candidates that
    # received no votes in the tally.
    vote_counts = np.bincount(votes, minlength=3)

    # Plurality winner, ties broken randomly.
    return argmax_with_tie_break(vote_counts)


def calculate_scores_exp1(sampling_weights):
    """
    Aggregate utility of candidates A, B, C as a weighted mean over three voter groups.

    Candidate indices: 0=A, 1=B, 2=C. Each group has a fixed utility vector
    (a cyclic preference profile); the result is the weight-averaged utility
    per candidate.

    Args:
        sampling_weights: Exactly three group weights (w1, w2, w3).

    Returns:
        np.ndarray: Length-3 array of aggregate scores, or all zeros when the
        weights sum to zero.
    """
    weight_1, weight_2, weight_3 = sampling_weights
    weight_sum = weight_1 + weight_2 + weight_3

    # Nothing to average over: report a three-way tie instead of dividing by zero.
    if weight_sum == 0:
        return np.zeros(3)

    utility_group_1 = np.array([3.0, 2.0, 1.0])  # ranks A > B > C
    utility_group_2 = np.array([1.0, 3.0, 2.0])  # ranks B > C > A
    utility_group_3 = np.array([2.0, 1.0, 3.0])  # ranks C > A > B

    weighted_total = (
        weight_1 * utility_group_1
        + weight_2 * utility_group_2
        + weight_3 * utility_group_3
    )
    return weighted_total / weight_sum

def run_experiment_1_with_bagging(N_TRIALS=5000, N_BOOTSTRAPS=50, BASE_WEIGHT=100, PERTURBATION=1, INFLATION_EPSILON=0.01):
    """
    Compare the stability of standard, inflated, and bagged argmax under a small weight perturbation.

    Each trial draws three group weights near ``BASE_WEIGHT``, builds an
    adjacent dataset by shifting one randomly chosen weight by
    ``+/- PERTURBATION``, runs all three selection rules on both datasets and
    records whether each rule's selections agree. Summary statistics are
    printed and the per-trial records are returned.

    Args:
        N_TRIALS (int): Number of independent trials. Zero (or a negative
            value) yields an empty result frame and NaN averages.
        N_BOOTSTRAPS (int): Bootstrap samples per ``bagged_argmax`` call.
        BASE_WEIGHT (float): Centre of the group weights; each weight is
            ``BASE_WEIGHT`` plus uniform noise in [-0.5, 0.5).
        PERTURBATION (float): Magnitude of the shift applied to one weight.
        INFLATION_EPSILON (float): Tolerance passed to ``inflated_argmax``.

    Returns:
        pd.DataFrame: One row per trial with columns ``Stab_Argmax``,
        ``Stab_Inflated``, ``Stab_Bagged`` (1.0 if the selections on the two
        datasets agree / overlap, else 0.0) and ``Avg_Size_Inflated`` (size of
        the inflated-argmax set on the original dataset).
    """
    print("--- Running Experiment 1 (Stress Test with Bagging) ---")
    print(f"Parameters: N_TRIALS={N_TRIALS}, N_BOOTSTRAPS={N_BOOTSTRAPS}")

    # Per-trial records, one list per output column.
    stab_argmax = []
    stab_inflated = []
    stab_bagged = []
    inflated_sizes = []

    for _ in range(N_TRIALS):
        # 1. Generate the original dataset's scores
        W_S = BASE_WEIGHT + np.random.uniform(-0.5, 0.5, 3)
        Scores_S = calculate_scores_exp1(W_S)

        # 2. Generate the adjacent (perturbed) dataset's scores: shift one
        #    randomly chosen group weight up or down by PERTURBATION.
        perturb_vector = np.zeros(3)
        perturb_idx = np.random.choice(3)
        perturb_dir = np.random.choice([-1, 1])
        perturb_vector[perturb_idx] = perturb_dir * PERTURBATION
        # Floor at 0.1 so a large perturbation cannot drive a weight to zero or below.
        W_Sp = np.maximum(0.1, W_S + perturb_vector)
        Scores_Sp = calculate_scores_exp1(W_Sp)

        # 3. Get selections for each method. The call order is kept fixed
        #    because the randomised methods all consume the global RNG stream.
        Sel_A_S = argmax_with_tie_break(Scores_S)
        Sel_A_Sp = argmax_with_tie_break(Scores_Sp)

        Sel_I_S = inflated_argmax(Scores_S, INFLATION_EPSILON)
        Sel_I_Sp = inflated_argmax(Scores_Sp, INFLATION_EPSILON)

        Sel_B_S = bagged_argmax(W_S, B=N_BOOTSTRAPS)
        Sel_B_Sp = bagged_argmax(W_Sp, B=N_BOOTSTRAPS)

        # 4. Stability: point selections must match exactly; set-valued
        #    selections count as stable when they share at least one candidate.
        stab_argmax.append(1.0 if Sel_A_S == Sel_A_Sp else 0.0)
        stab_inflated.append(0.0 if Sel_I_S.isdisjoint(Sel_I_Sp) else 1.0)
        stab_bagged.append(1.0 if Sel_B_S == Sel_B_Sp else 0.0)
        inflated_sizes.append(len(Sel_I_S))

    # 5. Aggregate and print results. Building the frame from explicitly typed
    #    columns keeps all four columns present even when no trial was run,
    #    instead of failing with a KeyError on a column-less frame.
    df = pd.DataFrame({
        "Stab_Argmax": np.array(stab_argmax, dtype=np.float64),
        "Stab_Inflated": np.array(stab_inflated, dtype=np.float64),
        "Stab_Bagged": np.array(stab_bagged, dtype=np.float64),
        "Avg_Size_Inflated": np.array(inflated_sizes, dtype=np.int64),
    })
    avg_stab_A = df["Stab_Argmax"].mean()
    avg_stab_I = df["Stab_Inflated"].mean()
    avg_stab_B = df["Stab_Bagged"].mean()
    avg_size_I = df["Avg_Size_Inflated"].mean()

    print("\n--- Results ---")
    print(f"Average Stability (Standard Argmax):    {avg_stab_A:.4f}")
    print(f"Average Stability (Bagged Argmax):      {avg_stab_B:.4f}  <-- New Result")
    print(f"Average Stability (Inflated Argmax):    {avg_stab_I:.4f}")
    print(f"Average Output Size (Inflated):         {avg_size_I:.4f}")
    print("-----------------\n")

    return df

# Script entry point: when the file is executed directly (not imported), run
# the experiment with its default parameters and keep the per-trial DataFrame
# in `results_df`.
if __name__ == "__main__":
    results_df = run_experiment_1_with_bagging()
