from typing import List, Tuple
import numpy as np

class MDP_no_R:
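    """
    This class models a tabular, episodic, finite-horizon MDP without reward
    and with a single initial state.

    Attributes:
    - S, A, H: number of states, number of actions, and horizon
    - p: the transition model, indexed as p[h, s, a, s']
    - s0: the initial state
    - mapp: integer array indexed as mapp[h, s]; it gives the object attached
      to state s at step h (rewards are vectors indexed by object)
    - n_objs: the number of objects
    """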
    S: int
    A: int
    H: int
    p: np.ndarray  # (H,S,A,S)
    s0: int
    mapp: np.ndarray
    n_objs: int

    def __init__(
            self,
            S: int,
            A: int,
            H: int,
            p: np.ndarray,
            s0: int,
            mapp: np.ndarray,
            n_objs: int
        ):
        """
        Constructor method for class MDP_no_R.

        Input arguments:
        - S: an integer representing the cardinality of the state space
        - A: an integer representing the cardinality of the action space
        - H: the horizon (at least 2)
        - p: the transition model, an np.ndarray indexed as p[h, s, a, s']
        - s0: the initial state, an integer in {0,1,...,S-1}
        - mapp: an integer np.ndarray indexed as mapp[h, s], giving the object
          attached to state s at step h
        - n_objs: the number of objects

        Raises:
        - ValueError: if H < 2, or if s0 is not in {0,1,...,S-1}
        """
        # ValueError is a subclass of Exception, so callers that catch the
        # previously raised bare Exception keep working.
        if H < 2:
            raise ValueError("Sorry, H must be at least 2")
        # A negative s0 would silently wrap around when used as an index.
        if not 0 <= s0 < S:
            raise ValueError(f"s0 must be in {{0,1,...,{S - 1}}}, got {s0}")
        self.S = S
        self.A = A
        self.H = H
        self.p = p
        self.s0 = s0
        self.mapp = mapp
        self.n_objs = n_objs

    def compute_visit_distribution(self, pi: np.ndarray) -> Tuple[np.ndarray,np.ndarray]:
        """
        Compute the visit distributions induced by policy pi when starting
        from s0.

        Input arguments:
        - pi: the policy, an np.ndarray indexed as pi[h, s, a]

        Returns the pair (d_s, d_sa), where d_s[h, s] is the probability of
        being in state s at step h and d_sa[h, s, a] is the probability of
        being in state s and playing action a at step h.
        """
        horizon = self.H
        state_occ = np.zeros((horizon, self.S))
        joint_occ = np.zeros((horizon, self.S, self.A))

        # first step: all the mass sits on the initial state
        state_occ[0, self.s0] = 1
        joint_occ[0, self.s0, :] = pi[0, self.s0, :]

        # propagate the mass forward, one step at a time
        for step in range(1, horizon):
            prev = step - 1
            # sum over (s, a) of joint_occ[prev, s, a] * p[prev, s, a, s']
            state_occ[step] = np.tensordot(
                joint_occ[prev], self.p[prev], axes=([0, 1], [0, 1])
            )
            joint_occ[step] = state_occ[step][:, None] * pi[step]

        return state_occ, joint_occ
    

    def collect_trajectories(self, pi: np.ndarray, n_trajs: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample n_trajs trajectories from the MDP by playing policy pi from s0.

        Input arguments:
        - pi: the policy, an np.ndarray indexed as pi[h, s, a]
        - n_trajs: the number of trajectories to sample (0 is allowed)

        Returns the pair (states, actions): states is an integer array of
        shape (n_trajs, H+1) with states[:, 0] == s0, and actions is an
        integer array of shape (n_trajs, H).

        Sampling uses the global np.random state.
        """
        # initialize trajectories
        states = np.zeros((n_trajs, self.H + 1), dtype=int)
        actions = np.zeros((n_trajs, self.H), dtype=int)

        # initial state is the same for all trajectories
        states[:, 0] = self.s0

        for h in range(self.H):
            # current states across all trajectories
            s = states[:, h]

            # sample actions from pi[h, s, :]
            action_probs = pi[h, s, :]                  # shape (n_trajs, A)
            # dtype=int matters when n_trajs == 0: an empty array would
            # otherwise default to float64 and be rejected as an index below.
            a = np.array(
                [np.random.choice(self.A, p=row) for row in action_probs],
                dtype=int,
            )
            actions[:, h] = a

            # sample next states from p[h, s, a, :]
            next_probs = self.p[h, s, a, :]             # shape (n_trajs, S)
            states[:, h + 1] = np.array(
                [np.random.choice(self.S, p=row) for row in next_probs],
                dtype=int,
            )

        return states, actions
    

    def RF_Express(self, tau, n_episodes=10, delta=0.1, C=None):
        """
        Vectorized variant of algorithm RF-Express of paper "Fast active
        learning for pure exploration in reinforcement learning". Instead of
        playing the greedy policy for just 1 episode, we play it in parallel
        on n_episodes episodes, so that computation is faster.

        Input arguments:
        - tau: the episode budget; tau // n_episodes rounds are played
        - n_episodes: number of episodes played with the same greedy policy
          in each round
        - delta: confidence parameter entering the threshold beta of the bonus
        - C: the bonus 15 * H^2 * beta / n is divided by C. If None, C is set
          to 30 * H^2

        Returns the estimated transition model p_hat, an np.ndarray of shape
        (H,S,A,S). Rows of (h,s,a) triples that were never visited are uniform
        (1/S). If tau // n_episodes is not positive, no episode is played and
        the uniform model is returned.
        """
        S, A, H = self.S, self.A, self.H

        if C is None:
            C = 30 * (H**2)
        bonus_scale = 15 * (H**2) / C

        # transition counts n[h,s,a,s'] and optimistic values W
        n = np.zeros((H, S, A, S))
        W = np.zeros((H + 1, S, A))  # W[H] stays 0 and closes the recursion

        n_rounds = tau // n_episodes
        for t in range(n_rounds):
            # -----------------------------
            # greedy policy wrt W
            # -----------------------------
            a_star = np.argmax(W[:H], axis=-1)               # (H, S)
            pi_greedy = np.eye(A)[a_star]                    # (H, S, A)

            # -----------------------------
            # collect trajectories in parallel
            # -----------------------------
            states, actions = self.collect_trajectories(pi_greedy, n_episodes)

            # -----------------------------
            # update counts n[h,s,a,s']
            # -----------------------------
            # states/actions are (n_episodes, H) and ravel() is row-major, so
            # the step index varies fastest: 0,1,...,H-1,0,1,...  The step
            # indices must therefore be tiled (np.repeat would pair samples
            # with the wrong step whenever n_episodes > 1).
            h_idx = np.tile(np.arange(H), n_episodes)
            s_idx = states[:, :H].ravel()
            a_idx = actions.ravel()
            sp_idx = states[:, 1:H + 1].ravel()

            flat_idx = ((h_idx * S + s_idx) * A + a_idx) * S + sp_idx
            n += np.bincount(flat_idx, minlength=H * S * A * S).reshape(H, S, A, S)

            # -----------------------------
            # stopping criterion
            # -----------------------------
            if t == n_rounds - 1:
                # W is only needed to pick the next greedy policy
                break

            # -----------------------------
            # update W by backward induction
            # -----------------------------
            V_next = np.max(W[H], axis=-1)                   # (S,)
            for h in reversed(range(H)):
                totals = n[h].sum(axis=-1)                   # (S,A)
                visited = totals > 0
                cnt = totals[visited]

                # expected next value under the empirical transitions
                ev = (n[h][visited] @ V_next) / cnt

                # optimism bonus
                beta = np.log(3 * S * A * H / delta) + S * np.log(8 * np.e * (cnt + 1))
                bonus = bonus_scale * beta / cnt

                # unvisited pairs keep the maximal value H
                W_h = np.full((S, A), float(H))
                W_h[visited] = np.minimum(H, (1 + 1 / H) * ev + bonus)
                W[h] = W_h

                V_next = np.max(W[h], axis=-1)               # (S,)

        # normalize counts into transition probabilities; uniform when unseen
        n_hsa = n.sum(axis=-1, keepdims=True)
        p_hat = np.divide(n, n_hsa, out=np.full_like(n, 1 / S), where=n_hsa > 0)
        return p_hat

    def compute_visit_distribution_objects(self, pi: np.ndarray) -> np.ndarray:
        """
        Aggregate the state visit distribution of policy pi over objects.

        Input arguments:
        - pi: the policy, an np.ndarray indexed as pi[h, s, a]

        Returns a 1-D np.ndarray with at least n_objs entries; entry o is the
        sum of d_s[h, s] over all (h, s) such that mapp[h, s] == o.
        """
        # occupancy over states; the state-action occupancy is not needed
        state_occ = self.compute_visit_distribution(pi)[0]

        # mapp and state_occ are flattened in the same order, so each
        # (h, s) weight is added to the object that mapp assigns to it
        object_ids = self.mapp.ravel()
        return np.bincount(
            object_ids,
            weights=state_occ.ravel(),
            minlength=self.n_objs,
        )
    
    def compute_optimal_policy_and_performance(
            self,
            r: np.ndarray,
    ) -> Tuple[np.ndarray, float]:
        """
        Take a reward and compute an optimal policy through backward
        induction.

        Input arguments:
        - r: the reward of each object, a 1-D array (or array-like) indexed
          by object; the reward collected in state s at step h is
          r[mapp[h, s]]

        Returns the pair (pi, J): pi is a deterministic policy stored as a
        one-hot np.ndarray of shape (H,S,A), and J is its value from s0 at
        the first step. Ties between actions are broken in favor of the
        smallest action index.
        """
        # accept lists/tuples too, as compute_policy_performance does
        r = np.asarray(r)
        S, A, H = self.S, self.A, self.H

        pi = np.zeros((H, S, A))
        V = np.zeros((H + 1, S))  # V[H] = 0 closes the recursion
        all_states = np.arange(S)

        for h in reversed(range(H)):
            # rewards per state, shape (S,)
            r_s = r[self.mapp[h]]

            # Q[s,a] = r_s[s] + sum over s' of p[h,s,a,s'] * V[h+1,s']
            Q_h = r_s[:, None] + self.p[h] @ V[h + 1]

            # greedy policy and value
            a_star = np.argmax(Q_h, axis=1)        # shape (S,)
            V[h] = Q_h[all_states, a_star]
            pi[h, all_states, a_star] = 1          # one-hot policy

        return pi, V[0, self.s0]
    
    def compute_policy_performance(
            self,
            r: np.ndarray,
            pi: List[np.ndarray],
    ) -> float: 
        """
        Given a reward and a policy, compute the expected return.

        Input arguments:
        - r: the reward of each object, a 1-D array indexed by object
        - pi: the policy, either an np.ndarray indexed as pi[h, s, a] or a
          list with one (S,A) np.ndarray per step

        Returns the dot product between the visit distribution over objects
        induced by pi and r.
        """
        # The visit-distribution code indexes the policy as pi[h, s, a], which
        # a plain list does not support; stack it into one array first (this
        # is a no-op when pi already is an np.ndarray).
        pi = np.asarray(pi)

        # compute prob of visiting objects
        d_objs = self.compute_visit_distribution_objects(pi)

        # compute expected return through dot product
        return np.dot(d_objs, r)
        
        
        
            




