import numpy as np
from utils.policies import EGreedyPolicy, SoftmaxPolicy, ResMaxPolicy, random_argmax, MellowmaxPolicy
import utils.tiles3 as tiles3


class LinearAgent(object):
    def __init__(self, env, agent_params, preprocess_obs=None):

        self.env = env
        self.agent_params: dict = agent_params
        # self.last_obs: np.ndarray = self.env.reset()
        self.tile_coding = agent_params['tile_coding']
        if self.tile_coding:
            self.iht = None
            self.num_tilings = agent_params['num_tilings']
            self.num_tiles = agent_params['num_tiles']
            self.iht_size = agent_params['iht_size']
            self.tc = TileCoder(iht_size=self.iht_size, 
                                            num_tilings=self.num_tilings, 
                                            num_tiles=self.num_tiles,
                                            ob_dim=agent_params['ob_dim'],
                                            ob_low=agent_params['ob_low'],
                                            ob_high=agent_params['ob_high'])
            self.previous_tiles = None
            self.num_states = self.iht_size
        else:
            # Discrete state space
            self.num_states = self.num_states(env)
        
        self.num_actions: int = agent_params['ac_dim'] 
        self.ep_len = agent_params['ep_len']
        self.exploration = agent_params['exploration_schedule']
        self.exploration_strategy: str = agent_params['exploration_strategy']

        # self.preprocess_obs = preprocess_obs
        self.t: int = 0
        
        self.algorithm =  agent_params['algorithm'] 
        if self.algorithm == 'q-learning':
            self.update = self.q_learn_update
        elif self.algorithm == 'expected_sarsa':
            self.update = self.expected_sarsa_update
        else:
            raise NotImplementedError
        # add more algorithms here

        self.normalization = agent_params['normalization_scheme']
        self.g_min = agent_params['g_min']
        self.g_max = agent_params['g_max']

        self.exploration_value: float = self.exploration.value(self.t)
        if self.exploration_strategy == 'ResMax':
            self.policy = ResMaxPolicy(eta = self.exploration_value, normalization=self.normalization, 
            g_min=self.g_min, g_max=self.g_max, td_step_size=agent_params['td_step_size'], 
            td_epsilon=agent_params['td_epsilon'], zeta=agent_params['zeta'])
        elif self.exploration_strategy == 'softmax':
            self.policy = SoftmaxPolicy(temp = self.exploration_value, normalization=self.normalization, 
            g_min=self.g_min, g_max=self.g_max, td_step_size=agent_params['td_step_size'], 
            td_epsilon=agent_params['td_epsilon'], zeta=agent_params['zeta'])
        elif self.exploration_strategy == 'mellowmax':
            self.policy = MellowmaxPolicy(omega = self.exploration_value, normalization=self.normalization, 
            g_min=self.g_min, g_max=self.g_max, td_step_size=agent_params['td_step_size'], 
            td_epsilon=agent_params['td_epsilon'], zeta=agent_params['zeta'])
        elif self.exploration_strategy == 'epsilon-greedy':
             self.policy = EGreedyPolicy(epsilon = self.exploration_value)
        else:
            raise NotImplementedError
        
        # NOTE: intitializing W to zeros for now. Also, num_actions same for all states
        self.init_W = agent_params['init']
        if agent_params['rand_init'] == 0:
            self.W = np.ones((self.num_actions, self.num_states)) * self.init_W
        else:
            self.W = np.random.random((self.num_actions, self.num_states)) * self.init_W

        self.last_obs = None
        self.last_action = None
        self.last_action_value = None

        self.gamma = agent_params['gamma']
        self.N = np.zeros((self.num_actions, self.num_states))

        self.step_size = agent_params['step_size']
        if self.step_size == 0:
            self.step_size_schedule = True
        else:
            self.step_size_schedule = False
            # if self.tile_coding:
            #     self.step_size /= self.num_tilings
        
        # self.td_error = [] # to log the TD error
        # self.weights_changes = [] # to log the weights changes

    def start(self):
        '''
        initializes last_obs and last_action at the start of episode
        '''
        self.last_obs = self.env.reset()
        if self.tile_coding:
            self.last_obs = self.tc.get_tiles(self.last_obs)
        action_values = np.zeros(self.num_actions)
        for i in range(self.num_actions):
            action_values[i] = self.W[i][self.state_rep(self.last_obs)].sum()
        self.last_action = self.policy.get_action(action_values, exp_value=self.exploration.value(self.t))
        self.last_action_value = action_values[self.last_action]
    
    def step_env(self):
        '''
         observe new state and reward, then update values and take new action
        '''
        # if self.preprocess_obs is not None:
        #    self.last_obs = self.preprocess_obs(self.last_obs)
        
        self.N[self.last_action][self.state_rep(self.last_obs)] += 1
        ob, reward, done, _ = self.env.step(self.last_action)
        if self.tile_coding:
            ob = self.tc.get_tiles(ob)
        
        if done:
            td_error, weight_change = self.update(ob, reward, done)
            # self.last_obs = self.env.reset()
            self.start()
        else:
            td_error, weight_change = self.update(ob, reward, done)
            self.last_obs = ob
            action_values = np.zeros(self.num_actions)
            for i in range(self.num_actions):
                action_values[i] = self.W[i][self.state_rep(self.last_obs)].sum()
            self.last_action = self.policy.get_action(action_values, exp_value=self.exploration.value(self.t))
            # try:
                # self.last_action = self.policy.get_action(action_values, self.state_rep(self.last_obs))
                # print('action_values', action_values)
                # print('self.W', self.W)
                # print('self.state_rep(self.last_obs)', self.state_rep(self.last_obs))
                # print('self.last_obs', self.last_obs)
                # print('self.W', np.sum(self.W[0]))
                # exit(0)
            # except:
            #     print('action_values', action_values)
            #     print('self.W', self.W)
            #     print('self.state_rep(self.last_obs)', self.state_rep(self.last_obs))
            #     print('self.last_obs', self.last_obs)
                # exit(0)
            self.last_action_value = action_values[self.last_action]
        
        self.t += 1
        return done, td_error, weight_change

    def offline_eval(self, num_episodes, environment_class, seed):
        # ob = eval_env.reset()
        episodic_steps = []
        episodic_returns = []

        if not hasattr(self, 'eval_env'):
            self.eval_env = environment_class()
            self.eval_env.seed(seed)

        last_obs = self.eval_env.reset()
        if self.tile_coding:
            last_obs = self.tc.get_tiles(last_obs)
        action_values = np.zeros(self.num_actions)
        for i in range(self.num_actions):
            action_values[i] = self.W[i][self.state_rep(last_obs)].sum()
        # last_action = self.policy.get_action(action_values, self.state_rep(last_obs))
        # last_action = random_argmax(action_values, self.state_rep(last_obs))
        last_action = random_argmax(action_values) #test this
        last_action_value = action_values[last_action]

        for ep_idx in range(num_episodes):
            steps = 0
            rewards = []
            done = False
            
            while not done:
                # a = random_argmax(self.q[horizon_step, self.state_rep(ob), :])
                # ob, reward, done, info = eval_env.step(a)
                ob, reward, done, _ = self.eval_env.step(last_action)
                if self.tile_coding:
                    ob = self.tc.get_tiles(ob)
                rewards.append(reward)
                steps += 1
                last_obs = ob
                action_values = np.zeros(self.num_actions)
                for i in range(self.num_actions):
                    action_values[i] = self.W[i][self.state_rep(last_obs)].sum()
                # last_action = random_argmax(action_values, self.state_rep(last_obs))
                last_action = random_argmax(action_values) #test this
                last_action_value = action_values[last_action]

            # if taking this step resulted in done, break, reset the env (and the latest observation)
            ob = self.eval_env.reset()
            episodic_steps.append(steps)
            episodic_returns.append(np.sum(rewards))
                    
        return episodic_returns, episodic_steps

    def q_learn_update(self, ob, reward, done):
        if self.step_size_schedule:
            self.step_size = (self.ep_len + 1) / ((self.ep_len + np.average(self.N[self.last_action][self.state_rep(self.last_obs)])) * 10)
            # if self.tile_coding:
            #     self.step_size /= self.num_tilings
        if done:
            delta = reward - self.last_action_value
        else:
            action_values = np.zeros(self.num_actions)
            for i in range(self.num_actions):
                action_values[i] = self.W[i][self.state_rep(ob)].sum()
            max_a = np.max(action_values)
            delta = reward + self.gamma * max_a - self.last_action_value
        if self.normalization == "td_squared" or self.normalization == "td_absolute": 
            self.policy.update_zeta(delta)
            # self.zetas.append(self.policy.zeta)
        # self.td_error.append(delta**2) # logging
        # self.policy.update_state(self.step_size, delta, self.state_rep(self.last_obs))
        grad = np.zeros_like(self.W)
        grad[self.last_action][self.state_rep(self.last_obs)] = 1
        weights_change = self.step_size * delta * grad
        self.W += weights_change
        # self.weights_changes.append(np.sqrt(np.sum((weights_change)**2))) # logging
        if not done:
            self.last_action = self.policy.get_action(action_values, exp_value=self.exploration.value(self.t))
            self.last_action_value = action_values[self.last_action]
        return delta**2, np.sqrt(np.sum((weights_change)**2))

    def expected_sarsa_update(self, ob, reward, done):
        if self.step_size_schedule:
            self.step_size = (self.ep_len + 1) / ((self.ep_len + np.average(self.N[self.last_action][self.state_rep(self.last_obs)])) * 10)
            # if self.tile_coding:
            #     self.step_size /= self.num_tilings
        if done:
            delta = reward - self.last_action_value
        else:
            action_values = np.zeros(self.num_actions)
            for i in range(self.num_actions):
                action_values[i] = self.W[i][self.state_rep(ob)].sum()
            p = self.policy.get_p(action_values, exp_value=self.exploration.value(self.t))
            exp_a = np.dot(p, action_values.flatten())
            delta = reward + self.gamma * exp_a - self.last_action_value
        if self.normalization == "td_squared" or self.normalization == "td_absolute": 
            self.policy.update_zeta(delta)
            # self.zetas.append(self.policy.zeta)
        # self.td_error.append(delta**2) # logging
        # self.policy.update_state(self.step_size, delta, self.state_rep(self.last_obs))
        grad = np.zeros_like(self.W)
        grad[self.last_action][self.state_rep(self.last_obs)] = 1
        weights_change = self.step_size * delta * grad
        self.W += weights_change
        # self.weights_changes.append(np.sqrt(np.sum((weights_change)**2))) # logging
        if not done:
            self.last_action = self.policy.get_action(action_values, exp_value=self.exploration.value(self.t))
            self.last_action_value = action_values[self.last_action]
        return delta**2, np.sqrt(np.sum((weights_change)**2))

    def state_rep(self, ob, give_position=1):
        """Return the feature representation of ``ob``.

        With ``give_position == 1`` (the default) the observation is returned
        unchanged so it can be used as an index into the weight rows; any
        other value returns its one-hot encoding instead.
        """
        if give_position != 1:
            return self.to_one_hot(ob)
        return ob

    def num_states(self, env):
        '''
        Return the number of discrete states of ``env``.

        Subclasses must override this hook; ``__init__`` calls it when tile
        coding is disabled and stores the result over ``self.num_states``.

        Raises:
            NotImplementedError: always, in this base implementation.
        '''
        raise NotImplementedError(
            f"{type(self).__name__} must implement num_states(env) "
            "to be used without tile coding"
        )

    def to_one_hot(self, obs):
        one_hot = np.zeros(self.num_states)
        one_hot[obs] = 1.0
        return one_hot

