import argparse
import json
import matplotlib.pyplot as plt
from math import comb
import numpy as np
import os

def pass_at_k(n, c, k):
    """
    Returns the pass@k value for one problem.

    The value is ``1 - C(n - c, k) / C(n, k)``: the chance that a draw of k
    generations out of n (without replacement) contains at least one of the
    c correct ones.

    Args:
        n (int): Total number of generations.
        c (int): Number of correct generations.
        k (int): Number of generations drawn.

    Returns:
        float: 0.0 when k or c is zero, 1.0 when fewer than k generations
               are incorrect, otherwise the formula above.
    """
    nothing_to_pass = (k == 0) or (c == 0)
    if nothing_to_pass:
        return 0.0

    incorrect = n - c
    if incorrect < k:
        # Any k-subset must contain at least one correct generation.
        return 1.0

    all_incorrect_fraction = comb(incorrect, k) / comb(n, k)
    return 1.0 - all_incorrect_fraction

def compute_pass_at_k_curve(accuracies, n, max_k=None):
    """
    Computes a single pass@k curve for a given list of accuracies.

    Each accuracy is turned into a correct-generation count
    ``c = int(round(n * acc))``; the curve value at k is the mean of
    ``pass_at_k(n, c, k)`` over all accuracies.

    Args:
        accuracies (sequence): Accuracy scores (a list or 1-D array).
        n (int): Total number of generations.
        max_k (int, optional): Maximum value of k. Defaults to n.

    Returns:
        list: Mean pass@k values for k = 1..max_k (empty if max_k < 1).

    Raises:
        ValueError: If ``accuracies`` is empty, since the mean is undefined.
    """
    # len() rather than truthiness: accuracies may be a NumPy array.
    num_accuracies = len(accuracies)
    if num_accuracies == 0:
        raise ValueError("accuracies must contain at least one value")
    if max_k is None:
        max_k = n

    # The correct-count of each entry does not depend on k, so derive it once.
    correct_counts = [int(round(n * acc)) for acc in accuracies]
    distinct_counts = set(correct_counts)

    pass_k_values = []
    for k in range(1, max_k + 1):
        # Entries sharing a count share a pass@k value; evaluate each once.
        value_for_count = {c: pass_at_k(n, c, k) for c in distinct_counts}
        # Accumulate in input order so the result matches a plain per-entry sum.
        total = 0
        for c in correct_counts:
            total += value_for_count[c]
        pass_k_values.append(total / num_accuracies)
    return pass_k_values

def bootstrap_pass_at_k_curve(accuracies, n, max_k=None, n_bootstrap=1000):
    """
    Computes the pass@k curve with bootstrap estimates of variance.

    The accuracies are resampled with replacement ``n_bootstrap`` times. Each
    resample yields one curve (the mean pass@k over the resampled entries);
    the mean and standard deviation across those curves are returned.

    Args:
        accuracies (list): A list of accuracy scores.
        n (int): Total number of generations.
        max_k (int, optional): Maximum value of k. Defaults to n.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.

    Returns:
        tuple: A tuple containing the mean curve (np.ndarray) and the 
               standard deviation curve (np.ndarray).

    Raises:
        ValueError: If ``accuracies`` is empty or ``n_bootstrap`` is less than 1.
    """
    accuracies_np = np.asarray(accuracies, dtype=float)
    n_accuracies = len(accuracies_np)
    if n_accuracies == 0:
        raise ValueError("accuracies must contain at least one value")
    if n_bootstrap < 1:
        raise ValueError("n_bootstrap must be a positive integer")
    if max_k is None:
        max_k = n

    print(f"Running {n_bootstrap} bootstrap samples...")

    # The pass@k curve of one entry depends only on its own accuracy, so it is
    # computed once per distinct accuracy instead of once per bootstrap draw.
    unique_accuracies, inverse = np.unique(accuracies_np, return_inverse=True)
    unique_curves = np.array(
        [compute_pass_at_k_curve([acc], n, max_k) for acc in unique_accuracies],
        dtype=float,
    )
    # Row i holds the curve of accuracies[i]; shape (n_accuracies, number of k values).
    per_entry_curves = unique_curves[inverse]

    bootstrapped_curves = np.empty((n_bootstrap, per_entry_curves.shape[1]))
    for sample_idx in range(n_bootstrap):
        # Resample entry indices with replacement; averaging the selected rows
        # gives the pass@k curve of the resampled data.
        resampled_rows = np.random.choice(
            n_accuracies, size=n_accuracies, replace=True
        )
        bootstrapped_curves[sample_idx] = per_entry_curves[resampled_rows].mean(axis=0)

    # Calculate the mean and standard deviation across all bootstrap samples
    mean_curve = np.mean(bootstrapped_curves, axis=0)
    std_dev_curve = np.std(bootstrapped_curves, axis=0)

    return mean_curve, std_dev_curve

