import numpy as np
import torch
import torch.nn.functional as F
import os

from agent import Agent
import utils

import hydra
import wandb

# wandb.log is only called on update steps where `step % LOG_FREQ == 0`.
LOG_FREQ = 100


class IncrementalAgent(Agent):
    """Incremental skill learning algorithm.

    A SAC-style learner that trains one skill (actor) at a time. Calling
    ``add_new_skill`` freezes the current actor: it is switched to eval mode
    and saved to disk. It then builds a fresh actor, a fresh critic /
    target-critic pair, a fresh temperature and fresh optimizers for the next
    skill. All actors are kept in ``self.skill_actors``. ``self.actor`` is
    always the skill currently being trained, at index
    ``self.current_skill_num``.
    """

    def __init__(self, obs_dim, t_obs_dim, action_dim, action_range, device,
                 critic_cfg, actor_cfg, discount, init_temperature, alpha_lr,
                 alpha_betas, actor_lr, actor_betas, actor_update_frequency,
                 critic_lr, critic_betas, critic_tau,
                 critic_target_update_frequency, batch_size,
                 learnable_temperature, use_t_obs, use_t_vel,
                 policy_use_t_vel, use_timesteps=True, finetune_prev=False):
        """Store hyper-parameters and build the networks for skill 0.

        The networks themselves are built from ``critic_cfg`` / ``actor_cfg``
        via ``hydra.utils.instantiate``. ``obs_dim``, ``t_obs_dim`` and
        ``policy_use_t_vel`` are accepted for config compatibility but are not
        read by this class.

        Raises:
            AssertionError: if neither ``use_t_obs`` nor ``use_t_vel`` is set.
        """
        super().__init__()

        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size
        self.learnable_temperature = learnable_temperature
        assert use_t_obs or use_t_vel, \
            "At least one of obs or vel should be used"
        self._use_t_vel = use_t_vel
        self._use_t_obs = use_t_obs
        self._use_timesteps = use_timesteps

        self.skill_actors = []

        # Kept so every new skill can rebuild its networks / optimizers.
        self.critic_cfg = critic_cfg
        self.actor_cfg = actor_cfg
        self.init_temperature = init_temperature
        self.action_dim = action_dim
        self.actor_lr = actor_lr
        self.actor_betas = actor_betas
        self.critic_lr = critic_lr
        self.critic_betas = critic_betas
        self.alpha_lr = alpha_lr
        self.alpha_betas = alpha_betas

        # With self.actor None, init_new_skill neither saves a previous skill
        # nor advances the skill counter.
        self.actor = None
        self._finetune_prev = finetune_prev
        self.current_skill_num = 0
        new_skill_actor = self.init_new_skill()
        self.skill_actors.append(new_skill_actor)

        self.train()
        self.critic_target.train()

    def init_new_skill(self):
        """Build fresh networks, temperature and optimizers for a new skill.

        If an actor already exists, it is saved to disk first and
        ``current_skill_num`` is advanced. The caller must append the
        returned actor to ``self.skill_actors``.

        Returns:
            The newly created actor (also stored as ``self.actor``).
        """
        self.critic = hydra.utils.instantiate(self.critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(self.critic_cfg).to(
            self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        if self.actor is not None:
            # Persist the finished skill before self.actor is replaced.
            self.save_actor_model()
            self.current_skill_num += 1

        self.actor = hydra.utils.instantiate(self.actor_cfg).to(self.device)
        if self._finetune_prev:
            # Warm-start from a previously saved actor, if the file exists.
            # NOTE(review): nothing in this file writes 'last_actor.pt';
            # confirm which component produces it.
            path = self._populate_path(None, 'skills')
            actor_path = os.path.join(path, 'last_actor.pt')
            if os.path.exists(actor_path):
                print('Loading actor')
                self.actor.load_state_dict(
                    torch.load(actor_path, map_location=self.device))

        self.log_alpha = torch.tensor(
            np.log(self.init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -self.action_dim

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.actor_lr,
                                                betas=self.actor_betas)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr,
                                                 betas=self.critic_betas)

        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=self.alpha_lr,
                                                    betas=self.alpha_betas)

        return self.actor

    def register_reward_module(self, reward_module):
        """Attach the module that computes rewards inside ``update``.

        Raises:
            AssertionError: if timesteps are ignored (``use_timesteps=False``)
                but the module reports ``max_timesteps != 1``.
        """
        self.reward_module = reward_module
        self.max_timesteps = reward_module.max_timesteps
        if not self._use_timesteps:
            assert self.max_timesteps == 1, \
                (f"Must have exactly one, not {self.max_timesteps} timesteps,"
                 " if we're ignoring time.")

    def get_skill(self, skill_index):
        """Return the actor for ``skill_index``; negative indices count from the end."""
        assert skill_index <= self.current_skill_num, "Skill not learned yet"
        return self.skill_actors[skill_index]

    def add_new_skill(self, num_steps_next_skill=None):
        """Freeze the current skill and start training a new one.

        ``num_steps_next_skill`` is accepted for interface compatibility and is
        not used.

        Returns:
            The new actor, which is also ``self.actor``.
        """
        # The list may be empty when _load_agent restores zero finished skills.
        if self.skill_actors:
            self.skill_actors[-1].eval()
        new_actor = self.init_new_skill()
        self.skill_actors.append(new_actor)
        self.train()
        self.critic_target.train()
        return new_actor

    def train(self, training=True):
        """Set train/eval mode on the current actor and critic only."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    @property
    def alpha(self):
        """Current entropy temperature, ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def act(self, obs, t_obs,
            prev_obs, prev_t_obs,
            sample=False, skill_index=-1):
        """Select an action for a single (unbatched) observation.

        Each input is converted to a float tensor and given a leading batch
        dimension of 1. The action is sampled from the policy distribution
        when ``sample`` is True; otherwise the distribution mean is used. It
        is then clamped to ``action_range``.

        Returns:
            The action as a numpy array, with the batch dimension removed.
        """
        arglist = [torch.FloatTensor(arg).to(self.device).unsqueeze(0)
                   for arg in (obs, t_obs, prev_obs, prev_t_obs)]
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            dist = self.skill_actors[skill_index](*arglist)
            action = dist.sample() if sample else dist.mean
            action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, obs, t_obs,
                      prev_obs, prev_t_obs,
                      action, reward, next_obs, next_t_obs,
                      not_done, logger, step):
        """Take one soft-Q (SAC) regression step on the current critic.

        The target is ``reward + not_done * discount *
        (min(Q1', Q2') - alpha * log_prob)``. It is evaluated at the next
        state, with the current state serving as that state's "previous"
        observation.
        """
        dist = self.actor(next_obs, next_t_obs, obs, t_obs)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_Q1, target_Q2 = self.critic_target(next_obs,
                                                  next_t_obs,
                                                  obs,
                                                  t_obs,
                                                  next_action)
        target_V = torch.min(target_Q1,
                             target_Q2) - self.alpha.detach() * log_prob
        target_Q = reward + (not_done * self.discount * target_V)
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, t_obs,
                                             prev_obs, prev_t_obs,
                                             action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        logger.log('train_critic/loss', critic_loss, step)
        if step % LOG_FREQ == 0:
            wandb.log({
                'update_step': step,
                'train_critic/loss': critic_loss,
            })

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, obs, t_obs,
                               prev_obs, prev_t_obs,
                               logger, step):
        """Take one SAC policy step, and a temperature step if it is learnable."""
        dist = self.actor(obs, t_obs, prev_obs, prev_t_obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q1, actor_Q2 = self.critic(obs, t_obs,
                                         prev_obs, prev_t_obs, action)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/entropy', -log_prob.mean(), step)
        if step % LOG_FREQ == 0:
            wandb.log({
                'update_step': step,
                'train_actor/loss': actor_loss,
                'train_actor/target_entropy': self.target_entropy,
                'train_actor/entropy': -log_prob.mean()
            })

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            # The entropy gap is detached so only log_alpha receives gradients.
            alpha_loss = (self.alpha *
                          (-log_prob - self.target_entropy).detach()).mean()
            logger.log('train_alpha/loss', alpha_loss, step)
            logger.log('train_alpha/value', self.alpha, step)
            # Same wandb throttling as every other wandb.log in this class.
            if step % LOG_FREQ == 0:
                wandb.log({
                    'update_step': step,
                    'train_alpha/loss': alpha_loss,
                    'train_alpha/value': self.alpha
                })
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update(self, replay_buffer, logger, step):
        """Run one training iteration on a batch sampled from ``replay_buffer``.

        The reward stored in the buffer is ignored. Rewards are recomputed by
        the registered reward module (see ``register_reward_module``), which
        must be called first.
        """
        sample = replay_buffer.sample(self.batch_size)
        (obs, t_obs, action, _stored_reward,
         next_obs, next_t_obs,
         prev_obs, prev_t_obs, timesteps,
         _not_done, not_done_no_max) = sample

        reward = self.reward_module.get_rewards(next_t_obs,
                                                t_obs,
                                                timesteps,
                                                step=step)

        logger.log('train/batch_reward', reward.mean(), step)
        if step % LOG_FREQ == 0:
            wandb.log({
                'update_step': step,
                'reward': reward.mean(),
            })

        # The critic bootstraps with not_done_no_max rather than not_done
        # (presumably so time-limit truncations still bootstrap — confirm
        # against the replay buffer).
        self.update_critic(obs, t_obs, prev_obs, prev_t_obs,
                           action, reward, next_obs, next_t_obs,
                           not_done_no_max, logger, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, t_obs,
                                        prev_obs, prev_t_obs,
                                        logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)

    def save_actor_model(self, path=None, filename=None):
        """Save the actor currently being trained (the last in ``skill_actors``)."""
        self._save_actor_index(path=path, filename=filename, index=-1)

    def _save_actor_index(self, path=None, filename=None, index=-1):
        """Save ``skill_actors[index]``'s state dict.

        The default location is ``<cwd>/skills/skill_<n>``. For
        ``index == -1``, ``<n>`` is ``current_skill_num``.
        """
        path = self._populate_path(path, 'skills')
        skill_idx = str(self.current_skill_num if index == -1 else index)
        if filename is None:
            filename = f'skill_{skill_idx}'
        file_path = os.path.join(path, filename)
        actor = self.get_skill(index)
        torch.save(actor.state_dict(), file_path)

    def _save_current_critic(self, path=None, filename=None):
        """Save the current critic's state dict (default ``<cwd>/critic/critic``)."""
        path = self._populate_path(path, 'critic')
        if filename is None:
            filename = 'critic'
        file_path = os.path.join(path, filename)
        torch.save(self.critic.state_dict(), file_path)

    def _save_optimizers(self, path=None, filename=None):
        """Save temperature, optimizer states and the skill count.

        The default location is ``<cwd>/auxilaries/aux.pth``. All optimizers
        are stored as ``state_dict()``s, not pickled objects, so the file can
        be read by ``torch.load`` without unpickling optimizer classes.
        """
        path = self._populate_path(path, 'auxilaries')
        filename = filename or 'aux.pth'
        file_path = os.path.join(path, filename)
        to_save = {
            'log_alpha': self.log_alpha,
            'log_alpha_optimizer': self.log_alpha_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'num_skills': self.current_skill_num
        }
        torch.save(to_save, file_path)

    def save_agent(self, path=None):
        """Save critic, auxiliaries and every skill ``0..current_skill_num``.

        With ``path=None``, each part goes under the current working directory.
        """
        skill_path, critic_path, aux_path = None, None, None
        if path is not None:
            skill_path, critic_path, aux_path = \
                self._make_individual_paths(path)
        self._save_current_critic(path=critic_path)
        self._save_optimizers(path=aux_path)
        for i in range(self.current_skill_num + 1):
            self._save_actor_index(path=skill_path, index=i)

    def _load_agent(self, path=None):
        """Restore finished skills from ``path`` (default: cwd) and start a new one.

        ``save_agent`` writes skills ``0..num_skills``. The last of these was
        still in training, so only ``0..num_skills-1`` are restored. A fresh
        actor is created in its place, which doesn't matter since that actor
        was a novice anyway. The critic and optimizers are also rebuilt
        fresh rather than loaded.
        """
        if path is None:
            path = os.getcwd()
        skill_path, _, aux_path = self._make_individual_paths(path)
        aux_dict = torch.load(os.path.join(aux_path, 'aux.pth'),
                              map_location=self.device)
        self.current_skill_num = aux_dict['num_skills']
        self.skill_actors = []

        for i in range(self.current_skill_num):
            actor = hydra.utils.instantiate(self.actor_cfg).to(self.device)
            actor.load_state_dict(
                torch.load(os.path.join(skill_path, f'skill_{i}'),
                           map_location=self.device))
            # Finished skills are frozen, matching add_new_skill in training.
            actor.eval()
            self.skill_actors.append(actor)

        # With self.actor None, init_new_skill neither re-saves a skill nor
        # advances current_skill_num. The new actor therefore lands at index
        # current_skill_num.
        self.actor = None
        self.add_new_skill()

    def _populate_path(self, path, default):
        """Return ``path`` (or ``<cwd>/<default>`` if None), creating the directory."""
        if path is None:
            path = os.path.join(os.getcwd(), default)
        utils.make_dir(path)
        return path

    def _make_individual_paths(self, path):
        """Return the ``(skills, critic, auxilaries)`` sub-directories of ``path``."""
        return (os.path.join(path, 'skills'),
                os.path.join(path, 'critic'),
                os.path.join(path, 'auxilaries'))