from typing import Tuple, Union

import numpy as np

from icpe.MRP.mrp import MRP
from icpe.utils import compute_steady_dist


class BoyanChain(MRP):
    '''
    Chain-structured Markov reward process with randomly sampled dynamics.

    Every state i < n_states - 2 moves to state i + 1 or i + 2, the
    second-to-last state moves to the last state with probability 1, and the
    last state jumps to any state according to a random distribution.
    The initial-state distribution is the stationary distribution of the chain.
    '''

    def __init__(self,
                 n_states: int,
                 gamma: float = 0.9,
                 weight: Union[np.ndarray, None] = None,
                 phi: Union[np.ndarray, None] = None) -> None:
        '''
        param n_states: number of states of the Boyan Chain (at least 2)
        param gamma: discount factor
        param weight: optional weight vector; when given, the value function
                      is set to phi.dot(weight) and the reward is derived
                      from it through the Bellman equation
        param phi: feature matrix of shape (n_states, d); required when
                   weight is given, ignored otherwise
        raises ValueError: if n_states is smaller than 2
        '''
        super().__init__(n_states, gamma)

        # the chain needs a second-to-last state, i.e. at least two states;
        # without this check P[-2, -1] below fails with an opaque IndexError
        if n_states < 2:
            raise ValueError(
                f'n_states must be at least 2, got {n_states}')

        # initialize transition matrix
        self.P = np.zeros((n_states, n_states))
        for i in range(n_states - 2):
            trans_sample = np.random.uniform(0.01, 0.99)
            self.P[i, i + 1] = trans_sample
            self.P[i, i + 2] = 1 - trans_sample
        self.P[-2, -1] = 1.0
        # the last state restarts the chain from a random distribution
        self.P[-1, :] = np.random.uniform(0.01, 0.99, size=n_states)
        self.P[-1, :] /= self.P[-1, :].sum()
        assert np.allclose(self.P.sum(axis=1), 1)
        self.steady_d = compute_steady_dist(self.P)
        # set initial distribution to be the stationary distribution
        # (stored as a copy so in-place edits of one do not leak into the other)
        self.mu = self.steady_d.copy()

        bellman_op = np.eye(n_states) - gamma * self.P
        if weight is not None:
            assert phi is not None, 'feature matrix phi must be provided if weight is given'
            self.w = weight
            self.v = phi.dot(self.w)
            # r = (I - gamma * P) v
            self.r = bellman_op.dot(self.v)
            # step() indexes the reward as r[state, 0]; a 1-D weight vector
            # yields a 1-D reward, so promote it to a column vector
            if self.r.ndim == 1:
                self.r = self.r[:, None]
        else:
            self.r = np.random.uniform(low=-1.0, high=1.0, size=(n_states, 1))
            # v solves (I - gamma * P) v = r; solving the linear system avoids
            # forming the explicit inverse
            self.v = np.linalg.solve(bellman_op, self.r)

    def get_value(self) -> np.ndarray:
        '''
        return a copy of the value function of the MRP
        '''
        return self.v.copy()

    def get_steady_d(self) -> np.ndarray:
        '''
        return a copy of the steady state distribution of the MRP
        '''
        return self.steady_d.copy()

    def reset(self) -> int:
        '''
        return an initial state sampled from the initial distribution mu
        '''
        s = np.random.choice(self.n_states, p=self.mu)
        return s

    def step(self, state: int) -> Tuple[int, float]:
        '''
        param state: current state
        return next state and reward
        '''
        assert 0 <= state < self.n_states
        next_state = np.random.choice(self.n_states, p=self.P[state])
        # the reward depends only on the state being left
        reward = self.r[state, 0]
        return next_state, reward

    def get_feature_index(self, state: Union[int, np.ndarray]) -> Union[int, np.ndarray]:
        '''
        param state: state or array of states
        return the state(s) converted to integer indices (int32 for arrays)
        '''
        if isinstance(state, np.ndarray):
            return state.astype(np.int32)
        else:
            return int(state)

    def sample_stationary(self) -> int:
        '''
        return a state sampled from the stationary distribution
        '''
        return np.random.choice(self.n_states, p=self.steady_d)

    def copy(self) -> 'BoyanChain':
        '''
        return a new BoyanChain with copies of this chain's dynamics,
        initial distribution, reward, value function and (if set) weights
        '''
        # NOTE: building the fresh instance samples a throwaway chain, which
        # draws from the global numpy RNG; every sampled field is overwritten
        bc = BoyanChain(self.n_states, self.gamma)
        bc.P = self.P.copy()
        bc.mu = self.mu.copy()
        bc.r = self.r.copy()
        bc.v = self.v.copy()
        bc.steady_d = self.steady_d.copy()
        if hasattr(self, 'w'):
            bc.w = self.w.copy()
        return bc
