from typing import Dict, Any, Union, List, Tuple
from copy import deepcopy

import envs
from algos.algo import Algo
from algos.policy import Policy, BasePolicy
from utils import schedule, check_for_schedule_allowed, sample_from_dist


class ADP(Algo):

    def __init__(
        self,
        env: envs.Env = None,
        env_kwargs: Dict = None,
        policy: Policy = None,
        policy_kwargs: Dict = None,
        learning_rate_kwargs: Dict[str, Any] = None,
        learning_rate_state_action_wise: bool = False,
        gamma: Union[int, float] = 0.99,
        q_fct_manual_init: bool = False,
        initial_q_fct: Dict = None,
        cycle_lengths: Union[List[int], Dict[Tuple[int, int], List[int]], None] = None,
        special_logs_kwargs: Dict = None,
        rng_seed=42,
        randomization_seed: int = 42,
        checks: str = "all_checks",
        uniform_state_action_sampling: bool = False,
        lr_per_cycle: bool = False,
        adaptive_sync: bool = False,
        adaptive_eps: Union[float, List[float]] = 0.1,
        adaptive_hysteresis_k: int = 1,
        adaptive_min_cycle_steps: Union[int, Dict[Tuple[int, int], int]] = 1,
        no_target: bool = False,
    ) -> None:
        """
        Initializes the Approximate Dynamic Programming algorithm.

        The environment and policy classes are instantiated with their kwargs to model the
        Markov decision process. Two Q tables are kept: ``q_fct_update`` (updated every
        step) and ``q_fct`` (the target used for bootstrapping), which is synchronized from
        ``q_fct_update`` according to ``cycle_lengths``.

        Args:
            env: Environment class, instantiated here; defaults to ``envs.GridWorld``.
            env_kwargs: Extra keyword arguments for the environment constructor. If the
                constructor rejects them (TypeError/ValueError), the environment is built
                with its default parameters instead.
            policy: Policy class, instantiated here; defaults to ``BasePolicy``.
            policy_kwargs: Extra keyword arguments for the policy constructor. If the
                constructor rejects them (TypeError/ValueError), the policy is built
                without them (and without ``rng_seed``).
            learning_rate_kwargs: Learning rate schedule configuration that ``step`` passes
                to ``schedule``. It is deep-copied because ``step`` mutates it.
            learning_rate_state_action_wise: Keep a separate learning rate schedule per
                (state, action) pair.
            gamma: Discount factor.
            q_fct_manual_init: Initialize the Q function from ``initial_q_fct`` instead of
                zeros.
            initial_q_fct: Mapping (state, action) -> initial Q value (manual init only).
            cycle_lengths: ``None`` for basic Q-learning (target synced every step). A list
                of global cycle lengths: the whole target is synced at the end of each
                cycle, and the last length is repeated. A dict (state, action) -> list of
                cycle lengths: per-pair syncing.
            special_logs_kwargs: Configuration of the special logs.
            rng_seed: Seed forwarded to the environment and the policy constructors.
            randomization_seed: Seed forwarded to the environment constructor.
            checks: One of ``"all_checks"``, ``"no_checks"``, ``"only_initial_checks"``.
            uniform_state_action_sampling: Sample (state, action) pairs i.i.d. uniformly
                instead of following rollouts of the policy.
            lr_per_cycle: Reset the learning rate schedule at cycle boundaries.
            adaptive_sync: Allow ending a cycle early once the cycle-mean TD error is small.
            adaptive_eps: Threshold for the adaptive cut. In global mode this may also be a
                list schedule (advanced per cycle, last value repeated). In
                state-action-wise mode it must be a single number.
            adaptive_hysteresis_k: Consecutive below-threshold checks required for a cut.
            adaptive_min_cycle_steps: Minimum number of updates in a cycle before the
                adaptive check runs. An int in global mode; a dict (state, action) -> int
                in state-action-wise mode.
            no_target: Keep the target equal to the updating Q function after every step.

        Raises:
            TypeError, ValueError: On invalid arguments (see the individual checks).
        """
        self.lr_per_cycle = lr_per_cycle
        self.no_target = no_target
        self.uniform_state_action_sampling = uniform_state_action_sampling
        self.adaptive_sync = adaptive_sync
        self.adaptive_eps = adaptive_eps
        self.adaptive_min_cycle_steps = adaptive_min_cycle_steps
        self.adaptive_hysteresis_k = adaptive_hysteresis_k

        # Default arguments if arguments are None
        if env is None:
            env = envs.GridWorld
        if env_kwargs is None:
            env_kwargs = {}
        if policy is None:
            policy = BasePolicy
        if policy_kwargs is None:
            policy_kwargs = {}
        self.cycle_lengths = cycle_lengths
        # Only perform basic Q-learning (target synced every step) if no cycles are given
        self.perform_basic_q = self.cycle_lengths is None

        if learning_rate_kwargs is None:
            learning_rate_kwargs = {
                "initial_rate": 0.1,
                "mode": "rate",
                "mode_kwargs": {
                    "rate_fct": lambda n: 0.1 / (n + 1),
                    "iteration_num": 1,
                    "final_rate": 0,
                },
            }

        if initial_q_fct is None:
            initial_q_fct = {}
        if special_logs_kwargs is None:
            special_logs_kwargs = {}

        # Standard initialization of arguments from the input
        self.env = env
        self.env_kwargs = env_kwargs
        self.policy = policy
        self.policy_kwargs = policy_kwargs
        # Deep copy: step() mutates the learning rate kwargs in place
        self.learning_rate_kwargs = deepcopy(learning_rate_kwargs)
        self.learning_rate_state_action_wise = learning_rate_state_action_wise
        self.gamma = gamma
        self.q_fct_manual_init = q_fct_manual_init
        self.rng_seed = rng_seed
        self.randomization_seed = randomization_seed
        self.initial_q_fct = deepcopy(initial_q_fct)
        self.special_logs_kwargs = special_logs_kwargs
        if isinstance(checks, str):
            if not (
                checks == "all_checks"
                or checks == "no_checks"
                or checks == "only_initial_checks"
            ):
                raise ValueError(
                    "Checks needs to be either all_checks, no_checks, or only_initial_checks"
                )
        else:
            raise TypeError("Flag for type and value checking must be a string!")
        self.checks = checks
        # NOTE(review): get_special_log_keys also understands "cycle_means",
        # "which_cycle_means" and "global_cycle_mean_of_means", which are not listed here.
        # Confirm whether this list is used to reject special log keys.
        self.allowed_special_logs_kwargs_keys = [
            "updated_q_values",
            "which_updated_q_values",
            "last_gradient_values",
            "which_last_gradient_values",
            "last_mean_m_hat",
            "ema_values",
            "which_ema_values",
        ]

        # inputcheck() is not defined in this view (presumably inherited from Algo)
        if self.checks != "no_checks":
            self.inputcheck()

        try:
            self.env = self.env(
                rng_seed=self.rng_seed,
                randomization_seed=self.randomization_seed,
                **self.env_kwargs,
            )
        except (TypeError, ValueError) as e:
            print(
                "The environment kwargs you provided seem to be faulty. The environment was instead created with default parameters!"
            )
            print(f"Error message: {e}")
            self.env = env()
        try:
            self.policy = self.policy(
                rng_seed=self.rng_seed,
                env_allowed_actions=self.env.allowed_actions,
                env_num_states=self.env.num_states,
                env_num_actions=self.env.num_actions,
                **self.policy_kwargs,
            )
        except (TypeError, ValueError) as e:
            self.policy = self.policy(
                env_allowed_actions=self.env.allowed_actions,
                env_num_states=self.env.num_states,
                env_num_actions=self.env.num_actions,
            )
            print(
                "The policy kwargs you provided seem to be faulty. The policy was instead created with default parameters!"
            )
            print(f"Error message: {e}")

        # Initialize Qs: q_fct_update (updated each step) and q_fct (target/measured)
        if self.q_fct_manual_init:
            if self.checks != "no_checks":
                self._validate_initial_q_fct()
            self.q_fct_update = deepcopy(self.initial_q_fct)
        else:
            self.q_fct_update = {
                (state, action): 0
                for state in range(self.env.num_states)
                for action in self.env.allowed_actions[state]
            }
        # The measured/target Q function, used for bootstrapping and bias metrics
        self.q_fct = self.q_fct_update.copy()
        # Track the most recent abs_update per (s,a) for logging at eval
        self.last_abs_update = {
            (state, action): 0.0
            for state in range(self.env.num_states)
            for action in self.env.allowed_actions[state]
        }

        # Validate cycle_lengths / adaptive_min_cycle_steps now that (s,a) space is known
        if not self.perform_basic_q:
            if isinstance(self.cycle_lengths, list):
                # Global mode needs at least one cycle length (the first one is used below)
                if not self.cycle_lengths:
                    raise ValueError(
                        "With global cycle_lengths (list), at least one cycle length must be given."
                    )
                # Global mode: adaptive_min_cycle_steps must be int >= 1
                if (
                    not isinstance(self.adaptive_min_cycle_steps, int)
                    or self.adaptive_min_cycle_steps < 1
                ):
                    raise ValueError(
                        "With global cycle_lengths (list), adaptive_min_cycle_steps must be an integer >= 1 representing global updates per cycle."
                    )
            elif isinstance(self.cycle_lengths, dict):
                # SA-wise mode: adaptive_min_cycle_steps must be dict mapping each (s,a) to an int >= 1
                if not isinstance(self.adaptive_min_cycle_steps, dict):
                    raise ValueError(
                        "With state–action-wise cycle_lengths (dict), adaptive_min_cycle_steps must be a dict mapping (s,a) -> min updates before adaptive cut."
                    )
                required_keys = set(self.q_fct_update.keys())
                provided_keys = set(self.adaptive_min_cycle_steps.keys())
                if required_keys != provided_keys:
                    missing = required_keys - provided_keys
                    extra = provided_keys - required_keys
                    raise ValueError(
                        f"adaptive_min_cycle_steps keys must match all (s,a) pairs. Missing: {missing if missing else set()} | Extra: {extra if extra else set()}"
                    )
                for vv in self.adaptive_min_cycle_steps.values():
                    if not isinstance(vv, int) or vv < 1:
                        raise ValueError(
                            "All adaptive_min_cycle_steps values must be integers >= 1."
                        )

        # Precompute all allowed (state, action) pairs for optional iid sampling
        self.all_state_action_pairs = [
            (int(state), int(action))
            for state in range(self.env.num_states)
            for action in self.env.allowed_actions[state]
        ]
        self._num_sa_pairs = len(self.all_state_action_pairs)
        if self.uniform_state_action_sampling and self._num_sa_pairs > 0:
            self._uniform_sa_probs = [1 / self._num_sa_pairs] * self._num_sa_pairs
        else:
            self._uniform_sa_probs = None

        if not self.perform_basic_q:
            # Determine update mode from cycle_lengths: global or state-action-wise
            if isinstance(self.cycle_lengths, list):
                # Global periodic update mode
                self.is_state_action_wise = False
                self.cycle_lengths_list = self.cycle_lengths
                self.global_step_counter = 0
                self.global_cycle_index = 0
                self.global_cycle_step_counter = 0
                self.next_global_cycle_threshold = self.cycle_lengths_list[0]
            elif isinstance(self.cycle_lengths, dict):
                # State-action-wise periodic update mode
                self.is_state_action_wise = True
                self.cycle_lengths_dict = deepcopy(self.cycle_lengths)
                self.sa_update_counters = {k: 0 for k in self.q_fct_update}
                self.sa_cycle_indices = {k: 0 for k in self.q_fct_update}
                # For each (s,a), the first cycle length; pairs without (or with an
                # empty) schedule use a cycle length of 1
                self.sa_next_cycle_thresholds = {
                    k: (
                        self.cycle_lengths_dict[k][0]
                        if (
                            k in self.cycle_lengths_dict
                            and len(self.cycle_lengths_dict[k]) > 0
                        )
                        else 1
                    )
                    for k in self.q_fct_update
                }
            else:
                raise ValueError("cycle_lengths must be a list or a dict.")

        # Normalize adaptive_eps depending on mode and initialize current epsilon per cycle
        if not self.perform_basic_q:
            if isinstance(self.cycle_lengths, list):
                # Global mode: allow a list schedule like cycle_lengths
                if isinstance(self.adaptive_eps, list):
                    if not self.adaptive_eps:
                        raise ValueError(
                            "In global mode, adaptive_eps list must be non-empty."
                        )
                    for v in self.adaptive_eps:
                        if not isinstance(v, (int, float)) or v < 0:
                            raise ValueError(
                                "In global mode, adaptive_eps list must contain non-negative numbers."
                            )
                    self.global_eps_list = [float(v) for v in self.adaptive_eps]
                elif isinstance(self.adaptive_eps, (int, float)):
                    self.global_eps_list = [float(self.adaptive_eps)]
                else:
                    raise TypeError(
                        "In global mode, adaptive_eps must be a non-negative number or a list of such numbers."
                    )
                self.global_eps_index = 0
                self.current_adaptive_eps = self.global_eps_list[0]
            elif isinstance(self.cycle_lengths, dict):
                # SA-wise mode: require scalar epsilon
                if isinstance(self.adaptive_eps, list):
                    raise ValueError(
                        "In state–action-wise mode, adaptive_eps must be a single non-negative number, not a list."
                    )
                self.current_adaptive_eps = float(self.adaptive_eps)

        # Trackers for adaptive early cut per (s,a) across the current cycle (global or SA-wise)
        # We use a cycle-local running arithmetic mean of the signed TD error (raw_update).
        if not self.perform_basic_q:
            self.pair_cycle_mean = {k: 0.0 for k in self.q_fct_update}
            self.pair_cycle_count = {k: 0 for k in self.q_fct_update}
            self.pair_below_eps_streak = {k: 0 for k in self.q_fct_update}
            # Track the latest global mean of per-pair cycle means across all (s,a)
            self.last_mean_m_hat = 0.0

        # Start at the beginning of the game
        self.current_state = self.env.start_state_num

        # If learning rate statewise is activated, store the current learning rates
        if self.learning_rate_state_action_wise:
            self.learning_rate_schedule_list = [
                {
                    action: self.learning_rate_kwargs["initial_rate"]
                    for action in self.env.allowed_actions[state]
                }
                for state in range(self.env.num_states)
            ]
            if self.learning_rate_kwargs["mode"] == "rate":
                self.learning_rate_iteration_num_list = [
                    {action: 1 for action in self.env.allowed_actions[state]}
                    for state in range(self.env.num_states)
                ]

    def _validate_initial_q_fct(self) -> None:
        """
        Checks that the manually given ``initial_q_fct`` has exactly one numerical entry per
        allowed (state, action) pair of the environment. Also stores the expected number of
        entries in ``self.length_Q``.

        Raises:
            TypeError: If the number of entries is wrong, a key is not a tuple, or a key
                contains non-integer states/actions.
            ValueError: If a key tuple does not have length 2, an action is not allowed in
                its state, or a value is not numerical.
        """
        self.length_Q = 0
        for state in range(self.env.num_states):
            self.length_Q += len(self.env.allowed_actions[state])
        if len(self.initial_q_fct) != self.length_Q:
            raise TypeError("The given Q function misses some entries!")
        for key in self.initial_q_fct.keys():
            if not isinstance(key, tuple):
                raise TypeError(
                    "Keys of the given Q function need to be state action tuples!"
                )
            if len(key) != 2:
                raise ValueError(
                    "Keys of the given Q function need to be state action tuples of length 2!"
                )
            if not (isinstance(key[0], int) and isinstance(key[1], int)):
                raise TypeError(
                    "State and actions in the keys of the given Q function need to be integers!"
                )
            if key[1] not in self.env.allowed_actions[key[0]]:
                raise ValueError(f"Action {key[1]} not allowed in state {key[0]}!")
            if not isinstance(self.initial_q_fct[key], (int, float)):
                raise ValueError("The given Q function values need to be numerical!")

    def __str__(self):
        """Short display name of the algorithm."""
        label = "ADP"
        return label

    def step(self) -> Tuple[int, int, int, float, bool]:
        """
        Exercises a step of the algorithm with the base parameters.

        A step does the following, in order:
        1. Choose a (state, action) pair.
        2. Sample the transition.
        3. Draw the step size from the learning rate schedule.
        4. Apply the Q-learning update to ``q_fct_update``, bootstrapping from the target
           ``q_fct``.
        5. Update the adaptive cycle statistics.
        6. Synchronize the target according to the cycle mode.

        Returns:
        - state (int): The state on which the algorithm was updated.
        - action (int): The action chosen during exploration in one step of the algorithm.
        - next_state (int): The resulting next state after playing the chosen action.
        - reward (float): The reward obtained by executing one step of the algorithm.
        - restart (bool): If a terminal state was reached and the game will be restarted.
        """
        state, chosen_action = self._select_state_action()
        key = (state, chosen_action)

        # Sample reward and next state given chosen action
        next_state, t, reward = self.env.get_next_state_and_reward(
            state, chosen_action
        )

        # Step size for this update (this also advances the learning rate schedule)
        stepsize = self._next_stepsize(state, chosen_action)

        # Find action maximizing the Q Function at the next state using the *target* q_fct
        allowed_next_actions = self.env.allowed_actions[next_state]
        if allowed_next_actions:
            max_next_q_value = max(
                self.q_fct[(next_state, a)] for a in allowed_next_actions
            )
        else:
            max_next_q_value = 0

        # Q-learning update using target q_fct for bootstrap, updating q_fct_update;
        # the (1 - t) factor drops the bootstrap term on terminal transitions
        raw_update = (
            reward
            + (1 - t) * self.gamma * max_next_q_value
            - self.q_fct_update[key]
        )
        self.q_fct_update[key] = self.q_fct_update[key] + stepsize * raw_update
        # Remember the most recent abs_update for this (s,a)
        self.last_abs_update[key] = abs(raw_update)

        if getattr(self, "adaptive_sync", False) and not self.perform_basic_q:
            self._track_cycle_statistics(key, raw_update)

        # Synchronization of the target Q function
        if self.no_target or self.perform_basic_q:
            # Target mirrors the updating Q function every step
            self.q_fct = self.q_fct_update.copy()
        if not self.perform_basic_q:
            if hasattr(self, "is_state_action_wise") and not self.is_state_action_wise:
                self._advance_global_cycle()
            elif hasattr(self, "is_state_action_wise") and self.is_state_action_wise:
                self._advance_state_action_cycle(key)

        # Update current state
        self.current_state = next_state

        return state, chosen_action, next_state, reward, t

    def _select_state_action(self) -> Tuple[int, int]:
        """
        Picks the (state, action) pair to update in this step.

        With ``uniform_state_action_sampling`` the pair is drawn uniformly over all allowed
        pairs and ``current_state`` is moved to the drawn state. Otherwise the policy
        chooses an action in ``current_state`` based on the *updating* Q function.
        """
        if getattr(self, "uniform_state_action_sampling", False):
            pair = sample_from_dist(
                self.policy.rng,
                "choice",
                1,
                a=self.all_state_action_pairs,
                p=self._uniform_sa_probs,
            )[0]
            self.current_state, chosen_action = pair
            self.current_state = int(self.current_state)
            chosen_action = int(chosen_action)
        else:
            chosen_action = self.policy.choose_next_action(
                self.current_state, self.q_fct_update
            )
        return self.current_state, chosen_action

    def _next_stepsize(self, state: int, action: int) -> float:
        """
        Returns the step size for updating ``(state, action)`` and advances the schedule.

        With ``lr_per_cycle``, the schedule is first reset when the relevant cycle has just
        started. In basic Q mode there are no cycles, so nothing is reset.
        """
        if self.learning_rate_state_action_wise:
            if self.lr_per_cycle:
                if hasattr(self, "is_state_action_wise") and self.is_state_action_wise:
                    # SA-wise cycles: reset only this pair, at the start of its own cycle
                    if self.sa_update_counters[(state, action)] == 0:
                        self._reset_sa_learning_rate(state, action)
                elif hasattr(self, "is_state_action_wise"):
                    # Global cycles: reset all pairs at the start of the global cycle
                    if self.global_cycle_step_counter == 0:
                        self._reset_all_sa_learning_rates()
            stepsize = self.learning_rate_schedule_list[state][action]
            if stepsize > self.learning_rate_kwargs["mode_kwargs"]["final_rate"]:
                if self.learning_rate_kwargs["mode"] == "rate":
                    # Feed this pair's own iteration counter into the shared schedule kwargs
                    self.learning_rate_kwargs["mode_kwargs"]["iteration_num"] = (
                        self.learning_rate_iteration_num_list[state][action]
                    )
                self.learning_rate_kwargs["current_rate"] = stepsize
                self.learning_rate_kwargs = schedule(**self.learning_rate_kwargs)
                self.learning_rate_schedule_list[state][action] = (
                    self.learning_rate_kwargs["current_rate"]
                )
                if self.learning_rate_kwargs["mode"] == "rate":
                    self.learning_rate_iteration_num_list[state][action] = (
                        self.learning_rate_kwargs["mode_kwargs"]["iteration_num"]
                    )
            return stepsize

        # Global learning rate scheduling; reset at the start of a global cycle if requested
        if (
            self.lr_per_cycle
            and hasattr(self, "is_state_action_wise")
            and not self.is_state_action_wise
            and self.global_cycle_step_counter == 0
        ):
            self.learning_rate_kwargs = schedule(
                reset_schedule=True, **self.learning_rate_kwargs
            )
        stepsize = self.learning_rate_kwargs["current_rate"]
        if stepsize > self.learning_rate_kwargs["mode_kwargs"]["final_rate"]:
            self.learning_rate_kwargs = schedule(**self.learning_rate_kwargs)
        return stepsize

    def _reset_sa_learning_rate(self, state: int, action: int) -> None:
        """Restarts the learning rate schedule of a single (state, action) pair."""
        self.learning_rate_schedule_list[state][action] = self.learning_rate_kwargs[
            "initial_rate"
        ]
        if self.learning_rate_kwargs["mode"] == "rate":
            self.learning_rate_iteration_num_list[state][action] = 1

    def _reset_all_sa_learning_rates(self) -> None:
        """Restarts the learning rate schedule of every allowed (state, action) pair."""
        for s in range(self.env.num_states):
            for a in self.env.allowed_actions[s]:
                self._reset_sa_learning_rate(s, a)

    def _track_cycle_statistics(self, key: Tuple[int, int], raw_update: float) -> None:
        """
        Updates the running mean of the signed TD error of ``key`` within its current
        cycle. Then refreshes ``last_mean_m_hat``, the average over all pairs of the
        absolute per-pair cycle means; pairs not visited in the current cycle count as 0.
        """
        self.pair_cycle_count[key] += 1
        n = self.pair_cycle_count[key]
        prev_mean = self.pair_cycle_mean[key]
        self.pair_cycle_mean[key] = prev_mean + (raw_update - prev_mean) / n

        total_pairs = len(self.pair_cycle_mean)
        if total_pairs > 0:
            m_sum = sum(
                abs(mean)
                for kk, mean in self.pair_cycle_mean.items()
                if self.pair_cycle_count[kk] > 0
            )
            self.last_mean_m_hat = m_sum / total_pairs

    def _advance_global_cycle(self) -> None:
        """
        Counts one global update and, if the cycle ends, syncs the whole target Q function.

        The cycle ends once the fixed length of the current cycle is reached. With
        ``adaptive_sync`` it can also end early: after at least ``adaptive_min_cycle_steps``
        updates, once ``last_mean_m_hat`` stays at or below the current epsilon for
        ``adaptive_hysteresis_k`` consecutive steps.
        """
        self.global_step_counter += 1
        self.global_cycle_step_counter += 1

        trigger_end = self.global_cycle_step_counter >= self.next_global_cycle_threshold
        if not trigger_end and getattr(self, "adaptive_sync", False):
            min_needed = (
                self.adaptive_min_cycle_steps
                if isinstance(self.adaptive_min_cycle_steps, int)
                else 1
            )
            if sum(self.pair_cycle_count.values()) >= min_needed:
                if not hasattr(self, "global_below_eps_streak"):
                    self.global_below_eps_streak = 0
                # Per-cycle epsilon schedule in global mode (fallbacks kept for safety)
                if hasattr(self, "current_adaptive_eps"):
                    eps_thresh = self.current_adaptive_eps
                elif isinstance(self.adaptive_eps, (int, float)):
                    eps_thresh = self.adaptive_eps
                elif isinstance(self.adaptive_eps, list) and self.adaptive_eps:
                    eps_thresh = self.adaptive_eps[-1]
                else:
                    eps_thresh = 0.0
                if self.last_mean_m_hat <= eps_thresh:
                    self.global_below_eps_streak += 1
                else:
                    self.global_below_eps_streak = 0
                trigger_end = self.global_below_eps_streak >= self.adaptive_hysteresis_k

        if not trigger_end:
            return

        # Sync whole Q: start new global cycle
        self.q_fct = self.q_fct_update.copy()
        self.global_cycle_step_counter = 0
        # Advance along the cycle length schedule; the last length is reused afterwards
        if self.global_cycle_index < len(self.cycle_lengths_list) - 1:
            self.global_cycle_index += 1
        self.next_global_cycle_threshold = self.cycle_lengths_list[
            self.global_cycle_index
        ]
        # Advance the epsilon schedule analogously to cycle_lengths
        if hasattr(self, "global_eps_list"):
            if self.global_eps_index < len(self.global_eps_list) - 1:
                self.global_eps_index += 1
            self.current_adaptive_eps = self.global_eps_list[self.global_eps_index]
        # Reset adaptive trackers for ALL pairs and the global hysteresis (new global cycle)
        for kk in self.pair_cycle_mean:
            self.pair_cycle_mean[kk] = 0.0
            self.pair_cycle_count[kk] = 0
            self.pair_below_eps_streak[kk] = 0
        if hasattr(self, "global_below_eps_streak"):
            self.global_below_eps_streak = 0
        self.last_mean_m_hat = 0.0
        # Reset LR at cycle boundary if requested
        if self.lr_per_cycle:
            if self.learning_rate_state_action_wise:
                self._reset_all_sa_learning_rates()
            else:
                self.learning_rate_kwargs = schedule(
                    reset_schedule=True, **self.learning_rate_kwargs
                )

    def _advance_state_action_cycle(self, key: Tuple[int, int]) -> None:
        """
        Counts one update of ``key`` and, if its cycle ends, syncs only that target entry.

        The cycle of ``key`` ends once its fixed cycle length is reached. With
        ``adaptive_sync`` it can also end early: after its per-pair minimum number of
        updates, once the absolute running mean of its TD errors stays at or below the
        epsilon for ``adaptive_hysteresis_k`` consecutive updates of this pair.
        """
        self.sa_update_counters[key] += 1

        trigger_end = self.sa_update_counters[key] >= self.sa_next_cycle_thresholds[key]
        if not trigger_end and getattr(self, "adaptive_sync", False):
            per_pair_min = (
                self.adaptive_min_cycle_steps[key]
                if isinstance(self.adaptive_min_cycle_steps, dict)
                else 1
            )
            if self.pair_cycle_count[key] >= per_pair_min:
                # current_adaptive_eps is float(adaptive_eps) in SA-wise mode
                eps_thresh = getattr(self, "current_adaptive_eps", self.adaptive_eps)
                if abs(self.pair_cycle_mean[key]) <= eps_thresh:
                    self.pair_below_eps_streak[key] += 1
                else:
                    self.pair_below_eps_streak[key] = 0
                trigger_end = (
                    self.pair_below_eps_streak[key] >= self.adaptive_hysteresis_k
                )

        if not trigger_end:
            return

        # Sync only this entry: start new (s,a)-cycle
        self.q_fct[key] = self.q_fct_update[key]
        self.sa_update_counters[key] = 0
        # Reset adaptive trackers for this pair (new (s,a) cycle)
        self.pair_cycle_mean[key] = 0.0
        self.pair_cycle_count[key] = 0
        self.pair_below_eps_streak[key] = 0
        # Advance this pair's cycle length schedule (last length reused). Pairs without
        # (or with an empty) schedule use length 1, matching the initialization.
        if key in self.cycle_lengths_dict and len(self.cycle_lengths_dict[key]) > 0:
            lengths = self.cycle_lengths_dict[key]
            if self.sa_cycle_indices[key] < len(lengths) - 1:
                self.sa_cycle_indices[key] += 1
            self.sa_next_cycle_thresholds[key] = lengths[self.sa_cycle_indices[key]]
        else:
            self.sa_next_cycle_thresholds[key] = 1
        # Reset LR for this (s,a) at cycle boundary (required in SA-wise mode)
        if self.lr_per_cycle and self.learning_rate_state_action_wise:
            self._reset_sa_learning_rate(*key)

    def get_special_log_keys(self) -> List[Tuple[str, str]]:
        """
        Returns the log and plot names for all special plots and when to log them.

        Each entry is a ``(when, name)`` tuple; every key produced here is logged
        ``"at_eval"``.
        - ``"updated_q_values"`` and ``"cycle_means"`` cover every allowed
          (state, action) pair.
        - ``"which_updated_q_values"`` and ``"which_cycle_means"`` take a pair
          ``(states, actions_per_state)``, where ``actions_per_state[i]`` lists the
          actions to log for ``states[i]``.
        - ``"global_cycle_mean_of_means"`` adds one global entry.
        """
        requested = self.special_logs_kwargs

        def every_pair():
            # All allowed (state, action) pairs of the environment, ordered by state
            for s in range(self.env.num_states):
                for a in self.env.allowed_actions[s]:
                    yield s, a

        def chosen_pairs(option):
            # Pairs selected explicitly via requested[option] = (states, actions_per_state)
            for idx, s in enumerate(requested[option][0]):
                for a in requested[option][1][idx]:
                    yield s, a

        def q_entry(s, a):
            return ("at_eval", f"updated Q value of state {s} and action {a}")

        def mean_entry(s, a):
            return ("at_eval", f"cycle mean raw_update of state {s} and action {a}")

        collected: List[Tuple[str, str]] = []
        if "updated_q_values" in requested:
            collected += [q_entry(s, a) for s, a in every_pair()]
        if "which_updated_q_values" in requested:
            collected += [
                q_entry(s, a) for s, a in chosen_pairs("which_updated_q_values")
            ]
        if "cycle_means" in requested:
            collected += [mean_entry(s, a) for s, a in every_pair()]
        if "which_cycle_means" in requested:
            collected += [mean_entry(s, a) for s, a in chosen_pairs("which_cycle_means")]
        if "global_cycle_mean_of_means" in requested:
            collected.append(("at_eval", "global cycle mean of abs means"))
        return collected

    def get_special_logs_at_step(self) -> List[Tuple[str, Union[int, float]]]:
        # (intentionally unused; return empty list to avoid overhead at step time)
        return []

    def get_special_logs_at_epoch(self) -> List[Tuple[str, Union[int, float]]]:
        return []

    def get_special_logs_at_eval(self) -> List[Tuple[str, Union[int, float]]]:

        # Initialize the list for returning the values to log
        to_log_list = []

        if "updated_q_values" in self.special_logs_kwargs.keys():
            for state in range(self.env.num_states):
                for action in self.env.allowed_actions[state]:
                    label = f"updated Q value of state {state} and action {action}"
                    to_log_list.append((label, self.q_fct_update[(state, action)]))

        if "which_updated_q_values" in self.special_logs_kwargs.keys():
            for state_index, state in enumerate(
                self.special_logs_kwargs["which_updated_q_values"][0]
            ):
                for action in self.special_logs_kwargs["which_updated_q_values"][1][
                    state_index
                ]:
                    label = f"updated Q value of state {state} and action {action}"
                    to_log_list.append((label, self.q_fct_update[(state, action)]))

        if "cycle_means" in self.special_logs_kwargs.keys():
            for state in range(self.env.num_states):
                for action in self.env.allowed_actions[state]:
                    val = self.pair_cycle_mean[(state, action)]
                    label = (
                        f"cycle mean raw_update of state {state} and action {action}"
                    )
                    to_log_list.append((label, val))

        if "which_cycle_means" in self.special_logs_kwargs.keys():
            for state_index, state in enumerate(
                self.special_logs_kwargs["which_cycle_means"][0]
            ):
                for action in self.special_logs_kwargs["which_cycle_means"][1][
                    state_index
                ]:
                    val = self.pair_cycle_mean[(state, action)]
                    label = (
                        f"cycle mean raw_update of state {state} and action {action}"
                    )
                    to_log_list.append((label, val))

        if "global_cycle_mean_of_means" in self.special_logs_kwargs.keys():
            to_log_list.append(("global cycle mean of abs means", self.last_mean_m_hat))

        return to_log_list

    def get_greedy_policy(self) -> List[int]:
        """
        Takes the current Q function and gives out a policy list corresponding to the greedy policy.
        Returns:
        - list: The list of greedy actions with respect to the current Q function.
        """
        greedy_policy = []
        for state in range(self.env.num_states):
            allowed_actions = self.env.allowed_actions[state]
            vals = [self.q_fct[(state, a)] for a in allowed_actions]
            max_val = max(vals)
            arg_max = [a for a, v in zip(allowed_actions, vals) if v == max_val]
            if len(arg_max) == 1:
                greedy_policy.append(arg_max[0])
            else:
                greedy_policy.append(
                    int(
                        sample_from_dist(
                            self.policy.rng,
                            "choice",
                            1,
                            **{"a": arg_max, "p": [1 / len(arg_max) for _ in arg_max]},
                        )[0]
                    )
                )
        return greedy_policy

    def get_q_fct(self) -> Dict[Tuple[int, int], Union[int, float]]:
        """Returns the current estimate of the Q function (the *target* network view)."""
        return deepcopy(self.q_fct)

    def inputcheck(self) -> int:
        """
        Validates the input parameters to ensure they follow the expected formats and constraints.

        The checks run in a fixed order and the first violation raises. As a side effect, a
        valid ``learning_rate_kwargs`` dict is normalized in place: ``current_rate`` is set to
        ``initial_rate`` and ``mode_kwargs["iteration_num"]`` is reset to 1.

        Returns:
        - int: 1 when every check passes.
        Raises:
        - ValueError: If any of the input parameters are invalid.
        - TypeError: If any of the input types are invalid.
        """
        self._inputcheck_env_and_policy()
        self._inputcheck_learning_rate()
        self._inputcheck_special_logs()
        self._inputcheck_flags_and_gamma()
        self._inputcheck_cycle_lengths()
        self._inputcheck_seed()
        self._inputcheck_cycle_mode()
        self._inputcheck_adaptive()
        return 1

    def _inputcheck_env_and_policy(self) -> None:
        """Checks the environment/policy classes and their keyword-argument dictionaries."""
        # env must be a class deriving from envs.Env. The isinstance(..., type) guard gives this
        # message (instead of issubclass's generic TypeError) when an instance is passed.
        if not (isinstance(self.env, type) and issubclass(self.env, envs.Env)):
            raise TypeError("Environment needs to be of base type Env!")

        # env_kwargs is a dictionary and must not carry its own rng_seed
        if not isinstance(self.env_kwargs, dict):
            raise TypeError(
                "Environment key arguments need to be contained in a dictionary!"
            )
        if "rng_seed" in self.env_kwargs:
            raise ValueError("Environment key arguments should not contain rng_seed!")

        # policy must be a class deriving from Policy (same guard as for env)
        if not (isinstance(self.policy, type) and issubclass(self.policy, Policy)):
            raise TypeError("Policy needs to be of the correct type!")

        # policy_kwargs is a dictionary and must not contain rng_seed, env_allowed_actions,
        # env_num_states or env_num_actions as keys
        if not isinstance(self.policy_kwargs, dict):
            raise TypeError(
                "Policy key arguments need to be contained in a dictionary!"
            )
        if "rng_seed" in self.policy_kwargs:
            raise ValueError("Policy key arguments should not contain rng_seed!")
        if "env_allowed_actions" in self.policy_kwargs:
            raise ValueError(
                "Policy key arguments should not contain env_allowed_actions!"
            )
        # The singular spelling was the only one checked historically; keep rejecting it
        if "env_num_state" in self.policy_kwargs:
            raise ValueError("Policy key arguments should not contain env_num_state!")
        if "env_num_states" in self.policy_kwargs:
            raise ValueError("Policy key arguments should not contain env_num_states!")
        if "env_num_actions" in self.policy_kwargs:
            raise ValueError(
                "Policy key arguments should not contain env_num_actions!"
            )

    def _inputcheck_learning_rate(self) -> None:
        """Checks learning_rate_kwargs and resets its running schedule state in place."""
        if not isinstance(self.learning_rate_kwargs, dict):
            raise TypeError(
                "Learning rate keyword arguments need to be contained in a dictionary!"
            )
        if "initial_rate" not in self.learning_rate_kwargs:
            raise ValueError(
                "Initial rate is missing from the learning rate key word arguments!"
            )
        # Previously a missing mode_kwargs surfaced as a bare KeyError
        if "mode_kwargs" not in self.learning_rate_kwargs:
            raise ValueError(
                "The mode_kwargs entry is missing from the learning rate key word arguments!"
            )
        self.learning_rate_kwargs["current_rate"] = self.learning_rate_kwargs[
            "initial_rate"
        ]
        self.learning_rate_kwargs["mode_kwargs"]["iteration_num"] = 1
        check_for_schedule_allowed(**self.learning_rate_kwargs)

    def _inputcheck_special_logs(self) -> None:
        """Checks that special_logs_kwargs is a dict whose keys are all registered special logs."""
        if not isinstance(self.special_logs_kwargs, dict):
            raise TypeError(
                "The special logs keyword arguments need to be passed in a dictionary!"
            )
        for key in self.special_logs_kwargs:
            if key not in self.allowed_special_logs_kwargs_keys:
                raise ValueError(
                    "Invalid key for logging special parameters. If you tried implementing a new one register it in the self.allowed_special_logs_kwargs_keys list!"
                )

    def _inputcheck_flags_and_gamma(self) -> None:
        """Checks the boolean switches, the discount factor and the manual Q initialization."""
        if not isinstance(self.learning_rate_state_action_wise, bool):
            raise TypeError(
                "Mode for statewise learning schedule needs to be a boolean value!"
            )

        # Gamma is a number in the closed interval [0, 1]; both endpoints are accepted
        # but print a warning because they are degenerate discount factors.
        if not isinstance(self.gamma, (int, float)):
            raise TypeError("The discount factor needs to be a numerical value!")
        if not 0 <= self.gamma <= 1:
            raise ValueError(
                "The discount factor for your game needs to be between 0 and 1!"
            )
        if self.gamma == 0:
            print(
                "Your discount factor is set to zero! Proceed with caution, your MDP problem might not make sense!"
            )
        elif self.gamma == 1:
            print(
                "Your discount factor is set to one! Proceed with caution, your MDP problem might not make sense!"
            )

        if not isinstance(self.q_fct_manual_init, bool):
            raise TypeError("The variable q_fct_manual_init needs to be boolean!")
        # A manual initialization must be supplied as a dictionary
        if self.q_fct_manual_init and not isinstance(self.initial_q_fct, dict):
            raise TypeError(
                "The Q function passed for manual initialization needs to be a dictionary!"
            )

    def _inputcheck_cycle_lengths(self) -> None:
        """Checks cycle_lengths: None, a global list of lengths, or a per-(s,a) dict of lists."""
        if self.cycle_lengths is None:
            return
        if isinstance(self.cycle_lengths, list):
            if not self.cycle_lengths or not all(
                isinstance(x, int) and x > 0 for x in self.cycle_lengths
            ):
                raise ValueError(
                    "cycle_lengths must be a non-empty list of positive integers."
                )
        elif isinstance(self.cycle_lengths, dict):
            if not self.cycle_lengths:
                raise ValueError("cycle_lengths dict cannot be empty.")
            for k, v in self.cycle_lengths.items():
                if not (
                    isinstance(k, tuple)
                    and len(k) == 2
                    and all(isinstance(i, int) for i in k)
                ):
                    raise TypeError(
                        "cycle_lengths dict keys must be (state, action) tuples."
                    )
                if not (
                    isinstance(v, list)
                    and v
                    and all(isinstance(x, int) and x > 0 for x in v)
                ):
                    raise ValueError(
                        "cycle_lengths dict values must be non-empty lists of positive integers."
                    )
        else:
            raise TypeError("cycle_lengths must be a list or a dict.")

    def _inputcheck_seed(self) -> None:
        """Checks that rng_seed is an integer in [0, 2**32)."""
        if not isinstance(self.rng_seed, int):
            raise TypeError("The seed needs to be an integer!")
        if not (0 <= self.rng_seed < 2**32):
            raise ValueError(
                "The provided seed is not in the range of acceptable integer seeds!"
            )

    def _inputcheck_cycle_mode(self) -> None:
        """Checks that lr_per_cycle and the learning-rate granularity fit the cycle_lengths mode."""
        if self.lr_per_cycle and self.cycle_lengths is None:
            raise ValueError("lr_per_cycle=True requires a non-None cycle_lengths.")
        if (
            self.lr_per_cycle
            and isinstance(self.cycle_lengths, dict)
            and not self.learning_rate_state_action_wise
        ):
            raise ValueError(
                "lr_per_cycle=True with state-action-wise cycle_lengths requires learning_rate_state_action_wise=True."
            )

        # State–action-wise cycles: the per-(s,a) LR reset at a cycle boundary only happens
        # when both lr_per_cycle and learning_rate_state_action_wise are set.
        if isinstance(self.cycle_lengths, dict):
            if not self.lr_per_cycle:
                raise ValueError(
                    "State–action-wise cycle_lengths (dict) requires lr_per_cycle=True to reset LR at (s,a) cycle boundaries."
                )
            if not self.learning_rate_state_action_wise:
                raise ValueError(
                    "State–action-wise cycle_lengths (dict) requires learning_rate_state_action_wise=True."
                )

    def _inputcheck_adaptive(self) -> None:
        """Checks adaptive_eps, adaptive_min_cycle_steps and adaptive_hysteresis_k for the cycle mode."""
        if isinstance(self.cycle_lengths, list):
            # Global mode: adaptive_eps is a non-negative number or a non-empty list of them
            eps_val = self.adaptive_eps
            if isinstance(eps_val, list):
                if not eps_val or any(
                    not isinstance(x, (int, float)) or x < 0 for x in eps_val
                ):
                    raise ValueError(
                        "In global mode, adaptive_eps must be a non-empty list of non-negative numbers (or a single non-negative number)."
                    )
            elif not isinstance(eps_val, (int, float)) or eps_val < 0:
                raise ValueError(
                    "In global mode, adaptive_eps must be a non-negative number or list thereof."
                )
            if (
                not isinstance(self.adaptive_min_cycle_steps, int)
                or self.adaptive_min_cycle_steps < 1
            ):
                raise ValueError(
                    "With global cycle_lengths (list), adaptive_min_cycle_steps must be an integer >= 1."
                )
        elif isinstance(self.cycle_lengths, dict):
            # State–action-wise mode: a single threshold and a per-(s,a) minimum step count
            eps_val = self.adaptive_eps
            if not isinstance(eps_val, (int, float)) or eps_val < 0:
                raise ValueError(
                    "In state–action-wise mode, adaptive_eps must be a single non-negative number."
                )
            if not isinstance(self.adaptive_min_cycle_steps, dict):
                raise ValueError(
                    "With state–action-wise cycle_lengths (dict), adaptive_min_cycle_steps must be a dict mapping (s,a) -> min updates before adaptive cut."
                )
            # Only structural validation here; key coverage needs the environment's actions
            for kk, vv in self.adaptive_min_cycle_steps.items():
                if not (
                    isinstance(kk, tuple)
                    and len(kk) == 2
                    and all(isinstance(i, int) for i in kk)
                    and isinstance(vv, int)
                    and vv >= 1
                ):
                    raise ValueError(
                        "adaptive_min_cycle_steps dict must map (state:int, action:int) -> int >= 1."
                    )
        # Without cycle_lengths (basic Q mode) adaptive_eps / adaptive_min_cycle_steps are not checked

        # Hysteresis parameter for the adaptive early cut
        if (
            not isinstance(self.adaptive_hysteresis_k, int)
            or self.adaptive_hysteresis_k < 1
        ):
            raise ValueError("adaptive_hysteresis_k must be an integer >= 1.")
