'''
All agents should inherit from the Agent class.

There are three common settings, each with its own agent subclass:
- FiniteHorizonMDPAgent = finite *known* horizon H
- EpisodicMDPAgent = time-homogeneous problem with *unknown* episode length
- DiscountedMDPAgent = infinite horizon with discount factor

Most of the implemented functionality targets FiniteHorizonMDPAgent.

'''

import numpy as np

class Agent:
    '''
    Base interface for all agents.

    Every method here is a no-op placeholder that returns None; concrete
    agents override the methods they need.
    '''

    def __init__(self):
        pass

    def update_obs(self, obs, action, reward, newObs):
        '''Add an observed transition (obs, action, reward, newObs) to records.'''

    def update_policy(self):
        '''Update internal policy based upon records.'''

    def pick_action(self, obs):
        '''Select an action based upon the observation.'''
    
class MDPAgent(Agent):
    '''
    An MDP Bayesian learner.

    Subclasses are expected to provide:
        update_policy() - called at construction; judging by pick_action and
                          egreedy it should set a callable ``self.policy``
        sample_action() - used for the exploratory branch of egreedy
    '''

    def __init__(self, bayesian_model, seed=0, **kwargs):
        '''
        Args:
            bayesian_model - Bayesian model
            seed - The seed used to generate the random generator
            **kwargs - accepted for interface compatibility; currently unused

        Returns:
            a Bayesian learner, to be inherited from
        '''

        # Instantiate the Bayes learner
        self.bayesian_model = bayesian_model
        # Create the RNG *before* building the initial policy so that a
        # subclass's update_policy() may draw from self.randomgenerator
        # (e.g. posterior sampling) without hitting an AttributeError.
        self.randomgenerator = np.random.default_rng(seed)
        self.update_policy()

    def update_obs(self, oldState, action, reward, newState, done):
        '''
        Update the posterior belief based on one transition.

        Args:
            oldState - int
            action - int
            reward - double
            newState - int
            done - 0/1

        Returns:
            NULL - updates in place
        '''
        # update_history takes a batch (list) of transitions, so the single
        # transition is wrapped in a one-element list.
        self.bayesian_model.update_history([(oldState, action, reward, newState, done)])

    def pick_action(self, state):
        '''
        Select the greedy action for `state` using the current policy.

        Args:
            state - int - current state

        Returns:
            action - whatever self.policy returns
        '''
        action = self.policy(state)
        return action

    def egreedy(self, state, epsilon=0):
        '''
        Select action according to an epsilon-greedy policy

        Args:
            state - int - current state
            epsilon - probability of taking an exploratory action drawn from
                      self.sample_action() instead of the policy action

        Returns:
            action - int

        Note: one uniform draw is taken from self.randomgenerator on every
        call, even when epsilon is 0, so the RNG state always advances.
        '''
        if self.randomgenerator.random() < epsilon:
            action = self.sample_action()
        else:
            action = self.policy(state)

        return action


class FiniteHorizonMDPAgent(MDPAgent):
    '''
    A finite horizon MDP Bayesian learner.

    Transitions are buffered for the current episode and are only handed to
    the Bayesian model (followed by a policy refresh) once a transition with
    `done` set is observed. The policy is queried with the current
    within-episode time step (self.time_period).
    '''

    def __init__(self, bayesian_model, epLen, seed=0, **kwargs):
        '''
        Args:
            bayesian_model - Bayesian model
            epLen - episode length
            seed - The seed used to generate the random generator
            **kwargs - forwarded to MDPAgent.__init__

        Returns:
            a Bayesian learner, to be inherited from
        '''
        # These attributes are set before super().__init__ because that call
        # runs update_policy(), which a subclass may implement in terms of them.
        self.epLen = epLen
        self.epIdx = 1                # 1-based index of the current episode
        self.episode_memory = []      # transitions of the current episode
        self.time_period = 0          # time step within the current episode
        super().__init__(bayesian_model, seed, **kwargs)

    def update_obs(self, oldState, action, reward, newState, done):
        '''
        Record one transition; at episode end, update the posterior belief
        with the whole episode and refresh the policy.

        Args:
            oldState - int
            action - int
            reward - double
            newState - int
            done - 0/1

        Returns:
            NULL - updates in place
        '''

        self.episode_memory.append((oldState, action, reward, newState, done))
        self.time_period += 1

        if done:
            # Flush the buffered episode to the model, reset per-episode
            # bookkeeping, and rebuild the policy from the updated posterior.
            self.bayesian_model.update_history(self.episode_memory)
            self.episode_memory = []
            self.time_period = 0
            self.epIdx += 1
            self.update_policy()

    def pick_action(self, state):
        '''
        Select the greedy action for `state` at the current time step.

        Overrides MDPAgent.pick_action, which calls self.policy(state) without
        a time step and is therefore inconsistent with how egreedy queries the
        policy in this finite-horizon setting.

        Args:
            state - int - current state

        Returns:
            action - whatever self.policy returns
        '''
        return self.policy(state, self.time_period)

    def egreedy(self, state, epsilon=0):
        '''
        Select action according to an epsilon-greedy policy

        Args:
            state - int - current state
            epsilon - probability of taking an exploratory action drawn from
                      self.sample_action() instead of the policy action

        Returns:
            action - int

        Note: one uniform draw is taken from self.randomgenerator on every
        call, even when epsilon is 0, so the RNG state always advances.
        '''
        if self.randomgenerator.random() < epsilon:
            action = self.sample_action()
        else:
            action = self.policy(state, self.time_period)

        return action

class EpisodicMDPAgent(MDPAgent):
    '''Placeholder agent for time-homogeneous problems with an unknown
    episode length; currently inherits all behaviour from MDPAgent.'''

class DiscountedMDPAgent(MDPAgent):
    '''Placeholder agent for infinite-horizon problems with a discount
    factor; currently inherits all behaviour from MDPAgent.'''
