import numpy as np
import inspect
import random
import os
from typing import Union, Any, Tuple, Dict, Optional, Callable, List
import matplotlib.axes
import types
import math
from scipy.special import gamma as gamma_fn, zeta

# Utils for passing distributions to environments


def check_for_allowed_dist(rng: np.random.Generator, dist: str, **kwargs: dict) -> int:
    """
    Validates the input parameters for the sample_from_dist function.

    Parameters:
    - rng (numpy.random.Generator): The numpy random generator to be used in sample_from_dist.
    - kwargs (dict): The keyword arguments to be used in sample_from_dist.

    Raises:
    - ValueError: If any of the input parameters are invalid.
    - TypeError: If any of the input types are invalid.
    """

    if isinstance(rng, np.random.Generator):
        dist_func = getattr(rng, dist, None)
        if dist_func is None:
            raise ValueError(
                f"Distribution '{dist}' is not supported with np.random.Generator!"
            )
        sig = inspect.signature(dist_func)
        for key in kwargs:
            if key not in sig.parameters:
                raise ValueError(
                    f"Invalid argument {key} for distribution {dist} specified!"
                )
        for param in sig.parameters.values():
            if (
                param.default == inspect.Parameter.empty
                and param.name not in kwargs
                and param.name != "size"
            ):
                raise ValueError(
                    f"Missing required argument {param.name} for distribution {dist}!"
                )
    else:
        raise TypeError(
            "The only random number generator currently supported and tested is np.random.Generator!"
        )
    return 1


def sample_from_dist(
    rng: np.random.Generator, dist: str, size: int, **kwargs: dict
) -> np.ndarray:
    """
    Draws samples by calling the generator method named dist with the given
    keyword arguments. No validation of the keyword arguments happens here; run
    check_for_allowed_dist beforehand if the combination of distribution and
    arguments is not known to be valid.

    Parameters:
    - rng (numpy.random.Generator): The numpy random generator to draw from.
    - dist (str): Name of the generator method that implements the distribution.
    - size (int): The number of random samples to be drawn.
    - kwargs (dict): The keyword arguments for the chosen distribution.

    Returns:
    - sample (numpy ndarray): The drawn sample(s) in a numpy array.

    Raises:
    - ValueError: If the generator has no attribute named dist.
    """

    # Look the sampler up by name so any generator method can be used.
    sampler = getattr(rng, dist, None)
    if sampler is not None:
        return sampler(size=size, **kwargs)
    raise ValueError(f"Distribution '{dist}' is not supported.")


def get_numpy_distribution_mean(
    dist_name: str, **kwargs
) -> Optional[Union[float, np.ndarray]]:
    """
    Computes the theoretical mean of a NumPy distribution, given its name and parameters.
    Supports distributions from numpy.random.Generator, assuming numeric input for all params.

    Returns None if the distribution is unsupported or inputs are insufficient.

    Parameters:
    - dist_name (str): Name of the distribution method from numpy.random.Generator.
    - kwargs: Parameters of the distribution (e.g. n, p, loc, scale, a, etc.)

    Returns:
    - Mean (float, np.ndarray, or None)
    """

    params = kwargs.copy()

    try:
        match dist_name:
            case "binomial":
                n = params.get("n", 1)
                p = params.get("p", 0.5)
                return n * p

            case "normal":
                loc = params.get("loc", 0.0)
                return loc

            case "uniform":
                low = params.get("low", 0.0)
                high = params.get("high", 1.0)
                return 0.5 * (low + high)

            case "integers":
                low = params.get("low", 0)
                high = params.get("high", None)
                if high is None:
                    return None
                return 0.5 * (low + high - 1)

            case "choice":
                a = np.asarray(params.get("a", []))
                if a.size == 0:
                    return None
                p = np.asarray(params.get("p")) if "p" in params else None
                return float(np.average(a, weights=p))

            case "poisson":
                lam = params.get("lam", 1.0)
                return lam

            case "exponential":
                scale = params.get("scale", 1.0)
                return scale

            case "gamma":
                shape = params.get("shape", 1.0)
                scale = params.get("scale", 1.0)
                return shape * scale

            case "beta":
                a = params.get("a", 0.5)
                b = params.get("b", 0.5)
                return a / (a + b) if a + b != 0 else None

            case "geometric":
                p = params.get("p", 0.5)
                return 1 / p if p > 0 else None

            case "laplace":
                loc = params.get("loc", 0.0)
                return loc

            case "logistic":
                loc = params.get("loc", 0.0)
                return loc

            case "lognormal":
                mean = params.get("mean", 0.0)
                sigma = params.get("sigma", 1.0)
                return math.exp(mean + 0.5 * sigma**2)

            case "pareto":
                a = params.get("a", 1.0)
                return a / (a - 1) if a > 1 else None

            case "rayleigh":
                scale = params.get("scale", 1.0)
                return scale * math.sqrt(math.pi / 2)

            case "triangular":
                left = params.get("left", 0.0)
                mode = params.get("mode", 0.5)
                right = params.get("right", 1.0)
                return (left + mode + right) / 3

            case "weibull":
                a = params.get("a", 1.0)
                return math.gamma(1 + 1 / a) if a > 0 else None

            case "chisquare":
                df = params.get("df", 1.0)
                return df

            case "multivariate_normal":
                mean = params.get("mean", None)
                return np.asarray(mean) if mean is not None else None

            case "dirichlet":
                alpha = np.asarray(params.get("alpha", []))
                if alpha.ndim != 1 or alpha.size == 0:
                    return None
                alpha_sum = np.sum(alpha)
                return alpha / alpha_sum

            case "zipf":
                a = params.get("a", 2.0)
                if a <= 2:
                    return None
                from scipy.special import zeta

                return zeta(a - 1) / zeta(a)

            case "hypergeometric":
                ngood = params.get("ngood", None)
                nbad = params.get("nbad", None)
                nsample = params.get("nsample", None)
                if None in (ngood, nbad, nsample):
                    return None
                return nsample * ngood / (ngood + nbad)

            case "negative_binomial":
                n = params.get("n", 1)
                p = params.get("p", 0.5)
                return n * (1 - p) / p

            case _:
                return None  # Unknown distribution
    except Exception:
        return None  # Gracefully handle all invalid inputs


def adjust_mean_lambda_and_args(
    target_mean: Union[int, float], dist_name: str, kwargs: Dict
) -> None:
    """
    Modifies kwargs in-place so that the expected mean is the target mean.
    Raises ValueError if this is not possible.
    """
    match dist_name:
        case "normal" | "laplace" | "logistic":
            kwargs["loc"] = target_mean

        case "uniform":
            low = kwargs.get("low", 0.0)
            high = kwargs.get("high", 1.0)
            current_mean = 0.5 * (low + high)
            shift = current_mean - target_mean
            kwargs["low"] = low - shift
            kwargs["high"] = high - shift

        case "integers":
            low = kwargs.get("low", 0)
            high = kwargs.get("high", None)
            current_mean = 0.5 * (low + high - 1)
            delta = current_mean - target_mean
            kwargs["low"] = low - int(delta)
            kwargs["high"] = high - int(delta)

        case "choice":
            a = kwargs.get("a", [])
            if not a:
                raise ValueError("Cannot adjust empty array for 'choice'.")
            p = kwargs.get("p", None)
            weights = np.asarray(p) if p is not None else None
            mean = np.average(a, weights=weights)
            delta = float(mean - target_mean)
            kwargs["a"] = [x - delta for x in a]

        case "binomial":
            n = kwargs.get("n", 1)
            max_p = target_mean / n
            kwargs["p"] = max_p

        case "poisson":
            kwargs["lam"] = target_mean

        case "exponential":
            kwargs["scale"] = target_mean

        case "rayleigh":
            scale = kwargs.get("scale", 1.0)
            current_mean = scale * math.sqrt(math.pi / 2)
            max_scale = target_mean / math.sqrt(math.pi / 2)
            kwargs["scale"] = min(scale, max_scale)

        case "gamma":
            shape = kwargs.get("shape", 1.0)
            max_scale = target_mean / shape
            kwargs["scale"] = min(scale, max_scale)

        case "weibull":
            for a_new in np.linspace(0.1, 100, 1000):
                mean = gamma_fn(1 + 1 / a_new)
                if target_mean <= mean <= target_mean - 1:
                    kwargs["a"] = float(a_new)
                    break
                else:
                    raise ValueError("Cannot adjust Weibull mean to lie within bounds.")

        case "beta":
            b_new = a * (1 - target_mean) / target_mean
            if b_new <= 0:
                raise ValueError("Adjusted beta parameter b becomes non-positive.")
            kwargs["b"] = b_new

        case "lognormal":
            sigma = kwargs.get("sigma", 1.0)
            mu_new = math.log(target_mean) - 0.5 * sigma**2
            kwargs["mean"] = mu_new

        case "triangular":
            left = kwargs.get("left", 0.0)
            mode = kwargs.get("mode", 0.5)
            right = kwargs.get("right", 1.0)
            current_mean = (left + mode + right) / 3
            shift = current_mean - target_mean
            kwargs["left"] = left - shift
            kwargs["mode"] = mode - shift
            kwargs["right"] = right - shift

        case "geometric":
            new_p = 1 / target_mean
            if new_p > 1.0:
                raise ValueError("Cannot adjust geometric mean to lie within bounds.")
            kwargs["p"] = new_p

        case "chisquare":
            kwargs["df"] = target_mean

        case "multivariate_normal":
            mean = kwargs.get("mean", None)
            if mean is None:
                raise ValueError("Missing 'mean' in kwargs for multivariate_normal.")
            mean = np.asarray(mean)
            mean_clipped = np.clip(mean, target_mean, target_mean)
            mean_list = mean_clipped.tolist()
            kwargs["mean"] = [float(item) for item in mean_list]

        case "dirichlet":
            alpha = kwargs.get("alpha", None)
            if alpha is None:
                raise ValueError("Missing 'alpha' in kwargs for dirichlet.")
            alpha = np.asarray(alpha)
            alpha_sum = alpha.sum()
            mean = alpha / alpha_sum
            if np.any((mean < target_mean) | (mean > target_mean)):
                factor = min(target_mean / np.max(mean), target_mean / np.min(mean))
                new_alpha = (alpha * factor).tolist()
                kwargs["alpha"] = [float(item) for item in new_alpha]

        case "zipf":
            a = kwargs.get("a", 2.0)
            if a <= 2:
                raise ValueError("Cannot compute finite mean for a <= 2 in Zipf.")
            while a > 2 and zeta(a - 1) / zeta(a) > target_mean:
                a += 0.1
            kwargs["a"] = a

        case "hypergeometric":
            ngood = kwargs.get("ngood", None)
            nbad = kwargs.get("nbad", None)
            nsample = kwargs.get("nsample", None)
            if None in (ngood, nbad, nsample):
                raise ValueError("Missing parameter for hypergeometric.")
            current_mean = nsample * ngood / (ngood + nbad)
            scale = target_mean / current_mean
            kwargs["ngood"] = int(ngood * scale)

        case "negative_binomial":
            n = kwargs.get("n", 1)
            new_p = n / (n + target_mean)
            if new_p > 1.0:
                raise ValueError(
                    "Cannot adjust negative_binomial mean to lie within bounds."
                )
            kwargs["p"] = new_p

        case "pareto":
            new_a = -1 * (1 / (1 / target_mean - 1))
            kwargs["a"] = new_a

        case _:
            return None


