import numpy as np
from agent.policies import EGreedyPolicy, SoftmaxPolicy, ResMaxPolicy, MellowMaxPolicy, random_argmax

class TabularAgent(object):
    '''
    Tabular TD-control agent.

    Action values live in ``self.q`` with shape
    ``(H, num_states, num_actions)``, where ``H`` is ``agent_params['horizon']``
    when a fixed horizon is given and 1 otherwise. Subclasses supply
    ``state_rep`` (observation -> row index) and ``num_states`` (table size).

    ``agent_params`` keys read here: 'ac_dim', 'exploration_strategy',
    'algorithm', 'exp_value', 'step_size', 'horizon', 'initial_optimism',
    'gamma', 'save_policy'; plus 'num_timesteps' when 'save_policy' is truthy,
    'g_min'/'g_max' for 'softmax-normalized', and
    'td_step_size'/'td_epsilon'/'zeta' for the '*-normalized-td' strategies.
    '''

    def __init__(self, env, agent_params):
        self.env = env
        self.agent_params: dict = agent_params

        self.num_actions: int = agent_params['ac_dim']
        self.exploration_strategy: str = agent_params['exploration_strategy']
        self.algorithm = agent_params['algorithm']
        self.exp_value = agent_params['exp_value']  # eta / temp / omega / epsilon, depending on the policy
        self.step_size = agent_params['step_size']  # None selects the decaying schedule in alpha()
        self.horizon = agent_params['horizon']  # integer if fixed horizon, else False

        # Set up the learning algorithm (add more algorithms to this table).
        updates = {
            'q-learning': self.q_learn_update,
            'expected-sarsa': self.expected_sarsa_update,
            'sarsa': self.sarsa_update,
        }
        if self.algorithm not in updates:
            raise NotImplementedError(f"unknown algorithm: {self.algorithm!r}")
        self.update = updates[self.algorithm]

        # Set up exploration technique/policy. Any strategy name not matched
        # below falls back to epsilon-greedy with epsilon = exp_value.
        if self.exploration_strategy == 'resmax':
            self.policy = ResMaxPolicy(eta=self.exp_value)
        elif self.exploration_strategy == 'resmax-normalized':
            self.policy = ResMaxPolicy(
                eta=self.exp_value,
                normalize='non-expansion')
        elif self.exploration_strategy == 'resmax-normalized-td':
            self.policy = ResMaxPolicy(
                eta=self.exp_value,
                normalize='td',
                td_step_size=agent_params['td_step_size'],
                td_epsilon=agent_params['td_epsilon'],
                zeta=agent_params['zeta'])
        elif self.exploration_strategy == 'softmax':
            self.policy = SoftmaxPolicy(temp=self.exp_value)
        elif self.exploration_strategy == 'softmax-normalized':
            self.policy = SoftmaxPolicy(
                temp=self.exp_value,
                normalize=True,
                g_min=agent_params['g_min'],
                g_max=agent_params['g_max'])
        elif self.exploration_strategy == 'softmax-normalized-td':
            self.policy = SoftmaxPolicy(
                temp=self.exp_value,
                normalize='td',
                td_step_size=agent_params['td_step_size'],
                td_epsilon=agent_params['td_epsilon'],
                zeta=agent_params['zeta'])
        elif self.exploration_strategy == 'mellowmax':
            self.policy = MellowMaxPolicy(
                omega=self.exp_value)
        else:
            self.policy = EGreedyPolicy(epsilon=self.exp_value)

        # NOTE: this rebinds ``num_states`` on the instance from the method to
        # its integer result; later code reads ``self.num_states`` as an int.
        self.num_states = self.num_states(env)

        # One Q-table per horizon step (a single table without a fixed
        # horizon), every entry set to the initial-optimism value.
        self.q = np.ones((self.horizon or 1, self.num_states, self.num_actions))
        self.q = self.q * agent_params['initial_optimism']
        self.last_obs = None
        self.last_action = None
        # Next action A' sampled by sarsa_update; step_env executes it so the
        # action taken is the same one the SARSA target bootstrapped on.
        self._next_action = None

        self.gamma = agent_params['gamma']
        self.N = np.zeros((self.num_states, self.num_actions))  # per-(state, action) visit counts

        # Optionally track the policy over time. This is RiverSwim-specific:
        # it logs the probability of action index 1 ("right") in state index 0
        # whenever horizon_step == 0.
        self.record_policy = agent_params['save_policy']
        if self.record_policy:
            # Without a fixed horizon, use a period of 1 instead of dividing by False.
            self.policy_log = np.zeros(agent_params['num_timesteps'] // (self.horizon or 1))

        self.timesteps = 0
        self.horizon_step = 0  # index into the first axis of q

        # Per-update logs.
        self.states_visted = []
        self.actions_taken = []
        self.td_error = []
        self.zetas = []

    def start(self):
        '''
        Reset the environment and choose the first action of a new episode.

        Sets ``last_obs`` to the reset observation and ``last_action`` to an
        action sampled from the policy at the current ``horizon_step``.
        '''
        self.last_obs = self.env.reset()
        self._next_action = None  # discard any SARSA action left from the previous episode
        self.last_action = self.policy.get_action(
            self.q[self.horizon_step, self.state_rep(self.last_obs), :])

    def step_env(self):
        '''
        Take ``last_action``, observe the transition, update Q, pick the next action.

        On termination the episode is restarted via ``start()`` (env reset and
        a fresh first action). Under SARSA the next action is the one sampled
        inside ``sarsa_update``; otherwise it is sampled here from the
        (already updated) Q-row of the next state.

        returns:
            done : (bool) True if episode has terminated, o/w False
            reward : (float) Reward received
        '''
        # RiverSwim-only policy logging; assumes state index 0 is the start
        # state (TODO confirm).
        if self.record_policy and self.horizon_step == 0:
            ps = self.policy.get_p(self.q[self.horizon_step, 0, :])
            self.policy_log[self.timesteps // (self.horizon or 1)] = ps[1]

        s = self.state_rep(self.last_obs)
        a = self.last_action  # action taken in s

        ob, reward, done, _ = self.env.step(a)
        # update() reads last_obs / last_action / horizon_step, so it must run
        # before any of them are advanced. It returns the new value to store.
        self.q[self.horizon_step, s, a] = self.update(ob, reward, done)

        if done:
            self.horizon_step = 0
            self.start()  # always resets the env and samples the next episode's first action
        else:
            self.last_obs = ob
            if self.horizon:  # finite horizon: move to the next step's table
                self.horizon_step += 1

            next_action = self._next_action
            self._next_action = None
            if next_action is None:
                next_action = self.policy.get_action(
                    self.q[self.horizon_step, self.state_rep(ob), :])
            self.last_action = next_action

        self.timesteps += 1
        self.N[s, a] += 1
        return done, reward

    def offline_eval(self, num_episodes, eval_env):
        '''
        Evaluate the current Q-table on ``eval_env`` without learning.

        Each step picks ``random_argmax`` of the current Q-row (presumably a
        greedy choice with random tie-breaking). With a fixed horizon the
        table for step h is used at the h-th step of an episode, which
        assumes episodes end within the horizon.

        returns:
            episodic_returns : list of per-episode reward sums (undiscounted)
            episodic_steps : list of per-episode lengths
        '''
        ob = eval_env.reset()
        episodic_steps = []
        episodic_returns = []

        for _ in range(num_episodes):
            steps = 0
            rewards = []
            done = False
            horizon_step = 0

            while not done:
                a = random_argmax(self.q[horizon_step, self.state_rep(ob), :])
                ob, reward, done, _ = eval_env.step(a)
                rewards.append(reward)
                steps += 1
                if self.horizon:
                    horizon_step += 1

            # The episode ended: reset the env (and the latest observation).
            ob = eval_env.reset()
            episodic_steps.append(steps)
            episodic_returns.append(np.sum(rewards))

        return episodic_returns, episodic_steps

    def _bootstrap_index(self, h):
        '''Return the first-axis index of q used for the bootstrap target from step ``h``.'''
        # With a fixed horizon the next state belongs to the next step's table.
        return h + 1 if self.horizon else h

    def _apply_td(self, s, a, h, target):
        '''
        Shared tail of every update rule.

        Updates the policy's zeta for '*-normalized-td' strategies, appends
        to the td_error / states_visted / actions_taken logs, and returns
        ``(1 - alpha) * q[h, s, a] + alpha * target`` (q itself is not written).
        '''
        alpha = self.alpha(s, a)
        q_s_a = self.q[h, s, a]
        td = target - q_s_a

        if '-normalized-td' in self.exploration_strategy:
            self.policy.update_zeta(td)
            self.zetas.append(self.policy.zeta)

        self.td_error.append(td)
        self.states_visted.append(s)
        self.actions_taken.append(a)

        return (1 - alpha) * q_s_a + alpha * target

    def q_learn_update(self, ob, reward, done):
        '''
        Q-learning: target = r + gamma * max_a' Q(s', a'), or r when done.

        (s, a) come from ``last_obs`` / ``last_action``; returns the new value
        for q[horizon_step, s, a].
        '''
        s = self.state_rep(self.last_obs)
        a = self.last_action
        h = self.horizon_step

        if done:
            target = reward
        else:
            q_next = self.q[self._bootstrap_index(h), self.state_rep(ob), :]
            target = reward + self.gamma * np.max(q_next)

        return self._apply_td(s, a, h, target)

    def sarsa_update(self, ob, reward, done):
        '''
        SARSA: target = r + gamma * Q(s', a'), or r when done.

        a' is sampled from the policy here and saved in ``_next_action`` so
        that ``step_env`` executes exactly that action (on-policy SARSA).
        Returns the new value for q[horizon_step, s, a].
        '''
        s = self.state_rep(self.last_obs)
        a = self.last_action
        h = self.horizon_step

        if done:
            target = reward
            self._next_action = None
        else:
            q_next = self.q[self._bootstrap_index(h), self.state_rep(ob), :]
            a_next = self.policy.get_action(q_next)
            self._next_action = a_next
            target = reward + self.gamma * q_next[a_next]

        return self._apply_td(s, a, h, target)

    def expected_sarsa_update(self, ob, reward, done):
        '''
        Expected SARSA: target = r + gamma * sum_a' pi(a'|s') Q(s', a'), or r when done.

        pi is taken from ``policy.get_p`` on the next state's Q-row. Returns
        the new value for q[horizon_step, s, a].
        '''
        s = self.state_rep(self.last_obs)
        a = self.last_action
        h = self.horizon_step

        if done:
            target = reward
        else:
            q_next = self.q[self._bootstrap_index(h), self.state_rep(ob), :]
            p = self.policy.get_p(q_next)  # probability of each action in the next state
            target = reward + self.gamma * np.dot(p, q_next)

        return self._apply_td(s, a, h, target)

    def alpha(self, s, a):
        '''
        Step size for updating (s, a).

        Returns the fixed ``step_size`` when it is not None; otherwise
        1 / (k + 1) with k = timesteps // 100000, so it shrinks every
        100000 steps. ``s`` and ``a`` are currently unused.
        '''
        if self.step_size is not None:
            return self.step_size
        return 1 / (self.timesteps // 100000 + 1)

    def state_rep(self, ob):
        '''
        Turns state into index for table of values

        Inherit this class and implement this function wrt environment
        '''
        raise NotImplementedError

    def num_states(self, env):
        '''
        returns number of states for the given environment

        Inherit this class and implement this function wrt environment
        '''
        raise NotImplementedError


class DeepSeaTabularAgent(TabularAgent):
    '''Tabular agent for DeepSea-style 2-D grid observations.'''

    def state_rep(self, ob):
        '''
        Map a grid observation to a flat row-major index.

        Takes the first nonzero cell of ``ob`` (presumably the agent's one-hot
        position) and returns ``row * ob.shape[0] + col``.
        '''
        side = ob.shape[0]
        first_hit = np.argwhere(ob)[0]
        return first_hit[0] * side + first_hit[1]

    def num_states(self, env):
        '''Return side**2, with side the leading dimension of ``env.observation_spec()``.'''
        side = env.observation_spec().shape[0]
        return side ** 2


class RiverSwimTabularAgent(TabularAgent):
    '''Tabular agent for RiverSwim, whose observations index the Q-table directly.'''

    def state_rep(self, ob):
        '''Return ``ob`` unchanged; it is used as the state index.'''
        return ob

    def num_states(self, env):
        '''Return the hard-coded RiverSwim state count; ``env`` is not consulted.'''
        n_states = 6
        return n_states

class HardSquareTabularAgent(TabularAgent):
    '''Tabular agent for the HardSquare environment; observations index the Q-table directly.'''

    def state_rep(self, ob):
        '''Return ``ob`` unchanged; it is used as the state index.'''
        return ob

    def num_states(self, env):
        '''Return the hard-coded HardSquare state count; ``env`` is not consulted.'''
        n_states = 4
        return n_states

class TwoStateTabularAgent(TabularAgent):
    '''Tabular agent for the two-state environment; observations index the Q-table directly.'''

    def state_rep(self, ob):
        '''Return ``ob`` unchanged; it is used as the state index.'''
        return ob

    def num_states(self, env):
        '''Return the hard-coded state count of the two-state environment; ``env`` is not consulted.'''
        n_states = 2
        return n_states