import numpy as np
from model.generative_model import Generative_Model

class Large_MDP(Generative_Model):
    """Randomly generated finite MDP whose rewards take the values 0 and 1.

    States are ``0 .. num_of_state - 1`` and actions are
    ``0 .. num_of_action - 1``; ``sa_pairs`` holds every (state, action)
    combination. For each pair the model stores:

    - ``transition_map[(s, a)]``: a probability vector over next states.
    - ``reward_map[(s, a)]``: ``[P(r=0), P(r=1)]``, aligned with ``rewards``.

    All sampling goes through a single ``numpy.random.Generator`` seeded with
    ``random_seed``, so identical constructor arguments give identical models.
    """

    def __init__(self, p: float = 0.8, num_of_state: int = 2,
                 num_of_action: int = 3, random_seed=42):
        """
        Build the state/action sets and sample the transition and reward maps.

        Parameters:
        - p: Unnormalized weight of the self-transition s -> s (the other next
          states get Uniform[0, 1) weights before normalization). Must be
          non-negative.
        - num_of_state: Number of states.
        - num_of_action: Number of actions.
        - random_seed: Seed (int or None) for this instance's random
          generator. ``None`` draws fresh OS entropy (not reproducible).

        Raises:
        - ValueError: propagated from transition generation for invalid ``p``.
        """
        self.random_seed = random_seed
        # A private generator (instead of the global np.random state) makes the
        # model reproducible from ``random_seed`` alone.
        self._rng = np.random.default_rng(random_seed)

        self.num_of_state = num_of_state
        self.num_of_action = num_of_action
        self._states = np.arange(num_of_state)    # [0, 1, ..., num_of_state - 1]
        self._actions = np.arange(num_of_action)  # [0, 1, ..., num_of_action - 1]

        self._sa_pairs = [(s, a) for s in self._states for a in self._actions]

        # Transitions are drawn before rewards; keep this order so a given seed
        # keeps producing the same model.
        self.transition_map = self._generate_transition_map(p)
        self.reward_map = self._generate_reward_map()

    def _generate_transition_map(self, p: float):
        """
        Randomly generate a next-state distribution for each (state, action) pair.

        Every next state gets a Uniform[0, 1) weight, except the current state,
        whose weight is overwritten with ``p``; the weights are then normalized.
        Since the other weights average 0.5, ``p`` above 0.5 makes staying in
        the current state more likely than moving to a typical other state.

        Parameters:
        - p: Unnormalized self-transition weight; must be non-negative.

        Returns:
        - dict mapping (state, action) to a 1-D array of length num_of_state
          whose entries sum to 1.

        Raises:
        - ValueError: if ``p`` is negative, or if a row's weights sum to zero
          (e.g. a single state with ``p == 0``) and cannot be normalized.
        """
        if p < 0:
            raise ValueError(f"p must be non-negative, got {p}")

        n_states = len(self._states)
        transition_map = {}
        for sa in self._sa_pairs:
            state = sa[0]
            weights = self._rng.random(n_states)
            weights[state] = p
            total = weights.sum()
            if total <= 0:
                raise ValueError(
                    f"cannot normalize transition weights for state {state}: "
                    "all weights are zero"
                )
            probabilities = weights / total
            # Absorb floating-point round-off so the row sums to exactly 1.0.
            probabilities[-1] += 1.0 - probabilities.sum()
            transition_map[sa] = probabilities
        return transition_map

    def _generate_reward_map(self):
        """
        Randomly generate a reward distribution for each (state, action) pair.

        Rewards take the two values in ``rewards`` (0 and 1). For each pair,
        P(r=1) is drawn from Uniform[0, 1) and the stored array is
        ``[P(r=0), P(r=1)]``.

        Returns:
        - dict mapping (state, action) to a length-2 probability array.
        """
        reward_map = {}
        for sa in self._sa_pairs:
            p_1 = self._rng.random()
            reward_map[sa] = np.array([1.0 - p_1, p_1])
        return reward_map

    @property
    def is_mdp(self):
        """Always True: this generative model is a Markov decision process."""
        return True

    @property
    def states(self):
        """Array of state indices ``[0, ..., num_of_state - 1]``."""
        return self._states

    @property
    def rewards(self):
        """Possible reward values, ``[0, 1]``; indexes match ``reward_map`` rows."""
        return np.array([0, 1])

    @property
    def sa_pairs(self):
        """List of every (state, action) tuple."""
        return self._sa_pairs

    @property
    def action_at_state(self):
        """Return the actions available at each state.

        Delegates to ``generate_action_at_state``, which is not defined in this
        class (presumably inherited from ``Generative_Model`` - confirm).
        """
        return self.generate_action_at_state()