def adjust_distribution_kwargs_to_bound_mean_inplace(
    dist_name: str,
    kwargs: dict,
    min_mean: float = float("-inf"),
    max_mean: float = float("inf"),
) -> None:
    """
    Modifies kwargs in-place so that the expected mean of the distribution stays
    within [min_mean, max_mean].
    Raises ValueError if this is not possible.
    """

    match dist_name:

        case "normal" | "laplace" | "logistic":
            loc = kwargs.get("loc", 0.0)
            if loc > max_mean:
                kwargs["loc"] = max_mean
            elif loc < min_mean:
                kwargs["loc"] = min_mean

        case "uniform":
            low = kwargs.get("low", 0.0)
            high = kwargs.get("high", 1.0)
            current_mean = 0.5 * (low + high)
            if current_mean > max_mean or current_mean < min_mean:
                target_mean = min(max(current_mean, min_mean), max_mean)
                shift = current_mean - target_mean
                kwargs["low"] = low - shift
                kwargs["high"] = high - shift

        case "integers":
            low = kwargs.get("low", 0)
            high = kwargs.get("high", None)
            mean = 0.5 * (low + high - 1)
            target_mean = min(max(mean, min_mean), max_mean)
            delta = mean - target_mean
            kwargs["low"] = low - int(delta)
            kwargs["high"] = high - int(delta)

        case "choice":
            a = kwargs.get("a", [])
            if not a:
                raise ValueError("Cannot adjust empty array for 'choice'.")
            p = kwargs.get("p", None)
            weights = np.asarray(p) if p is not None else None
            mean = np.average(a, weights=weights)
            target_mean = min(max(mean, min_mean), max_mean)
            delta = float(mean - target_mean)
            kwargs["a"] = [x - delta for x in a]

        case "binomial":
            n = kwargs.get("n", 1)
            p = kwargs.get("p", 0.5)
            current_mean = n * p
            target_mean = min(max(current_mean, min_mean), max_mean)
            max_p = target_mean / n
            kwargs["p"] = max_p

        case "poisson":
            lam = kwargs.get("lam", 1.0)
            kwargs["lam"] = min(max(lam, min_mean), max_mean)

        case "exponential":
            scale = kwargs.get("scale", 1.0)
            kwargs["scale"] = min(max(scale, min_mean), max_mean)

        case "rayleigh":
            scale = kwargs.get("scale", 1.0)
            current_mean = scale * math.sqrt(math.pi / 2)
            target_mean = min(max(current_mean, min_mean), max_mean)
            max_scale = target_mean / math.sqrt(math.pi / 2)
            kwargs["scale"] = min(scale, max_scale)

        case "gamma":
            shape = kwargs.get("shape", 1.0)
            scale = kwargs.get("scale", 1.0)
            current_mean = shape * scale
            target_mean = min(max(current_mean, min_mean), max_mean)
            max_scale = target_mean / shape
            kwargs["scale"] = min(scale, max_scale)

        case "weibull":
            a = kwargs.get("a", 1.0)
            current_mean = gamma_fn(1 + 1 / a)
            if current_mean > max_mean or current_mean < min_mean:
                for a_new in np.linspace(0.1, 100, 1000):
                    mean = gamma_fn(1 + 1 / a_new)
                    if min_mean <= mean <= max_mean:
                        kwargs["a"] = float(a_new)
                        break
                else:
                    raise ValueError("Cannot adjust Weibull mean to lie within bounds.")

        case "beta":
            a = kwargs.get("a", 1.0)
            b = kwargs.get("b", 1.0)
            current_mean = a / (a + b)
            if not (min_mean <= current_mean <= max_mean):
                target_mean = min(max(current_mean, min_mean), max_mean)
                b_new = a * (1 - target_mean) / target_mean
                if b_new <= 0:
                    raise ValueError("Adjusted beta parameter b becomes non-positive.")
                kwargs["b"] = b_new

        case "lognormal":
            mu = kwargs.get("mean", 0.0)
            sigma = kwargs.get("sigma", 1.0)
            current_mean = math.exp(mu + 0.5 * sigma**2)
            if not (min_mean <= current_mean <= max_mean):
                target_mean = min(max(current_mean, min_mean), max_mean)
                mu_new = math.log(target_mean) - 0.5 * sigma**2
                kwargs["mean"] = mu_new

        case "triangular":
            left = kwargs.get("left", 0.0)
            mode = kwargs.get("mode", 0.5)
            right = kwargs.get("right", 1.0)
            current_mean = (left + mode + right) / 3
            target_mean = min(max(current_mean, min_mean), max_mean)
            shift = current_mean - target_mean
            kwargs["left"] = left - shift
            kwargs["mode"] = mode - shift
            kwargs["right"] = right - shift

        case "geometric":
            p = kwargs.get("p", 0.5)
            current_mean = 1 / p
            target_mean = min(max(current_mean, min_mean), max_mean)
            new_p = 1 / target_mean
            if new_p > 1.0:
                raise ValueError("Cannot adjust geometric mean to lie within bounds.")
            kwargs["p"] = new_p

        case "chisquare":
            df = kwargs.get("df", 1.0)
            kwargs["df"] = min(max(df, min_mean), max_mean)

        case "multivariate_normal":
            mean = kwargs.get("mean", None)
            if mean is None:
                raise ValueError("Missing 'mean' in kwargs for multivariate_normal.")
            mean = np.asarray(mean)
            mean_clipped = np.clip(mean, min_mean, max_mean)
            new_mean = mean_clipped.tolist()
            kwargs["mean"] = [float(item) for item in new_mean]

        case "dirichlet":
            alpha = kwargs.get("alpha", None)
            if alpha is None:
                raise ValueError("Missing 'alpha' in kwargs for dirichlet.")
            alpha = np.asarray(alpha)
            alpha_sum = alpha.sum()
            mean = alpha / alpha_sum
            if np.any((mean < min_mean) | (mean > max_mean)):
                factor = min(max_mean / np.max(mean), min_mean / np.min(mean))
                new_alpha = (alpha * factor).tolist()
                kwargs["alpha"] = [float(item) for item in new_alpha]

        case "zipf":
            a = kwargs.get("a", 2.0)
            if a <= 2:
                raise ValueError("Cannot compute finite mean for a <= 2 in Zipf.")
            while a > 2 and zeta(a - 1) / zeta(a) > max_mean:
                a += 0.1
            kwargs["a"] = a

        case "hypergeometric":
            ngood = kwargs.get("ngood", None)
            nbad = kwargs.get("nbad", None)
            nsample = kwargs.get("nsample", None)
            if None in (ngood, nbad, nsample):
                raise ValueError("Missing parameter for hypergeometric.")
            current_mean = nsample * ngood / (ngood + nbad)
            target_mean = min(max(current_mean, min_mean), max_mean)
            scale = target_mean / current_mean
            kwargs["ngood"] = int(ngood * scale)

        case "negative_binomial":
            n = kwargs.get("n", 1)
            p = kwargs.get("p", 0.5)
            mean = n * (1 - p) / p
            target_mean = min(max(mean, min_mean), max_mean)
            new_p = n / (n + target_mean)
            if new_p > 1.0:
                raise ValueError(
                    "Cannot adjust negative_binomial mean to lie within bounds."
                )
            kwargs["p"] = new_p

        case "pareto":
            a = kwargs.get("a", 1)
            mean = a / (a - 1)
            target_mean = min(max(mean, min_mean), max_mean)
            new_a = -(1 / (1 / target_mean - 1))
            kwargs["a"] = new_a

        # Add more cases if needed
        case _:
            raise NotImplementedError(
                f"Distribution '{dist_name}' is not supported for mean adjustment."
            )