class DeepSeaLinearAgent(LinearAgent):
    """Linear agent whose observations are 2-D grids with a marked cell."""

    def state_rep(self, ob, give_position=1):
        """Map a grid observation to a flat state index.

        The index is ``row * n + col`` for the first non-zero cell of ``ob``
        (in ``np.argwhere`` order), where ``n`` is ``ob.shape[0]``. With
        ``give_position == 1`` the index is returned; otherwise its one-hot
        encoding is returned.
        """
        side = ob.shape[0]
        first_nonzero = np.argwhere(ob)[0]
        flat_index = first_nonzero[0] * side + first_nonzero[1]
        if give_position != 1:
            return self.to_one_hot(flat_index)
        return flat_index

    def num_states(self, env):
        """Return ``n ** 2`` where ``n`` is the first dimension of the env's observation spec."""
        side = env.observation_spec().shape[0]
        return side**2


class RiverSwimLinearAgent(LinearAgent):
    """Linear agent with a fixed state count of 6, where observations index states directly."""

    def state_rep(self, ob, give_position=1):
        """Return ``ob`` itself when ``give_position == 1``, else its one-hot encoding."""
        if give_position != 1:
            return self.to_one_hot(ob)
        return ob

    def num_states(self, env):
        """Return the number of states; ``env`` is ignored."""
        # Hard coded for now
        return 6

