import math
from typing import Tuple, Union

import numpy as np

from icpe.MRP.mrp import MRP
from joblib import Parallel, delayed


class CartPole(MRP):
    """Cart-pole dynamics wrapped as a Markov reward process (MRP).

    Credit : https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py#L7

    Differences from the gym environment, as implemented below:
    - the physical constants (gravity, masses, pole length, force, tau) are
      drawn uniformly at random when the instance is created;
    - actions come from a fixed stochastic policy: push right (action 1) with
      probability ``epsilon`` (drawn once in ``__init__``), left otherwise;
    - the continuous state (x, x_dot, theta, theta_dot) is discretized into
      ``bins_per_feature ** 4`` cells, each with a random reward in [-1, 1);
    - leaving the valid region does not terminate anything: the state is
      re-sampled with ``reset()`` and the trajectory continues.
    """

    def __init__(self,
                 bins_per_feature: int,
                 gamma: float,
                 *,
                 n_rollout_steps: int = 1000,
                 n_rollouts_per_state: int = 30):
        """
        param bins_per_feature: number of bins per state feature (>= 1); the
            MRP has bins_per_feature ** 4 discrete states
        param gamma: discount factor
        param n_rollout_steps: transitions simulated per Monte Carlo rollout
            when estimating v / steady_d (default 1000)
        param n_rollouts_per_state: the estimators run
            n_states * n_rollouts_per_state rollouts (default 30)
        raises ValueError: if bins_per_feature < 1, n_rollout_steps < 0 or
            n_rollouts_per_state < 1
        """
        if bins_per_feature < 1:
            raise ValueError(
                f'bins_per_feature must be >= 1, got {bins_per_feature}')
        if n_rollout_steps < 0:
            raise ValueError(
                f'n_rollout_steps must be >= 0, got {n_rollout_steps}')
        if n_rollouts_per_state < 1:
            raise ValueError(
                f'n_rollouts_per_state must be >= 1, got {n_rollouts_per_state}')

        # Physical constants are sampled uniformly so that each instance is a
        # different task; the fixed gym values are noted for reference.
        # NOTE: the order of the np.random draws below determines which random
        # value ends up where; keep it stable for np.random.seed reproducibility.
        self.gravity = np.random.uniform(low=7, high=12)        # gym: 9.8
        self.masscart = np.random.uniform(low=0.5, high=1.5)    # gym: 1.0
        self.masspole = np.random.uniform(low=0.05, high=0.15)  # gym: 0.1
        self.total_mass = (self.masspole + self.masscart)
        # actually half the pole's length
        self.length = np.random.uniform(low=0.5, high=1.5)
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = np.random.uniform(low=5, high=15)      # gym: 10.0
        # seconds between state updates
        self.tau = np.random.uniform(low=0.01, high=0.05)       # gym: 0.02
        self.kinematics_integrator = 'euler'
        # probability that the behaviour policy pushes right (action 1)
        self.epsilon = np.random.rand()

        # Angle at which to fail the episode (12 degrees)
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        # Position at which to fail the episode
        self.x_threshold = 2.4

        # Monte Carlo settings used by get_value / get_steady_d
        self.n_rollout_steps = n_rollout_steps
        self.n_rollouts_per_state = n_rollouts_per_state

        # Number of bins per dimension
        self.n_bins = bins_per_feature
        # Observation boundaries per state variable,
        # state = (x, x_dot, theta, theta_dot). Position and angle get 25%
        # headroom beyond the failure thresholds; values outside any bound
        # simply land in that feature's first/last bin.
        self.obs_bounds = [[-self.x_threshold * 1.25, self.x_threshold * 1.25],
                           [-2.5, 2.5],
                           [-self.theta_threshold_radians * 1.25,
                               self.theta_threshold_radians * 1.25],
                           [-2.5, 2.5]]
        # Keep only the interior bin edges so np.digitize maps every real
        # value to a bin index in 0 .. n_bins - 1.
        self.bins = [
            np.linspace(low, high, self.n_bins + 1)[1:-1]
            for low, high in self.obs_bounds
        ]
        super().__init__(self.n_bins ** 4, gamma)
        self.total_states = self.n_bins ** 4

        # Action space
        self._all_actions = [0, 1]  # left, right

        # One random reward per discrete state
        # (self.n_states is presumably set by MRP.__init__).
        self.rewards = np.random.uniform(low=-1, high=1,
                                         size=self.n_states)

        # Lazily computed Monte Carlo estimates (filled by _run_monte_carlo)
        self.v = None
        self.steady_d = None

    def _estimate_stationary_and_value(self, n_states: int, gamma: float,
                                       n_steps: int = 1000) -> Tuple[np.ndarray, int, float]:
        """
        Simulate a single rollout that starts from a fresh reset() state.

        param n_states: number of discrete states (length of the count vector)
        param gamma: discount factor
        param n_steps: number of transitions to simulate (default 1000)
        return: (visit count per discrete state, including the start state;
                 discrete index of the start state;
                 discounted return sum_t gamma**t * r_{t+1}, t = 0..n_steps-1,
                 where r_{t+1} is the reward of the state reached by step t+1)
        """
        visit_count = np.zeros(n_states)

        state = self.reset()
        start_idx = self.get_discretized_state_idx(state)
        visit_count[start_idx] += 1

        discounted_return = 0.0
        discount = 1.0
        for _ in range(n_steps):
            state, reward = self.step(state)
            visit_count[self.get_discretized_state_idx(state)] += 1
            discounted_return += reward * discount
            discount *= gamma
        return visit_count, start_idx, discounted_return

    def _run_monte_carlo(self) -> None:
        """
        Estimate the value function and visit distribution, then cache them
        on self.v and self.steady_d.

        Runs n_states * n_rollouts_per_state rollouts in parallel.
        v[s] is the mean discounted return of the rollouts that *started* in
        discrete state s; states in which no rollout started get 0. Because
        reset() samples every feature from [-0.05, 0.05], only bins that
        overlap that small region can receive a non-zero value.
        steady_d is the normalized total visit count over all rollouts.

        NOTE(review): rollouts execute in joblib workers, which presumably
        keep their own np.random state, so seeding np.random in the caller
        may not make these estimates reproducible -- confirm if it matters.
        """
        n_rollouts = self.n_states * self.n_rollouts_per_state
        results = Parallel(n_jobs=-1)(
            delayed(self._estimate_stationary_and_value)(
                self.n_states, self.gamma, self.n_rollout_steps)
            for _ in range(n_rollouts))

        total_visit_count = np.zeros(self.n_states)
        start_count = np.zeros(self.n_states)
        return_sum = np.zeros(self.n_states)
        for visit_count, start_idx, discounted_return in results:
            total_visit_count += visit_count
            start_count[start_idx] += 1
            return_sum[start_idx] += discounted_return

        # Divide only where at least one rollout started. The previous
        # np.where(count > 0, v / count, 0) form still evaluated 0 / 0 and
        # emitted a RuntimeWarning; the resulting values are the same.
        mean_return = np.divide(return_sum, start_count,
                                out=np.zeros(self.n_states),
                                where=start_count > 0)
        self.v = mean_return.reshape(self.n_states, 1)
        self.steady_d = total_visit_count / total_visit_count.sum()

    def get_value(self) -> np.ndarray:
        """
        return: copy of the Monte Carlo value estimate, shape (n_states, 1);
            computed on first use (together with steady_d) and then cached
        """
        if self.v is None:
            self._run_monte_carlo()
        return self.v.copy()

    def get_steady_d(self) -> np.ndarray:
        """
        return: copy of the estimated state-visit distribution, shape
            (n_states,); computed on first use (together with v) and then cached
        """
        if self.steady_d is None:
            self._run_monte_carlo()
        return self.steady_d.copy()

    def is_state_valid(self, state: np.ndarray) -> bool:
        """
        param state: current state of the agent (x, x_dot, theta, theta_dot)
        return: True if x and theta are within their failure thresholds,
            False otherwise
        """
        x, _, theta, _ = state
        # Velocities aren't bounded, therefore cannot be checked.
        is_state_invalid = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )
        return not is_state_invalid

    def reset(self) -> np.ndarray:
        """Sample a random start state with every feature in [-0.05, 0.05)."""
        state = np.random.uniform(low=-0.05, high=0.05, size=(4,))
        return state

    def step(self, state: np.ndarray) -> Tuple[np.ndarray, float]:
        """
        Advance the system by one time step under the fixed stochastic policy.

        param state: current state of the agent (x, x_dot, theta, theta_dot)
        return: (next_state, reward), where reward is self.rewards of the
            discrete bin that contains next_state
        """
        x, x_dot, theta, theta_dot = state
        # Behaviour policy: push right (action 1) with probability epsilon,
        # otherwise push left (action 0).
        action = np.random.binomial(1, self.epsilon)
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (force + self.polemass_length * theta_dot **
                2 * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (self.length *
                                                                  (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == 'euler':
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        next_state = np.array((x, x_dot, theta, theta_dot))

        # Leaving the valid region does not end the trajectory: restart from
        # a fresh reset() state. The reward is then that of the reset state's
        # bin (there is no special failure penalty).
        if not self.is_state_valid(next_state):
            next_state = self.reset()

        discretized_state_idx = self.get_discretized_state_idx(next_state)
        reward = self.rewards[discretized_state_idx]

        return next_state, reward

    def _discretize_state(self, state: np.ndarray) -> np.ndarray:
        """
        Map each continuous feature to its bin index.

        param state: current state of the agent (one value per feature)
        return: array with one bin index in 0 .. n_bins - 1 per feature
        """
        return np.array([np.digitize(value, self.bins[i])
                         for i, value in enumerate(state)])

    def get_discretized_state_idx(self, state: np.ndarray) -> int:
        """
        Flatten the per-feature bin indices into one discrete state index.

        param state: current state of the agent
        return: index in 0 .. n_bins ** 4 - 1; the first feature is the most
            significant "digit" in base n_bins
        """
        feature_index = 0
        for bin_idx in self._discretize_state(state):
            feature_index = feature_index * self.n_bins + bin_idx
        return feature_index

    def get_feature_index(self, state: np.ndarray) -> Union[int, np.ndarray]:
        """
        param state: a single state (1-D) or an array of states (one per row)
        return: the discrete index for a single state, or an int32 array of
            indices for an array of states
        """
        if state.ndim == 1:
            return self.get_discretized_state_idx(state)
        else:
            return np.array([self.get_discretized_state_idx(s) for s in state], dtype=np.int32)