# Utils for learning rate and exploration schedules


def check_for_schedule_allowed(initial_check: bool = True, **kwargs: dict) -> int:
    """
    Validates a schedule dictionary before it is handed to the schedule function.

    Parameters:
    - initial_check (bool): True, if this is the first check and the schedule function has not
      already been applied to the passed dictionary.
    - kwargs (dict): The keyword arguments to be used in schedule. Exactly the keys
      "initial_rate", "current_rate", "mode" and "mode_kwargs" must be present.

    Returns:
    - 1 if every check passes.

    Raises:
    - ValueError: If any of the input parameters are invalid.
    - TypeError: If any of the input types are invalid.
    """

    required_keys = ("initial_rate", "current_rate", "mode", "mode_kwargs")
    # Keys that "mode_kwargs" must contain (no more, no fewer) for each mode.
    mode_kwarg_names = {
        "constant": ("final_rate",),
        "linear": ("final_rate", "num_steps", "slope"),
        "rate": ("rate_fct", "iteration_num", "final_rate"),
    }

    # Exactly the mandatory keys are contained
    if not isinstance(kwargs, dict):
        raise TypeError(f"Keyword arguments need to be passed in a dictionary!")
    for key in required_keys:
        if key not in kwargs:
            raise ValueError(f"Key {key} is missing!")
    for key in kwargs:
        if key not in required_keys:
            raise ValueError(f"Keyword {key} is appearing but not implemented!")

    # NOTE: initial_rate itself is not type- or range-checked here.

    # Current rate is a number between 0 and initial rate
    current_rate = kwargs["current_rate"]
    if not isinstance(current_rate, (float, int)):
        raise TypeError("Current rate needs to be a number!")
    if not (0 <= current_rate <= kwargs["initial_rate"]):
        raise ValueError("Current rate needs to be between 0 and initial rate!")

    # On the first check the schedule must not have advanced yet
    if initial_check and not (current_rate == kwargs["initial_rate"]):
        raise ValueError("Current and initial rate need to coincide in the beginning!")

    # Mode is implemented
    mode = kwargs["mode"]
    if not isinstance(mode, str):
        raise TypeError("Mode needs to be a string!")
    if mode not in mode_kwarg_names:
        raise ValueError(f"Mode {kwargs['mode']} is not implemented!")

    # Keyword arguments for mode are supported
    mode_kwargs = kwargs["mode_kwargs"]
    if not isinstance(mode_kwargs, dict):
        raise TypeError(
            "Keyword arguments for the mode must be passed as a dictionary!"
        )
    expected_names = mode_kwarg_names[mode]
    for key in mode_kwargs:
        if key not in expected_names:
            raise ValueError(f"Keyword {key} is appearing but not implemented!")
    for key in expected_names:
        if key not in mode_kwargs:
            raise ValueError(f"Keyword {key} is missing!")

    # Keyword arguments for mode take the right values
    if mode == "constant":
        return 1

    if mode == "linear":
        final_rate = mode_kwargs["final_rate"]
        if not isinstance(final_rate, (int, float)):
            raise TypeError(
                "For the schedule mode linear, the final rate needs to be a numerical value!"
            )
        if not (0 <= final_rate <= current_rate):
            raise ValueError(
                "For the schedule mode linear, the final rate needs to be less than the initial and current rates!"
            )

        num_steps = mode_kwargs["num_steps"]
        if not isinstance(num_steps, int):
            raise TypeError(
                "For the schedule mode linear, the number of steps needs to be an integer!"
            )
        if not (0 < num_steps):
            raise ValueError(
                "For the schedule mode linear, the number of steps needs to be positive!"
            )

        slope = mode_kwargs["slope"]
        if not isinstance(slope, (int, float)):
            raise TypeError(
                "For the schedule mode linear, the slope must be numerical!"
            )
        # -1 is the sentinel for "derive the slope from final rate and number of steps".
        if not (0 <= slope or slope == -1):
            raise ValueError(
                "For the schedule mode linear, the slope needs to either be non-negative, or take the value -1 for initialization via final rate and number of steps!"
            )
        return 1

    if mode == "rate":
        rate_fct = mode_kwargs["rate_fct"]
        if not callable(rate_fct):
            raise TypeError(
                "For the schedule mode rate, the rate function needs to be a callable!"
            )
        if not rate_fct.__name__ == "<lambda>":
            raise TypeError(
                "For the schedule mode rate, the rate function needs to be passed as a lambda function!"
            )

        iteration_num = mode_kwargs["iteration_num"]
        if not isinstance(iteration_num, int):
            raise TypeError(
                "For the schedule mode rate, the iteration number needs to be numerical!"
            )
        if not iteration_num > 0:
            raise ValueError(
                "For the schedule mode rate, the iteration number needs to be positive!"
            )
        if initial_check and not iteration_num == 1:
            raise ValueError(
                "In the beginning, the number of iterations done needs to be one!"
            )
        return 1

    # Only reachable when a mode is registered above without a value check here.
    raise ValueError(
        "If you want to implement a new schedule mode please specify a type check in check_for_schedule_allowed!"
    )


def schedule(reset_schedule: bool = False, **kwargs: dict) -> dict:
    """
    Returns the next scheduled rate for a dictionary containing the initial rate, the current
    rate, the mode, and the necessary keyword arguments for the mode to be applied.

    The nested "mode_kwargs" dictionary is updated in place (derived slope in linear
    mode, iteration counter in rate mode).

    Parameters:
    - reset_schedule (bool): If True, resets the schedule to the initial state.
    - kwargs (dict): The keyword arguments for the update of the rate.

    Returns:
    - kwargs (dict): The keyword arguments dictionary, but with updated current rate.

    Raises:
    - ValueError: If the mode is not one of "constant", "linear" or "rate".
    """

    if reset_schedule:
        kwargs["current_rate"] = kwargs["initial_rate"]
        if kwargs["mode"] == "rate":
            kwargs["mode_kwargs"]["iteration_num"] = 1
        return kwargs

    # Normal scheduling: a rate that has reached zero stays there.
    if kwargs["current_rate"] == 0:
        return kwargs

    mode = kwargs["mode"]
    mode_kwargs = kwargs["mode_kwargs"]

    if mode == "constant":
        next_rate = kwargs["current_rate"]
    elif mode == "linear":
        if mode_kwargs["slope"] == -1:
            # -1 is the sentinel for "derive the slope from final rate and number
            # of steps". With a single step there is no interval to divide by, so
            # the final rate is reached in one decrement.
            intervals = max(mode_kwargs["num_steps"] - 1, 1)
            mode_kwargs["slope"] = (
                kwargs["initial_rate"] - mode_kwargs["final_rate"]
            ) / intervals
        next_rate = kwargs["current_rate"] - mode_kwargs["slope"]
        # Clamp at the final rate but keep the slope, so that the schedule decays
        # again after a reset.
        if next_rate < mode_kwargs["final_rate"]:
            next_rate = mode_kwargs["final_rate"]
    elif mode == "rate":
        next_rate = kwargs["initial_rate"] * mode_kwargs["rate_fct"](
            mode_kwargs["iteration_num"]
        )
        mode_kwargs["iteration_num"] += 1
        if next_rate < mode_kwargs["final_rate"]:
            next_rate = mode_kwargs["final_rate"]
    else:
        raise ValueError(
            "If you want to implement a new schedule mode please modify the schedule function!"
        )

    if next_rate <= 0:
        next_rate = 0
        print(
            "Warning: Your algorithm has reached a point where the scheduled next epsilon or stepsize is zero! You may reconsider your choices of schedule!"
        )

    kwargs["current_rate"] = next_rate
    return kwargs


# Utils for train function


