import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils
from encoder import make_encoder, Transition_model, RNNModel

# Default interval (in training steps) between histogram/parameter dumps in
# the ``log`` methods below.
LOG_FREQ = 10_000

def gaussian_logprob(noise, log_std):
    """Return the log-density of a reparameterised diagonal-Gaussian sample.

    ``noise`` is the standard-normal draw used to build the sample
    (``x = mu + noise * exp(log_std)``), so the per-dimension density is
    ``-0.5 * noise**2 - log_std - 0.5 * log(2*pi)``. The per-dimension terms
    are summed over the last axis, which is kept with size 1.
    """
    dim = noise.size(-1)
    log_norm_const = 0.5 * np.log(2 * np.pi) * dim
    per_dim = -0.5 * noise.pow(2) - log_std
    return per_dim.sum(-1, keepdim=True) - log_norm_const


def squash(mu, pi, log_pi):
    """Apply the tanh squashing function and correct the log-probability.

    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.

    Args:
        mu: pre-squash mean; always passed through ``tanh``.
        pi: pre-squash sample, or None.
        log_pi: log-probability of ``pi`` before squashing, or None.

    Returns:
        ``(tanh(mu), tanh(pi) or None, corrected log_pi or None)``. The
        caller's ``log_pi`` tensor is not modified.

    Raises:
        ValueError: if ``log_pi`` is given without ``pi`` (the correction
            needs the squashed sample).
    """
    mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is not None:
        if pi is None:
            raise ValueError("squash: log_pi requires pi for the tanh correction")
        # log|d tanh(u)/du| = log(1 - tanh(u)^2); relu guards against tiny
        # negative values from rounding and 1e-6 avoids log(0).
        # Out-of-place subtraction so the caller's tensor is left untouched.
        log_pi = log_pi - torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return mu, pi, log_pi


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Intended for ``module.apply(weight_init)``; other module types are left
    unchanged.

    * ``nn.Linear``: orthogonal weights, zero bias.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: delta-orthogonal init — all
      weights zero except the kernel centre tap, which is orthogonal with
      ReLU gain; zero bias.

    Layers built with ``bias=False`` are supported (their bias is None).

    Raises:
        ValueError: for a convolution whose kernel is not square.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        if m.weight.size(2) != m.weight.size(3):
            raise ValueError(
                'weight_init: delta-orthogonal init needs a square kernel, got %s'
                % (tuple(m.weight.shape[2:]),)
            )
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)



