import concurrent.futures
import os
import time
from dataclasses import dataclass
from enum import Enum
from math import ceil

import numpy as np
from matplotlib import pyplot as plt
from rich.console import Group
from rich.live import Live
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskID,
    TextColumn,
)
from rich.table import Table
from scipy.stats import sem
from tqdm.auto import tqdm

from research.wsl_ece.metric.distribution import (
    CIFAR10PredictionDistribution,
    DDI2013PredictionDistribution,
    MixNMatchDistribution,
    MNISTPredictionDistribution,
    SyntheticDistribution,
)
from research.wsl_ece.metric.ece import BinningStrategy


class DatasetNames(Enum):
    """
    Enum for dataset names used in experiments.

    Each member's value is a human-readable label for the dataset.
    ``ConvergenceExperimentRunner.__init__`` matches on these members to decide
    which prediction distribution to construct.
    """

    # Mix-n-Match synthetic distribution built with beta_0=-0.5, beta_1=1.5.
    MIX_N_MATCH_LESS_CALIBRATED_CASE = "Synthetic: less calibrated case"
    # Mix-n-Match synthetic distribution built with beta_0=-0.2, beta_1=1.9.
    MIX_N_MATCH_BETTER_CALIBRATED_CASE = "Synthetic: better calibrated case"
    # Selects MNISTPredictionDistribution.
    MNIST = "MNIST"
    # Selects CIFAR10PredictionDistribution.
    CIFAR10 = "CIFAR-10"
    # Selects DDI2013PredictionDistribution.
    DDI2013 = "DDI2013"