def generate_random_seed() -> int:
    """Draws a seed for a NumPy RNG from Python's global random module.

    Returns:
    - An integer in the range [0, 2**32 - 1].
    """
    seed_space = 2**32  # number of distinct 32-bit seeds
    return random.randrange(seed_space)


def softmax(x: list, beta: float) -> list:
    """
    Compute softmax probability values for a list.

    Parameters:
    - x (list): The scores to be converted into probabilities.
    - beta (float): Factor the scores are multiplied with before exponentiation.

    Returns:
    - A list of floats that sum to one (an empty list for empty input).
    """
    scaled = np.asarray(x, dtype=float) * beta
    if scaled.size == 0:
        return []
    # Subtracting the maximum leaves the result unchanged mathematically but keeps
    # every exponent <= 0, so np.exp cannot overflow into inf / inf = nan.
    weights = np.exp(scaled - np.max(scaled))
    probabilities = weights / np.sum(weights)
    return [val.item() for val in probabilities]


def _train_require_bool(kwargs: dict, name: str) -> None:
    """Raises a TypeError if the optional train parameter `name` is present but not boolean."""
    if name in kwargs and not isinstance(kwargs[name], bool):
        raise TypeError(f"The parameter {name} needs to be boolean!")


def _train_reject_forbidden_keys(sub_kwargs: dict, forbidden: list) -> None:
    """Raises a ValueError with the paired message for the first forbidden key contained in `sub_kwargs`."""
    for key, message in forbidden:
        if key in sub_kwargs:
            raise ValueError(message)


def _train_check_seed_schedule(schedule, label: str, non_int_message: str) -> None:
    """Checks that a seed schedule is a list of integers that are either valid 32 bit seeds or -1."""
    if not isinstance(schedule, list):
        raise TypeError(f"The {label} seed schedule needs to be a list!")
    for s in schedule:
        if not isinstance(s, int):
            raise TypeError(non_int_message)
        if not (0 <= s < 2**32 or s == -1):
            raise ValueError(
                f"The provided seed {s} in the {label} seed schedule list is not in the range of acceptable integer seeds and does not take the value -1 corresponding to a random seed!"
            )


def _train_check_positive_int(value, type_message: str, value_message: str) -> None:
    """Checks that `value` is an integer (TypeError otherwise) and strictly positive (ValueError otherwise)."""
    if not isinstance(value, int):
        raise TypeError(type_message)
    if value <= 0:
        raise ValueError(value_message)


def _train_check_which_state_actions_focus(focus) -> None:
    """Checks the (states, action lists) tuple that selects what the individual bias estimations are logged for."""
    if not isinstance(focus, tuple):
        raise ValueError(
            "The states and actions provided for logging individual bias estimations need to be provided as a tuple!"
        )
    if len(focus) != 2:
        raise ValueError(
            "The states and actions provided for logging individual bias estimations need to be provided as a tuple of length two!"
        )
    states, action_lists = focus
    if not (isinstance(states, list) and isinstance(action_lists, list)):
        raise TypeError(
            "The states and actions provided for logging individual bias estimations have to be provided in the form of a list!"
        )
    for state in states:
        if isinstance(state, int):
            if state < 0:
                raise ValueError(
                    f"The state {state} you provided for logging individual bias estimations is not valid, since it is negative!"
                )
        elif state != "start":
            raise TypeError(
                f"The state {state} you provided for logging individual bias estimations is not valid since it is neither an integer nor 'start'!"
            )
    for actlist in action_lists:
        if not isinstance(actlist, list):
            raise TypeError(
                f"The list of actions {actlist} you provided for logging individual bias estimations is not valid since it is not a list!"
            )
        for act in actlist:
            if isinstance(act, int):
                if act < 0:
                    raise ValueError(
                        f"The action {act} you provided for logging individual bias estimations is not valid since it is negative!"
                    )
            elif act != "best":
                raise ValueError(
                    f"The action {act} you provided for logging individual bias estimations is not valid since it is neither an integer nor best!"
                )


def _train_check_correct_action_log_which(which) -> None:
    """Checks that `which` is either "all" or a duplicate free list of non-negative numbers."""
    if which == "all":
        return
    if not isinstance(which, list):
        raise ValueError(
            "The states considered for the correct action rates need to be contained in a list!"
        )
    if len(which) != len(set(which)):
        raise ValueError(
            "The states considered for the correct action rates cannot contain duplicates!"
        )
    for state in which:
        if not isinstance(state, (float, int)):
            raise TypeError(
                "The states considered for the correct action rates need to be integers!"
            )
        if state < 0:
            raise ValueError(
                "The states considered for the correct action rates need to be positive integers!"
            )


def _train_check_exact_keys(mode_kwargs: dict, necessary_kwargs: list) -> None:
    """Checks that `mode_kwargs` contains every necessary keyword and no other keyword."""
    for kwarg in necessary_kwargs:
        if kwarg not in mode_kwargs:
            raise ValueError(
                f"The keyword {kwarg} is missing for determining the correct action and q function!"
            )
    for kwarg in mode_kwargs:
        if kwarg not in necessary_kwargs:
            # Name the offending keyword itself (not the whole kwargs dictionary)
            raise ValueError(
                f"The keyword {kwarg} is not allowed for determining the correct action and q function!"
            )


def _train_check_manual_mode_kwargs(mode_kwargs: dict) -> None:
    """Checks the keyword arguments of the manual correct action and q function mode."""
    _train_check_exact_keys(mode_kwargs, ["correct_actions", "correct_q_fct"])

    correct_actions = mode_kwargs["correct_actions"]
    if not isinstance(correct_actions, list):
        raise TypeError(
            "The correct actions for the manual initialization of correct actions need to be passed as a list!"
        )
    for actlist in correct_actions:
        if not isinstance(actlist, list):
            raise TypeError(
                "The correct actions for the manual initialization of correct actions need to be passed in lists!"
            )
        for act in actlist:
            if not isinstance(act, int):
                raise TypeError(
                    f"The action {act} passed in the dictionary of correct actions is wrong as it is not an interger!"
                )
            if act < 0:
                raise ValueError(
                    f"The action {act} passed in the dictionary of correct actions is wrong as it is negative!"
                )

    # Only dictionaries are inspected here; a correct_q_fct of another type is not validated
    correct_q_fct = mode_kwargs["correct_q_fct"]
    if isinstance(correct_q_fct, dict):
        for key, val in correct_q_fct.items():
            if not isinstance(val, (int, float)):
                raise TypeError(
                    f"The q value {val} you passed for the manual initialization of the Q value is not allowed, as it is not numerical!"
                )
            if not isinstance(key, tuple):
                raise TypeError(
                    f"The state action pair {key} is not a tuple and is thus invalid!"
                )
            if len(key) != 2:
                raise ValueError(
                    f"The state action pair {key} is not a tuple of length 2 and is thus invalid!"
                )
            if not (isinstance(key[0], int) and isinstance(key[1], int)):
                raise TypeError(
                    f"The state action pair {key} contains non-integers and is thus invalid!"
                )
            if key[0] < 0 or key[1] < 0:
                raise ValueError(
                    f"The state action pair {key} contains a negative integer and is thus invalid!"
                )


def _train_check_value_iteration_mode_kwargs(mode_kwargs: dict) -> None:
    """Checks the keyword arguments of the value iteration correct action and q function mode."""
    _train_check_exact_keys(
        mode_kwargs,
        ["n_max", "tol", "env_mean_rewards", "env_mean_rewards_mc_runs"],
    )
    _train_check_positive_int(
        mode_kwargs["n_max"],
        "The number of maximum iterations for the value iteration needs to be an integer!",
        "The number of maximum iterations for the value iteration needs to be a positive number!",
    )
    tol = mode_kwargs["tol"]
    if not isinstance(tol, (int, float)):
        raise TypeError(
            "The error tolerance for the value iteration needs to be a numerical value!"
        )
    if tol < 0:
        raise ValueError(
            "The error tolerance for the value iteration needs to be positive!"
        )


