import numpy as np
from tqdm import tqdm

from env import GridWorld, DummyEnv

import logging

# Module-level logger; the solvers below report their derived settings and
# convergence through it.
logger = logging.getLogger(__name__)


class StatisticalQueryAlgo:
    """The statistical query algorithm.

    Answers a statistical query with the mean of a sample. In replicable mode
    the mean is additionally rounded to the midpoint of a randomly shifted bin
    whose offset is seeded by a caller-supplied ``randomness`` value; the
    intent is that runs sharing that seed return the same answer even though
    their samples differ.
    """

    def __init__(
        self,
        replicable: bool = False,
        tolerance_param: float = None,
        reprod_param: float = None,
        failure_param: float = None,
    ) -> None:
        """Initialize the statistical query algorithm.

        Args:
            replicable (bool, optional): Whether to use the replicable
                version of the algorithm. Defaults to False.
            tolerance_param (float, optional): The tolerance parameter.
                Required when ``replicable`` is True. Defaults to None.
            reprod_param (float, optional): The reproducibility parameter.
                Required when ``replicable`` is True. Defaults to None.
            failure_param (float, optional): The failure parameter.
                Required when ``replicable`` is True. Defaults to None.
        """
        if replicable:
            assert (
                tolerance_param is not None
                and reprod_param is not None
                and failure_param is not None
            ), "tolerance_param, reprod_param, and failure_param cannot be None"
        self.tolerance_param = tolerance_param
        self.reprod_param = reprod_param
        self.failure_param = failure_param
        self.replicable = replicable

    def _statistical_query(self, sample: np.ndarray) -> float:
        """
        Compute the mean of a sample.

        Args:
            sample (np.ndarray): The sample.

        Returns:
            float: The mean of the sample.
        """
        sample_mean = np.mean(sample)
        return sample_mean

    def _replicable_statistical_query(
        self, sample: np.ndarray, randomness: int
    ) -> float:
        """
        Compute the replicable mean of a sample.

        The interval [0, 1] is split at the sorted boundaries
        ``0, offset, offset + alpha, ..., 1``, where ``offset`` is drawn
        uniformly from ``[0, alpha)`` with ``randomness`` as the seed. The
        sample mean is mapped to the midpoint of the bin that contains it
        (the last bin is closed, so a mean of exactly 1 is accepted).

        Args:
            sample (np.ndarray): The sample; its mean must lie in [0, 1].
            randomness (int): Seed for the random bin offset. Calls that are
                meant to agree must pass the same value.

        Returns:
            float: The midpoint of the bin containing the sample mean.

        Raises:
            ValueError: If the sample mean is outside [0, 1] or is NaN
                (e.g. for an empty sample).
        """
        assert randomness is not None, "randomness cannot be None"
        rng = np.random.RandomState(randomness)

        # Bin width derived from the tolerance, reproducibility and failure
        # parameters; the grid is shifted by a random offset in [0, alpha).
        alpha = (2 * self.tolerance_param) / (
            self.reprod_param + 1 - 2 * self.failure_param
        )
        alpha_offset = rng.uniform(0, alpha)
        num_bins = int(1 / alpha) + 2

        # Sorted bin boundaries: 0, offset, offset + alpha, ..., 1.
        regions = np.zeros((num_bins,))
        regions[0] = 0
        regions[1:-1] = np.arange(0, num_bins - 2) * alpha + alpha_offset
        regions[-1] = 1

        sample_mean = np.mean(sample)
        # The grid only covers [0, 1]; reject anything else (the negated
        # comparison also catches NaN) instead of returning a meaningless bin.
        if not 0 <= sample_mean <= 1:
            raise ValueError(
                "replicable statistical query requires a sample mean in [0, 1], "
                f"got {sample_mean}"
            )

        # Index i of the bin [regions[i], regions[i + 1]) that contains the
        # mean. A mean of exactly 1 falls past the last boundary, so clip it
        # into the final bin.
        bin_index = int(np.searchsorted(regions, sample_mean, side="right")) - 1
        bin_index = min(bin_index, num_bins - 2)

        # Round to the midpoint of that bin.
        lower, upper = regions[bin_index], regions[bin_index + 1]
        return lower + (upper - lower) / 2

    def compute_statistical_query(
        self, sample: np.ndarray, randomness: int = None
    ) -> float:
        """Compute the statistical query.

        Args:
            sample (np.ndarray): The sample.
            randomness (int, optional): The randomness to use for the
                replicable statistical query. Defaults to None.

        Returns:
            float: The statistical query.
        """
        if self.replicable:
            return self._replicable_statistical_query(sample, randomness)
        else:
            return self._statistical_query(sample)

    def __call__(self, sample: np.ndarray, *args, **kwargs) -> float:
        """Compute the statistical query (alias of ``compute_statistical_query``)."""
        return self.compute_statistical_query(sample, *args, **kwargs)