class ConvergenceExperimentRunner:
    """
    Orchestrates experiments to see convergence of PU-ECE and ECE to the ground truth TCE.
    """

    class Setting(Enum):
        """
        Experiment configurations that the runner can evaluate.

        The value of each member is the (mathtext) legend label used when the
        configuration is plotted.
        """

        PU_ECE_INFINITE_UNLABELED = "PU-ECE ($n_\\mathrm{{P}} = N$, $n_\\mathrm{{U}} = \\infty$)"
        PU_ECE_INFINITE_POSITIVE = "PU-ECE ($n_\\mathrm{{P}} = \\infty$, $n_\\mathrm{{U}} = N$)"
        PU_ECE_100_POSITIVE_N_UNLABELED = "PU-ECE ($n_\\mathrm{{P}} = 100$, $n_\\mathrm{{U}} = N$)"
        PU_ECE_1000_POSITIVE_N_UNLABELED = "PU-ECE ($n_\\mathrm{{P}} = 1000$, $n_\\mathrm{{U}} = N$)"
        PU_ECE_10000_POSITIVE_N_UNLABELED = "PU-ECE ($n_\\mathrm{{P}} = 10000$, $n_\\mathrm{{U}} = N$)"
        PU_ECE_N_POSITIVE_10N_UNLABELED = "PU-ECE ($n_\\mathrm{{P}} = N$, $n_\\mathrm{{U}} = 10N$)"
        PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_5PCT = (
            "PU-ECE ($n_\\mathrm{{P}} = N$, $n_\\mathrm{{U}} = 10N$), $\\pi'_\\mathrm{{P}} = 1.05\\pi_\\mathrm{{P}}$"
        )
        PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_5PCT = (
            "PU-ECE ($n_\\mathrm{{P}} = N$, $n_\\mathrm{{U}} = 10N$), $\\pi'_\\mathrm{{P}} = 0.95\\pi_\\mathrm{{P}}$"
        )
        PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_10PCT = (
            "PU-ECE ($n_\\mathrm{{P}} = N$, $n_\\mathrm{{U}} = 10N$), $\\pi'_\\mathrm{{P}} = 1.1\\pi_\\mathrm{{P}}$"
        )
        PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_10PCT = (
            "PU-ECE ($n_\\mathrm{{P}} = N$, $n_\\mathrm{{U}} = 10N$), $\\pi'_\\mathrm{{P}} = 0.9\\pi_\\mathrm{{P}}$"
        )
        BINNED_TCE_CBRT_N_BIN = "Binned TCE ($B = N^{1/3}$)"
        BINNED_TCE_CBRT_N_POSITIVE_BIN = "Binned TCE ($B = (N / \\pi_\\mathrm{{P}}^2)^{1/3}$)"
        BINNED_TCE_N_BIN = "Binned TCE ($B = N$)"
        ECE_CBRT_N_BIN = "ECE ($n = N$, $B = N^{1/3}$)"
        ECE_CBRT_N_POSITIVE_BIN = "ECE ($n = N$, $B = (N / \\pi_\\mathrm{{P}}^2)^{1/3}$)"

        def to_marker(self) -> str:
            """
            Returns the plot marker string assigned to this setting.

            Raises:
                ValueError: If the setting has no marker assigned.
            """
            members = type(self)
            # The table is built inside the method: a dict assigned in the Enum
            # body would itself be turned into an enum member.
            marker_table = {
                members.PU_ECE_INFINITE_UNLABELED: "h",
                members.PU_ECE_INFINITE_POSITIVE: "H",
                members.PU_ECE_100_POSITIVE_N_UNLABELED: "1",
                members.PU_ECE_1000_POSITIVE_N_UNLABELED: "2",
                members.PU_ECE_10000_POSITIVE_N_UNLABELED: "3",
                members.PU_ECE_N_POSITIVE_10N_UNLABELED: "4",
                members.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_5PCT: "P",
                members.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_5PCT: "X",
                members.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_10PCT: "+",
                members.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_10PCT: "x",
                members.BINNED_TCE_CBRT_N_BIN: "v",
                members.BINNED_TCE_CBRT_N_POSITIVE_BIN: "^",
                members.BINNED_TCE_N_BIN: ">",
                members.ECE_CBRT_N_BIN: "D",
                members.ECE_CBRT_N_POSITIVE_BIN: "d",
            }
            marker = marker_table.get(self)
            if marker is None:
                raise ValueError(f"Unknown setting: {self}")
            return marker

    @dataclass
    class Result:
        """
        Data class to hold the result of a single experiment.

        One instance summarises the repeated estimates obtained for one setting at one sample size, as produced by
        ``ConvergenceExperimentRunner.get_single_ece_statistics``.
        """

        # Experiment setting the statistics belong to.
        setting: "ConvergenceExperimentRunner.Setting"
        # Sample size N used for the experiment.
        n_samples: int
        # Mean of the estimated ECE / PU-ECE values over the repeats.
        mean_ece: float
        # Standard error of the mean of the estimated values (not the standard deviation).
        sem_ece: float
        # Ground-truth calibration error (TCE) reported by the distribution.
        tce: float
        # Mean of |estimate - TCE| over the repeats.
        mean_abs_bias: float
        # Standard error of the mean of |estimate - TCE|.
        sem_abs_bias: float
        # 5th percentile of |estimate - TCE| over the repeats.
        lower_90ci_abs_bias: float
        # 95th percentile of |estimate - TCE| over the repeats.
        upper_90ci_abs_bias: float

    @dataclass
    class Results:
        """
        Data class to hold the results of one setting across several sample sizes.

        All list attributes are parallel: index ``i`` of every list describes the experiment run with
        ``n_samples[i]`` samples. ``append`` keeps the lists sorted by ``n_samples``. ``tce`` is a single value
        shared by every entry.
        """

        setting: "ConvergenceExperimentRunner.Setting"
        n_samples: list[int]
        mean_ece: list[float]
        sem_ece: list[float]
        tce: float
        mean_abs_bias: list[float]
        sem_abs_bias: list[float]
        lower_90ci_abs_bias: list[float]
        upper_90ci_abs_bias: list[float]

        def append(self, result: "ConvergenceExperimentRunner.Result"):
            """
            Appends a single result and re-sorts all lists by number of samples.

            Both compatibility checks run before anything is modified, so a rejected result leaves this object
            unchanged.

            Args:
                result (ConvergenceExperimentRunner.Result): The result to add.

            Raises:
                ValueError: If the setting or the TCE of ``result`` differs from this object's, or if the lists end
                    up with inconsistent lengths.
            """
            if self.setting != result.setting:
                raise ValueError("Cannot append result with different setting.")
            if self.tce != result.tce:
                raise ValueError(
                    f"Cannot append result with different TCE. Expected TCE: {self.tce}, got: {result.tce}"
                )

            series_names = (
                "n_samples",
                "mean_ece",
                "sem_ece",
                "mean_abs_bias",
                "sem_abs_bias",
                "lower_90ci_abs_bias",
                "upper_90ci_abs_bias",
            )
            for name in series_names:
                getattr(self, name).append(getattr(result, name))

            # Sort the results by number of samples to ensure consistency. The permutation is computed once,
            # before any list (including n_samples itself) is reordered.
            order = np.argsort(self.n_samples, kind="stable")
            for name in series_names:
                values = getattr(self, name)
                setattr(self, name, [values[i] for i in order])
            self.validate()

        def validate(self):
            """
            Validates the results to ensure all lists have the same length.

            Raises:
                ValueError: If any of the parallel lists differs in length from ``n_samples``.
            """
            expected = len(self.n_samples)
            series = (
                self.mean_ece,
                self.sem_ece,
                self.mean_abs_bias,
                self.sem_abs_bias,
                self.lower_90ci_abs_bias,
                self.upper_90ci_abs_bias,
            )
            if any(len(values) != expected for values in series):
                raise ValueError("All result lists must have the same length.")

        @classmethod
        def from_single_result(cls, result: "ConvergenceExperimentRunner.Result"):
            """
            Creates a Results object from a single Result object.

            Args:
                result (ConvergenceExperimentRunner.Result): The result that becomes the only entry.

            Returns:
                A new instance whose lists each contain exactly one element.
            """
            return cls(
                setting=result.setting,
                n_samples=[result.n_samples],
                mean_ece=[result.mean_ece],
                sem_ece=[result.sem_ece],
                tce=result.tce,
                mean_abs_bias=[result.mean_abs_bias],
                sem_abs_bias=[result.sem_abs_bias],
                lower_90ci_abs_bias=[result.lower_90ci_abs_bias],
                upper_90ci_abs_bias=[result.upper_90ci_abs_bias],
            )

    def __init__(self, dataset_name: DatasetNames, t: float = 1.0, gmm_fit_mode="separate"):
        """
        Initializes the experiment runner with the specified dataset name.

        Args:
            dataset_name (DatasetNames): The name of the dataset to use for experiments.
            t (float): The temperature parameter for the synthetic distribution. Default is 1.0.
        """
        self.dataset_name = dataset_name
        self.distribution: SyntheticDistribution
        match dataset_name:
            case DatasetNames.CIFAR10:
                self.distribution = CIFAR10PredictionDistribution(gmm_fit_mode=gmm_fit_mode)
            case DatasetNames.MNIST:
                self.distribution = MNISTPredictionDistribution(gmm_fit_mode=gmm_fit_mode)
            case DatasetNames.MIX_N_MATCH_LESS_CALIBRATED_CASE:
                self.distribution = MixNMatchDistribution(beta_0=-0.5, beta_1=1.5)
            case DatasetNames.MIX_N_MATCH_BETTER_CALIBRATED_CASE:
                self.distribution = MixNMatchDistribution(beta_0=-0.2, beta_1=1.9)
            case DatasetNames.DDI2013:
                self.distribution = DDI2013PredictionDistribution(gmm_fit_mode=gmm_fit_mode)
        self.distribution.plot_distribution()  # Plot the distribution for visualization

    def get_single_ece_statistics(
        self, setting: Setting, n_samples: int, num_repeats: int, binning_strategy: BinningStrategy, seed: int = 42
    ):
        """
        Gets ECE or PU-ECE statistics for a single setting and sample size.
        This method runs an experiment to compute the Expected Calibration Error (ECE) or Positive-Unlabeled ECE
        (PU-ECE) based on the specified setting and number of samples. It repeats the experiment a specified number of
        times to obtain a mean and standard deviation of the ECE or PU-ECE values. The results are printed and returned
        as a dictionary.

        Args:
            setting (Settings): The experiment setting to use.
            n_samples (int): The number of samples to use for the experiment.
            num_repeats (int): The number of times to repeat the experiment.
            binning_strategy (BinningStrategy): The binning strategy to use for ECE calculation.
            seed (int): The random seed for reproducibility.
        Returns:
            dict: A dictionary containing the results of the experiment, including mean ECE, standard deviation, and
                  TCE.
        """
        if n_samples <= 0:
            raise ValueError("n_samples must be a positive integer.")
        n_bins: int | None = None
        n_positive: int | float = np.inf
        n_unlabeled: int | float = np.inf
        prior_estimation_error = 1.0
        match setting:
            case self.Setting.PU_ECE_INFINITE_UNLABELED:
                n_positive = n_samples
                n_unlabeled = np.inf
            case self.Setting.PU_ECE_INFINITE_POSITIVE:
                n_positive = np.inf
                n_unlabeled = n_samples
            case self.Setting.PU_ECE_100_POSITIVE_N_UNLABELED:
                n_positive = 100
                n_unlabeled = n_samples
            case self.Setting.PU_ECE_1000_POSITIVE_N_UNLABELED:
                n_positive = 1000
                n_unlabeled = n_samples
            case self.Setting.PU_ECE_10000_POSITIVE_N_UNLABELED:
                n_positive = 10000
                n_unlabeled = n_samples
            case self.Setting.PU_ECE_N_POSITIVE_10N_UNLABELED:
                n_positive = n_samples
                n_unlabeled = 10 * n_samples
            case self.Setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_5PCT:
                n_positive = n_samples
                n_unlabeled = 10 * n_samples
                prior_estimation_error = 1.05  # Overestimate prior by 5%
            case self.Setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_5PCT:
                n_positive = n_samples
                n_unlabeled = 10 * n_samples
                prior_estimation_error = 0.95  # Underestimate prior by 5%
            case self.Setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_10PCT:
                n_positive = n_samples
                n_unlabeled = 10 * n_samples
                prior_estimation_error = 1.1  # Overestimate prior by 10%
            case self.Setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_10PCT:
                n_positive = n_samples
                n_unlabeled = 10 * n_samples
                prior_estimation_error = 0.9  # Underestimate prior by 10%
            case self.Setting.BINNED_TCE_CBRT_N_BIN:
                n_positive = np.inf
                n_unlabeled = np.inf
                n_bins = ceil(np.cbrt(n_samples))  # Use cube root of n_samples for binning
            case self.Setting.BINNED_TCE_CBRT_N_POSITIVE_BIN:
                n_positive = np.inf
                n_unlabeled = np.inf
                n_bins = ceil(np.cbrt(n_samples / (self.distribution.prior**2)))
            case self.Setting.BINNED_TCE_N_BIN:
                n_positive = np.inf
                n_unlabeled = np.inf
                n_bins = n_samples
            case self.Setting.ECE_CBRT_N_BIN:
                n_bins = ceil(np.cbrt(n_samples))  # Use cube root of n_samples for binning
            case self.Setting.ECE_CBRT_N_POSITIVE_BIN:
                n_bins = ceil(np.cbrt(n_samples / (self.distribution.prior**2)))

        rng = np.random.default_rng(seed)
        ece_values = []
        for _ in range(num_repeats):
            if setting == self.Setting.ECE_CBRT_N_BIN or setting == self.Setting.ECE_CBRT_N_POSITIVE_BIN:
                ece = self.distribution.ece(
                    n_samples, binning_strategy, n_bins=n_bins, seed=int(rng.integers(2**32 - 1))
                )
            else:
                seed = int(rng.integers(2**32 - 1))
                ece = self.distribution.pu_ece(
                    n_positive,
                    n_unlabeled,
                    binning_strategy,
                    n_bins=n_bins,
                    seed=seed,
                    prior_estimation_error=prior_estimation_error,
                )
            ece_values.append(ece)
        ece_array = np.array(ece_values)
        mean_ece = float(np.mean(ece_values))
        sem_ece = float(sem(ece_values))
        tce = self.distribution.tce
        mean_abs_bias = float(np.mean(np.abs(ece_array - tce)))
        sem_abs_bias = float(sem(np.abs(ece_array - tce)))
        lower_90ci_abs_bias = float(np.percentile(np.abs(ece_array - tce), 5))
        upper_90ci_abs_bias = float(np.percentile(np.abs(ece_array - tce), 95))
        return self.Result(
            setting=setting,
            n_samples=n_samples,
            mean_ece=mean_ece,
            sem_ece=sem_ece,
            tce=tce,
            mean_abs_bias=mean_abs_bias,
            sem_abs_bias=sem_abs_bias,
            lower_90ci_abs_bias=lower_90ci_abs_bias,
            upper_90ci_abs_bias=upper_90ci_abs_bias,
        )

    def get_ece_statistics(
        self,
        n_values_to_iterate: list[int],
        num_repeats: int,
        binning_strategy: BinningStrategy,
        settings: list[Setting] | None = None,
        seed: int = 42,
    ) -> dict[Setting, Results]:
        """
        Gets ECE statistics for multiple settings and sample sizes.
        This method runs experiments in parallel using a process pool executor to compute ECE or PU-ECE for each
        setting and sample size. It returns a list of dictionaries containing the results of each experiment.

        Args:
            n_values_to_iterate (list[int]): The list of sample sizes to use for each experiment.
            num_repeats (int): The number of times to repeat each experiment.
            settings (list[Setting] | None): The list of settings to run experiments for. If None, all settings used.
            seed (int): The random seed for reproducibility.

        Returns:
            list[dict]: A list of dictionaries containing the results of each experiment.
        """
        if settings is None:
            settings = list(self.Setting)

        experiment_results: list[ConvergenceExperimentRunner.Result] = []
        # Add serial execution option for debugging
        debug = False  # Set to True to run in serial mode for debugging

        def should_not_be_run(setting: ConvergenceExperimentRunner.Setting, n_samples: int) -> bool:
            """
            Determines if a setting should not be run based on the number of samples.
            Args:
                setting (ConvergenceExperimentRunner.Setting): The experiment setting.
                n_samples (int): The number of samples.
            Returns:
                bool: True if the setting should not be run, False otherwise.
            """
            # Since BINNED_TCE_N_BINS is computationally expensive, we skip it for large n_samples
            if setting == self.Setting.BINNED_TCE_N_BIN and n_samples > 100:
                print(f"Skipping {setting.value} for n_samples={n_samples} due to high computational cost.")
                return True
            return False

        if debug:
            print("Running experiments in serial mode for debugging...")
            # Serial execution for debugging
            # Use a simple tqdm progress bar for serial execution if rich is not preferred here
            for setting_item in settings:
                for n_samples_item in tqdm(n_values_to_iterate, desc=f"Running {setting_item.value}"):
                    if should_not_be_run(setting_item, n_samples_item):
                        continue
                    result = self.get_single_ece_statistics(
                        setting_item, n_samples_item, num_repeats, binning_strategy, seed
                    )
                    experiment_results.append(result)
        else:
            # Parallel execution (default)
            with concurrent.futures.ProcessPoolExecutor() as executor:
                futures_map: dict[concurrent.futures.Future, tuple[ConvergenceExperimentRunner.Setting, int]] = {}
                all_jobs_info: list[tuple[ConvergenceExperimentRunner.Setting, int]] = []

                for setting_item in settings:
                    for n_samples_item in n_values_to_iterate:
                        if should_not_be_run(setting_item, n_samples_item):
                            continue
                        future = executor.submit(
                            self.get_single_ece_statistics,
                            setting_item,
                            n_samples_item,
                            num_repeats,
                            binning_strategy,
                            seed,
                        )
                        futures_map[future] = (setting_item, n_samples_item)
                        all_jobs_info.append((setting_item, n_samples_item))

                self._run_parallel_jobs_with_progress(futures_map, experiment_results)
        results_dict: dict[ConvergenceExperimentRunner.Setting, ConvergenceExperimentRunner.Results] = {}
        for result_item in experiment_results:  # Use the renamed variable
            setting_val = result_item.setting
            if setting_val not in results_dict:
                results_dict[setting_val] = self.Results.from_single_result(result_item)
            else:
                results_dict[setting_val].append(result_item)
        return results_dict

    def _run_parallel_jobs_with_progress(
        self, futures_map: dict[concurrent.futures.Future, tuple[Setting, int]], experiment_results: list[Result]
    ):
        """
        Runs parallel jobs with a progress display showing pending and completed tasks.

        Args:
            futures_map: Dictionary mapping futures to (setting, n_samples) tuples
            experiment_results: List to append results to
        """
        pending_jobs_info = list(futures_map.values())
        total_jobs = len(futures_map)

        progress_bar = self._create_progress_bar()
        overall_task_id: TaskID = progress_bar.add_task("Overall Progress", total=total_jobs)
        live_display_group = Group(progress_bar, self._generate_pending_jobs_table(pending_jobs_info, total_jobs))

        with Live(live_display_group, refresh_per_second=2, vertical_overflow="visible") as live:
            completed_count = 0
            for future in concurrent.futures.as_completed(futures_map):
                job_setting, job_n_samples = futures_map[future]
                try:
                    result_data = future.result()
                    experiment_results.append(result_data)
                except Exception as exc:
                    live.console.print(
                        f"[bold red]Error in job {job_setting.value} (N={job_n_samples}): {exc}[/bold red]"
                    )
                finally:
                    completed_count += 1
                    try:
                        pending_jobs_info.remove((job_setting, job_n_samples))
                    except ValueError:
                        live.console.print(
                            f"[yellow]Warning: Job ({job_setting.value}, {job_n_samples}) "
                            f"not found in pending list for removal.[/yellow]"
                        )

                    # Update progress bar and display
                    self._update_progress_display(
                        progress_bar, overall_task_id, live, pending_jobs_info, completed_count, total_jobs
                    )

            # Final update
            live.refresh()
            time.sleep(0.2)

    def _create_progress_bar(self) -> Progress:
        """Creates and returns a configured progress bar."""
        return Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.1f}%"),
            TextColumn("• Jobs: [bold blue]{task.completed}/{task.total}"),
            TextColumn("• Elapsed: [bold green]{task.elapsed:.1f}s"),
            expand=True,
        )

    def _generate_pending_jobs_table(self, pending_jobs: list[tuple[Setting, int]], total_jobs: int) -> Table:
        """
        Generates a table showing pending jobs.

        Args:
            pending_jobs: List of (setting, n_samples) tuples for pending jobs
            total_jobs: Total number of jobs that were submitted

        Returns:
            Table displaying pending jobs information
        """
        table = Table(title="[bold bright_blue]Pending Jobs[/bold bright_blue]", expand=False, show_lines=False)
        table.add_column("Setting", style="cyan", no_wrap=True, overflow="fold")
        table.add_column("N Samples", style="magenta", no_wrap=True)

        display_limit = 10
        for i, (job_setting, job_n_samples) in enumerate(pending_jobs):
            if i < display_limit:
                table.add_row(str(job_setting.value), str(job_n_samples))
            elif i == display_limit:
                table.add_row(f"... and {len(pending_jobs) - display_limit} more ...", "")
                break

        if not pending_jobs and total_jobs > 0:
            table.add_row("[dim]All jobs completed![/dim]", "")
        elif total_jobs == 0:
            table.add_row("[dim]No jobs to process.[/dim]", "")

        return table

    def _update_progress_display(
        self,
        progress_bar: Progress,
        task_id: TaskID,
        live: Live,
        pending_jobs: list[tuple[Setting, int]],
        completed_count: int,
        total_jobs: int,
    ):
        """
        Updates the progress display with current status.

        Args:
            progress_bar: The progress bar instance
            task_id: Task ID for the progress bar
            live: Live display instance
            pending_jobs: Current list of pending jobs
            completed_count: Number of completed jobs
            total_jobs: Total number of jobs
        """
        description = (
            f"Processing job {completed_count}/{total_jobs}" if completed_count < total_jobs else "All jobs completed!"
        )

        progress_bar.update(task_id, completed=completed_count, description=description)
        live.update(Group(progress_bar, self._generate_pending_jobs_table(pending_jobs, total_jobs)))

    def plot_bias_analysis(
        self,
        plot_title: str,
        results: dict[Setting, Results],
        save_path: str | None = None,
        show_plot: bool = True,
        log_x: bool = True,
    ):
        """
        Plots the bias analysis based on the results of the experiments.
        This method generates a plot showing the mean ECE or PU-ECE values for each setting and sample size, along with
        error bars representing the standard deviation.

        Args:
            plot_title (str): The title of the plot.
            results (dict[str, Results]): The results of the experiments to plot.
            save_path (str | None): The path to save the plot. If None, the plot will not be saved.
            show_plot (bool): Whether to display the plot.
            log_x (bool): Whether to use logarithmic scale for the x-axis.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for setting, result in results.items():
            n_samples = result.n_samples
            mean_bias = np.array(result.mean_ece) - result.tce
            sem_bias = np.array(result.sem_ece)
            (line_plot,) = ax.plot(n_samples, mean_bias, marker=setting.to_marker(), label=setting.value)
            ax.fill_between(
                n_samples,
                mean_bias - sem_bias,
                mean_bias + sem_bias,
                alpha=0.1,
                color=line_plot.get_color(),
            )
        ax.set_xlabel("$N$")
        ax.set_ylabel("Bias")
        ax.set_title(plot_title)
        ax.legend()
        ax.grid(True, which="both", linestyle="--", linewidth=0.5)
        if log_x:
            ax.set_xscale("log")
        if save_path:
            plt.savefig(save_path)
        if show_plot:
            plt.show()

    def plot_absolute_bias_analysis(
        self,
        plot_title: str,
        results: dict[Setting, Results],
        save_path: str | None = None,
        show_plot: bool = True,
        log_x: bool = True,
        log_y: bool = False,
    ):
        """
        Plots the absolute bias analysis based on the results of the experiments.
        This method generates a plot showing the mean ECE or PU-ECE values for each setting and sample size, along with
        error bars representing the standard deviation.

        Args:
            plot_title (str): The title of the plot.
            results (dict[str, Results]): The results of the experiments to plot.
            save_path (str | None): The path to save the plot. If None, the plot will not be saved.
            show_plot (bool): Whether to display the plot.
            log_x (bool): Whether to use logarithmic scale for the x-axis.
            log_y (bool): Whether to use logarithmic scale for the y-axis.
        """
        fig, ax = plt.subplots(figsize=(6, 4.5))

        min_x, max_x = np.inf, -np.inf
        min_y, max_y = np.inf, -np.inf

        for result in results.values():
            min_x = min(min_x, np.min(result.n_samples))
            max_x = max(max_x, np.max(result.n_samples))

        if min_x == np.inf or max_x == -np.inf:
            raise ValueError("No valid sample sizes found in results.")

        if log_x and log_y:
            # Add reference lines for O(N^{-1/3}) convergence rate
            x_ref = 1000  # Reference point for N
            y_refs = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]  # Reference y-values at x_ref
            x_vals = np.logspace(np.log10(min_x), np.log10(max_x), 100)

            for y_ref in y_refs:
                # y = C * x^{-1/3}, where C = y_ref * x_ref^{1/3}
                y_vals = y_ref * (x_vals / x_ref) ** (-1 / 3)
                ax.plot(x_vals, y_vals, "k--", alpha=0.5, linewidth=0.8)

        for setting, result in results.items():
            n_samples = result.n_samples
            mean_abs_bias = np.array(result.mean_abs_bias)
            sem_abs_bias = np.array(result.sem_abs_bias)
            (line_plot,) = ax.plot(n_samples, mean_abs_bias, marker=setting.to_marker(), label=setting.value)
            if log_y:
                y_lower = np.array(result.lower_90ci_abs_bias)
                y_upper = np.array(result.upper_90ci_abs_bias)
            else:
                y_lower = mean_abs_bias - sem_abs_bias
                y_upper = mean_abs_bias + sem_abs_bias
            ax.fill_between(
                n_samples,
                y_lower,
                y_upper,
                alpha=0.1,
                color=line_plot.get_color(),
            )

            min_y = min(min_y, np.min(y_lower))
            max_y = max(max_y, np.max(y_upper))

        ax.set_xlabel("$N$")
        ax.set_ylabel("Total bias")
        # ax.set_title(plot_title)
        ax.legend()
        ax.grid(True, which="major", linestyle=":", linewidth=0.1)

        if np.isfinite(min_x) and np.isfinite(max_x):
            ax.set_xlim(min_x, max_x)
        if np.isfinite(min_y) and np.isfinite(max_y):
            ax.set_ylim(min_y, max_y)

        if log_x:
            ax.set_xscale("log")
        else:
            ax.set_xlim(left=0)  # Ensure x-axis starts at 0
        if log_y:
            ax.set_yscale("log")
        else:
            ax.set_ylim(bottom=0)  # Ensure y-axis starts at 0
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path)
        if show_plot:
            plt.show()
        else:
            plt.close(fig)

    def plot_prior_error_analysis(
        self,
        plot_title: str,
        prior: float,
        results: dict[Setting, Results],
        save_path: str | None = None,
        show_plot: bool = True,
        log_x: bool = True,
        log_y: bool = False,
    ):
        """
        Plots the mean absolute bias per setting together with class-prior error reference lines.

        For every setting in ``results`` this draws ``mean_abs_bias`` against ``n_samples`` with a shaded band around
        it. The band spans ``lower_90ci_abs_bias`` to ``upper_90ci_abs_bias`` when ``log_y`` is set, and
        ``mean_abs_bias -/+ sem_abs_bias`` otherwise. Two horizontal reference lines are drawn at ``0.05 * prior``
        and ``0.1 * prior``.

        Args:
            plot_title (str): The title of the plot. Currently not rendered on the figure.
            prior (float): The true class prior used in the experiments.
            results (dict[Setting, Results]): The results of the experiments to plot, keyed by setting.
            save_path (str | None): The path to save the plot. If None, the plot will not be saved.
            show_plot (bool): Whether to display the plot. If False, the figure is closed instead.
            log_x (bool): Whether to use logarithmic scale for the x-axis.
            log_y (bool): Whether to use logarithmic scale for the y-axis.

        Raises:
            ValueError: If no sample sizes are found in ``results``.
        """
        # Validate the x-range BEFORE creating the figure, so that a failure here does not leave an orphaned
        # figure registered with pyplot.
        min_x, max_x = np.inf, -np.inf
        for result in results.values():
            min_x = min(min_x, np.min(result.n_samples))
            max_x = max(max_x, np.max(result.n_samples))

        if min_x == np.inf or max_x == -np.inf:
            raise ValueError("No valid sample sizes found in results.")

        fig, ax = plt.subplots(figsize=(6, 4.5))
        min_y, max_y = np.inf, -np.inf

        if log_x and log_y:
            # Add reference lines for O(N^{-1/3}) convergence rate
            x_ref = 1000  # Reference point for N
            y_refs = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]  # Reference y-values at x_ref
            x_vals = np.logspace(np.log10(min_x), np.log10(max_x), 100)

            for y_ref in y_refs:
                # y = C * x^{-1/3}, where C = y_ref * x_ref^{1/3}
                y_vals = y_ref * (x_vals / x_ref) ** (-1 / 3)
                ax.plot(x_vals, y_vals, "k--", alpha=0.5, linewidth=0.8)

        for setting, result in results.items():
            n_samples = result.n_samples
            mean_abs_bias = np.array(result.mean_abs_bias)
            sem_abs_bias = np.array(result.sem_abs_bias)
            (line_plot,) = ax.plot(n_samples, mean_abs_bias, marker=setting.to_marker(), label=setting.value)
            if log_y:
                y_lower = np.array(result.lower_90ci_abs_bias)
                y_upper = np.array(result.upper_90ci_abs_bias)
            else:
                y_lower = mean_abs_bias - sem_abs_bias
                y_upper = mean_abs_bias + sem_abs_bias
            # Shade the band in the same colour as the line it belongs to.
            ax.fill_between(
                n_samples,
                y_lower,
                y_upper,
                alpha=0.1,
                color=line_plot.get_color(),
            )

            # Track the extent of the bands; it becomes the y-range of the axes below.
            min_y = min(min_y, np.min(y_lower))
            max_y = max(max_y, np.max(y_upper))

        ax.hlines(prior * 0.05, min_x, max_x, colors="black", linestyles="dotted", label="5% prior error", alpha=0.7)
        ax.hlines(prior * 0.1, min_x, max_x, colors="black", linestyles="dashed", label="10% prior error", alpha=0.7)

        ax.set_xlabel("$N$")
        ax.set_ylabel("Total bias")
        # NOTE: the title is currently not rendered (the ax.set_title(plot_title) call is disabled).
        ax.legend()
        ax.grid(True, which="major", linestyle=":", linewidth=0.1)

        # Non-finite extents (e.g. caused by inf/NaN in the data) are skipped so matplotlib keeps its own limits.
        if np.isfinite(min_x) and np.isfinite(max_x):
            ax.set_xlim(min_x, max_x)
        if np.isfinite(min_y) and np.isfinite(max_y):
            ax.set_ylim(min_y, max_y)

        if log_x:
            ax.set_xscale("log")
        else:
            ax.set_xlim(left=0)  # Ensure x-axis starts at 0
        if log_y:
            ax.set_yscale("log")
        else:
            ax.set_ylim(bottom=0)  # Ensure y-axis starts at 0

        # Operate on this figure explicitly instead of relying on pyplot's implicit "current figure".
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path)
        if show_plot:
            plt.show()
        else:
            plt.close(fig)

    def run_experiment(
        self, base_path: str, binning_strategy: BinningStrategy, num_repeats: int = 10, seed: int = 42
    ) -> dict[Setting, Results]:
        """
        Runs the convergence experiments and saves the resulting plots.

        This method defines the sample sizes to iterate over, computes the statistics for all settings via
        ``get_ece_statistics``, and writes three PDF figures into ``base_path``:

        * ``<dataset>_ece_vs_pu_ece_<binning>.pdf``
        * ``<dataset>_prior_estimation_error_<binning>.pdf``
        * ``<dataset>_n_pos_convergence_<binning>.pdf``

        Args:
            base_path (str): Directory the plots are written to. It is created if it does not exist.
            binning_strategy (BinningStrategy): The binning strategy passed to ``get_ece_statistics``. Its value is
                also part of the output file names.
            num_repeats (int): The number of times to repeat each experiment.
            seed (int): The random seed for reproducibility.

        Returns:
            dict[Setting, Results]: The statistics returned by ``get_ece_statistics``.
        """
        setting = self.Setting

        # Create the output directory up front: a missing directory would otherwise only surface at the first
        # savefig call, i.e. after the whole computation has already been done.
        if base_path:
            os.makedirs(base_path, exist_ok=True)

        # Sample sizes in roughly half-decade steps from 10 to 100,000.
        n_values_to_iterate = [10, 32, 100, 316, 1_000, 3_162, 10_000, 31_622, 100_000]
        results = self.get_ece_statistics(n_values_to_iterate, num_repeats, binning_strategy, seed=seed)

        def select(*settings):
            # Pick the given settings out of the full result set, preserving the given order for the legend.
            return {chosen: results[chosen] for chosen in settings}

        def pdf_path(tag: str) -> str:
            # Build "<base_path>/<dataset>_<tag>_<binning>.pdf".
            return os.path.join(base_path, f"{self.dataset_name.name}_{tag}_{binning_strategy.value}.pdf")

        # Figure 1: ECE vs. PU-ECE.
        self.plot_absolute_bias_analysis(
            f"ECE vs PU-ECE Convergence ({self.dataset_name.value})",
            select(
                setting.PU_ECE_INFINITE_UNLABELED,
                setting.PU_ECE_INFINITE_POSITIVE,
                setting.PU_ECE_N_POSITIVE_10N_UNLABELED,
                setting.ECE_CBRT_N_BIN,
                setting.ECE_CBRT_N_POSITIVE_BIN,
            ),
            save_path=pdf_path("ece_vs_pu_ece"),
            show_plot=False,
            log_x=True,
            log_y=True,
        )

        # Figure 2: effect of class-prior over-/under-estimation. This plot gets its own title instead of
        # reusing the title of figure 1.
        self.plot_prior_error_analysis(
            f"PU-ECE under Class-Prior Estimation Error ({self.dataset_name.value})",
            prior=self.distribution.prior,
            results=select(
                setting.PU_ECE_N_POSITIVE_10N_UNLABELED,
                setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_5PCT,
                setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_5PCT,
                setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_OVERESTIMATE_10PCT,
                setting.PU_ECE_N_POSITIVE_10N_UNLABELED_PRIOR_UNDERSTIMATE_10PCT,
            ),
            save_path=pdf_path("prior_estimation_error"),
            show_plot=False,
            log_x=True,
            log_y=False,
        )

        # Figure 3: PU-ECE with a varying number of positive samples.
        self.plot_absolute_bias_analysis(
            f"PU-ECE Convergence with Varying Positive Samples ({self.dataset_name.value})",
            select(
                setting.PU_ECE_100_POSITIVE_N_UNLABELED,
                setting.PU_ECE_1000_POSITIVE_N_UNLABELED,
                setting.PU_ECE_10000_POSITIVE_N_UNLABELED,
                setting.PU_ECE_INFINITE_UNLABELED,
            ),
            save_path=pdf_path("n_pos_convergence"),
            show_plot=False,
            log_x=True,
            log_y=True,
        )
        return results


if __name__ == "__main__":
    # Script entry point: run the convergence experiments for every dataset under both binning strategies.
    # Each run writes its plots into the current working directory.
    num_repeats = 100
    seed = 42
    # Keyed by (binning strategy, dataset). Keying by dataset alone would let the results of the second
    # binning strategy silently overwrite those of the first.
    results = {}
    for binning_strategy in [BinningStrategy.UWB, BinningStrategy.UMB]:
        for dataset in DatasetNames:
            print(f"Running convergence experiments for {dataset.value} distribution...")
            runner = ConvergenceExperimentRunner(dataset)
            results[binning_strategy, dataset] = runner.run_experiment(
                base_path=".", binning_strategy=binning_strategy, num_repeats=num_repeats, seed=seed
            )
    print("Experiment completed.")