def check_input_for_train(**kwargs: dict) -> int:
    """
    Validates the input parameters for the train function. Does not validate algo and env!

    Only the keywords that are actually passed are checked. The checks are performed in a fixed order
    and the first violation found is the one that is raised.

    Parameters:
    - kwargs (dict): The keyword arguments to be used in train.

    Returns:
    - int: 1 if all checks passed.

    Raises:
    - ValueError: If any of the input parameters are invalid.
    - TypeError: If any of the input types are invalid.
    """
    allowed_arguments = [
        "algo",
        "algo_kwargs",
        "algo_special_logs",
        "algo_special_logs_kwargs",
        "env",
        "env_kwargs",
        "env_randomization",
        "env_randomization_kwargs",
        "env_randomization_schedule",
        "training_mode",
        "num_steps",
        "max_steps_per_epoch",
        "training_seed_schedule",
        "training_reseeding",
        "eval_freq",
        "eval_steps",
        "eval_seed_schedule",
        "eval_reseeding",
        "eval_policy_choice",
        "eval_policy_choice_kwargs",
        "bias_estimation",
        "focus_state_actions",
        "which_state_actions_focus",
        "correct_action_log",
        "correct_action_log_which",
        "correct_act_q_fct_mode",
        "correct_act_q_fct_mode_kwargs",
        "safe_mode",
        "progress",
        "measure_runtime",
    ]

    # Are all keywords allowed
    for key in kwargs:
        if key not in allowed_arguments:
            raise ValueError(
                f"The keyword {key} is not allowed for the train function!"
            )

    # Algo kwargs should be dictionary and not contain environment name, environment kwargs, special_logs_kwargs, or the seed or checks parameters
    if "algo_kwargs" in kwargs:
        algo_kwargs = kwargs["algo_kwargs"]
        if not isinstance(algo_kwargs, dict):
            raise TypeError(
                "The Algorithm keyword arguments need to be passed as a dictionary!"
            )
        _train_reject_forbidden_keys(
            algo_kwargs,
            [
                (
                    "env",
                    "The Environment name should not be passed as an algorithm keyword argument but instead only as the env argument!",
                ),
                (
                    "env_kwargs",
                    "The Environment keyword arguments should not be passed as an algorithm keyword argument but instead only as the env_kwargs argument!",
                ),
                (
                    "rng_seed",
                    "The seed for the random number generator should not be passed as an algorithm keyword argument but instead for the training and evaluation cycles seperate schedules may be provided with the option to reseed each!",
                ),
                (
                    "checks",
                    "The checks parameter should not be passed as an algorithm keyword argument but instead you should use the safe_mode argument!",
                ),
                (
                    "special_logs_kwargs",
                    "The special logs keyword arguments should not be passed as an algorithm keyword argument but instead you should use the algo_special_logs_kwargs argument!",
                ),
            ],
        )

    # algo_special_logs needs to be boolean
    _train_require_bool(kwargs, "algo_special_logs")

    # algo_special_logs_kwargs needs to be a dictionary
    if "algo_special_logs_kwargs" in kwargs:
        if not isinstance(kwargs["algo_special_logs_kwargs"], dict):
            raise TypeError(
                "The parameter algo_special_logs_kwargs needs to be a dictionary!"
            )

    # Env kwargs should be dictionary and not contain the seed or checks parameters
    if "env_kwargs" in kwargs:
        env_kwargs = kwargs["env_kwargs"]
        if not isinstance(env_kwargs, dict):
            raise TypeError(
                "The Environment keyword arguments need to be passed as a dictionary!"
            )
        _train_reject_forbidden_keys(
            env_kwargs,
            [
                (
                    "rng_seed",
                    "The seed for the random number generator should not be passed as an environment keyword argument but instead for the training and evaluation cycles seperate schedules may be provided with the option to reseed each!",
                ),
                (
                    "checks",
                    "The checks parameter should not be passed as an environment keyword argument but instead you should use the safe_mode argument!",
                ),
            ],
        )

    # Env Randomization should be bool
    _train_require_bool(kwargs, "env_randomization")

    # Environment randomization seed schedule is a list whose seeds are in the range of possible seeds or -1 for random
    if "env_randomization_schedule" in kwargs:
        _train_check_seed_schedule(
            kwargs["env_randomization_schedule"],
            "environment randomization",
            "The seeds in the environment randomization seed schedule list need to be integers!",
        )

    # Environment randomization kwargs needs to be a dictionary
    if "env_randomization_kwargs" in kwargs:
        if not isinstance(kwargs["env_randomization_kwargs"], dict):
            raise TypeError(
                "The environment randomization keyword arguments need to be passed in a dictionary!"
            )

    # Training mode should be either steps or epoch
    if "training_mode" in kwargs:
        training_mode = kwargs["training_mode"]
        if not isinstance(training_mode, str):
            raise TypeError("The training mode needs to be a string!")
        if training_mode not in ("steps", "epoch"):
            raise ValueError(
                f"It seems that the training mode {training_mode} is not implemented!"
            )

    # Number of steps is positive integer
    if "num_steps" in kwargs:
        _train_check_positive_int(
            kwargs["num_steps"],
            "The number of steps needs to be an integer!",
            "The number of steps needs to be positive!",
        )

    # Number of maximum steps per epoch needs to be a positive integer or -1
    if "max_steps_per_epoch" in kwargs:
        max_steps = kwargs["max_steps_per_epoch"]
        if not isinstance(max_steps, int):
            raise TypeError(
                "The maximum number of steps per epoch needs to be an integer!"
            )
        if not (max_steps > 0 or max_steps == -1):
            raise ValueError(
                "The maximum number of steps per epoch needs to be either a positive integer or -1 in case no maximum is applied!"
            )

    # Training seed schedule is a list whose seeds are in the range of possible seeds or -1 for random
    if "training_seed_schedule" in kwargs:
        _train_check_seed_schedule(
            kwargs["training_seed_schedule"],
            "training",
            "The seeds in the training seed schedule list need to be integers!",
        )

    # Training reseeding needs to be boolean
    _train_require_bool(kwargs, "training_reseeding")

    # Evaluation frequency needs to be a positive integer
    if "eval_freq" in kwargs:
        _train_check_positive_int(
            kwargs["eval_freq"],
            "The evaluation frequency needs to be an integer!",
            "The evaluation frequency needs to be a positive integer!",
        )

    # Evaluation steps needs to be a positive integer
    if "eval_steps" in kwargs:
        _train_check_positive_int(
            kwargs["eval_steps"],
            "The evaluation steps need to be an integer!",
            "The evaluation steps need to be a positive integer!",
        )

    # Evaluation seed schedule is a list whose seeds are in the range of possible seeds or -1 for random
    if "eval_seed_schedule" in kwargs:
        _train_check_seed_schedule(
            kwargs["eval_seed_schedule"],
            "evaluation",
            "The seeds in the evaluation seed schedule list need to be a integers!",
        )

    # Eval reseeding needs to be boolean
    _train_require_bool(kwargs, "eval_reseeding")

    # Eval policy choice is allowed
    if "eval_policy_choice" in kwargs:
        if kwargs["eval_policy_choice"] not in ("greedy", "softmax"):
            raise ValueError("Eval policy choice needs to be either greedy or softmax!")

    # Eval policy choice kwargs are allowed
    if "eval_policy_choice_kwargs" in kwargs:
        policy_kwargs = kwargs["eval_policy_choice_kwargs"]
        if not isinstance(policy_kwargs, dict):
            raise TypeError(
                "Evaluation policy mode choice keyword arguments need to be passed in a dictionary!"
            )
        # The policy choice itself is optional here, so the mode specific check is skipped
        # (instead of failing with a KeyError) if only the keyword arguments were passed
        policy_choice = kwargs.get("eval_policy_choice")
        if policy_choice == "greedy":
            if policy_kwargs != {}:
                raise ValueError(
                    "Policy choice mode greedy does not need keyword arguments!"
                )
        elif policy_choice == "softmax":
            if "beta" not in policy_kwargs:
                raise ValueError(
                    "Policy choice mode softmax needs inverse temperature parameter beta!"
                )

    # Bias estimation needs to be boolean
    _train_require_bool(kwargs, "bias_estimation")

    # Choice of states and actions to log the bias estimation for should be a Tuple containing a list of states and a list containing lists of action numbers or "best"
    if "which_state_actions_focus" in kwargs:
        _train_check_which_state_actions_focus(kwargs["which_state_actions_focus"])

    # Focus state actions needs to be boolean
    _train_require_bool(kwargs, "focus_state_actions")

    # Correct action log needs to be boolean
    _train_require_bool(kwargs, "correct_action_log")

    # Correct action log which needs to be either all or a list of positive integers without doubling
    if "correct_action_log_which" in kwargs:
        _train_check_correct_action_log_which(kwargs["correct_action_log_which"])

    # Correct action and q function mode needs to be manual or value iteration. If manual, environment_randomization should be off if bias_estimation or correct_action_log or focus_state_actions are also on
    if "correct_act_q_fct_mode" in kwargs:
        mode = kwargs["correct_act_q_fct_mode"]
        if mode not in ("manual", "value_iteration"):
            raise ValueError(
                "The correct action and q function mode you passed seems to not be valid. If you tried implementing a new one make sure to update the inputcheck function!"
            )
        # All flags read here were already verified to be boolean above (if they were passed at all)
        needs_correct_values = (
            kwargs.get("bias_estimation", False)
            or kwargs.get("correct_action_log", False)
            or kwargs.get("focus_state_actions", False)
        )
        if (
            mode == "manual"
            and kwargs.get("env_randomization", False)
            and needs_correct_values
        ):
            raise ValueError(
                "If Environment randomization is on and the correct_act_q_fct_mode will be computed, the correct action and q function mode is not allowed to be set to manual!"
            )

    # Correct action and q function mode keyword arguments need to match the chosen keyword
    if "correct_act_q_fct_mode_kwargs" in kwargs:
        if "correct_act_q_fct_mode" not in kwargs:
            raise ValueError(
                "You cannot specify keyword argument for the determination of action and q values without passing a determination mode!"
            )
        mode_kwargs = kwargs["correct_act_q_fct_mode_kwargs"]
        if not isinstance(mode_kwargs, dict):
            raise TypeError(
                "The keyword arguments for the determination of the correct actions and q function need to be passed in a dictionary!"
            )
        if kwargs["correct_act_q_fct_mode"] == "manual":
            _train_check_manual_mode_kwargs(mode_kwargs)
        elif kwargs["correct_act_q_fct_mode"] == "value_iteration":
            _train_check_value_iteration_mode_kwargs(mode_kwargs)
        else:
            raise ValueError(
                "The correct action and q value determination mode you chose seems to be wrong. If you tried implementing a new one, you should also update the check input function!"
            )

    # Safe mode needs to be boolean
    _train_require_bool(kwargs, "safe_mode")

    # Progress needs to be boolean
    _train_require_bool(kwargs, "progress")

    # Measure runtime needs to be boolean
    _train_require_bool(kwargs, "measure_runtime")

    # Warning if focus_state_actions and environment randomization are both switched on
    # (merely passing both keywords with the value False does not trigger the warning)
    if kwargs.get("focus_state_actions", False) and kwargs.get(
        "env_randomization", False
    ):
        print(
            "Warning: randomizing the environment while trying to get Q or Bias Values for certain states might result in unreasonable values or even errors due to states not existing in certain occasions!"
        )

    return 1


