import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import utils
import hydra

from agent import Agent
from agent.critic import DoubleQCritic
from agent.actor import DiagGaussianActor


class DQNAgent(Agent):
    """DQN agent with an online Q-network and a target Q-network.

    The Q-network is stored under the name ``actor`` (and its target copy
    under ``actor_target``). Actions are chosen epsilon-greedily in ``act``
    and the network is trained with a 1-step TD Huber loss in
    ``update_actor``.
    """

    def __init__(self, obs_dim, action_dim, device,
                 actor_cfg, discount,
                 actor_lr, actor_betas, actor_update_frequency,
                 actor_target_update_frequency,
                 max_grad_norm, actor_tau,
                 batch_size, exploration_rate, exploration_initial_eps, exploration_final_eps,
                 exploration_fraction):
        """Build the online/target networks, the optimizer and the exploration schedule.

        :param obs_dim: stored on the agent; not otherwise used in this class
        :param action_dim: exclusive upper bound for random action indices in ``act``
        :param device: anything accepted by ``torch.device``
        :param actor_cfg: hydra config instantiated into the Q-network
        :param discount: discount factor applied to the bootstrapped target
        :param actor_lr: Adam learning rate for the Q-network
        :param actor_betas: Adam betas for the Q-network
        :param actor_update_frequency: ``update`` trains only when
            ``step % actor_update_frequency == 0``
        :param actor_target_update_frequency: ``update`` refreshes the target
            only when ``step % actor_target_update_frequency == 0``
        :param max_grad_norm: gradient-norm clipping threshold
        :param actor_tau: coefficient passed to ``utils.soft_update_params``
        :param batch_size: number of transitions sampled per gradient step
        :param exploration_rate: probability of a random action in ``act``
        :param exploration_initial_eps: start value given to ``utils.get_linear_fn``
        :param exploration_final_eps: end value given to ``utils.get_linear_fn``
        :param exploration_fraction: fraction argument given to ``utils.get_linear_fn``
        """
        super().__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = torch.device(device)
        self.discount = discount
        self.actor_update_frequency = actor_update_frequency
        self.actor_target_update_frequency = actor_target_update_frequency
        self.batch_size = batch_size
        self.actor_cfg = actor_cfg
        # Kept on the instance because reset_actor() rebuilds the optimizer.
        self.actor_lr = actor_lr
        self.actor_betas = actor_betas
        self.max_grad_norm = max_grad_norm
        self.actor_tau = actor_tau

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)
        self.actor_target = hydra.utils.instantiate(actor_cfg).to(self.device)
        # Start the target network as an exact copy of the online network,
        # matching what reset_actor() does.
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.exploration_rate = exploration_rate
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        # NOTE(review): the schedule is built here but nothing in this class
        # applies it to `exploration_rate`; presumably the caller does — confirm.
        self.exploration_schedule = utils.get_linear_fn(self.exploration_initial_eps, self.exploration_final_eps,
                                                        self.exploration_fraction)

        self._current_progress_remaining = 1  # for updating exploration schedule, goes from 1 to 0

        # optimizers
        self.actor_optimizer = self._make_actor_optimizer()

        # change mode
        self.train()
        self.actor_target.train()

    def _make_actor_optimizer(self):
        """Return a fresh Adam optimizer bound to the current online network."""
        return torch.optim.Adam(
            self.actor.parameters(),
            lr=self.actor_lr,
            betas=self.actor_betas)

    def reset_actor(self):
        """Re-instantiate both networks from ``actor_cfg`` and rebuild the optimizer."""
        self.actor = hydra.utils.instantiate(self.actor_cfg).to(self.device)
        self.actor_target = hydra.utils.instantiate(self.actor_cfg).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        # Keep the new online network in the same train/eval mode as the agent.
        self.actor.train(self.training)

        # The old optimizer still references the discarded parameters.
        self.actor_optimizer = self._make_actor_optimizer()

    def train(self, training=True):
        """Set the agent's training flag and the online network's mode."""
        self.training = training
        self.actor.train(training)

    @property
    def alpha(self):
        # NOTE(review): `log_alpha` is never assigned in this class, so this
        # raises AttributeError unless it is set elsewhere. Kept unchanged
        # for interface compatibility — confirm whether it is still needed.
        return self.log_alpha.exp()

    def act(self, obs, sample=False, deterministic=False):
        """Choose an action epsilon-greedily for a single observation.

        :param obs: one observation (no batch dimension); a batch dimension
            of size 1 is added before the forward pass
        :param sample: unused; kept for interface compatibility
        :param deterministic: if true, never take the random-action branch
        :return: ``utils.to_np`` of a one-element tensor holding the action index
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
            action = torch.randint(0, self.action_dim, (1,))
        else:
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            # Inference only: no autograd graph is needed for action selection.
            with torch.no_grad():
                _, action = self.actor(obs).max(dim=1)

        return utils.to_np(action)

    def save(self, model_dir, step):
        """Write the online and target state dicts under ``model_dir``."""
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.actor_target.state_dict(), '%s/actor_target_%s.pt' % (model_dir, step)
        )

    def load(self, model_dir, step):
        """Load the online and target state dicts written by ``save``.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on one
        device can be restored on another.
        """
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step),
                       map_location=self.device)
        )
        self.actor_target.load_state_dict(
            torch.load('%s/actor_target_%s.pt' % (model_dir, step),
                       map_location=self.device)
        )

    def update_actor(self, replay_buffer_sample, logger, step, print_flag=False):
        """Run one gradient step on the Q-network with a 1-step TD Huber loss.

        :param replay_buffer_sample: 6-tuple ``(obs, action, reward, next_obs,
            not_done, not_done_no_max)``; the last element is not used here
        :param logger: object with a ``log(key, value, step)`` method
        :param step: step value forwarded to the logger
        :param print_flag: if true, log the loss under ``train_actor/loss``
        """
        obs, action, reward, next_obs, not_done, _ = replay_buffer_sample
        with torch.no_grad():
            # Compute the target Q values
            target_q = self.actor_target(next_obs)
            # Follow greedy policy: use the one with the highest value
            target_q, _ = target_q.max(dim=1)
            # Avoid potential broadcast issue
            # (assumes reward / not_done are column tensors — TODO confirm)
            target_q = target_q.reshape(-1, 1)
            # 1-step TD target
            target_q = reward + not_done * self.discount * target_q

        # Get current Q estimates
        current_q = self.actor(obs)

        # Retrieve the q-values for the actions from the replay buffer
        current_q = torch.gather(current_q, dim=1, index=action.long())

        # Compute Huber loss (less sensitive to outliers)
        loss = F.smooth_l1_loss(current_q, target_q)

        # Optimize the actor
        self.actor_optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.actor_optimizer.step()

        if print_flag:
            logger.log('train_actor/loss', loss.detach(), step)
        self.actor.log(logger, step)

    def update(self, replay_buffer, logger, step, gradient_update=1):
        """Sample ``gradient_update`` batches and train on each, then refresh the target.

        A batch is sampled on every iteration; the gradient step itself only
        runs when ``step % actor_update_frequency == 0``. The mean batch
        reward is logged on the last iteration. After the loop the target
        network is soft-updated when
        ``step % actor_target_update_frequency == 0``.
        """
        for index in range(gradient_update):
            replay_buffer_sample = replay_buffer.sample(
                self.batch_size)
            _, _, reward, _, _, _ = replay_buffer_sample
            # Only the final gradient step of this call is logged.
            print_flag = index == gradient_update - 1
            if print_flag:
                logger.log('train/batch_reward', reward.mean(), step)

            if step % self.actor_update_frequency == 0:
                self.update_actor(replay_buffer_sample, logger, step, print_flag)

        if step % self.actor_target_update_frequency == 0:
            utils.soft_update_params(self.actor, self.actor_target,
                                     self.actor_tau)

    def update_after_reset(self, replay_buffer, logger, step, gradient_update=1, policy_update=True):
        """No-op: returns ``None`` immediately without using any argument.

        Presumably kept so this agent exposes the same methods as other
        agents in the project — confirm against callers.
        """
        return

    def update_state_ent(self, replay_buffer, logger, step, gradient_update=1, K=5):
        """No-op: returns ``None`` immediately without using any argument.

        Presumably kept so this agent exposes the same methods as other
        agents in the project — confirm against callers.
        """
        return

    def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
        """
        Compute current progress remaining (starts from 1 and ends to 0)

        Only ``_current_progress_remaining`` is updated; ``exploration_rate``
        is not touched here.

        :param num_timesteps: current number of timesteps
        :param total_timesteps: total number of timesteps; must be non-zero
        """
        self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
