import numpy as np

import wandb as wb
from gpi.dynamics.tabular_model import TabularModel
from gpi.rl_algorithm import RLAlgorithm
from gpi.successor_features.gpi import GPI
from gpi.utils.eval import eval_mo
from gpi.utils.utils import linearly_decaying_epsilon


class SF(RLAlgorithm):
    """Tabular successor-features (SF) learner with optional GPI, Dyna and PER.

    ``self.q_table`` maps ``tuple(obs)`` to an array of shape
    ``(action_dim, phi_dim)``; row ``a`` holds the successor-feature vector of
    action ``a``. The scalar action value for a weight vector ``w`` is
    ``np.dot(q_table[obs], w)[a]``.

    Args:
        env: Environment exposing ``reward_space.shape``, ``action_space.sample()``,
            ``reset() -> (obs, info)`` and ``step(a)`` returning a 5-tuple whose
            ``info`` dict contains the vector reward under ``"vector_reward"``.
        alpha: Step size of the tabular TD updates.
        gamma: Discount factor.
        initial_epsilon: Starting exploration rate for epsilon-greedy acting.
        final_epsilon: Final exploration rate of the linear decay.
        epsilon_decay_steps: Length of the linear epsilon decay; ``None`` keeps
            epsilon fixed at ``initial_epsilon``.
        learning_starts: Forwarded to ``linearly_decaying_epsilon``.
        dyna: If True, fit a ``TabularModel`` and run planning updates from it.
        buffer_size: Stored and reported by ``get_config``; not otherwise used here.
        dyna_updates: Number of model-sampled updates per real environment step.
        dyna_deterministic_dynamics: Forwarded to ``TabularModel(deterministic=...)``.
        per: If True, the model is built with ``prioritize=True`` and transition
            priorities are maintained via ``update_priority``.
        min_priority: Lower clip applied to ``|td_error . w|`` before exponentiation.
        alpha_per: Exponent applied to the clipped priority.
        gpi: Optional GPI object. When set, bootstrap actions come from
            ``gpi.eval`` and the other policies it holds are updated as well.
        use_gpi: If True (and ``gpi`` is set), ``act``/``eval`` choose actions via GPI.
        project_name: Project name passed to ``setup_wandb`` when ``log`` is True.
        experiment_name: Run name passed to ``setup_wandb`` when ``log`` is True.
        log: Enable metric logging through ``self.writer``.
    """

    def __init__(
        self,
        env,
        alpha: float = 0.01,
        gamma: float = 0.99,
        initial_epsilon: float = 0.01,
        final_epsilon: float = 0.01,
        epsilon_decay_steps: int = None,
        learning_starts: int = 0,
        dyna: bool = False,
        buffer_size: int = 10000,
        dyna_updates: int = 5,
        dyna_deterministic_dynamics: bool = False,
        per: bool = False,
        min_priority: float = 0.01,
        alpha_per: float = 0.6,
        gpi: GPI = None,
        use_gpi: bool = False,
        project_name: str = "sf",
        experiment_name: str = "sf",
        log: bool = False,
    ):
        super().__init__(env, device=None)
        # Dimensionality of the vector reward / successor features.
        self.phi_dim = env.reward_space.shape[0]
        self.alpha = alpha
        self.gamma = gamma
        self.initial_epsilon = initial_epsilon
        self.epsilon = initial_epsilon
        self.final_epsilon = final_epsilon
        self.epsilon_decay_steps = epsilon_decay_steps
        self.learning_starts = learning_starts
        self.gpi = gpi
        self.use_gpi = use_gpi
        self.dyna = dyna
        self.buffer_size = buffer_size
        self.dyna_updates = dyna_updates
        self.per = per
        self.min_priority = min_priority
        self.alpha_per = alpha_per

        self.q_table = dict()
        if self.dyna:
            self.model = TabularModel(deterministic=dyna_deterministic_dynamics, prioritize=self.per)
        else:
            self.model = None
            self.replay_buffer = None

        self.log = log
        if self.log:
            self.setup_wandb(project_name, experiment_name)

    def __getstate__(self):
        """Return a picklable copy of the instance state.

        The environment and the logging writer are excluded. Both keys are
        removed only if present, so an agent that never created a writer
        (e.g. one built with ``log=False``) can still be pickled.
        """
        state = self.__dict__.copy()
        state.pop("env", None)
        state.pop("writer", None)
        return state

    # ------------------------------------------------------------------ helpers

    def _get_or_create_row(self, table: dict, key: tuple) -> np.ndarray:
        """Return ``table[key]``, inserting a zero ``(action_dim, phi_dim)`` array if absent."""
        if key not in table:
            table[key] = np.zeros((self.action_dim, self.phi_dim))
        return table[key]

    def _per_priority(self, td_error: np.ndarray, w: np.ndarray):
        """Priority of a vector TD error: ``max(|td_error . w|, min_priority) ** alpha_per``."""
        return max(np.abs(np.dot(td_error, w)), self.min_priority) ** self.alpha_per

    def _bootstrap_action(self, next_obs_array, next_key: tuple, w: np.ndarray):
        """Action used to bootstrap from ``next_key``.

        Uses ``gpi.eval`` when a GPI object is attached, otherwise the greedy
        action of this agent's own table under ``w``.
        """
        if self.gpi is not None:
            return self.gpi.eval(next_obs_array, w)
        return np.argmax(np.dot(self.q_table[next_key], w))

    def _update_other_policies(self, s: tuple, a, r, next_s: tuple, next_s_array, terminal):
        """Apply the same transition as a TD update to the other GPI policies.

        Each policy ``i`` bootstraps with the GPI action for its own task
        ``gpi.tasks[i]``. The last entry of ``gpi.policies`` is skipped; this
        presumably is the agent being trained — NOTE(review): confirm.
        """
        if self.gpi is None:
            return
        for i in range(len(self.gpi.policies) - 1):
            pi = self.gpi.policies[i]
            next_row = self._get_or_create_row(pi.q_table, next_s)
            # The other policy may never have visited ``s``; start it at zero
            # instead of raising KeyError.
            self._get_or_create_row(pi.q_table, s)
            pi_w = self.gpi.tasks[i]
            pi_max_q = next_row[self.gpi.eval(next_s_array, pi_w)]
            pi_td_error = r + (1 - terminal) * self.gamma * pi_max_q - pi.q_table[s][a]
            pi.q_table[s][a] += self.alpha * pi_td_error

    def _planning_update(self, w: np.ndarray):
        """Apply one Dyna update from a transition sampled from ``self.model``.

        Returns:
            ``(priority, priority_td)`` when PER is enabled, otherwise ``None``.
            ``priority`` is computed from the TD error before the update (and is
            written back to the model); ``priority_td`` from the TD error
            recomputed after the update (logged only).
        """
        if self.per:
            s, a, r, next_s, terminal, ind = self.model.random_transition()
        else:
            s, a, r, next_s, terminal = self.model.random_transition()

        self._get_or_create_row(self.q_table, s)
        self._get_or_create_row(self.q_table, next_s)
        next_s_array = np.array(next_s)

        max_q = self.q_table[next_s][self._bootstrap_action(next_s_array, next_s, w)]
        td_err = r + (1 - terminal) * self.gamma * max_q - self.q_table[s][a]
        self.q_table[s][a] += self.alpha * td_err

        result = None
        if self.per:
            priority = self._per_priority(td_err, w)
            # Re-evaluate the bootstrap action: the update above may have changed it
            # (e.g. when s == next_s).
            post_max_q = self.q_table[next_s][self._bootstrap_action(next_s_array, next_s, w)]
            post_td = r + (1 - terminal) * self.gamma * post_max_q - self.q_table[s][a]
            priority_td = self._per_priority(post_td, w)
            self.model.update_priority(ind, priority)
            result = (priority, priority_td)

        self._update_other_policies(s, a, r, next_s, next_s_array, terminal)
        return result

    # ------------------------------------------------------------------ acting

    def act(self, obs: np.ndarray, w: np.ndarray):
        """Epsilon-greedy action for ``obs`` under task weights ``w``.

        With probability ``epsilon`` a random action is sampled from the
        environment. Otherwise the GPI action is returned (and
        ``self.policy_index`` records the chosen policy) when ``gpi`` is set
        and ``use_gpi`` is True, else the greedy action of this agent's table.
        ``self.policy_index`` is reset to ``None`` on every call.
        """
        np_obs = obs
        key = tuple(obs)
        self._get_or_create_row(self.q_table, key)
        self.policy_index = None
        if np.random.rand() < self.epsilon:
            return int(self.env.action_space.sample())
        if self.gpi is not None and self.use_gpi:
            action, self.policy_index = self.gpi.eval(np_obs, w, return_policy_index=True)
            return action
        return int(np.argmax(np.dot(self.q_table[key], w)))

    def max_action(self, obs: np.ndarray, w: np.ndarray):
        """Greedy action of this agent's own table (no GPI, no exploration)."""
        row = self._get_or_create_row(self.q_table, tuple(obs))
        return int(np.argmax(np.dot(row, w)))

    def eval(self, obs: np.ndarray, w: np.ndarray) -> int:
        """Action used for evaluation: GPI when enabled, else greedy.

        An unseen ``obs`` yields a random action and is NOT added to the table.
        """
        if self.gpi is not None and self.use_gpi:
            return self.gpi.eval(obs, w)
        key = tuple(obs)
        if key not in self.q_table:
            return int(self.env.action_space.sample())
        return int(np.argmax(np.dot(self.q_table[key], w)))

    def add_noise(self, noise: float):
        """Add i.i.d. Gaussian noise with std ``noise`` to every table entry, in place."""
        for row in self.q_table.values():
            row += np.random.randn(*row.shape) * noise

    def q_values(self, obs: np.ndarray, w: np.ndarray) -> np.ndarray:
        """Scalar action values ``q_table[obs] . w`` (zero-initialising unseen states)."""
        return np.dot(self._get_or_create_row(self.q_table, tuple(obs)), w)

    # ---------------------------------------------------------------- learning

    def train(self, w: np.ndarray):
        """Perform one learning step from the last real transition.

        Reads ``self.obs``, ``self.action``, ``self.reward`` (vector reward),
        ``self.next_obs`` and ``self.terminal`` as set by ``learn``. Applies a TD
        update to the successor features, updates the other GPI policies,
        runs Dyna planning updates when enabled, decays epsilon and logs metrics.

        Args:
            w: Weight vector defining the current task's scalar reward.
        """
        obs = tuple(self.obs)
        next_obs = tuple(self.next_obs)
        self._get_or_create_row(self.q_table, obs)
        next_row = self._get_or_create_row(self.q_table, next_obs)

        max_q = next_row[self._bootstrap_action(self.next_obs, next_obs, w)]
        td_error = self.reward + (1 - self.terminal) * self.gamma * max_q - self.q_table[obs][self.action]
        self.q_table[obs][self.action] += self.alpha * td_error

        self._update_other_policies(obs, self.action, self.reward, next_obs, self.next_obs, self.terminal)

        # Filled only by PER planning updates; always defined so the logging
        # below cannot hit an undefined or empty list.
        priorities = []
        priorities_td = []
        if self.dyna:
            if self.per:
                self.model.update(obs, self.action, self.reward, next_obs, self.terminal, self._per_priority(td_error, w))
            else:
                self.model.update(obs, self.action, self.reward, next_obs, self.terminal)

            for _ in range(self.dyna_updates):
                sampled = self._planning_update(w)
                if sampled is not None:
                    priority, priority_td = sampled
                    priorities.append(priority)
                    priorities_td.append(priority_td)

        if self.epsilon_decay_steps is not None:
            self.epsilon = linearly_decaying_epsilon(self.initial_epsilon, self.epsilon_decay_steps, self.num_timesteps, self.learning_starts, self.final_epsilon)

        if self.log and self.num_timesteps % 100 == 0:
            self.writer.add_scalar("losses/td_error", np.dot(td_error, w), self.num_timesteps)
            self.writer.add_scalar("metrics/epsilon", self.epsilon, self.num_timesteps)
            if self.per and priorities:
                self.writer.add_scalar("metrics/max_priority", max(priorities), self.num_timesteps)
                self.writer.add_scalar("metrics/mean_priority", np.mean(priorities), self.num_timesteps)
                self.writer.add_scalar("metrics/mean_priority_td", np.mean(priorities_td), self.num_timesteps)
                self.writer.add_scalar("metrics/max_priority_td", max(priorities_td), self.num_timesteps)

    def reset_priorities(self, w: np.ndarray):
        """Recompute the priority of every state-action pair stored in the model.

        Requires ``dyna=True`` (uses ``self.model``). Each pair's priority is
        derived from the TD error of the model's predicted transition.
        NOTE(review): the bootstrap here is always greedy on this agent's own
        table, whereas ``train`` bootstraps with GPI when ``gpi`` is set —
        confirm this difference is intended.
        """
        for i, sa in enumerate(self.model.state_actions_pairs):
            s, a = sa[0], sa[1]
            next_s, r, terminal = self.model.predict(s, a)

            self._get_or_create_row(self.q_table, s)
            next_row = self._get_or_create_row(self.q_table, next_s)

            max_q = next_row[np.argmax(np.dot(next_row, w))]
            td = r + (1 - terminal) * self.gamma * max_q - self.q_table[s][a]
            self.model.update_priority(i, self._per_priority(td, w))

    def get_config(self) -> dict:
        """Hyper-parameters reported to the experiment logger."""
        return {
            "alpha": self.alpha,
            "gamma": self.gamma,
            "initial_epsilon": self.initial_epsilon,
            "final_epsilon": self.final_epsilon,
            "epsilon_decay_steps": self.epsilon_decay_steps,
            "dyna": self.dyna,
            "buffer_size": self.buffer_size,
            "dyna_updates": self.dyna_updates,
            "min_priority": self.min_priority,
            "alpha_per": self.alpha_per,
            "per": self.per,
        }

    def learn(
        self,
        total_timesteps,
        total_episodes=None,
        reset_num_timesteps=True,
        eval_env=None,
        eval_freq=1000,
        w=None,
    ):
        """Interact with ``self.env`` and train on task ``w``.

        Args:
            total_timesteps: Maximum number of environment steps.
            total_episodes: Optional cap on completed episodes in this call.
            reset_num_timesteps: If True, restart the step/episode counters at 0.
            eval_env: Optional environment evaluated with ``eval_mo`` every
                ``eval_freq`` steps (only when logging is enabled).
            eval_freq: Evaluation period in environment steps.
            w: Task weight vector; defaults to ``np.array([1.0, 0.0])``. A fresh
                array is created per call instead of sharing a mutable default.
        """
        if w is None:
            w = np.array([1.0, 0.0])

        episode_reward = 0.0
        episode_vec_reward = np.zeros(w.shape[0])
        num_episodes = 0
        self.obs, _ = self.env.reset()

        self.env.w = w

        self.num_timesteps = 0 if reset_num_timesteps else getattr(self, "num_timesteps", 0)
        self.num_episodes = 0 if reset_num_timesteps else getattr(self, "num_episodes", 0)
        for _ in range(total_timesteps):
            if total_episodes is not None and num_episodes == total_episodes:
                break

            self.num_timesteps += 1

            self.action = self.act(self.obs, w)
            self.next_obs, reward, terminated, truncated, info = self.env.step(self.action)
            self.reward = info["vector_reward"]  # vectorized reward
            # Only true termination stops bootstrapping; truncation just ends the episode.
            self.terminal = terminated

            self.train(w)

            if eval_env is not None and self.log and self.num_timesteps % eval_freq == 0:
                (total_reward, discounted_return, total_vec_r, total_vec_return) = eval_mo(self, eval_env, w)
                self.writer.add_scalar("eval/total_reward", total_reward, self.num_timesteps)
                self.writer.add_scalar("eval/discounted_return", discounted_return, self.num_timesteps)
                for i in range(episode_vec_reward.shape[0]):
                    self.writer.add_scalar(f"eval/total_reward_obj{i}", total_vec_r[i], self.num_timesteps)
                    self.writer.add_scalar(f"eval/return_obj{i}", total_vec_return[i], self.num_timesteps)

            episode_reward += reward
            episode_vec_reward += info["vector_reward"]
            if terminated or truncated:
                self.obs, _ = self.env.reset()
                num_episodes += 1
                self.num_episodes += 1

                if num_episodes % 1000 == 0:
                    print(f"Episode: {self.num_episodes} Step: {self.num_timesteps}, Ep. Total Reward: {episode_reward}, {episode_vec_reward}")
                if self.log:
                    self.writer.add_scalar("metrics/episode", self.num_episodes, self.num_timesteps)
                    self.writer.add_scalar("metrics/episode_reward", episode_reward, self.num_timesteps)
                    for i in range(episode_vec_reward.shape[0]):
                        self.writer.add_scalar(f"metrics/episode_reward_obj{i}", episode_vec_reward[i], self.num_timesteps)

                episode_reward = 0.0
                episode_vec_reward = np.zeros(w.shape[0])
            else:
                self.obs = self.next_obs