def load_json_accuracies(path):
    """
    Loads accuracy data from a JSON file.

    Two layouts are accepted:
      * a JSON object, in which case the value stored under its first key
        is returned;
      * a top-level JSON array, which is returned as-is.

    Args:
        path (str): Path to the JSON file.

    Returns:
        The value under the first key of the object, or the top-level list.

    Raises:
        ValueError: If the top-level value is an empty object, or is neither
            an object nor an array.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a JSON object or array, got {type(data).__name__}"
        )
    if not data:
        # Avoid leaking a bare StopIteration from next() on an empty object.
        raise ValueError(f"{path}: JSON object is empty")
    return next(iter(data.values()))

def _select_most_difficult(all_accuracies, base_path, keep):
    """
    Restricts every accuracy list to the entries the base model scores lowest on.

    Args:
        all_accuracies (dict): Maps each file path to its list of accuracies.
        base_path (str): Key of the entry whose accuracies rank the problems.
        keep (int): Number of lowest-accuracy problems to retain.

    Returns:
        dict: A new mapping with the same keys, each list reduced to the
              selected positions.

    Raises:
        ValueError: If any list differs in length from the base list, because
            problems are matched across files purely by position.
    """
    base_accuracies = np.array(all_accuracies[base_path])
    num_problems = len(base_accuracies)

    for path, accs in all_accuracies.items():
        if len(accs) != num_problems:
            raise ValueError(
                f"'{os.path.basename(path)}' has {len(accs)} problems but the base model "
                f"has {num_problems}; --keep-most-difficult requires the same problems in every file."
            )

    # A stable sort breaks ties by original position, so the selection is reproducible.
    difficult_indices = np.argsort(base_accuracies, kind="stable")[:keep]
    return {
        path: np.array(accs)[difficult_indices].tolist()
        for path, accs in all_accuracies.items()
    }


def main():
    """
    Command-line entry point: loads accuracy lists from JSON files, optionally
    keeps only the problems the 'base' model finds hardest, and plots one
    pass@k curve per file (with a +/- 1 standard deviation bootstrap band
    unless bootstrapping is disabled). The figure is saved to --output if
    given, otherwise shown interactively.
    """
    parser = argparse.ArgumentParser(description="Plot Pass@k curves from JSON accuracy files.")
    parser.add_argument("json_paths", nargs="+", help="Paths to JSON files containing accuracy lists.")
    parser.add_argument("--n", type=int, default=128, help="Total number of generations (default: 128).")
    parser.add_argument("--max-k", type=int, help="Maximum value of k to compute Pass@k for (default: same as n).")
    parser.add_argument("--labels", nargs="+", help="Labels for each curve (default: file name).")
    parser.add_argument("--output", type=str, help="Path to save the output plot (e.g., plot.png).")
    parser.add_argument("--bootstrap-samples", type=int, default=1000, 
                        help="Number of bootstrap samples for variance estimation. Set to 0 to disable. (default: 1000).")
    # Optional filter: restrict the plot to the hardest problems
    parser.add_argument("--keep-most-difficult", type=int, metavar="K",
                        help="Keep only the K most difficult problems based on the 'base' model's performance. Requires one filename to contain 'base'.")
    args = parser.parse_args()

    # Validate the filtering options up front. parser.error() writes to stderr
    # and exits with a non-zero status, so callers and scripts see the failure.
    base_model_path = None
    if args.keep_most_difficult is not None:
        if args.keep_most_difficult <= 0:
            parser.error("--keep-most-difficult must be a positive integer.")

        # Identify the base model (filename contains "base", case-insensitive)
        base_model_path = next(
            (path for path in args.json_paths if "base" in os.path.basename(path).lower()),
            None,
        )
        if base_model_path is None:
            parser.error("A model with 'base' in its filename is required when using --keep-most-difficult.")

    # Load all accuracies into a dictionary first
    all_accuracies = {path: load_json_accuracies(path) for path in args.json_paths}

    # Filter problems based on difficulty if the option is provided
    if base_model_path is not None:
        print(f"Using '{os.path.basename(base_model_path)}' to determine the most difficult problems.")
        num_problems = len(all_accuracies[base_model_path])

        if args.keep_most_difficult > num_problems:
            print(f"Warning: K ({args.keep_most_difficult}) is larger than the number of problems ({num_problems}). Using all problems.")
        else:
            try:
                all_accuracies = _select_most_difficult(
                    all_accuracies, base_model_path, args.keep_most_difficult
                )
            except ValueError as exc:
                parser.error(str(exc))
            print(f"Filtered data to keep the {args.keep_most_difficult} most difficult problems.")

    use_bootstrap = args.bootstrap_samples > 0

    plt.figure(figsize=(10, 6))

    for idx, path in enumerate(args.json_paths):
        # `accuracies` will be the filtered list if the option was used
        accuracies = all_accuracies[path]
        print(f"Plotting for '{os.path.basename(path)}' with {len(accuracies)} data points.")
        # Fall back to the file name when no label was supplied for this curve.
        label = args.labels[idx] if args.labels and idx < len(args.labels) else os.path.basename(path)

        if use_bootstrap:
            # Calculate mean and std dev using bootstrapping
            mean_curve, std_dev_curve = bootstrap_pass_at_k_curve(
                accuracies, args.n, args.max_k, args.bootstrap_samples
            )

            x_values = np.arange(1, len(mean_curve) + 1)
            # Plot the mean curve
            (mean_line,) = plt.plot(x_values, mean_curve, label=label)
            # Shade +/- 1 standard deviation, reusing the line's colour explicitly
            # rather than relying on two colour cycles staying in step.
            plt.fill_between(
                x_values,
                mean_curve - std_dev_curve,
                mean_curve + std_dev_curve,
                alpha=0.2,
                color=mean_line.get_color(),
            )
        else:
            # Plain curve without bootstrapping
            pass_k_curve = compute_pass_at_k_curve(accuracies, args.n, args.max_k)
            plt.plot(range(1, len(pass_k_curve) + 1), pass_k_curve, label=label)

    # Only mention the bootstrap band in the title when one is actually drawn.
    plt.title("Pass@k Curve with Bootstrap Variance" if use_bootstrap else "Pass@k Curve")
    plt.xlabel("k")
    plt.ylabel("Pass@k")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    if args.output:
        plt.savefig(args.output)
        print(f"Plot saved to {args.output}")
    else:
        plt.show()

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()