class PVI:
    """The PVI algorithm.

    Sample-based value iteration over a tabular environment. Every iteration
    draws ``num_calls_to_ps`` next-state samples for each state-action pair
    through ``env.parallel_sampling`` and replaces each Q-value by the reward
    plus the discounted statistical-query estimate of the next-state value.
    """

    def __init__(
        self,
        acc_param: float,
        failure_param: float,
        sample_divisor: int,
    ) -> None:
        """Initialize the PVI algorithm.

        Args:
            acc_param (float): The accuracy parameter.
            failure_param (float): The failure probability parameter.
            sample_divisor (int): Divisor applied to the theoretical number of
                next-state samples drawn per state-action pair.
        """
        self.acc_param = acc_param
        self.failure_param = failure_param
        self.sample_divisor = sample_divisor

    def build(self, env: GridWorld) -> None:
        """Instantiate the environment and initialize the Q-values.

        Also derives the number of iterations and the number of samples drawn
        per state-action pair, and creates the non-replicable
        statistical-query oracle.

        Args:
            env (GridWorld): The environment.
        """
        self.env = env
        self.Q_values = {
            (s, a): 0
            for s in range(self.env.num_states)
            for a in range(self.env.num_actions)
        }

        self.num_iterations = int(
            np.log(2 / ((1 - env.gamma) ** 2 * self.acc_param)) / (1 - env.gamma)
        )
        logger.info(f"Setting num_iterations to {self.num_iterations}")

        # A large sample_divisor (or acc_param >= 1) can push this estimate to
        # zero or below; drawing no samples per pair would leave every
        # statistical query with an empty sample, so draw at least one.
        self.num_calls_to_ps = max(
            1,
            int(
                np.log(1 / self.acc_param)
                / (self.acc_param) ** 2
                * np.log(
                    self.env.num_states * self.env.num_actions / self.failure_param
                )
                / self.sample_divisor
            ),
        )
        logger.info(f"Setting num_calls_to_ps to {self.num_calls_to_ps}")

        self.querier = StatisticalQueryAlgo(
            replicable=False,
            tolerance_param=self.acc_param,
            failure_param=self.failure_param,
        )

    def _iteration(self, i: int):
        """Run one iteration of the PVI algorithm.

        The backup is synchronous: every state-action pair reads the state
        values ``V(s) = max_a Q(s, a)`` from the previous iteration, i.e.
        ``Q(s, a) <- r(s, a) + gamma * SQ([V(s') for sampled s'])``.

        Args:
            i (int): The iteration number; combined with the pair index to give
                each statistical query its own deterministic seed.
        """
        # Assumes parallel_sampling returns two dicts keyed by (s, a): sampled
        # next states and rewards.
        next_states, r = self.env.parallel_sampling(self.num_calls_to_ps)

        # Snapshot V once before any Q-value changes. Reading self.Q_values
        # inside the loop would mix updated and stale values within a single
        # iteration, and would redo the max over actions per sampled state.
        state_values = q_values_to_v_values(
            self.Q_values, self.env.num_states, self.env.num_actions
        )[:, 0]

        num_pairs = len(self.Q_values)
        for j, k in enumerate(self.Q_values.keys()):
            randomness = i * num_pairs + j
            next_state_values = [state_values[s] for s in next_states[k]]
            expected_value = self.querier(next_state_values, randomness=randomness)

            self.Q_values[k] = r[k] + self.env.gamma * expected_value

    def run(self):
        """Run the PVI algorithm.

        Returns:
            dict: ``"vf"`` maps to a ``(num_states, 1)`` array holding, per
            state, the max over actions of the final Q-values; ``"policy"``
            maps each state to the index of its greedy action.
        """
        for i in tqdm(range(self.num_iterations)):
            self._iteration(i)

        policy = {
            s: np.argmax([self.Q_values[(s, a)] for a in range(self.env.num_actions)])
            for s in range(self.env.num_states)
        }

        v_values = q_values_to_v_values(
            self.Q_values, self.env.num_states, self.env.num_actions
        )

        return_dict = {
            "vf": v_values,
            "policy": policy,
        }

        return return_dict