class TileCoder():
    """Thin wrapper around ``tiles3`` that scales observations before tiling."""

    def __init__(self, iht_size=4096, num_tilings=8, num_tiles=8, ob_dim=2, ob_low=None, ob_high=None):
        """
        Initializes the Tile Coder
        Initializers:
        iht_size -- int, the size of the index hash table, typically a power of 2
        num_tilings -- int, the number of tilings
        num_tiles -- int, the number of tiles. The same number is used along
                     every observation dimension
        ob_dim -- int, the number of observation dimensions to tile
        ob_low -- per-dimension lower bounds, indexed by dimension
        ob_high -- per-dimension upper bounds, indexed by dimension
        Class Variables:
        self.iht -- tc.IHT, the index hash table that the tile coder will use
        self.num_tilings -- int, the number of tilings the tile coder will use
        self.num_tiles -- int, the number of tiles the tile coder will use

        NOTE: ob_low and ob_high default to None but get_tiles indexes them,
        so both must be supplied before get_tiles is called.
        """
        self.iht = tiles3.IHT(iht_size)
        self.num_tilings = num_tilings
        self.num_tiles = num_tiles
        self.ob_dim = ob_dim
        self.ob_low = ob_low
        self.ob_high = ob_high

    def get_tiles(self, obs):
        """
        Takes in an observation from the environment
        and returns a numpy array of active tiles.

        obs may be a scalar (for ob_dim == 1) or any sequence / array with at
        least ob_dim entries; lists, tuples and numpy arrays are all accepted.

        Each dimension is multiplied by a scale before tiling:
        - bounded dimensions use num_tiles / (ob_high - ob_low);
        - dimensions whose bounds exceed +/-1e6 are treated as unbounded and
          use num_tiles / sqrt(obs**2 + 1), which keeps the scaled value
          strictly inside (-num_tiles, num_tiles).
        """
        # Normalise once so scalars, tuples, lists and arrays share one path.
        values = np.asarray(obs, dtype=float).reshape(-1)

        scaled = []
        for i in range(self.ob_dim):
            value = values[i]
            if self.ob_high[i] > 1e6 or self.ob_low[i] < -1e6:
                scale = self.num_tiles / np.sqrt(value * value + 1)
            else:
                scale = self.num_tiles / (self.ob_high[i] - self.ob_low[i])
            scaled.append(value * scale)

        tiles = tiles3.tiles(self.iht, self.num_tilings, scaled)
        return np.array(tiles)