class Actor(nn.Module):
    """MLP actor network with auxiliary state and action predictors.

    The policy head maps encoder features through ``pre_process`` and
    ``trunk`` to the mean and log-std of a Gaussian whose samples are then
    squashed with tanh. The actor also owns two recurrent predictors used by
    :meth:`bound`:

    * ``rnn_obs`` + ``transition_model_obs`` predict the latent of the next
      observation from the feature/action history (see :meth:`obs_bound`);
    * ``rnn_action`` + ``transition_model_action`` produce a distribution over
      the last action given the feature/action history (see
      :meth:`action_bound`).
    """

    def __init__(
            self, obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
            predictor_update_encoder=False
    ):
        """Build the encoder, policy head and predictors.

        Args:
            obs_shape: observation shape; ``obs_shape[0]`` is the input width
                when ``encoder_type == 'identity'``.
            action_shape: action shape; ``action_shape[0]`` is the action size.
            hidden_dim: width of the policy trunk's hidden layers.
            encoder_type, encoder_feature_dim, num_layers, num_filters:
                forwarded to ``make_encoder``.
            log_std_min, log_std_max: range the policy log-std is mapped into.
            predictor_update_encoder: if True, gradients from the latent
                predictor in :meth:`obs_bound` flow into the encoder features
                of ``obs``; otherwise those features are detached.
        """
        super().__init__()

        self.encoder_type = encoder_type
        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, output_logits=True
        )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        if encoder_type == 'identity':
            input_dim = obs_shape[0]
        else:
            input_dim = encoder_feature_dim
        self.pre_process = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, encoder_feature_dim),
            nn.LayerNorm(encoder_feature_dim)
        )

        # Outputs mean and log-std for every action dimension.
        self.trunk = nn.Sequential(
            nn.Linear(encoder_feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )

        # state predictor: input is [last feature, rnn summary, last action]
        self.rnn_obs = RNNModel(encoder_feature_dim + action_shape[0], output_size=encoder_feature_dim, hidden_dim=256)
        self.transition_model_obs = Transition_model(action_shape[0] + 2 * encoder_feature_dim, encoder_feature_dim, 2, output_logits=True)

        # action predictor: input is [last feature, rnn summary]
        self.rnn_action = RNNModel(encoder_feature_dim + action_shape[0], output_size=action_shape[0], hidden_dim=256)
        self.transition_model_action = Transition_model(action_shape[0]+ encoder_feature_dim, action_shape[0], 2,
                                                     output_logits=True)

        self.predictor_update_encoder = predictor_update_encoder

        self.outputs = dict()
        self.apply(weight_init)

    def forward(
            self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False, encoder_use_mean=False
    ):
        """Run the policy.

        Args:
            obs: observation batch passed to ``self.encoder``.
            compute_pi: if True, draw a reparameterised sample ``pi``.
            compute_log_pi: if True, also return the tanh-corrected
                log-probability of ``pi``; requires ``compute_pi``.
            detach_encoder: if True, stop gradients from the policy head from
                reaching the encoder.
            encoder_use_mean: forwarded to the encoder as ``output_mean``.

        Returns:
            ``(mu, pi, log_pi, log_std)``: ``mu`` and ``pi`` are tanh-squashed,
            ``pi``/``log_pi`` are None when not requested, and ``log_std`` is
            the pre-squash log standard deviation mapped into
            ``[log_std_min, log_std_max]``.

        Raises:
            ValueError: if ``compute_log_pi`` is set without ``compute_pi``.
        """
        if compute_log_pi and not compute_pi:
            # log_pi is the density of the sampled pi; without a sample there
            # is nothing to evaluate.
            raise ValueError('Actor.forward: compute_log_pi=True requires compute_pi=True')

        _, _, feat = self.encoder(obs, output_mean=encoder_use_mean)
        if detach_encoder:
            feat = feat.detach()
        feat = self.pre_process(feat)

        mu, log_std = self.trunk(feat).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
                self.log_std_max - self.log_std_min
        ) * (log_std + 1)
        std = log_std.exp()

        self.outputs['mu'] = mu
        self.outputs['std'] = std

        if compute_pi:
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            pi = None

        if compute_log_pi:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log output histograms and trunk parameters every ``log_freq`` steps."""
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        L.log_param('train_actor/fc1', self.trunk[0], step)
        L.log_param('train_actor/fc2', self.trunk[2], step)
        L.log_param('train_actor/fc3', self.trunk[4], step)

    def obs_bound(self, obs, action, next_obs):
        """Return the negative KL between the encoder and the latent predictor.

        The encoder features of ``obs`` and the matching ``action`` history
        (leading axis treated as time: ``[:-1]`` is the history, ``[-1]`` the
        current step) are summarised by ``rnn_obs``. Together with the last
        feature and last action this parameterises the predicted Gaussian for
        the next latent. The result is
        ``-KL(encoder(next_obs) || prediction)`` summed over the last axis.

        Gradients from the predictor reach the ``obs`` features only when
        ``predictor_update_encoder`` is True; the ``next_obs`` encoding is
        never detached.
        """
        _, _, feat = self.encoder(obs)
        if not self.predictor_update_encoder:
            feat = feat.detach()
        z_a = torch.cat((feat[:-1, :, :], action[:-1, :, :]), dim=-1)
        # NOTE(review): assumes rnn_obs returns a 2-D (batch, hidden) summary,
        # as required by the dim=1 concatenation below.
        obs_hidden = self.rnn_obs(z_a)
        pred_mu, pred_logstd, _ = self.transition_model_obs(
            torch.cat([feat[-1, :, :], obs_hidden, action[-1, :, :]], dim=1)
        )
        encod_mu_next, encod_logstd_next, _ = self.encoder(next_obs)

        # closed-form KL(N(mu_e, s_e) || N(mu_p, s_p)) per dimension:
        # log(s_p / s_e) + (s_e^2 + (mu_e - mu_p)^2) / (2 s_p^2) - 1/2
        kl = pred_logstd - encod_logstd_next + (
                torch.exp(encod_logstd_next) ** 2 + (encod_mu_next - pred_mu) ** 2) / (
                     2 * torch.exp(pred_logstd) ** 2) - 0.5
        kl = torch.sum(kl, dim=-1)
        lw = -1.0 * kl
        return lw

    def action_bound(self, obs, action):
        """
        compute log q(a_t|z_1:t, a_1:t-1).

        The last action ``action[-1]`` is scored under the distribution
        produced by ``transition_model_action`` from the last encoder feature
        and an ``rnn_action`` summary of the earlier feature/action pairs.
        Log-probabilities are summed over the last axis.
        """
        _, _, feat = self.encoder(obs)
        z_a = torch.cat((feat[:-1, :, :], action[:-1, :, :]), dim=-1)
        action_hidden = self.rnn_action(z_a)
        _, _, dist = self.transition_model_action(torch.cat([feat[-1, :, :], action_hidden], dim=1))
        log_prob = dist.log_prob(action[-1, :, :])
        log_prob = log_prob.sum(-1)
        return log_prob

    def bound(self, obs, action, next_obs):
        """Return ``obs_bound + action_bound`` for the given sequence batch."""
        obs_bound = self.obs_bound(obs, action, next_obs)
        action_bound = self.action_bound(obs, action)
        bound = obs_bound + action_bound
        return bound

class QFunction(nn.Module):
    """Two-hidden-layer MLP mapping an (observation, action) pair to a scalar Q-value."""

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        in_features = obs_dim + action_dim
        layers = [
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.trunk = nn.Sequential(*layers)

    def forward(self, obs, action):
        # Both inputs must share the same batch size.
        assert obs.size(0) == action.size(0)

        joint = torch.cat((obs, action), dim=1)
        return self.trunk(joint)


class Critic(nn.Module):
    """Critic network with two independent Q-functions (clipped double-Q).

    Both heads consume the raw observation vector (``obs_shape[0]``
    features). The encoder-related constructor arguments and the
    ``detach_encoder`` flag of ``forward`` are accepted for signature
    compatibility but are not used here.
    """

    def __init__(
            self, obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
    ):
        super().__init__()

        obs_dim, act_dim = obs_shape[0], action_shape[0]
        self.Q1 = QFunction(obs_dim, act_dim, hidden_dim)
        self.Q2 = QFunction(obs_dim, act_dim, hidden_dim)

        self.outputs = {}
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        q1, q2 = self.Q1(obs, action), self.Q2(obs, action)
        # Keep the latest Q-values around for histogram logging.
        self.outputs.update(q1=q1, q2=q2)
        return q1, q2

    def log(self, L, step, log_freq=LOG_FREQ):
        if step % log_freq:
            return

        for name, value in self.outputs.items():
            L.log_histogram(f'train_critic/{name}_hist', value, step)

        # Linear layers sit at even positions of each trunk (ReLUs in between).
        for idx in range(3):
            layer_pos = 2 * idx
            L.log_param(f'train_critic/q1_fc{idx}', self.Q1.trunk[layer_pos], step)
            L.log_param(f'train_critic/q2_fc{idx}', self.Q2.trunk[layer_pos], step)

class TCSacAgent(object):
    """TC-SAC

    Soft actor-critic agent whose actor is additionally trained with
    ``Actor.bound`` (negative latent-prediction KL plus action log-likelihood).
    When ``use_aug_reward`` is True, ``exp(log_kl_coeff) * bound`` is added to
    the environment reward as an intrinsic reward, and ``log_kl_coeff`` is
    adapted by a dual update against ``kl_constraint``.
    """
    def __init__(
            self,
            obs_shape,
            action_shape,
            device,
            hidden_dim=256,
            discount=0.99,
            init_temperature=0.01,
            alpha_lr=1e-4,
            alpha_beta=0.9,
            actor_lr=1e-4,
            actor_beta=0.9,
            actor_log_std_min=-10,
            actor_log_std_max=2,
            actor_update_freq=2,
            critic_lr=1e-4,
            critic_beta=0.9,
            critic_tau=0.005,
            critic_target_update_freq=2,
            encoder_type='prop',
            encoder_feature_dim=50,
            encoder_tau=0.005,
            num_layers=4,
            num_filters=32,
            log_interval=100,
            detach_encoder=False,
            use_aug_reward=True,
            kl_coef_lr=1e-4,
            kl_constraint=3.0,
            horizon=5,
            init_log_kl=1e-6
    ):
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq
        self.log_interval = log_interval
        self.image_size = obs_shape[-1]
        self.detach_encoder = detach_encoder
        self.encoder_type = encoder_type
        self.use_aug_reward = use_aug_reward
        self.horizon = horizon

        # Note: despite its name, init_log_kl is the initial coefficient
        # itself; its log is what gets optimised.
        self.log_kl_coeff = torch.tensor(np.log(init_log_kl)).to(device)
        self.log_kl_coeff.requires_grad = True
        self.kl_constraint = kl_constraint

        self.critic = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)

        self.actor = Actor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters
        ).to(device)

        self.critic_target = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)
        # Start the target critic as an exact copy of the online critic; it is
        # then tracked by Polyak averaging in update(). Without this copy the
        # target would begin from an unrelated random initialisation.
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )

        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
        )

        self.dual_kl_optimizer = torch.optim.Adam(
            [self.log_kl_coeff], lr=kl_coef_lr, betas=(alpha_beta, 0.999)
        )

        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Switch the actor and critic between train and eval mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    @property
    def alpha(self):
        """Current entropy temperature ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def select_action(self, obs):
        """Return the deterministic (squashed-mean) action for one observation."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, _, _, _ = self.actor(
                obs, compute_pi=False, compute_log_pi=False, encoder_use_mean=True
            )
            return mu.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        """Return a stochastic (squashed-sample) action for one observation."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
            return pi.cpu().data.numpy().flatten()

    def update_critic(self, obs, action, reward, next_obs, L, step):
        """One gradient step on both Q-functions with a soft Bellman target."""
        with torch.no_grad():
            _, policy_action, log_pi, _ = self.actor(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1,
                                 target_Q2) - self.alpha.detach() * log_pi
            # NOTE(review): no (1 - done) mask is applied here; update()
            # discards not_done from the replay buffer. Confirm this is
            # intended (e.g. episodes that end only on time limits).
            target_Q = reward + (self.discount * target_V)

        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1,
                                 target_Q) + F.mse_loss(current_Q2, target_Q)

        if step % self.log_interval == 0:
            L.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def update_actor_and_alpha(self, obs, action, next_obs, env_reward, L, step):
        """One gradient step on the actor, then one on the temperature.

        With ``use_aug_reward`` the actor minimises
        ``mean(discount * (alpha * log_pi' - Q(s', a')) - kl_coeff * bound
        - env_reward + alpha * log_pi)``, where ``s`` is ``obs[-1]`` and ``s'``
        is ``next_obs``. Otherwise the standard SAC actor loss on ``obs`` is
        used.
        """
        if self.use_aug_reward:
            _, pi, log_pi, log_std = self.actor(obs[-1, :, :])
            _, pi_next, log_pi_next, log_std_next = self.actor(next_obs)
            actor_Q1, actor_Q2 = self.critic(next_obs, pi_next)
            actor_Q_next = torch.min(actor_Q1, actor_Q2)
            actor_loss = self.discount * (self.alpha.detach() * log_pi_next - actor_Q_next)

            lw = self.actor.bound(obs, action, next_obs)
            kl_coeff = torch.exp(self.log_kl_coeff).detach()

            actor_loss = actor_loss - kl_coeff * lw.unsqueeze(dim=1)

            actor_loss = torch.mean(actor_loss - env_reward + self.alpha.detach() * log_pi)
        else:
            _, pi, log_pi, log_std = self.actor(obs)
            actor_Q1, actor_Q2 = self.critic(obs, pi)

            actor_Q = torch.min(actor_Q1, actor_Q2)
            actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

        if step % self.log_interval == 0:
            L.log('train_actor/loss', actor_loss, step)
            L.log('train_actor/target_entropy', self.target_entropy, step)
            if self.use_aug_reward:
                L.log('train_actor/lw', lw.mean(), step)
                L.log('train_actor/kl_coeff', kl_coeff, step)

        # differential entropy of the (pre-squash) diagonal Gaussian
        entropy = 0.5 * log_std.shape[1] * \
                  (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)
        if step % self.log_interval == 0:
            L.log('train_actor/entropy', entropy.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha *
                      (-log_pi - self.target_entropy).detach()).mean()
        if step % self.log_interval == 0:
            L.log('train_alpha/loss', alpha_loss, step)
            L.log('train_alpha/alpha_value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def calulate_reward(self, obs, action, next_obs, L, step):
        """Return ``(exp(log_kl_coeff) * bound, bound)`` computed without gradients."""

        with torch.no_grad():
            lw = self.actor.bound(obs,action, next_obs)             # (B,)

            reward = torch.exp(self.log_kl_coeff) * lw

        return reward, lw

    def update_dual_multiplier(self, lw, L, step):
        """One dual-ascent step on ``log_kl_coeff`` against ``kl_constraint``."""

        dual_kl_loss = torch.exp(self.log_kl_coeff) * (lw.mean().detach() - self.kl_constraint)

        self.dual_kl_optimizer.zero_grad()
        dual_kl_loss.backward()
        self.dual_kl_optimizer.step()

        # clip log_kl_coeff to the (-log(1e6), log(1e6)) range
        self.log_kl_coeff.data = torch.clamp(self.log_kl_coeff, -1.0 * np.log(1e6), np.log(1e6)).data

        if step % self.log_interval == 0:
            L.log('train_dual/dual_kl_loss', dual_kl_loss, step)
            L.log('train_dual/kl', lw.mean(), step)
            L.log('train_dual/log_kl_coeff', self.log_kl_coeff, step)

    def update(self, replay_buffer, L, step):
        """Run one training iteration on a batch of consecutive transitions.

        NOTE(review): assumes ``sample_consecutive`` returns ``obs``/``action``
        with a leading horizon axis (the critic uses ``obs[-1]``) and a
        single-step ``next_obs`` — confirm against the replay buffer.
        """

        obs, action, reward, next_obs, not_done = replay_buffer.sample_consecutive(self.horizon)

        if self.use_aug_reward:
            intrinsic_reward, lw = self.calulate_reward(obs, action, next_obs, L, step)
            intrinsic_reward = intrinsic_reward.unsqueeze(1).detach()
            aug_reward = reward + intrinsic_reward
        else:
            aug_reward = reward

        if step % self.log_interval == 0:
            L.log('train/batch_reward', reward.mean(), step)
            if self.use_aug_reward:
                L.log('train/intrinsic_reward', intrinsic_reward.mean(), step)
                L.log('train/augmented_reward', aug_reward.mean(), step)

        if self.use_aug_reward:
            self.update_dual_multiplier(lw, L, step)

        self.update_critic(obs[-1], action[-1], aug_reward, next_obs, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, action, next_obs, reward, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )

    def save(self, model_dir, step):
        """Save actor and critic weights to ``model_dir`` tagged with ``step``."""
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )

    def load(self, model_dir, step):
        """Load actor and critic weights saved by :meth:`save`.

        Tensors are mapped onto ``self.device``, so checkpoints written on a
        GPU can be restored on a CPU-only machine. The target critic is not
        checkpointed, so it is resynchronised with the loaded critic.
        """
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step), map_location=self.device)
        )
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step), map_location=self.device)
        )
        self.critic_target.load_state_dict(self.critic.state_dict())

