from mbrl.third_party.unrolled_actor_soft_critic.agent.unrolling import unroll_actor
from mbrl.models.one_dim_tr_model import OneDTransitionRewardModel
import pathlib

import hydra
import numpy as np
import omegaconf
import torch
import torch.nn.functional as F

import mbrl
from mbrl.third_party.unrolled_actor_soft_critic import utils
import mbrl.third_party.unrolled_actor_soft_critic as uasc


class UASCAgent(uasc.agent.Agent):
    """Soft actor-critic style agent whose actor is trained by unrolling a model.

    The critic is trained on a soft Bellman target built from a target critic that
    returns two Q estimates (the minimum is used) and an entropy bonus weighted by the
    temperature ``alpha = exp(log_alpha)``. The target critic is soft-updated every
    ``critic_train.target_update_frequency`` steps. Actor updates are delegated to
    ``unroll_actor``, which receives this agent, the imagined replay buffer and the
    dynamics/reward model.
    """

    def __init__(self, obs_dim, action_dim, action_range, batch_size, discount, device,
                 critic_cfg, actor_cfg, actor_train, critic_train, temperature_train):
        """Build the actor, critic, target critic, temperature and their optimizers.

        Args:
            obs_dim: observation dimensionality. Not used by this class (presumably
                accepted so the Hydra config can pass it uniformly — confirm).
            action_dim (int): action dimensionality; ``-action_dim`` is the default
                target entropy.
            action_range: ``(low, high)`` pair used to clamp actions in :meth:`act`.
            batch_size (int): number of transitions sampled per :meth:`update`.
            discount (float): reward discount factor used in the critic target.
            device: anything accepted by ``torch.device``.
            critic_cfg: Hydra config instantiated for both the critic and its target.
            actor_cfg: Hydra config instantiated for the actor.
            actor_train: actor options; reads ``lr``, ``betas`` and ``update_frequency``.
            critic_train: critic options; reads ``lr``, ``betas``, ``tau`` and
                ``target_update_frequency``.
            temperature_train: temperature options; reads ``initial``, ``lr``, ``betas``
                and, optionally, ``target_entropy``.
        """
        super().__init__()

        self.action_range = action_range
        self.batch_size = batch_size
        self.discount = discount
        self.device = torch.device(device)

        self.actor_train = actor_train
        self.critic_train = critic_train
        self.temperature_train = temperature_train

        # The target critic starts as an exact copy of the critic.
        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        # The temperature is learned in log space so that alpha = exp(log_alpha) stays
        # positive. The default target entropy is -|A|.
        self.log_alpha = torch.tensor(np.log(temperature_train.initial)).to(self.device)
        self.log_alpha.requires_grad = True
        self.target_entropy = temperature_train.get("target_entropy", -action_dim)

        # Optimizers. NOTE(review): log_alpha_optimizer (and actor_optimizer) are not
        # stepped in this class; presumably unroll_actor does so — confirm.
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_train.lr, betas=actor_train.betas)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_train.lr, betas=critic_train.betas)
        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=temperature_train.lr, betas=temperature_train.betas)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Set the training flag on the actor and critic.

        The target critic's mode is not changed here; it is put in train mode once in
        ``__init__``.
        """
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    @property
    def alpha(self):
        """Entropy temperature ``exp(log_alpha)``; still attached to ``log_alpha``'s graph."""
        return self.log_alpha.exp()

    def act(self, obs: np.ndarray, sample: bool = False, batched: bool = False, **_kwargs) -> np.ndarray:
        """Issues an action given an observation.

        Args:
            obs (np.ndarray): the observation (or batch of observations) for which the action
                is needed.
            sample (bool): if ``True`` the agent samples actions from its policy, otherwise it
                returns the mean policy value. Defaults to ``False``.
            batched (bool): if ``True`` signals to the agent that the obs should be interpreted
                as a batch.
            **_kwargs: accepted for interface compatibility and ignored.

        Returns:
            (np.ndarray): the action, clamped to ``action_range``. It has one row per
            observation when ``batched``; otherwise it is a single action.
        """
        # NOTE(review): eval_mode() is given no models, so (assuming it only toggles the
        # models passed to it) the actor is not switched to eval mode here; confirm
        # whether utils.eval_mode(self) was intended.
        with utils.eval_mode(), torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            if not batched:
                # Add a batch dimension so the actor always sees a batch.
                obs = obs.unsqueeze(0)
            dist = self.actor(obs)
            action = dist.sample() if sample else dist.mean
            action = action.clamp(*self.action_range)
            if not batched:
                assert action.ndim == 2 and action.shape[0] == 1
                return utils.to_np(action[0])
            assert action.ndim == 2
            return utils.to_np(action)

    def update_critic(self, obs, action, reward, next_obs, not_done, logger, step):
        """Take one optimizer step on the critic towards the soft Bellman target.

        The target is ``reward + not_done * discount * V'``, where
        ``V' = min(Q1', Q2')(next_obs, a') - alpha * log pi(a' | next_obs)``,
        ``a'`` is drawn from the current actor, and ``Q1', Q2'`` come from the
        target critic.

        Args:
            obs, action, reward, next_obs, not_done: batch tensors sampled from the
                replay buffer.
            logger: object with a ``log(key, value, step)`` method, or ``None`` to
                skip logging.
            step (int): step index forwarded to the logger.
        """
        # The target is a constant for the critic loss. Computing it without autograd
        # avoids building a graph through the actor and target critic that would only
        # be thrown away.
        with torch.no_grad():
            dist = self.actor(next_obs)
            next_action = dist.rsample()
            log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_V = torch.min(target_Q1, target_Q2) - self.alpha * log_prob
            target_Q = reward + (not_done * self.discount * target_V)

        # Every Q estimate returned by the critic is regressed onto the same target.
        current_Qs = self.critic(obs, action)
        critic_loss = sum(F.mse_loss(current_Q, target_Q) for current_Q in current_Qs)
        if logger is not None:
            logger.log("train_critic/loss", critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def update(self, cfg: omegaconf.DictConfig, imagined_buffer: uasc.buffer.CircularReplayBuffer,
               dynamics_model: OneDTransitionRewardModel, logger=None, step=0, rng=None,
               device: torch.DeviceObjType = torch.device("cpu"), steps_unrolling=None):
        """Update the critic, the actor (by unrolling) and the target critic.

        The critic is updated on every call. The actor is updated via ``unroll_actor``
        every ``actor_train.update_frequency`` steps, and the target critic is
        soft-updated every ``critic_train.target_update_frequency`` steps.

        TODO: Test this with the critic.

        Args:
            cfg (omegaconf.DictConfig): configuration options, `cfg.algorithm.agent`.
            imagined_buffer (uasc.buffer.CircularReplayBuffer): buffer to sample from.
            dynamics_model (OneDTransitionRewardModel): dynamics/reward model to unroll.
            logger (optional): object with ``log(key, value, step)``. When ``None``,
                logging in this class is skipped. It is still forwarded to
                ``unroll_actor``, whose ``None`` handling is not visible here.
            step (int, optional): number of times agent.update(...) has been called.
            rng (np.random.Generator, optional): random generator forwarded to
                ``unroll_actor``. When ``None``, a fresh ``np.random.default_rng()`` is
                created for this call instead of sharing one module-level generator.
            device (torch.DeviceObjType, optional): device forwarded to ``unroll_actor``.
            steps_unrolling (optional): forwarded unchanged to ``unroll_actor``.
        """
        if rng is None:
            rng = np.random.default_rng()

        obs, action, reward, next_obs, not_done, not_done_no_max = imagined_buffer.sample(
            self.batch_size)

        if logger is not None:
            logger.log("train/batch_reward", reward.mean(), step)
        # The critic target is masked with not_done_no_max rather than not_done
        # (presumably so that time-limit truncations still bootstrap — confirm
        # against the buffer).
        self.update_critic(obs, action, reward, next_obs, not_done_no_max, logger, step)

        if step % self.actor_train.update_frequency == 0:
            unroll_actor(cfg, imagined_buffer, dynamics_model, self, rng=rng, device=device,
                         logger=logger, step=step, steps_unrolling=steps_unrolling)

        if step % self.critic_train.target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_train.tau)

    def save(self, save_dir):
        """Save the critic and actor state dicts as ``critic.pth`` / ``actor.pth``.

        Args:
            save_dir (str or os.PathLike): existing directory to write into.

        NOTE(review): the target critic, ``log_alpha`` and optimizer states are not saved.
        """
        save_dir = pathlib.Path(save_dir)  # also accept plain string paths
        critic_path = save_dir / "critic.pth"
        actor_path = save_dir / "actor.pth"

        torch.save(self.critic.state_dict(), critic_path)
        torch.save(self.actor.state_dict(), actor_path)