# Utils for experiment manager


def is_lambda(obj):
    """
    Checks whether an object is an anonymous function.

    Parameters:
    - obj: The object to be inspected.

    Returns:
    - bool: True if obj is a plain Python function whose __name__ is "<lambda>", otherwise False.
    """
    # Lambdas and def-functions share one type, so only the name tells them apart
    if not isinstance(obj, types.LambdaType):
        return False
    function_name = obj.__name__
    return function_name == "<lambda>"


# Utils for plot functions


def check_input_for_results_single_to_batch_for_plot(
    result_paths: list[str],
    labels: list[str],
    output_folder: str,
    project_name: str,
    safe_mode: bool,
    conditional_plots: bool,
) -> None:
    """
    Validates the input parameters for the results_single_to_batch_for_plot function.

    The checks run in a fixed order (conditional_plots, result_paths, labels, output_folder,
    project_name, safe_mode) and the first violation found is raised.

    Parameters:
    - result_paths (list): A list of unique paths to existing folders. Each folder must contain a results.pkl file and a
      reproduce_run folder holding a parameters.yaml file.
    - labels (list): A list of unique string labels for the runs in result_paths. It may be shorter than result_paths but not
      longer, and no label may be identical to one of the result paths.
    - output_folder (str): The folder in which the aggregated results should be saved. It has to exist.
    - project_name (str): The project name under which the aggregated results should be saved.
    - safe_mode (bool): Only its type is validated here.
    - conditional_plots (bool): If True, every result folder must additionally contain a
      correct_policy_and_q_function.txt file.

    Raises:
    - ValueError: If any of the input parameters are invalid.
    - TypeError: If any of the input types are invalid.
    """

    # conditional_plots needs to be boolean
    if not isinstance(conditional_plots, bool):
        raise TypeError("The parameter conditional_plots needs to be boolean!")

    # result_paths: list of unique strings pointing to complete result folders
    if not isinstance(result_paths, list):
        raise TypeError("The paths in result_paths need to be passed in a list!")
    if len(set(result_paths)) != len(result_paths):
        raise ValueError("The list of paths you provided contains doubles!")
    for run_dir in result_paths:
        if not isinstance(run_dir, str):
            raise TypeError(
                f"The path {run_dir} you provided is no string and thus invalid!"
            )
        if not os.path.exists(run_dir):
            raise ValueError(f"The path {run_dir} you provided does not exist!")
        if not os.path.exists(os.path.join(run_dir, "results.pkl")):
            raise ValueError(
                f"The path {run_dir} you provided contains no results file!"
            )
        reproduce_dir = os.path.join(run_dir, "reproduce_run")
        if not os.path.exists(reproduce_dir):
            raise ValueError(
                f"The path {run_dir} you provided contains no reproduce_run folder!"
            )
        if not os.path.exists(os.path.join(reproduce_dir, "parameters.yaml")):
            raise ValueError(
                f"The path {run_dir} you provided contains no parameters file!"
            )
        # The policy / Q function file is only required when conditional plots are requested
        if conditional_plots and not os.path.exists(
            os.path.join(run_dir, "correct_policy_and_q_function.txt")
        ):
            raise ValueError(
                f"The path {run_dir} you provided contains no file containing the correct policy and Q function even though you want to plot plots requiring conditional statements that need them!"
            )

    # labels: list of unique strings, not longer than result_paths, none equal to a result path
    if not isinstance(labels, list):
        raise TypeError("The labels in labels need to be passed in a list!")
    if len(set(labels)) != len(labels):
        raise ValueError("The list of labels you provided contains doubles!")
    if len(labels) > len(result_paths):
        raise ValueError(
            "The list of labels is longer than the list of given paths to results!"
        )
    for name in labels:
        if not isinstance(name, str):
            raise TypeError(f"The label {name} is not a string and thus invalid!")
        if name in result_paths:
            raise ValueError(
                f"The label {name} matches one of the result paths and is thus not allowed!"
            )

    # output_folder needs to be a string naming an existing path
    if not isinstance(output_folder, str):
        raise TypeError("The path to the output folder needs to be a string!")
    if not os.path.exists(output_folder):
        raise ValueError(
            f"The path {output_folder} you provided for saving the aggregated results does not exist!"
        )

    # project_name needs to be a string
    if not isinstance(project_name, str):
        raise TypeError("The project name needs to be a string!")

    # safe_mode needs to be boolean
    if not isinstance(safe_mode, bool):
        raise TypeError("The parameter safe_mode needs to be boolean!")


def check_input_for_single_plot_fct(
    input_path: str,
    plot_folder: str,
    project_name: str,
    figsize: tuple[Union[int, float]],
    loc: Union[str, int],
    grid: bool,
    show: bool,
    save: bool,
    mode: Any,
    safe_mode: bool,
) -> None:
    """
    Validates the input parameters for the different plot functions.

    Parameters:
    - input_path (str): The input path where the aggregated results file is located. Can be passed as None. In this case
      plot_folder joined with project_name plus '.pkl' or '.pickle' has to exist. If given, it must be a string pointing to an
      existing file ending in '.pkl' or '.pickle'.
    - plot_folder (str): The folder to which the plots should be saved. It has to exist.
    - project_name (str): The project name under which the plots should be saved.
    - figsize (tuple): A tuple of two integers or floats, specifying the width and height of the plot.
    - loc (str or int): The location of the legend, either one of the allowed location strings or an integer between 0 and 10.
    - grid (bool): Only its type is validated here.
    - show (bool): Only its type is validated here.
    - save (bool): Only its type is validated here.
    - mode (any): Either 'single plot' or a tuple of length three consisting of 'multiple plots', a matplotlib axis, and a
      title that is a string or None. Note that this function does not check that show and save are turned off in the
      latter case.
    - safe_mode (bool): Only its type is validated here.

    Raises:
    - ValueError: If any of the input parameters are invalid.
    - TypeError: If any of the input types are invalid.
    """

    allowed_locs = (
        "best",
        "upper right",
        "upper left",
        "lower right",
        "lower left",
        "center",
        "center left",
        "center right",
        "upper center",
        "lower center",
        "right",
    )
    mode_error = "The mode can either be 'single plot' or a tuple consisting of 'multiple plots', a matplotlib axis, and a title!"

    # Input path needs to be None or a valid path to a pickle file.
    # Identity comparison is used so that objects with custom (in)equality are
    # rejected as non-strings instead of being mistaken for None.
    if input_path is not None:
        if not isinstance(input_path, str):
            raise TypeError("The given input path is no string and thus invalid!")
        if not os.path.isfile(input_path):
            raise ValueError(
                f"The given input path {input_path} does not point to a file!"
            )
        if not input_path.endswith((".pkl", ".pickle")):
            raise ValueError(
                f"The given input path {input_path} does not point to a pickle file!"
            )

    # Plot folder needs to be a string and point to an existing directory
    if not isinstance(plot_folder, str):
        raise TypeError("The given plot folder is no string and thus invalid!")
    if not os.path.exists(plot_folder):
        raise ValueError(f"The given plot folder {plot_folder} does not exist!")

    # Project name needs to be a string
    if not isinstance(project_name, str):
        raise TypeError("The project name needs to be a string!")

    # If input path is None, the combination of plot folder and project name needs to lead to a pickle file
    if input_path is None:
        candidates = (project_name + ".pkl", project_name + ".pickle")
        if not any(
            os.path.exists(os.path.join(plot_folder, candidate))
            for candidate in candidates
        ):
            raise ValueError(
                "In case the input path was not specified, the plot folder and project name need to specify the location of a pickle file!"
            )

    # Fig size needs to be tuple of two numerical values
    if not isinstance(figsize, tuple):
        raise TypeError("The figure size needs to be contained in a tuple!")
    if len(figsize) != 2:
        raise TypeError(
            "The figure width and height need to be specified by a tuple of length two!"
        )
    if not all(isinstance(extent, (int, float)) for extent in figsize):
        raise ValueError(
            "The figure width and height need to be specified by numerical values!"
        )

    # Location of the legend needs to be an allowed string or int in the range
    if isinstance(loc, str):
        if loc not in allowed_locs:
            raise ValueError(f"The location {loc} of the legend is not allowed!")
    elif isinstance(loc, int):
        if not (0 <= loc <= 10):
            raise ValueError(
                "The location of the legend can only be an integer between 0 and 10!"
            )
    else:
        raise TypeError(
            "The location of the legend needs to either be an integer or a string!"
        )

    # grid needs to be boolean
    if not isinstance(grid, bool):
        raise TypeError("Parameter grid needs to be boolean!")

    # show needs to be boolean
    if not isinstance(show, bool):
        raise TypeError("Parameter show needs to be boolean!")

    # save needs to be boolean
    if not isinstance(save, bool):
        raise TypeError("Parameter save needs to be boolean!")

    # mode needs to be either "single plot" or a tuple of "multiple plots", a matplotlib axis, and a title (string or None)
    if mode != "single plot":
        mode_is_valid = (
            isinstance(mode, tuple)
            and len(mode) == 3
            and mode[0] == "multiple plots"
            and isinstance(mode[1], matplotlib.axes.Axes)
            and (mode[2] is None or isinstance(mode[2], str))
        )
        if not mode_is_valid:
            raise ValueError(mode_error)

    # safe_mode needs to be boolean
    if not isinstance(safe_mode, bool):
        raise TypeError("Parameter safe_mode needs to be boolean!")