class ReplicablePVI(PVI):
    """PVI variant whose statistical queries are replicable.

    Behaves like ``PVI`` except that ``build`` installs a replicable
    statistical-query oracle. The failure and reproducibility budgets are
    divided evenly over all state-action pairs, and the number of samples per
    pair can optionally be pinned.
    """

    def __init__(
        self,
        acc_param: float,
        sample_divisor: int,
        reprod_param: float,
        failure_param: float,
        set_calls_to: int = None,
    ) -> None:
        """Create a replicable PVI solver.

        Args:
            acc_param (float): Accuracy parameter.
            sample_divisor (int): Divisor for the number of samples.
            reprod_param (float): Reproducibility parameter, split evenly over
                all state-action pairs in ``build``.
            failure_param (float): Failure parameter, split evenly over all
                state-action pairs in ``build``.
            set_calls_to (int, optional): When given, replaces the computed
                number of samples per state-action pair. Defaults to None.
        """
        super().__init__(acc_param, failure_param, sample_divisor)
        self.replicable = True
        self.set_calls_to = set_calls_to
        self.reprod_param = reprod_param

    def build(self, env: GridWorld) -> None:
        """Set up the environment, sample budget and replicable oracle.

        Args:
            env (GridWorld): GridWorld environment.
        """
        super().build(env)

        pinned_calls = self.set_calls_to
        if pinned_calls is not None:
            self.num_calls_to_ps = pinned_calls
            logger.info(f"Overriding num_calls_to_ps to {self.num_calls_to_ps}")

        n_states = self.env.num_states
        n_actions = self.env.num_actions
        per_pair_failure = self.failure_param / n_states / n_actions
        per_pair_reprod = self.reprod_param / n_states / n_actions
        self.querier = StatisticalQueryAlgo(
            replicable=True,
            tolerance_param=self.acc_param * (1 - self.env.gamma),
            failure_param=per_pair_failure,
            reprod_param=per_pair_reprod,
        )


