from copy import deepcopy
from functools import lru_cache
import pickle
from typing import Callable, Union

import numpy as np
import torch as th

import wandb
from gpi.rl_algorithm import RLAlgorithm


class GPI(RLAlgorithm):
    """Generalized Policy Improvement (GPI) over a growing library of policies.

    Each call to :meth:`learn` trains one policy for a task weight vector ``w``
    and stores it (together with ``w``) in ``self.policies`` / ``self.tasks``.

    Action selection (:meth:`eval`):

    * ``h_step == 0`` (or negative): greedy action over the maximum of all
      policies' Q-values for the given weights.
    * ``h_step > 0``: an ``h_step``-deep lookahead through
      ``self.policies[0].model.transitions``. Leaves are valued by the GPI max
      Q-value, or by ``0.0`` when ``self.mpc`` is True.
    """

    def __init__(
        self,
        env,
        algorithm_constructor: Callable,
        h_step: int = 0,
        log: bool = True,
        project_name: str = "gpi",
        experiment_name: str = "gpi",
        device: Union[th.device, str] = "auto",
    ):
        """Create an empty policy library.

        Args:
            env: Environment passed to the base class; ``env.action_space.n`` is
                read during lookahead.
            algorithm_constructor: Zero-argument callable returning a new policy
                object; called by :meth:`learn` when ``new_policy`` is True.
            h_step: Lookahead depth used by :meth:`eval` (0 disables lookahead).
            log: If True, set up wandb logging and share the writer with policies.
            project_name: wandb project name (only used when ``log`` is True).
            experiment_name: wandb experiment name (only used when ``log`` is True).
            device: Torch device (or string) forwarded to the base class.
        """
        super().__init__(env, device)

        self.algorithm_constructor = algorithm_constructor
        self.h_step = h_step
        self.policies = []
        self.tasks = []
        # Current task weights; set by eval() and read by the lookahead helpers.
        self.w = None
        # When True, lookahead leaves / unknown transitions contribute 0 instead of GPI Q-values.
        self.mpc = False

        self.log = log
        if self.log:
            self.setup_wandb(project_name, experiment_name)

    def __getstate__(self):
        """Return the state to pickle, excluding ``algorithm_constructor`` and ``writer``.

        These attributes are presumably not picklable (a callable factory and a
        logging writer). ``pop`` with a default is used because ``writer`` may be
        absent: ``setup_wandb`` (which presumably creates it) only runs when
        ``log`` is True.
        """
        state = self.__dict__.copy()
        state.pop("algorithm_constructor", None)
        state.pop("writer", None)
        return state

    def save(self, path):
        """Pickle this instance (minus the attributes dropped by ``__getstate__``) to ``path``."""
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(path):
        """Unpickle and return an object previously written by :meth:`save`.

        Security: ``pickle.load`` can execute arbitrary code, so only load files
        from trusted sources.

        The restored object has no ``algorithm_constructor`` (dropped on save),
        so ``learn(new_policy=True)`` fails until it is reassigned.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def _gpi(self, obs, return_policy_index=False):
        """Return the GPI action for ``obs`` under the current weights ``self.w``.

        The per-policy Q-vectors are stacked into a 2-D (policy, action) array.
        Its global argmax yields both the best policy index and the action.

        Returns:
            ``action`` (int), or ``(action, policy_index)`` if
            ``return_policy_index`` is True.
        """
        q_vals = np.stack([policy.q_values(obs, self.w) for policy in self.policies])
        policy_index, action = np.unravel_index(np.argmax(q_vals), q_vals.shape)
        action = int(action)
        if return_policy_index:
            return action, policy_index
        return action

    def _action_value(self, obs, action, depth):
        """Expected value of taking ``action`` in ``obs`` and then looking ``depth`` steps ahead.

        Iterates over ``self.policies[0].model.transitions(obs, action)``, which
        yields ``((next_obs, r, terminal), prob)`` pairs.

        * Terminal transitions contribute ``prob * r . w``.
        * Non-terminal transitions add the discounted recursive value
          ``gamma * _h_step_gpi(next_obs, depth)``.
        * ``next_obs is None`` (and not ``mpc``) replaces the accumulated value
          with the GPI max Q-value for ``(obs, action)``.
        """
        ret = 0.0
        for (next_obs, r, terminal), prob in self.policies[0].model.transitions(obs, action):
            if next_obs is None:
                # NOTE(review): this overwrites (not adds to) ret and is not weighted by prob;
                # presumably the model yields a single None entry for unknown transitions — confirm.
                if not self.mpc:
                    ret = np.max(np.stack([policy.q_values(obs, self.w)[action] for policy in self.policies]))
            elif terminal:
                ret += prob * np.dot(np.array(r), self.w)
            else:
                ret += prob * (np.dot(np.array(r), self.w) + self.gamma * self._h_step_gpi(next_obs, depth))
        return ret

    # NOTE: lru_cache on a method keys on (self, obs, n). obs must therefore be
    # hashable, and the class-level cache keeps instances alive until cleared.
    # eval() clears it on every call because cached values depend on self.w and self.mpc.
    @lru_cache(maxsize=None)
    def _h_step_gpi(self, obs, n):
        """Optimal ``n``-step lookahead value of ``obs`` under the current weights.

        With ``n == 0`` the value is the GPI max Q-value (or ``0.0`` when ``mpc``).
        Otherwise it is the max over actions of :meth:`_action_value` with depth ``n - 1``.
        """
        if n == 0:
            if self.mpc:
                return 0.0
            q_vals = np.stack([policy.q_values(obs, self.w) for policy in self.policies])
            return np.max(q_vals)
        # NOTE(review): eval() iterates policies[0].action_dim while this uses env.action_space.n;
        # presumably equal — confirm.
        return max(self._action_value(obs, a_t, n - 1) for a_t in range(self.env.action_space.n))

    def eval(self, obs, w, return_policy_index=False) -> Union[int, tuple]:
        """Select an action for ``obs`` under task weights ``w``.

        Side effect: stores ``w`` in ``self.w``.

        Returns:
            The chosen action, or ``(action, policy_index)`` when
            ``return_policy_index`` is True. With lookahead (``h_step > 0``) the
            policy index is always ``-1``.
        """
        self.w = w
        if self.h_step <= 0:
            return self._gpi(obs, return_policy_index)

        # Cached lookahead values were computed for the previous self.w.
        self._h_step_gpi.cache_clear()
        best_return = -np.inf
        best_action = 0
        for a0 in range(self.policies[0].action_dim):
            ret = self._action_value(obs, a0, self.h_step - 1)
            # Strict '>' keeps the first action on ties.
            if ret > best_return:
                best_return = ret
                best_action = a0

        if return_policy_index:
            return best_action, -1
        return best_action

    def max_q(self, obs, w):
        """Return the best policy's ``q_table`` entry for ``obs`` and its GPI-greedy action under ``w``.

        The policy/action pair is chosen by the scalar Q-values. The returned
        value is read from that policy's ``q_table[tuple(obs)][action]``; it may
        be non-scalar (e.g. a successor-feature vector) — confirm against the
        policy class.
        """
        q_vals = np.stack([policy.q_values(obs, w) for policy in self.policies])
        policy_ind, action = np.unravel_index(np.argmax(q_vals), q_vals.shape)
        action = int(action)
        return self.policies[policy_ind].q_table[tuple(obs)][action]

    def delete_policies(self, delete_indx):
        """Remove the policies (and their tasks) at the given indices.

        Duplicate indices are ignored. Indices are removed in descending order
        so that earlier removals do not shift later ones.
        """
        for i in sorted(set(delete_indx), reverse=True):
            self.policies.pop(i)
            self.tasks.pop(i)

    def learn(
        self,
        w,
        total_timesteps,
        total_episodes=None,
        reset_num_timesteps=False,
        eval_env=None,
        eval_freq=1000,
        use_gpi=True,
        reset_learning_starts=True,
        new_policy=True,
        reuse_value_ind=None,
    ):
        """Train a policy for task weights ``w`` and record ``w`` in ``self.tasks``.

        Args:
            w: Task weight vector forwarded to the policy's ``learn``.
            total_timesteps, total_episodes, reset_num_timesteps, eval_env, eval_freq:
                Forwarded unchanged to the policy's ``learn``.
            use_gpi: If True, the trained policy gets a back-reference to this GPI
                object (``policy.gpi``); otherwise ``policy.gpi`` is None.
            reset_learning_starts: If True, set ``learning_starts`` to the previous
                policy's timestep count (to reset the exploration schedule).
            new_policy: If True, construct and append a new policy; otherwise keep
                training the last one.
            reuse_value_ind: Optional index of a policy whose values (``q_table``,
                or ``psi_net`` weights) initialise the trained policy.
        """
        if new_policy:
            self.policies.append(self.algorithm_constructor())
        # NOTE(review): a task is appended even when new_policy is False, so tasks and
        # policies can fall out of alignment — confirm this is intended.
        self.tasks.append(w)

        policy = self.policies[-1]
        policy.gpi = self if use_gpi else None

        if self.log:
            policy.log = self.log
            policy.writer = self.writer
            wandb.config.update(policy.get_config())

        if len(self.policies) > 1:
            previous = self.policies[-2]
            policy.num_timesteps = previous.num_timesteps
            policy.num_episodes = previous.num_episodes
            if reset_learning_starts:
                policy.learning_starts = previous.num_timesteps  # to reset exploration schedule

            if reuse_value_ind is not None:
                source = self.policies[reuse_value_ind]
                if hasattr(policy, "q_table"):
                    policy.q_table = deepcopy(source.q_table)
                else:
                    policy.psi_net.load_state_dict(source.psi_net.state_dict())
                    policy.target_psi_net.load_state_dict(source.psi_net.state_dict())

            if hasattr(policy, "replay_buffer"):
                policy.replay_buffer = previous.replay_buffer  # use shared buffer

            if policy.dyna:
                policy.model = previous.model
                if policy.per:
                    policy.reset_priorities(w)

        policy.learn(
            w=w,
            total_timesteps=total_timesteps,
            total_episodes=total_episodes,
            reset_num_timesteps=reset_num_timesteps,
            eval_env=eval_env,
            eval_freq=eval_freq,
        )

    @property
    def gamma(self):
        """Discount factor, taken from the first policy (requires at least one policy)."""
        return self.policies[0].gamma

    def train(self):
        """No-op: training happens inside each policy's ``learn``."""
        pass

    def get_config(self) -> dict:
        """Return the first policy's config, or an empty dict if no policy exists yet."""
        if len(self.policies) > 0:
            return self.policies[0].get_config()
        return {}