def _check_plot_key_necessary_kwargs(plot_key: str, extra_kwargs: dict) -> None:
    """
    Validates the plot-key specific keyword arguments for check_input_for_plot_fct_one.

    Parameters:
    - plot_key (str): The already validated plot key.
    - extra_kwargs (dict): The dictionary passed as plot_key_necessary_kwargs.

    Raises:
    - ValueError: If a required entry is missing or has an invalid value.
    - TypeError: If an entry has an invalid type.
    """

    if plot_key == "Percent of capped epochs":
        if "max_steps_per_epoch" not in extra_kwargs:
            raise ValueError(
                "If plot_key is Percent of capped epochs the parameter max_steps_per_epoch needs to be passed!"
            )
        max_steps = extra_kwargs["max_steps_per_epoch"]
        if not isinstance(max_steps, int):
            raise TypeError("The parameter max_steps_per_epoch needs to be an integer!")
        # A non-positive cap is only warned about, not rejected
        if max_steps <= 0:
            print(
                "Warning: If max_steps_per_epoch is not positive no episode was capped and therefore the plot corresponding to the plot_key Percent of capped epochs can not be plotted!"
            )

    elif plot_key in (
        "Mean biases at chosen at evals",
        "Mean Q function values at chosen at evals",
    ):
        pair_error = "The state action pair whose bias or Q function value should be plotted needs to be given as a tuple of integer or 'start' and integer or 'best'!"
        if "which" not in extra_kwargs:
            raise ValueError(
                "If plot_key is Mean biases at chosen at evals or Mean Q function values at chosen at evals the parameter which needs to be passed!"
            )
        which = extra_kwargs["which"]
        if not isinstance(which, tuple):
            raise TypeError(pair_error)
        if len(which) != 2:
            raise ValueError(pair_error)
        state, action = which
        state_ok = isinstance(state, int) or state == "start"
        action_ok = isinstance(action, int) or action == "best"
        if not (state_ok and action_ok):
            raise ValueError(pair_error)

    elif plot_key == "Mean bias metrics at evals":
        # The flags are only type-checked when all three are present; missing
        # flags are not rejected here. The first key used to be misspelled as
        # "sqared", which meant correctly spelled kwargs were never validated.
        flags = ("squared", "normalized", "best_arms")
        if all(flag in extra_kwargs for flag in flags):
            for flag in flags:
                if not isinstance(extra_kwargs[flag], bool):
                    raise TypeError(f"Parameter {flag} needs to be boolean!")

    elif plot_key == "Mean policy values at eval":
        if "which" not in extra_kwargs:
            raise ValueError(
                "If plot_key is Mean policy values at eval the parameter which needs to be passed!"
            )
        which = extra_kwargs["which"]
        if not isinstance(which, tuple):
            raise TypeError(
                "The state action pair whose policy values should be plotted needs to be given as a tuple of integers!"
            )
        if len(which) != 2:
            raise ValueError(
                "The state action pair whose policy values should be plotted needs to be given as a tuple of integers of length two!"
            )
        if not (isinstance(which[0], int) and isinstance(which[1], int)):
            raise ValueError(
                "The state action pair whose policy values should be plotted needs to be given as a tuple of integers!"
            )

    elif plot_key in (
        "Mean special logs at steps",
        "Mean special logs at epochs",
        "Mean special logs at evals",
    ):
        if "index" not in extra_kwargs:
            raise ValueError(
                "If special plots should be plotted, an index must be provided!"
            )
        if not isinstance(extra_kwargs["index"], int):
            raise TypeError("Index for special plots must be positive integer!")
        if extra_kwargs["index"] <= 0:
            raise ValueError("Index for special plots must be positive integer!")
        # real_value, real_value_label and y_label are optional; the label is
        # only looked at when a real value is given.
        if "real_value" in extra_kwargs:
            if not isinstance(extra_kwargs["real_value"], (int, float)):
                raise TypeError(
                    "If a real value for the special plots is provided it needs ot be numerical!"
                )
            if "real_value_label" in extra_kwargs and not isinstance(
                extra_kwargs["real_value_label"], str
            ):
                raise TypeError(
                    "If a label for the real value for the special plots is provided it needs to be a string!"
                )
        if "y_label" in extra_kwargs and not isinstance(extra_kwargs["y_label"], str):
            raise TypeError(
                "If a y label for the special plots is provided it needs to be a string!"
            )