class ValueIteration:
    """Model-based value iteration.

    Reads the environment's transition array ``P`` and reward array ``R``
    directly rather than sampling.
    """

    def __init__(self, acc_param: float, max_iter: int, verbose: bool) -> None:
        """Initialize value iteration.

        Args:
            acc_param (float): Accuracy parameter; sweeps stop once the largest
                per-state change is at most ``acc_param * (1 - gamma) / gamma``.
            max_iter (int): Maximum number of sweeps over the state space.
            verbose (bool): Whether to log a message on convergence.
        """
        self.acc_param = acc_param
        self.max_iter = max_iter
        self.verbose = verbose

    def build(self, env: GridWorld) -> None:
        """Instantiate the environment.

        Args:
            env (GridWorld): GridWorld environment.
        """
        self.env = env

    def _action_values(self, state: int, value_function: np.ndarray) -> np.ndarray:
        """Return the one-step lookahead value of every action in ``state``.

        Computes ``R[state, a, 0] + gamma * sum_s' P[state, s', a] * V(s')`` for
        each action ``a``. NOTE(review): the index order of ``P`` is inferred
        from how it broadcasts against the ``(num_states, 1)`` value function;
        confirm against the environment.
        """
        return self.env.R[state, :, 0] + np.sum(
            self.env.P[state, :, :] * (self.env.gamma * value_function),
            0,
        )

    def run(self):
        """Run the value iteration algorithm.

        The value function is updated in place, state by state, so later
        states in a sweep already see the new values of earlier ones.

        Returns:
            dict: ``"vf"`` is the ``(num_states, 1)`` value function and
            ``"policy"`` a ``(num_states, 1)`` float array whose entry ``s`` is
            the index of the greedy action in state ``s``.
        """
        # initialize the value function and policy
        pi = np.ones((self.env.num_states, 1))
        value_function = np.zeros((self.env.num_states, 1))

        # The stopping threshold is loop-invariant. With gamma == 0 a single
        # sweep is already exact, so stop after it instead of dividing by 0.
        gamma = self.env.gamma
        if gamma > 0:
            threshold = self.acc_param * (1 - gamma) / gamma
        else:
            threshold = np.inf

        for i in range(self.max_iter):
            delta = 0
            # perform Bellman update for each state
            for state in range(self.env.num_states):
                tmp = value_function[state].copy()
                value_function[state] = np.max(
                    self._action_values(state, value_function)
                )
                # track the largest change in value during this sweep
                delta = np.max((delta, np.abs(tmp[0] - value_function[state][0])))
            if delta <= threshold:
                if self.verbose:
                    # i is 0-based, so i + 1 sweeps have been performed
                    logger.info(f"Converged after {i + 1} iterations.")
                break

        # Greedy policy w.r.t. the same lookahead used in the backup: the
        # immediate reward and discount must be included, not just the
        # expected next-state value.
        for state in range(self.env.num_states):
            pi[state] = np.argmax(self._action_values(state, value_function))

        return_dict = {
            "vf": value_function,
            "policy": pi,
        }

        return return_dict


def stack_dicts(base_dict: dict, to_stack_dict: dict) -> dict:
    """
    Append each value of ``to_stack_dict`` to the list stored under the same
    key in ``base_dict``.

    Only keys of ``to_stack_dict`` appear in the result; a key that
    ``base_dict`` lacks starts a new single-element list. Neither input is
    mutated.

    Args:
        base_dict (dict): Mapping from key to the list accumulated so far.
        to_stack_dict (dict): Mapping from key to the new value to append.

    Returns:
        dict: A new dictionary mapping each key of ``to_stack_dict`` to a list.
    """
    return {
        key: (base_dict[key] if key in base_dict else []) + [value]
        for key, value in to_stack_dict.items()
    }


def q_values_to_v_values(
    q_values: dict, num_states: int, num_actions: int
) -> np.ndarray:
    """Reduce state-action values to state values by maximizing over actions.

    Args:
        q_values (dict): Q-value for every ``(state, action)`` key.
        num_states (int): Number of states; states are ``0..num_states-1``.
        num_actions (int): Number of actions; actions are ``0..num_actions-1``.

    Returns:
        np.ndarray: Float array of shape ``(num_states, 1)`` whose row ``s`` is
        ``max_a q_values[(s, a)]``.
    """
    best_per_state = [
        np.max([q_values[(state, a)] for a in range(num_actions)])
        for state in range(num_states)
    ]
    return np.asarray(best_per_state, dtype=float).reshape(num_states, 1)


def test_replicable_statistical_query():
    """
    Demonstrate the replicable statistical query.

    Draws 1000 uniform samples (each from its own seed), answers every one
    with the same shared ``randomness``, and prints the distinct answers with
    how often each occurred.
    """
    tolerance_param = 0.005
    reprod_param = 0.5
    failure_param = 0.0001

    querier = StatisticalQueryAlgo(
        replicable=True,
        tolerance_param=tolerance_param,
        reprod_param=reprod_param,
        failure_param=failure_param,
    )

    # The sample size does not depend on the seed, so compute it once.
    sample_size = int(
        1 / (2 * tolerance_param**2 * (reprod_param - 2 * failure_param) ** 2)
    )

    answers = [
        querier(
            np.random.RandomState(seed).uniform(0, 1, sample_size),
            randomness=1,
        )
        for seed in tqdm(range(1000))
    ]

    print(np.unique(answers, return_counts=True))


if __name__ == "__main__":
    # Executing this file directly runs the replicable statistical query demo.
    test_replicable_statistical_query()
