"""
https://github.com/jmichaux/dqn-pytorch
https://github.com/hungtuchen/pytorch-dqn
"""

import os.path as osp
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from copy import deepcopy

from gen_rl.commons.utils import logging, soft_update
from gen_rl.policy.env_models import DIM_LATENT


class DQN(object):
    def __init__(self, args: dict = None, **kwargs):
        """Build the main/target Q-networks, the optimiser and the per-step caches.

        Args:
            args: Configuration mapping. Keys read here: ``device``,
                ``act_embed``, ``if_use_latent_state``, ``state_dim``,
                ``if_use_act_val_fn``, ``num_envs`` and, depending on the
                flags, ``if_use_prev_state`` and ``action_dim``. ``seed`` is
                optional and defaults to 2021.
            **kwargs: Accepted but not used by this constructor.
        """
        self._args = args
        self._device = args["device"]
        self._rng = np.random.RandomState(self._args.get("seed", 2021))
        self.act_embed = args["act_embed"]
        # One discrete action per row of the embedding table.
        self._num_actions = self.act_embed.shape[0]

        # Width of the state part of the network input.
        if args["if_use_latent_state"]:
            state_width = DIM_LATENT
        else:
            state_width = args["state_dim"]
        use_act_val_fn = args["if_use_act_val_fn"]
        # Without the action-value head, the previous state may be concatenated
        # to the input, which doubles the state width.
        if not use_act_val_fn and args["if_use_prev_state"]:
            state_width *= 2
        # The action embedding is only part of the input for the action-value head.
        action_width = args["action_dim"] if use_act_val_fn else 0

        layers = [
            nn.Linear(state_width + action_width, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.main_Q_net = nn.Sequential(*layers).to(self._device)

        # The target network starts as an exact copy of the main network.
        self.target_Q_net = deepcopy(self.main_Q_net)
        self.optimizer = optim.Adam(self.main_Q_net.parameters(), lr=0.001)
        logging(self.main_Q_net)

        # Previous state/action caches, overwritten by select_action in latent-state mode.
        num_envs = args["num_envs"]
        self.s_tm1 = np.zeros((num_envs, args["state_dim"]))
        self.a_tm1 = np.zeros((num_envs, 1))

    def select_action(self, state, epsilon=0.0, if_warmup=False):
        """Pick one discrete action per row of ``state`` with an epsilon-greedy rule.

        Each row independently takes a uniformly random action with probability
        ``epsilon``; every other row takes the top-scoring action according to
        ``main_Q_net`` (plus the reward-model prediction when the action-value
        head is disabled).

        Args:
            state: 2-D NumPy array with one observation per row.
            epsilon: Per-row probability of a random action; exploration is
                skipped entirely unless ``epsilon > 0``.
            if_warmup: Unused; kept for interface compatibility.

        Returns:
            np.ndarray of dtype int64 and shape ``(state.shape[0],)`` holding
            action indices in ``[0, num_actions)``.
        """
        batch_size = state.shape[0]

        if epsilon > 0.0:
            # A row explores when its uniform draw falls below epsilon.
            explore = self._rng.uniform(low=0.0, high=1.0, size=batch_size) < epsilon
        else:
            explore = np.zeros(batch_size, dtype=bool)
        exploit = ~explore

        # Integer buffer so the returned dtype does not depend on whether any
        # row happened to act greedily.
        action = np.zeros((batch_size,), dtype=np.int64)

        if exploit.any():
            self.main_Q_net.eval()
            with torch.no_grad():
                # Pair every observation with every candidate action. Cast to
                # float32 because NumPy observations are often float64, which
                # does not match the network parameters.
                s = torch.tensor(
                    np.tile(state[:, None, :], reps=(1, self._num_actions, 1)),
                    dtype=torch.float32,
                    device=self._device,
                )
                a = self.act_embed.embedding_torch[None, ...].repeat(s.shape[0], 1, 1)
                if self._args["if_use_act_val_fn"]:
                    q_i = self.main_Q_net(torch.cat([s, a], dim=-1))
                else:
                    # Predicted next state and reward for every candidate action.
                    ns = self.state_model((s, a), if_return_latent=self._args["if_use_latent_state"])
                    r = self.reward_model((s, a))

                    if self._args["if_use_latent_state"]:
                        # Rebuild the current latent state from the cached
                        # previous observation/action pair.
                        s = torch.tensor(self.s_tm1, device=self._device).float()
                        a = self.act_embed.get(index=self.a_tm1).squeeze(1)
                        s = self.state_model((s, a), if_return_latent=self._args["if_use_latent_state"])
                        s = s[:, None, :].repeat(1, self._num_actions, 1)

                    _in = torch.cat([ns, s], dim=-1) if self._args["if_use_prev_state"] else ns
                    q_i = self.main_Q_net(_in)
                    q_i += r
            self.main_Q_net.train()
            greedy = torch.topk(q_i.squeeze(-1), k=1).indices.cpu().detach().numpy().astype(np.int64).flatten()
            action[exploit] = greedy[exploit]

        # Exactly one randint call per invocation, sized by the number of
        # exploring rows, which keeps the RNG stream unchanged.
        action[explore] = self._rng.randint(low=0, high=self._num_actions, size=(int(explore.sum()),))

        if self._args["if_use_latent_state"]:
            # Cache this step so the next call can rebuild the latent state.
            self.s_tm1, self.a_tm1 = state, action
        return action

    def update_policy(self, buffer, batch_size=256):
        """Run one TD update of ``main_Q_net`` on a sampled batch, then soft-update the target net.

        Args:
            buffer: Replay buffer whose ``sample(batch_size)`` returns
                ``(obses_t, actions, rewards, obses_tp1, dones)``. When
                ``if_use_latent_state`` is set, ``obses_t`` and ``actions`` are
                each unpacked as a ``(previous, current)`` pair.
            batch_size: Number of transitions to sample.

        Returns:
            dict with ``"loss"`` (the MSE between prediction and target),
            ``"mean-q"`` (scores averaged over actions, then over the batch)
            and ``"std-q"`` (per-state standard deviation of the scores over
            actions, averaged over the batch).
        """
        obses_t, actions, rewards, obses_tp1, dones = buffer.sample(batch_size)
        if self._args["if_use_latent_state"]:
            # In this mode the buffer returns (previous, current) pairs.
            (obses_tm1, obses_t), (actions_tm1, actions) = obses_t, actions

        # ================ TD target (no gradients)
        with torch.no_grad():
            # Pair every next observation with every candidate action.
            s = obses_tp1[:, None, :].repeat(1, self._num_actions, 1)
            a = self.act_embed.embedding_torch[None, ...].repeat(s.shape[0], 1, 1)
            if self._args["if_use_act_val_fn"]:
                targetQ = self.target_Q_net(torch.cat([s, a], dim=-1))
            else:
                # Predicted successor state and reward for every candidate action.
                ns = self.state_model((s, a), if_return_latent=self._args["if_use_latent_state"])
                r = self.reward_model((s, a))

                if self._args["if_use_latent_state"]:
                    # Latent encoding of the state the candidates start from.
                    s = torch.as_tensor(obses_t, device=self._device)
                    a = self.act_embed.get(index=actions)
                    s = self.state_model((s, a), if_return_latent=self._args["if_use_latent_state"])
                    s = s[:, None, :].repeat(1, self._num_actions, 1)

                _in = torch.cat([ns, s], dim=-1) if self._args["if_use_prev_state"] else ns
                targetQ = self.target_Q_net(_in)
                targetQ += r
            # Best score per sample and the index of the action achieving it.
            target, indices = torch.topk(targetQ.squeeze(-1), k=1)

            # NOTE(review): rewards/dones are assumed to broadcast against the
            # (batch, 1) target -- confirm against the buffer's output shapes.
            if self._args["if_use_act_val_fn"]:
                target = rewards + self._args["discount"] * target * (1 - dones)
            else:
                # Predicted reward of the selected action for every sample.
                r = r.squeeze(-1)
                r = r[torch.arange(r.size(0), device=r.device).unsqueeze(1), indices]
                target = rewards + self._args["discount"] * (r + self._args["discount"] * target) * (1 - dones)

        # ================ prediction for the actions actually taken
        s = torch.as_tensor(obses_t, device=self._device)
        a = self.act_embed.get(index=actions)
        if self._args["if_use_act_val_fn"]:
            val_t = self.main_Q_net(torch.cat([s, a], dim=-1))
        else:
            ns = self.state_model((s, a), if_return_latent=self._args["if_use_latent_state"])
            r = self.reward_model((s, a))

            if self._args["if_use_latent_state"]:
                # Latent encoding of the current state from the previous step.
                s = torch.as_tensor(obses_tm1, device=self._device)
                a = self.act_embed.get(index=actions_tm1)
                s = self.state_model((s, a), if_return_latent=self._args["if_use_latent_state"])

            _in = torch.cat([ns, s], dim=-1) if self._args["if_use_prev_state"] else ns
            val_t = self.main_Q_net(_in)
            val_t = r + self._args["discount"] * val_t
        bellmann_error = F.mse_loss(val_t, target)
        self.optimizer.zero_grad()
        bellmann_error.backward()
        nn.utils.clip_grad_norm_(self.main_Q_net.parameters(), 1.)
        self.optimizer.step()

        soft_update(target=self.target_Q_net, source=self.main_Q_net, tau=self._args["tau"])

        # ================ visualisation
        self.main_Q_net.eval()
        with torch.no_grad():
            s = obses_t[:, None, :].repeat(1, self._num_actions, 1)
            a = self.act_embed.embedding_torch[None, ...].repeat(s.shape[0], 1, 1)
            if self._args["if_use_act_val_fn"]:
                q_i = self.main_Q_net(torch.cat([s, a], dim=-1)).squeeze(-1)
            else:
                ns = self.state_model((s, a), if_return_latent=self._args["if_use_latent_state"])
                r = self.reward_model((s, a)).squeeze(-1)

                if self._args["if_use_latent_state"]:
                    a = self.act_embed.get(index=actions_tm1).squeeze(1)
                    s = self.state_model((obses_tm1, a), if_return_latent=self._args["if_use_latent_state"])
                    s = s[:, None, :].repeat(1, self._num_actions, 1)

                _in = torch.cat([ns, s], dim=-1) if self._args["if_use_prev_state"] else ns
                q_i = self.main_Q_net(_in).squeeze(-1)
                q_i += r
        self.main_Q_net.train()
        return {
            "loss": bellmann_error.item(),
            "mean-q": q_i.mean(-1).mean(-1).item(),
            # Standard deviation, matching the key (this used to report variance).
            "std-q": q_i.std(-1).mean(-1).item()
        }

    def set_models(self, reward_model, state_model, decompose_obs_fn):
        """Attach the learned models and the observation-decomposition callable.

        Args:
            reward_model: Stored as ``self.reward_model`` unless it is ``None``.
            state_model: Stored as ``self.state_model`` unless it is ``None``.
            decompose_obs_fn: Always stored as ``self._decompose_obs``, even
                when it is ``None``.
        """
        self._decompose_obs = decompose_obs_fn
        # A None model leaves any previously attached model untouched.
        for attr_name, model in (("reward_model", reward_model), ("state_model", state_model)):
            if model is not None:
                setattr(self, attr_name, model)

    def update_models(self, buffer, batch_size=256):
        """Take one optimisation step on the reward and/or state model.

        Args:
            buffer: Replay buffer whose ``sample(batch_size)`` returns
                ``(obses_t, actions, rewards, obses_tp1, dones)``. When
                ``if_use_latent_state`` is set, ``obses_t`` and ``actions`` are
                each unpacked as a ``(previous, current)`` pair.
            batch_size: Number of transitions to sample.

        Returns:
            An empty dict when ``if_train_models`` is off; otherwise a dict
            with ``"dynamics-model-loss"`` and ``"reward-model-loss"``, each
            0.0 for a model that was not trained in this call.
        """
        if not self._args["if_train_models"]:
            return {}

        train_state = self._args["if_train_state_model"]
        if train_state:
            self.state_model.train()
        train_reward = self._args["if_train_reward_model"]
        if train_reward:
            self.reward_model.train()

        obses_t, actions, rewards, obses_tp1, dones = buffer.sample(batch_size)
        if self._args["if_use_latent_state"]:
            # Only the current half of each (previous, current) pair is used below.
            (obses_tm1, obses_t), (actions_tm1, actions) = obses_t, actions
        act_vec = self.act_embed.get(index=actions)

        reward_loss_value = 0.0
        state_loss_value = 0.0

        if train_reward:
            # Fit the reward model on (state, action) -> reward.
            self.reward_model.zero_grad()
            reward_pred = self.reward_model((obses_t, act_vec))
            r_loss = self.reward_model.criterion(rewards, reward_pred)
            r_loss.backward()
            self.reward_model.optim.step()
            reward_loss_value = r_loss.item()

        if train_state:
            # Fit the state model on (state, action) -> next state.
            self.state_model.zero_grad()
            state_pred = self.state_model((obses_t, act_vec))
            s_loss = self.state_model.criterion(obses_tp1, state_pred)
            s_loss.backward()
            self.state_model.optim.step()
            state_loss_value = s_loss.item()

        # Put the trained models back into eval mode.
        if train_state:
            self.state_model.eval()
        if train_reward:
            self.reward_model.eval()

        return {
            "dynamics-model-loss": state_loss_value,
            "reward-model-loss": reward_loss_value,
        }

    def _save(self, save_dir):
        """Write the main net, target net and optimiser state dicts into ``save_dir``.

        The files are named ``main.pkl``, ``target.pkl`` and ``opt.pkl``.

        Args:
            save_dir: Destination directory; created (with parents) if missing.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(save_dir, exist_ok=True)
        checkpoints = (
            ("main.pkl", self.main_Q_net),
            ("target.pkl", self.target_Q_net),
            ("opt.pkl", self.optimizer),
        )
        for filename, component in checkpoints:
            torch.save(component.state_dict(), os.path.join(save_dir, filename))

    def _load(self, load_dir):
        """Restore the main net, target net and optimiser from ``load_dir``.

        Expects the files ``main.pkl``, ``target.pkl`` and ``opt.pkl`` in that
        directory.

        Args:
            load_dir: Directory containing the three checkpoint files.
        """
        # NOTE: torch.load unpickles the files; only load checkpoints from a
        # trusted source.
        checkpoints = (
            ("main.pkl", self.main_Q_net),
            ("target.pkl", self.target_Q_net),
            ("opt.pkl", self.optimizer),
        )
        for filename, component in checkpoints:
            # map_location lets a checkpoint written on another device (e.g. a
            # GPU) be restored onto the device this agent runs on.
            state = torch.load(os.path.join(load_dir, filename), map_location=self._device)
            component.load_state_dict(state)

    def _sync(self, tau: float = 0.0):
        """Bring the target network towards (or exactly onto) the main network.

        Args:
            tau: Interpolation weight. When ``tau > 0`` every target parameter
                becomes ``tau * main + (1 - tau) * target``; otherwise the
                target receives a full copy of the main network's state dict.
        """
        if not (tau > 0.0):
            # Hard update: overwrite the target with the main network.
            self.target_Q_net.load_state_dict(self.main_Q_net.state_dict())
            return

        # Soft update, parameter by parameter.
        pairs = zip(self.main_Q_net.parameters(), self.target_Q_net.parameters())
        for source_param, target_param in pairs:
            blended = tau * source_param.data + (1. - tau) * target_param.data
            target_param.data.copy_(blended)

    def _train(self):
        """Switch both the main and the target Q-network to training mode."""
        for net in (self.main_Q_net, self.target_Q_net):
            net.train()

    def _eval(self):
        """Switch both the main and the target Q-network to evaluation mode."""
        for net in (self.main_Q_net, self.target_Q_net):
            net.eval()

    def reset(self, **kwargs):
        """Do nothing; any keyword arguments are accepted and ignored."""
        return None