def check_input_for_plot_fct_one(
    input_path: str,
    plot_folder: str,
    project_name: str,
    show: bool,
    save: bool,
    save_format: str,
    safe_mode: bool,
    mode: Any,
    plot_key: str,
    plot_key_necessary_kwargs: Dict,
    plot_ci: bool,
    ci: float,
    figsize: Tuple[Union[int, float]],
    fontsizes: Tuple[int, int],
    dpi: int,
    further_line_configs: Dict,
    grid: bool,
    further_grid_configs: Dict,
    loc: Union[str, int],
) -> None:
    """
    Validates the input parameters for a plot function that is selected via a plot key.

    The parameters are checked in the order in which they are listed below (with plot_ci and ci checked after
    plot_key_necessary_kwargs) and the first violation found is raised.

    Parameters:
    - input_path (str): None or a string pointing to an existing file ending in '.pkl' or '.pickle'.
    - plot_folder (str): A string naming an existing path.
    - project_name (str): Needs to be a string.
    - show (bool): Needs to be boolean.
    - save (bool): Needs to be boolean.
    - save_format (str): Needs to be a string.
    - safe_mode (bool): Needs to be boolean.
    - mode (any): Either 'single plot' or a tuple of length three consisting of 'multiple plots', a matplotlib axis, and a
      title that is a string or None.
    - plot_key (str): One of the allowed plot keys listed in this function.
    - plot_key_necessary_kwargs (dict): None or a dictionary with the additional arguments the chosen plot key needs. It
      must not be None for plot keys that require additional arguments.
    - plot_ci (bool): Needs to be boolean.
    - ci (float): A float strictly between 0 and 1.
    - figsize (tuple): A tuple of two integers or floats.
    - fontsizes (tuple): A tuple of two positive integers.
    - dpi (int): A positive integer.
    - further_line_configs (dict): A dictionary with string keys that map to lists of equal length.
    - grid (bool): Needs to be boolean.
    - further_grid_configs (dict): A dictionary with string keys.
    - loc (str or int): One of the allowed legend location strings or an integer between 0 and 10.

    Raises:
    - ValueError: If any of the input parameters are invalid.
    - TypeError: If any of the input types are invalid.
    """

    allowed_plotkeys = (
        "Num times timesteps reached",
        "Mean rewards at steps",
        "Num times epochs reached",
        "Mean scores at epochs",
        "Mean correct action rates at epochs",
        "Mean durations of epochs",
        "Percent of capped epochs",
        "Num times eval times reached",
        "Mean scores at evals",
        "Mean correct action rates at evals",
        "Mean correct action rates at chosen at evals",
        "Mean biases at chosen at evals",
        "Mean Q function values at chosen at evals",
        "Mean bias metrics at evals",
        "Mean termination rates at evals",
        "Mean lengths at evals",
        "Mean policy values at eval",
        "Runtimes",
        "Mean special logs at steps",
        "Mean special logs at epochs",
        "Mean special logs at evals",
    )
    allowed_locs = (
        "best",
        "upper right",
        "upper left",
        "lower right",
        "lower left",
        "center",
        "center left",
        "center right",
        "upper center",
        "lower center",
        "right",
    )
    needing_kwargs = (
        "Percent of capped epochs",
        "Mean biases at chosen at evals",
        "Mean Q function values at chosen at evals",
        "Mean bias metrics at evals",
        "Mean policy values at eval",
        "Mean special logs at steps",
        "Mean special logs at epochs",
        "Mean special logs at evals",
    )
    mode_error = "The mode can either be 'single plot' or a tuple consisting of 'multiple plots', a matplotlib axis, and a title!"

    # Input path needs to be None or a valid path to a pickle file
    if input_path is not None:
        if not isinstance(input_path, str):
            raise TypeError("The given input path is no string and thus invalid!")
        if not os.path.isfile(input_path):
            raise ValueError(
                f"The given input path {input_path} does not point to a file!"
            )
        if not input_path.endswith((".pkl", ".pickle")):
            raise ValueError(
                f"The given input path {input_path} does not point to a pickle file!"
            )

    # Plot folder needs to be a string and point to an existing directory
    if not isinstance(plot_folder, str):
        raise TypeError("The given plot folder is no string and thus invalid!")
    if not os.path.exists(plot_folder):
        raise ValueError(f"The given plot folder {plot_folder} does not exist!")

    # Project name needs to be a string
    if not isinstance(project_name, str):
        raise TypeError("The project name needs to be a string!")

    # show needs to be boolean
    if not isinstance(show, bool):
        raise TypeError("Parameter show needs to be boolean!")

    # save needs to be boolean
    if not isinstance(save, bool):
        raise TypeError("Parameter save needs to be boolean!")

    # save_format needs to be string
    if not isinstance(save_format, str):
        raise TypeError("Parameter save_format needs to be string!")

    # safe_mode needs to be boolean
    if not isinstance(safe_mode, bool):
        raise TypeError("Parameter safe_mode needs to be boolean!")

    # mode needs to be either "single plot" or a tuple of "multiple plots", a matplotlib axis, and a title (string or None)
    if mode != "single plot":
        mode_is_valid = (
            isinstance(mode, tuple)
            and len(mode) == 3
            and mode[0] == "multiple plots"
            and isinstance(mode[1], matplotlib.axes.Axes)
            and (mode[2] is None or isinstance(mode[2], str))
        )
        if not mode_is_valid:
            raise ValueError(mode_error)

    # plot_key needs to be a valid plot key
    if not isinstance(plot_key, str):
        raise TypeError("Parameter plot_key needs to be a string!")
    if plot_key not in allowed_plotkeys:
        raise ValueError(f"Plot key {plot_key} is invalid!")

    # plot_key_necessary_kwargs must be either a dictionary with valid content or None
    if plot_key_necessary_kwargs is None:
        if plot_key in needing_kwargs:
            raise ValueError(
                f"The plot_key {plot_key} needs additional keyword arguments!"
            )
    elif isinstance(plot_key_necessary_kwargs, dict):
        _check_plot_key_necessary_kwargs(plot_key, plot_key_necessary_kwargs)
    else:
        raise TypeError("Parameter plot_key_necessary_kwargs needs to be a dictionary!")

    # plot_ci needs to be boolean
    if not isinstance(plot_ci, bool):
        raise TypeError("Parameter plot_ci needs to be boolean!")

    # ci needs to be float between 0 and 1
    if not isinstance(ci, float):
        raise TypeError("Parameter ci needs to be a float between 0 and 1!")
    if not (0 < ci < 1):
        raise ValueError("Parameter ci needs to be between 0 and 1!")

    # Fig size needs to be tuple of two numerical values
    if not isinstance(figsize, tuple):
        raise TypeError("The figure size needs to be contained in a tuple!")
    if len(figsize) != 2:
        raise TypeError(
            "The figure width and height need to be specified by a tuple of length two!"
        )
    if not all(isinstance(extent, (int, float)) for extent in figsize):
        raise ValueError(
            "The figure width and height need to be specified by numerical values!"
        )

    # Fontsizes needs to be tuple of two positive integers
    if not isinstance(fontsizes, tuple):
        raise TypeError("The fontsizes need to be contained in a tuple!")
    if len(fontsizes) != 2:
        raise TypeError(
            "The fontsizes for title and axes need to be specified by a tuple of length two!"
        )
    if not all(isinstance(size, int) for size in fontsizes):
        raise ValueError(
            "The fontsizes for title and axes need to be specified by integers!"
        )
    if not all(size > 0 for size in fontsizes):
        raise ValueError(
            "The fontsizes for title and axes need to be specified by positive integers!"
        )

    # dpi needs to be positive integer
    if not isinstance(dpi, int):
        raise TypeError("The resolution needs to be a positive integer!")
    if dpi <= 0:
        raise ValueError("The resolution needs to be a positive integer!")

    # further_line_configs needs to be dictionary mapping string keys to lists of same length
    if not isinstance(further_line_configs, dict):
        raise ValueError("Parameter further_line_configs needs to be a dictionary!")
    expected_len = None
    for key, configs in further_line_configs.items():
        if not isinstance(key, str):
            raise ValueError("The keys of further_line_configs need to be strings!")
        if not isinstance(configs, list):
            raise TypeError(
                "The keys of further_line_configs need to map to lists containing the configurations!"
            )
        # The first list fixes the length every following list has to match
        if expected_len is None:
            expected_len = len(configs)
        elif len(configs) != expected_len:
            raise ValueError(
                "The keys of further_line_configs need to map to lists of equal length containing the configurations!"
            )

    # grid needs to be boolean
    if not isinstance(grid, bool):
        raise TypeError("Parameter grid needs to be boolean!")

    # further_grid_configs needs to be dictionary with string keys
    if not isinstance(further_grid_configs, dict):
        raise TypeError("Parameter further_grid_configs needs to be a dictionary!")
    for key in further_grid_configs:
        if not isinstance(key, str):
            raise ValueError("The keys of further_grid_configs need to be strings!")

    # Location of the legend needs to be an allowed string or int in the range
    if isinstance(loc, str):
        if loc not in allowed_locs:
            raise ValueError(f"The location {loc} of the legend is not allowed!")
    elif isinstance(loc, int):
        if not (0 <= loc <= 10):
            raise ValueError(
                "The location of the legend can only be an integer between 0 and 10!"
            )
    else:
        raise TypeError(
            "The location of the legend needs to either be an integer or a string!"
        )


def parse_correct_policy_and_q_function(file_path: str) -> tuple[list, dict]:
    """
    Parses a text file holding an estimated correct policy and Q function.

    The file has to contain the header lines 'Estimated correct Policy:' and 'Estimated correct Q Function:'. The policy
    entries start two lines below the policy header and end at the first blank line; each entry is '<label>: <value>' and
    only the evaluated value is kept. The Q function entries are all non-blank lines after the Q function header up to the
    end of the file; each entry is '<key>: <float>'.

    Parameters:
    - file_path (str): The path to the text file to be parsed.

    Returns:
    - tuple: A list with the evaluated policy values in file order and a dictionary mapping the evaluated keys to their
      Q function values as floats.

    Raises:
    - ValueError: If one of the two header lines is missing, an entry contains no colon, or a Q value is no valid float.
    """

    with open(file_path, "r") as file:
        lines = file.readlines()

    # NOTE(security): eval executes arbitrary Python expressions found in the file.
    # Only parse files written by this project; never point this at untrusted input.
    # TODO(review): switch to ast.literal_eval once it is confirmed that all stored
    # values and keys are plain literals.

    # Parse the "Estimated correct Policy" section (the line right below the header is skipped)
    policy_list = []
    policy_start = lines.index("Estimated correct Policy:\n") + 2
    for line in lines[policy_start:]:
        if line.strip() == "":  # Stop at the empty line
            break
        # Split only at the first colon so that values containing colons
        # (e.g. dictionaries) stay intact.
        _, values = line.split(":", 1)
        policy_list.append(eval(values.strip()))

    # Parse the "Estimated correct Q Function" section
    q_function_dict = {}
    q_function_start = lines.index("Estimated correct Q Function:\n") + 1
    for line in lines[q_function_start:]:
        if not line.strip():  # Ignore empty lines
            continue
        # The value is a plain float, so the last colon separates key and value
        # even if the key itself contains a colon.
        key, value = line.rsplit(":", 1)
        q_function_dict[eval(key.strip())] = float(value.strip())

    return policy_list, q_function_dict
