from typing import Optional
import numpy as np


class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy or softmax exploration.

    Q-values are stored in ``self.q_values``, a ``(num_states, num_actions)``
    array initialised to zero.
    """

    def __init__(
        self,
        num_states: int,
        num_actions: int,
        learning_rate: float = 0.1,
        discount_factor: float = 0.99,
        epsilon: float = 0.1,
        softmax: bool = False,
        softmax_temp: float = 1.0,
    ):
        """Store the hyper-parameters and allocate the zero-filled Q-table.

        Args:
            num_states: Number of rows of the Q-table.
            num_actions: Number of columns of the Q-table.
            learning_rate: Step size applied to the TD error in ``update``.
            discount_factor: Discount applied to the bootstrapped value.
            epsilon: Probability of a uniformly random action when
                ``softmax`` is False.
            softmax: If True, ``select_action`` samples from a softmax over
                the Q-values instead of acting epsilon-greedily.
            softmax_temp: Temperature dividing the Q-values in the softmax.
        """
        self.num_states = num_states
        self.num_actions = num_actions
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.softmax = softmax
        self.softmax_temp = softmax_temp

        self.q_values = np.zeros((num_states, num_actions))

    def select_action(self, state: int) -> int:
        """Sample an exploratory action for ``state``.

        Returns:
            The chosen action index as a plain ``int``.
        """
        q_values = self.q_values[state]
        if self.softmax:
            # Shifting by the maximum leaves the softmax unchanged but keeps
            # np.exp from overflowing to inf (which would turn the
            # probabilities into NaN and make np.random.choice raise).
            shifted = (q_values - np.max(q_values)) / self.softmax_temp
            exp_q = np.exp(shifted)
            action_probs = exp_q / np.sum(exp_q)
            return int(np.random.choice(self.num_actions, p=action_probs))

        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.num_actions))
        return self.select_greedy_action(state)

    def update(
        self,
        state: int,
        action: int,
        reward: float,
        next_state: int,
        done: bool,
    ):
        """Apply one Q-learning update to ``q_values[state, action]``.

        The target is ``reward + gamma * max_a Q(next_state, a)``; the
        bootstrap term is zero when ``done`` is true.
        """
        q_next_star = 0 if done else np.max(self.q_values[next_state])
        td_target = reward + self.gamma * q_next_star
        td_error = td_target - self.q_values[state, action]
        self.q_values[state, action] += self.lr * td_error

    def select_greedy_action(self, state: int) -> int:
        """Return an action with maximal Q-value, breaking ties at random."""
        row = self.q_values[state]
        max_actions = np.where(row == np.max(row))[0]
        return int(np.random.choice(max_actions))
    

class EigenOptionAgent(QLearningAgent):
    """Tabular Q-learning agent whose action set is primitives plus eigenoptions.

    The inherited Q-table gets ``num_primitive_actions + num_eigenoptions``
    columns; ``update_smdp`` adds a multi-step (SMDP) update on top of the
    one-step update of the parent class.
    """

    def __init__(
        self,
        num_states: int,
        num_primitive_actions: int,
        num_eigenoptions: int,
        learning_rate: float = 0.1,
        discount_factor: float = 0.99,
        epsilon: float = 0.1,
        softmax: bool = False,
        softmax_temp: float = 1.0,
    ):
        """Build the parent agent over the combined action set."""
        total_actions = num_primitive_actions + num_eigenoptions
        super().__init__(
            num_states=num_states,
            num_actions=total_actions,
            learning_rate=learning_rate,
            discount_factor=discount_factor,
            epsilon=epsilon,
            softmax=softmax,
            softmax_temp=softmax_temp,
        )
        self.num_primitive_actions = num_primitive_actions
        self.num_eigenoptions = num_eigenoptions

    def update_smdp(
        self,
        start_state: int,
        action: int,
        total_reward: float,
        end_state: int,
        done: bool,
        k: int,
        trajectory_states=None,
    ):
        """Update ``q_values[start_state, action]`` from a multi-step outcome.

        The target is ``total_reward + gamma**k * max_a Q(end_state, a)``,
        with the bootstrap value set to zero when ``done`` is true.
        ``trajectory_states`` is accepted but not used here.
        """
        if done:
            bootstrap = 0.0
        else:
            bootstrap = np.max(self.q_values[end_state])

        discount = self.gamma ** k
        current = self.q_values[start_state, action]
        delta = total_reward + discount * bootstrap - current
        self.q_values[start_state, action] = current + self.lr * delta


class QLearningEigenoptionAgent:
    """Linear Q-learning agent over primitive actions and options.

    States are represented by the rows of an identity matrix (one-hot
    features), so ``w[s, a]`` is the Q-value of action ``a`` in state ``s``.
    """

    def __init__(
        self,
        num_states: int,
        num_primitive_actions: int,
        num_options: int,
        lr: float = 0.1,
        discount_factor: float = 0.99,
        epsilon: float = 0.1,
        softmax: bool = False,
        softmax_temp: float = 1.0,
    ):
        """Store hyper-parameters and allocate features and zero weights.

        Args:
            num_states: Number of states (rows of the feature matrix).
            num_primitive_actions: Number of primitive actions.
            num_options: Number of options appended to the action set.
            lr: Step size for the weight update.
            discount_factor: Per-step discount; raised to ``k`` in
                ``update_smdp``.
            epsilon: Probability of a uniformly random action when
                ``softmax`` is False.
            softmax: If True, sample actions from a softmax over Q-values.
            softmax_temp: Temperature dividing the Q-values in the softmax.
        """
        self.num_states = num_states
        self.num_primitive_actions = num_primitive_actions
        self.num_options = num_options
        self.num_actions = num_primitive_actions + num_options

        self.lr = lr
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.softmax = softmax
        self.softmax_temp = softmax_temp

        self.state_representation = np.eye(num_states, dtype=float)
        self.w = np.zeros((num_states, self.num_actions), dtype=float)

    def compute_q_values(self, state: int):
        """Return the vector of Q-values (one per action) for ``state``."""
        phi_s = self.state_representation[state, :]
        return phi_s @ self.w

    def _safe_argmax(self, q):
        """Return a random maximiser of ``q``, treating NaN as ``-inf``.

        Falls back to a uniformly random action when no entry is finite.
        """
        if not np.any(np.isfinite(q)):
            return int(np.random.randint(self.num_actions))
        q = np.nan_to_num(q, nan=-np.inf)
        best = np.flatnonzero(q == np.max(q))
        return int(np.random.choice(best))

    def select_action(self, state: int) -> int:
        """Sample an exploratory action (softmax or epsilon-greedy)."""
        if self.softmax:
            q = self.compute_q_values(state)
            # With every entry NaN the max-shift below would compute
            # -inf - (-inf) = NaN and np.random.choice would raise; pick
            # uniformly instead, as _safe_argmax does for unusable values.
            if np.all(np.isnan(q)):
                return int(np.random.randint(self.num_actions))
            q = np.nan_to_num(q, nan=-np.inf)
            q = q - np.max(q)
            exp_q = np.exp(q / self.softmax_temp)
            probs = exp_q / np.sum(exp_q)
            return int(np.random.choice(self.num_actions, p=probs))

        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.num_actions))
        return self._safe_argmax(self.compute_q_values(state))

    def select_greedy_action(self, state: int) -> int:
        """Return a greedy action for ``state`` (ties broken at random)."""
        return self._safe_argmax(self.compute_q_values(state))

    def update_smdp(
        self,
        start_state: int,
        action: int,
        total_reward: float,
        end_state: int,
        done: bool, k: int,
        trajectory_states=None,
    ):
        """Update the weights of ``action`` from a multi-step outcome.

        Args:
            start_state: State in which ``action`` was initiated.
            action: Index of the primitive action or option.
            total_reward: Reward accumulated over the execution.
            end_state: State in which the execution ended.
            done: If True, the target does not bootstrap from ``end_state``.
            k: Exponent applied to the discount factor (presumably the
                number of primitive steps taken; confirm against caller).
            trajectory_states: Ignored by this agent.
        """
        phi_s = self.state_representation[start_state]
        q_sa = float(phi_s @ self.w[:, action])
        if done:
            target = float(total_reward)
        else:
            # Evaluate the end state once and reuse it for both the argmax
            # and the bootstrapped value.
            q_next = self.compute_q_values(end_state)
            v_next = float(q_next[self._safe_argmax(q_next)])
            target = float(total_reward) + (self.gamma ** k) * v_next

        td_error = target - q_sa
        self.w[:, action] += self.lr * td_error * phi_s
        
        
class SRQLearningEigenoptionAgent:
    """Linear Q-learning agent whose state features are a learned SR.

    Row ``s`` of ``self.SR`` (the successor representation) is the feature
    vector of state ``s``; Q-values are ``SR[s] @ w``. ``update_smdp``
    updates both the weights and the SR rows along the executed trajectory.
    """

    def __init__(
        self,
        num_states: int,
        num_primitive_actions: int,
        num_options: int,
        lr: float = 0.1,
        lr_sr: float = 0.1,
        discount_factor: float = 0.99,
        discount_factor_sr: float = 0.99,
        epsilon: float = 0.1,
        softmax: bool = False,
        softmax_temp: float = 1.0,
        sr_init: Optional[np.ndarray] = None,
    ):
        """Store hyper-parameters and initialise the SR and zero weights.

        Args:
            num_states: Number of states.
            num_primitive_actions: Number of primitive actions.
            num_options: Number of options appended to the action set.
            lr: Step size for the weight update.
            lr_sr: Step size for the SR update.
            discount_factor: Discount used in the Q-learning target.
            discount_factor_sr: Discount used in the SR target.
            epsilon: Probability of a uniformly random action when
                ``softmax`` is False.
            softmax: If True, sample actions from a softmax over Q-values.
            softmax_temp: Temperature dividing the Q-values in the softmax.
            sr_init: Optional ``(num_states, num_states)`` initial SR; it is
                copied. Defaults to the identity matrix.
        """
        self.num_states = num_states
        self.num_primitive_actions = num_primitive_actions
        self.num_options = num_options
        self.num_actions = num_primitive_actions + num_options

        self.lr = lr
        self.lr_sr = lr_sr
        self.gamma = discount_factor
        self.gamma_sr = discount_factor_sr
        self.epsilon = epsilon
        self.softmax = softmax
        self.softmax_temp = softmax_temp

        if sr_init is None:
            self.SR = np.eye(num_states, dtype=float)
        else:
            assert sr_init.shape == (num_states, num_states)
            self.SR = sr_init.astype(float, copy=True)
        self.w = np.zeros((num_states, self.num_actions), dtype=float)

    def compute_q_values(self, state: int):
        """Return the vector of Q-values (one per action) for ``state``."""
        phi = self.SR[state, :]
        return phi @ self.w

    def _safe_argmax(self, q):
        """Return a random maximiser of ``q``, treating NaN as ``-inf``.

        Falls back to a uniformly random action when no entry is finite.
        """
        if not np.any(np.isfinite(q)):
            return int(np.random.randint(self.num_actions))
        q = np.nan_to_num(q, nan=-np.inf)
        best = np.flatnonzero(q == np.max(q))
        return int(np.random.choice(best))

    def select_action(self, state: int) -> int:
        """Sample an exploratory action (softmax or epsilon-greedy)."""
        if self.softmax:
            q = self.compute_q_values(state)
            # With every entry NaN the max-shift below would compute
            # -inf - (-inf) = NaN and np.random.choice would raise; pick
            # uniformly instead, as _safe_argmax does for unusable values.
            if np.all(np.isnan(q)):
                return int(np.random.randint(self.num_actions))
            q = np.nan_to_num(q, nan=-np.inf)
            q = q - np.max(q)
            exp_q = np.exp(q / self.softmax_temp)
            probs = exp_q / np.sum(exp_q)
            return int(np.random.choice(self.num_actions, p=probs))

        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.num_actions))

        return self._safe_argmax(self.compute_q_values(state))

    def select_greedy_action(self, state: int) -> int:
        """Return a greedy action for ``state`` (ties broken at random)."""
        return self._safe_argmax(self.compute_q_values(state))

    def update_smdp(
        self,
        start_state: int,
        action: int,
        total_reward: float,
        end_state: int,
        done: bool,
        k: int,
        trajectory_states: Optional[list[int]] = None,
    ):
        """Update the weights of ``action`` and the SR along the trajectory.

        Args:
            start_state: State in which ``action`` was initiated.
            action: Index of the primitive action or option.
            total_reward: Reward accumulated over the execution.
            end_state: State in which the execution ended.
            done: If True, neither the Q target nor the SR target of the
                final transition bootstraps from the end state.
            k: Exponent applied to the discount factor; the SR update is
                skipped when ``k <= 0``.
            trajectory_states: States visited after ``start_state``
                (presumably ending with ``end_state``; confirm against
                caller). ``None`` or empty is treated as ``[end_state]``.
        """
        phi_s = self.SR[start_state, :]
        q_sa = float(phi_s @ self.w[:, action])

        if done:
            target = float(total_reward)
        else:
            # Evaluate the end state once and reuse it for both the argmax
            # and the bootstrapped value.
            q_next = self.compute_q_values(end_state)
            v_next = float(q_next[self._safe_argmax(q_next)])
            target = float(total_reward) + (self.gamma ** k) * v_next

        td_error = target - q_sa
        # phi_s is a view of the SR row; the weight update must happen
        # before the SR rows are modified below.
        self.w[:, action] += self.lr * td_error * phi_s

        if k <= 0:
            return

        if trajectory_states is None or len(trajectory_states) == 0:
            visited = [end_state]
        else:
            visited = list(trajectory_states)

        # Build the path as a list so an array argument is concatenated
        # rather than added element-wise.
        path = [start_state, *visited]
        last_transition = len(path) - 2

        for i, (s_curr, s_next) in enumerate(zip(path[:-1], path[1:])):
            one_hot_s = np.zeros(self.num_states, dtype=float)
            one_hot_s[s_curr] = 1.0

            # The final transition of a terminating trajectory does not
            # bootstrap. (Comparing against len(path) - 1 could never match
            # inside this loop, so the terminal case was never taken.)
            if done and i == last_transition:
                sr_target = one_hot_s
            else:
                sr_target = one_hot_s + self.gamma_sr * self.SR[s_next, :]
            self.SR[s_curr, :] += self.lr_sr * (sr_target - self.SR[s_curr, :])
            
            
class HSRQLearningEigenoptionAgent:
    """Linear Q-learning agent whose state features are the rows of ``HSR``.

    Q-values are ``HSR[s] @ w``. ``update_smdp`` updates the weights and then
    moves the ``HSR`` row of the start state toward a target built from the
    discounted states visited during the execution.
    """

    def __init__(
        self,
        num_states: int,
        num_primitive_actions: int,
        num_options: int,
        lr: float = 0.05,
        lr_hsr: float = 0.1,
        discount_factor: float = 0.99,
        discount_factor_hsr: float = 0.99,
        epsilon: float = 0.1,
        softmax: bool = False,
        softmax_temp: float = 1.0,
        hsr_init: Optional[np.ndarray] = None,
    ):
        """Store hyper-parameters and initialise ``HSR`` and zero weights.

        Args:
            num_states: Number of states.
            num_primitive_actions: Number of primitive actions.
            num_options: Number of options appended to the action set.
            lr: Step size for the weight update.
            lr_hsr: Step size for the ``HSR`` update.
            discount_factor: Discount used in the Q-learning target.
            discount_factor_hsr: Discount used in the ``HSR`` target.
            epsilon: Probability of a uniformly random action when
                ``softmax`` is False.
            softmax: If True, sample actions from a softmax over Q-values.
            softmax_temp: Temperature dividing the Q-values in the softmax.
            hsr_init: Optional ``(num_states, num_states)`` initial matrix;
                it is copied. Defaults to the identity matrix.
        """
        self.num_states = num_states
        self.num_primitive_actions = num_primitive_actions
        self.num_options = num_options
        self.num_actions = num_primitive_actions + num_options

        self.lr = lr
        self.lr_hsr = lr_hsr

        self.gamma = discount_factor
        self.gamma_hsr = discount_factor_hsr

        self.epsilon = epsilon
        self.softmax = softmax
        self.softmax_temp = softmax_temp

        if hsr_init is None:
            self.HSR = np.eye(num_states, dtype=float)
        else:
            assert hsr_init.shape == (num_states, num_states)
            self.HSR = hsr_init.astype(float, copy=True)

        self.w = np.zeros((num_states, self.num_actions), dtype=float)

    def compute_q_values(self, state: int):
        """Return the vector of Q-values (one per action) for ``state``."""
        phi = self.HSR[state, :]
        return phi @ self.w

    def _safe_argmax(self, q):
        """Return a random maximiser of ``q``, treating NaN as ``-inf``.

        Falls back to a uniformly random action when no entry is finite.
        """
        if not np.any(np.isfinite(q)):
            return int(np.random.randint(self.num_actions))
        q = np.nan_to_num(q, nan=-np.inf)
        best = np.flatnonzero(q == np.max(q))
        return int(np.random.choice(best))

    def select_action(self, state: int) -> int:
        """Sample an exploratory action (softmax or epsilon-greedy)."""
        if self.softmax:
            q = self.compute_q_values(state)
            # With every entry NaN the max-shift below would compute
            # -inf - (-inf) = NaN and np.random.choice would raise; pick
            # uniformly instead, as _safe_argmax does for unusable values.
            if np.all(np.isnan(q)):
                return int(np.random.randint(self.num_actions))
            q = np.nan_to_num(q, nan=-np.inf)
            q = q - np.max(q)
            exp_q = np.exp(q / self.softmax_temp)
            probs = exp_q / np.sum(exp_q)
            return int(np.random.choice(self.num_actions, p=probs))

        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.num_actions))

        return self._safe_argmax(self.compute_q_values(state))

    def select_greedy_action(self, state: int) -> int:
        """Return a greedy action for ``state`` (ties broken at random)."""
        return self._safe_argmax(self.compute_q_values(state))

    def _hsr_intra_feature(self, start_state: int, trajectory_states: list[int]):
        """Return discounted visit counts for the start and intermediate states.

        ``start_state`` gets weight 1; the i-th entry of
        ``trajectory_states`` (1-based) gets ``gamma_hsr**i``. The last
        entry is excluded; the caller accounts for it separately.
        """
        phi = np.zeros(self.num_states, dtype=float)
        phi[start_state] = 1.0

        g = self.gamma_hsr
        for s in trajectory_states[:-1]:
            phi[s] += g
            g *= self.gamma_hsr
        return phi

    def update_smdp(
        self,
        start_state: int,
        action: int,
        total_reward: float,
        end_state: int,
        done: bool,
        k: int,
        trajectory_states: Optional[list[int]] = None,
    ):
        """Update the weights of ``action`` and the ``HSR`` row of the start.

        Args:
            start_state: State in which ``action`` was initiated.
            action: Index of the primitive action or option.
            total_reward: Reward accumulated over the execution.
            end_state: State in which the execution ended.
            done: If True, the Q target does not bootstrap and the ``HSR``
                target adds ``gamma_hsr**k`` at ``end_state`` instead of
                bootstrapping from ``HSR[end_state]``.
            k: Exponent applied to the discount factors; the ``HSR`` update
                is skipped when ``k <= 0``.
            trajectory_states: States visited after ``start_state``
                (presumably ending with ``end_state``; confirm against
                caller). ``None`` or empty is treated as ``[end_state]``.
        """
        phi_s = self.HSR[start_state, :]
        q_sa = float(phi_s @ self.w[:, action])

        if done:
            target = float(total_reward)
        else:
            # Evaluate the end state once and reuse it for both the argmax
            # and the bootstrapped value.
            q_next = self.compute_q_values(end_state)
            v_next = float(q_next[self._safe_argmax(q_next)])
            target = float(total_reward) + (self.gamma ** k) * v_next

        td_error = target - q_sa
        # phi_s is a view of the HSR row; the weight update must happen
        # before that row is modified below.
        self.w[:, action] += self.lr * td_error * phi_s

        if k <= 0:
            return

        if trajectory_states is None or len(trajectory_states) == 0:
            traj = [end_state]
        else:
            traj = trajectory_states

        intra = self._hsr_intra_feature(start_state, traj)

        if done:
            intra[end_state] += (self.gamma_hsr ** k)
            target_h = intra
        else:
            target_h = intra + (self.gamma_hsr ** k) * self.HSR[end_state, :]

        self.HSR[start_state, :] += self.lr_hsr * (target_h - self.HSR[start_state, :])
        

class OfflineFeaturesEigenoptionAgent:
    """Linear Q-learning agent over a fixed, externally supplied feature matrix.

    Row ``s`` of ``state_representation`` is the feature vector of state
    ``s``; Q-values are ``state_representation[s] @ w``. The features are
    never modified by this class; only the weights ``w`` are learned.
    """

    def __init__(
        self,
        num_states: int,
        num_primitive_actions: int,
        num_options: int,
        lr_reward: float = 0.1,
        discount_factor: float = 0.99,
        epsilon: float = 0.1,
        softmax: bool = False,
        softmax_temp: float = 1.0,
        state_representation: Optional[np.ndarray] = None,
    ):
        """Store hyper-parameters and allocate weights matching the features.

        Args:
            num_states: Number of states.
            num_primitive_actions: Number of primitive actions.
            num_options: Number of options appended to the action set.
            lr_reward: Step size for the weight update.
            discount_factor: Per-step discount; raised to ``k`` in
                ``update_smdp``.
            epsilon: Probability of a uniformly random action when
                ``softmax`` is False.
            softmax: If True, sample actions from a softmax over Q-values.
            softmax_temp: Temperature dividing the Q-values in the softmax.
            state_representation: Optional 2-D feature matrix indexed by
                state along its first axis. Defaults to the identity matrix.

        Raises:
            ValueError: If ``state_representation`` is not two-dimensional.
        """
        self.num_states = num_states
        self.num_primitive_actions = num_primitive_actions
        self.num_options = num_options
        self.num_actions = num_primitive_actions + num_options

        self.lr_reward = lr_reward
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.softmax = softmax
        self.softmax_temp = softmax_temp

        if state_representation is None:
            self.state_representation = np.eye(num_states, dtype=float)
        else:
            # np.asarray returns the same object for an ndarray, so the
            # caller's array is still referenced rather than copied.
            self.state_representation = np.asarray(state_representation)
            if self.state_representation.ndim != 2:
                raise ValueError(
                    "state_representation must be a 2-D array, got "
                    f"{self.state_representation.ndim} dimension(s)"
                )

        # One weight row per feature: sizing by num_states would break
        # ``phi @ w`` whenever the feature dimension differs from it.
        num_features = self.state_representation.shape[1]
        self.w = np.zeros((num_features, self.num_actions), dtype=float)

    def compute_q_values(self, state: int):
        """Return the vector of Q-values (one per action) for ``state``."""
        phi_s = self.state_representation[state, :]
        return phi_s @ self.w

    def _safe_argmax(self, q):
        """Return a random maximiser of ``q``, treating NaN as ``-inf``.

        Falls back to a uniformly random action when no entry is finite.
        """
        if not np.any(np.isfinite(q)):
            return int(np.random.randint(self.num_actions))
        q = np.nan_to_num(q, nan=-np.inf)
        best = np.flatnonzero(q == np.max(q))
        return int(np.random.choice(best))

    def select_action(self, state: int) -> int:
        """Sample an exploratory action (softmax or epsilon-greedy)."""
        if self.softmax:
            q = self.compute_q_values(state)
            # With every entry NaN the max-shift below would compute
            # -inf - (-inf) = NaN and np.random.choice would raise; pick
            # uniformly instead, as _safe_argmax does for unusable values.
            if np.all(np.isnan(q)):
                return int(np.random.randint(self.num_actions))
            q = np.nan_to_num(q, nan=-np.inf)
            q = q - np.max(q)
            exp_q = np.exp(q / self.softmax_temp)
            probs = exp_q / np.sum(exp_q)
            return int(np.random.choice(self.num_actions, p=probs))

        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.num_actions))
        return self._safe_argmax(self.compute_q_values(state))

    def select_greedy_action(self, state: int) -> int:
        """Return a greedy action for ``state`` (ties broken at random)."""
        return self._safe_argmax(self.compute_q_values(state))

    def update_smdp(
        self,
        start_state: int,
        action: int,
        total_reward: float,
        end_state: int,
        done: bool, k: int,
        trajectory_states=None,
    ):
        """Update the weights of ``action`` from a multi-step outcome.

        Args:
            start_state: State in which ``action`` was initiated.
            action: Index of the primitive action or option.
            total_reward: Reward accumulated over the execution.
            end_state: State in which the execution ended.
            done: If True, the target does not bootstrap from ``end_state``.
            k: Exponent applied to the discount factor (presumably the
                number of primitive steps taken; confirm against caller).
            trajectory_states: Ignored by this agent; accepted so callers
                can pass the visited states without a ``TypeError``.
        """
        phi_s = self.state_representation[start_state]
        q_sa = float(phi_s @ self.w[:, action])
        if done:
            target = float(total_reward)
        else:
            # Evaluate the end state once and reuse it for both the argmax
            # and the bootstrapped value.
            q_next = self.compute_q_values(end_state)
            v_next = float(q_next[self._safe_argmax(q_next)])
            target = float(total_reward) + (self.gamma ** k) * v_next

        td_error = target - q_sa
        self.w[:, action] += self.lr_reward * td_error * phi_s
